From 92c020c80d7add595a3d3b37ec2f0fe99b0b7db5 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 29 Nov 2022 08:53:46 +0000 Subject: [PATCH] fix test --- test/__init__.py | 0 test/test_broadcasting.py | 3 +-- test/test_gather.py | 5 ++--- test/test_multi_gpu.py | 3 +-- test/test_scatter.py | 3 +-- test/test_segment.py | 5 ++--- test/test_zero_tensors.py | 7 +++---- test/utils.py => torch_scatter/testing.py | 12 ++++++++---- 8 files changed, 18 insertions(+), 20 deletions(-) delete mode 100644 test/__init__.py rename test/utils.py => torch_scatter/testing.py (52%) diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py index cfb3593c..0b332e49 100644 --- a/test/test_broadcasting.py +++ b/test/test_broadcasting.py @@ -3,8 +3,7 @@ import pytest import torch from torch_scatter import scatter - -from .utils import reductions, devices +from torch_scatter.testing import devices, reductions @pytest.mark.parametrize('reduce,device', product(reductions, devices)) diff --git a/test/test_gather.py b/test/test_gather.py index 8d0d100f..0b40e5d4 100644 --- a/test/test_gather.py +++ b/test/test_gather.py @@ -3,9 +3,8 @@ import pytest import torch from torch.autograd import gradcheck -from torch_scatter import gather_csr, gather_coo - -from .utils import tensor, dtypes, devices +from torch_scatter import gather_coo, gather_csr +from torch_scatter.testing import devices, dtypes, tensor tests = [ { diff --git a/test/test_multi_gpu.py b/test/test_multi_gpu.py index cdaf893e..98ed38c4 100644 --- a/test/test_multi_gpu.py +++ b/test/test_multi_gpu.py @@ -3,8 +3,7 @@ import pytest import torch import torch_scatter - -from .utils import reductions, tensor, dtypes +from torch_scatter.testing import dtypes, reductions, tensor tests = [ { diff --git a/test/test_scatter.py b/test/test_scatter.py index 8ba83813..93257619 100644 --- a/test/test_scatter.py +++ b/test/test_scatter.py @@ -4,8 +4,7 @@ import torch import torch_scatter from torch.autograd import gradcheck - -from .utils import devices, dtypes, reductions, tensor +from torch_scatter.testing import devices, dtypes, reductions, tensor reductions = reductions + ['mul'] diff --git a/test/test_segment.py b/test/test_segment.py index 3e8996be..9adc49da 100644 --- a/test/test_segment.py +++ b/test/test_segment.py @@ -2,10 +2,9 @@ import pytest import torch -from torch.autograd import gradcheck import torch_scatter - -from .utils import reductions, tensor, dtypes, devices +from torch.autograd import gradcheck +from torch_scatter.testing import devices, dtypes, reductions, tensor tests = [ { diff --git a/test/test_zero_tensors.py b/test/test_zero_tensors.py index 60855427..f744eb56 100644 --- a/test/test_zero_tensors.py +++ b/test/test_zero_tensors.py @@ -2,10 +2,9 @@ import pytest import torch -from torch_scatter import scatter, segment_coo, gather_coo -from torch_scatter import segment_csr, gather_csr - -from .utils import reductions, tensor, grad_dtypes, devices +from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo, + segment_csr) +from torch_scatter.testing import devices, grad_dtypes, reductions, tensor @pytest.mark.parametrize('reduce,dtype,device', diff --git a/test/utils.py b/torch_scatter/testing.py similarity index 52% rename from test/utils.py rename to torch_scatter/testing.py index 4f23e248..2407b8a0 100644 --- a/test/utils.py +++ b/torch_scatter/testing.py @@ -1,15 +1,19 @@ +from typing import Any + import torch reductions = ['sum', 'add', 'mean', 'min', 'max'] -dtypes = [torch.half, torch.bfloat16, torch.float, torch.double, - torch.int, torch.long] +dtypes = [ + torch.half, torch.bfloat16, torch.float, torch.double, torch.int, + torch.long +] grad_dtypes = [torch.float, torch.double] devices = [torch.device('cpu')] if torch.cuda.is_available(): - devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] + devices += [torch.device('cuda:0')] -def tensor(x, dtype, device): +def tensor(x: Any, dtype: torch.dtype, device: torch.device): return None if x is None else torch.tensor(x, device=device).to(dtype)