diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_add.py b/test/test_add.py index e4839220..0fc3971b 100644 --- a/test/test_add.py +++ b/test/test_add.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse import SparseTensor, add -from .utils import dtypes, devices, tensor +from torch_sparse import SparseTensor, add +from torch_sparse.testing import devices, dtypes, tensor @pytest.mark.parametrize('dtype,device', product(dtypes, devices)) diff --git a/test/test_cat.py b/test/test_cat.py index be1e1e53..1b07799f 100644 --- a/test/test_cat.py +++ b/test/test_cat.py @@ -1,9 +1,9 @@ import pytest import torch -from torch_sparse.tensor import SparseTensor -from torch_sparse.cat import cat -from .utils import devices, tensor +from torch_sparse.cat import cat +from torch_sparse.tensor import SparseTensor +from torch_sparse.testing import devices, tensor @pytest.mark.parametrize('device', devices) diff --git a/test/test_diag.py b/test/test_diag.py index 0ed7564a..a3142fce 100644 --- a/test/test_diag.py +++ b/test/test_diag.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse.tensor import SparseTensor -from .utils import dtypes, devices, tensor +from torch_sparse.tensor import SparseTensor +from torch_sparse.testing import devices, dtypes, tensor @pytest.mark.parametrize('dtype,device', product(dtypes, devices)) diff --git a/test/test_eye.py b/test/test_eye.py index 293d676f..f3f2275b 100644 --- a/test/test_eye.py +++ b/test/test_eye.py @@ -1,9 +1,9 @@ from itertools import product import pytest -from torch_sparse.tensor import SparseTensor -from .utils import dtypes, devices +from torch_sparse.tensor import SparseTensor +from torch_sparse.testing import devices, dtypes @pytest.mark.parametrize('dtype,device', product(dtypes, devices)) diff --git a/test/test_matmul.py b/test/test_matmul.py index bf7ad129..0de64dfe 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -6,8 +6,7 @@ from torch_sparse.matmul import matmul from torch_sparse.tensor import SparseTensor - -from .utils import devices, grad_dtypes, reductions +from torch_sparse.testing import devices, grad_dtypes, reductions @pytest.mark.parametrize('dtype,device,reduce', diff --git a/test/test_metis.py b/test/test_metis.py index 897526af..a0eac32c 100644 --- a/test/test_metis.py +++ b/test/test_metis.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse.tensor import SparseTensor -from .utils import devices +from torch_sparse.tensor import SparseTensor +from torch_sparse.testing import devices try: rowptr = torch.tensor([0, 1]) diff --git a/test/test_permute.py b/test/test_permute.py index 206338f1..d8fd041f 100644 --- a/test/test_permute.py +++ b/test/test_permute.py @@ -1,8 +1,8 @@ import pytest import torch -from torch_sparse.tensor import SparseTensor -from .utils import devices, tensor +from torch_sparse.tensor import SparseTensor +from torch_sparse.testing import devices, tensor @pytest.mark.parametrize('device', devices) diff --git a/test/test_spmm.py b/test/test_spmm.py index 4d0e05c7..f0bb0a80 100644 --- a/test/test_spmm.py +++ b/test/test_spmm.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse import spmm -from .utils import dtypes, devices, tensor +from torch_sparse import spmm +from torch_sparse.testing import devices, dtypes, tensor @pytest.mark.parametrize('dtype,device', product(dtypes, devices)) diff --git a/test/test_spspmm.py b/test/test_spspmm.py index 95647b84..d6311d85 100644 --- a/test/test_spspmm.py +++ b/test/test_spspmm.py @@ -4,8 +4,7 @@ import torch from torch_sparse import SparseTensor, spspmm - -from .utils import devices, grad_dtypes, tensor +from torch_sparse.testing import devices, grad_dtypes, tensor @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) diff --git a/test/test_storage.py b/test/test_storage.py index 81bfbb49..04f62bb6 100644 --- a/test/test_storage.py +++ b/test/test_storage.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse.storage import SparseStorage -from .utils import dtypes, devices, tensor +from torch_sparse.storage import SparseStorage +from torch_sparse.testing import devices, dtypes, tensor @pytest.mark.parametrize('device', devices) diff --git a/test/test_tensor.py b/test/test_tensor.py index c83abcbc..e94f6926 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse import SparseTensor -from .utils import grad_dtypes, devices +from torch_sparse import SparseTensor +from torch_sparse.testing import devices, grad_dtypes @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @@ -15,8 +15,8 @@ def test_getitem(dtype, device): mat = torch.randn(m, n, dtype=dtype, device=device) mat = SparseTensor.from_dense(mat) - idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device) - idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device) + idx1 = torch.randint(0, m, (k, ), dtype=torch.long, device=device) + idx2 = torch.randint(0, n, (k, ), dtype=torch.long, device=device) bool1 = torch.zeros(m, dtype=torch.bool, device=device) bool2 = torch.zeros(n, dtype=torch.bool, device=device) bool1.scatter_(0, idx1, 1) diff --git a/test/test_transpose.py b/test/test_transpose.py index 8cf6946a..18bdcdd5 100644 --- a/test/test_transpose.py +++ b/test/test_transpose.py @@ -2,9 +2,9 @@ import pytest import torch -from torch_sparse import transpose -from .utils import dtypes, devices, tensor +from torch_sparse import transpose +from torch_sparse.testing import devices, dtypes, tensor @pytest.mark.parametrize('dtype,device', product(dtypes, devices)) diff --git a/test/utils.py b/torch_sparse/testing.py similarity index 80% rename from test/utils.py rename to torch_sparse/testing.py index 7cc58c1d..9383ee07 100644 --- a/test/utils.py +++ b/torch_sparse/testing.py @@ -1,3 +1,5 @@ +from typing import Any + import torch import torch_scatter from packaging import version @@ -13,8 +15,8 @@ devices = [torch.device('cpu')] if torch.cuda.is_available(): - devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] + devices += [torch.device('cuda:0')] -def tensor(x, dtype, device): +def tensor(x: Any, dtype: torch.dtype, device: torch.device): return None if x is None else torch.tensor(x, dtype=dtype, device=device)