Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/building.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
- name: Upgrade pip
run: |
pip install --upgrade setuptools
pip list

- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
Expand All @@ -67,7 +66,11 @@ jobs:
if: ${{ runner.os != 'macOS' }}
run: |
VERSION=`sed -n "s/^__version__ = '\(.*\)'/\1/p" torch_sparse/__init__.py`
sed -i "s/$VERSION/$VERSION+${{ matrix.cuda-version }}/" torch_sparse/__init__.py
TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed "s/\.//g"`
CUDA_VERSION=`echo ${{ matrix.cuda-version }}`
echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION"
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" setup.py
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" torch_sparse/__init__.py
shell:
bash

Expand Down
18 changes: 15 additions & 3 deletions test/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import torch
import torch_scatter

from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor

Expand All @@ -12,6 +13,9 @@
@pytest.mark.parametrize('dtype,device,reduce',
product(grad_dtypes, devices, reductions))
def test_spmm(dtype, device, reduce):
if device == torch.device('cuda:0') and dtype == torch.bfloat16:
return # Not yet implemented.

src = torch.randn((10, 8), dtype=dtype, device=device)
src[2:4, :] = 0 # Remove multiple rows.
src[:, 2:4] = 0 # Remove multiple columns.
Expand Down Expand Up @@ -39,13 +43,21 @@ def test_spmm(dtype, device, reduce):
out = matmul(src, other, reduce)
out.backward(grad_out)

assert torch.allclose(expected, out, atol=1e-2)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)
if dtype == torch.float16 or dtype == torch.bfloat16:
assert torch.allclose(expected, out, atol=1e-1)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-1)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-1)
else:
assert torch.allclose(expected, out)
assert torch.allclose(expected_grad_value, value.grad)
assert torch.allclose(expected_grad_other, other.grad)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16:
return # Not yet implemented.

src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
device=device)

Expand Down
11 changes: 9 additions & 2 deletions test/test_spspmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import pytest
import torch
from torch_sparse import spspmm, SparseTensor

from .utils import grad_dtypes, devices, tensor
from torch_sparse import SparseTensor, spspmm

from .utils import devices, grad_dtypes, tensor


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16:
return # Not yet implemented.

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = tensor([1, 2, 3, 4, 5], dtype, device)
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
Expand All @@ -21,6 +25,9 @@ def test_spspmm(dtype, device):

@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device):
if device == torch.device('cuda:0') and dtype == torch.bfloat16:
return # Not yet implemented.

x = SparseTensor(
row=torch.tensor(
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
Expand Down
2 changes: 1 addition & 1 deletion torch_sparse/cat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple # noqa

import torch
from torch_sparse.storage import SparseStorage
Expand Down