Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed test/__init__.py
Empty file.
3 changes: 1 addition & 2 deletions test/test_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest
import torch
from torch_scatter import scatter

from .utils import reductions, devices
from torch_scatter.testing import devices, reductions


@pytest.mark.parametrize('reduce,device', product(reductions, devices))
Expand Down
5 changes: 2 additions & 3 deletions test/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import gather_csr, gather_coo

from .utils import tensor, dtypes, devices
from torch_scatter import gather_coo, gather_csr
from torch_scatter.testing import devices, dtypes, tensor

tests = [
{
Expand Down
3 changes: 1 addition & 2 deletions test/test_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import pytest
import torch
import torch_scatter

from .utils import reductions, tensor, dtypes
from torch_scatter.testing import dtypes, reductions, tensor

tests = [
{
Expand Down
3 changes: 1 addition & 2 deletions test/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import torch
import torch_scatter
from torch.autograd import gradcheck

from .utils import devices, dtypes, reductions, tensor
from torch_scatter.testing import devices, dtypes, reductions, tensor

reductions = reductions + ['mul']

Expand Down
5 changes: 2 additions & 3 deletions test/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter

from .utils import reductions, tensor, dtypes, devices
from torch.autograd import gradcheck
from torch_scatter.testing import devices, dtypes, reductions, tensor

tests = [
{
Expand Down
7 changes: 3 additions & 4 deletions test/test_zero_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import pytest
import torch
from torch_scatter import scatter, segment_coo, gather_coo
from torch_scatter import segment_csr, gather_csr

from .utils import reductions, tensor, grad_dtypes, devices
from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
segment_csr)
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor


@pytest.mark.parametrize('reduce,dtype,device',
Expand Down
12 changes: 8 additions & 4 deletions test/utils.py → torch_scatter/testing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from typing import Any

import torch

reductions = ['sum', 'add', 'mean', 'min', 'max']

dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
torch.int, torch.long]
dtypes = [
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]


def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, device=device).to(dtype)