Skip to content

Commit

Permalink
Replace withCUDA decorator: withDevice (#9082)
Browse files Browse the repository at this point in the history
Replace `withCUDA` for a `withDevice` decorator.

Change variable name from devices to processors to reduce confusion
against pytorch api (backends/devices) and reflect the hardware choices.

Note that at this time:

## Hardware
3 repertoires of hardware can be used to run pyTorch code:

* CPU only
* CPU and GPU
* Unified Memory Single Chip

## Backend Software
The backend is the software framework used to process tensors by
pytorch. There are several. For example, the following are the backends
available today:


```python
torch.backends.cpu
torch.backends.cuda
torch.backends.cudnn
torch.backends.mha
torch.backends.mps
torch.backends.mkl
torch.backends.mkldnn
torch.backends.nnpack
torch.backends.openmp
torch.backends.opt_einsum
torch.backends.xeon
```
`mps` is the only backend which works with the unified memory single
chip - currently Apple's M series.

Hardware is determined by `torch.cuda.is_available()` and
`torch.backends.mps.is_available()` which confuses when looking up
`backends` and hardaware types/architectures/vendors. Introducing
`processor/processors` variable name reduces cognitive load, while
allowing future developments from vendors/frameworks.



~~Modify `withCUDA` decorator to be single purpose. Downstream
comprehensive distributed device and backend testing can be developed
later.~~

~~Redundant code was removed, this doesn't affect other components.~~

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people committed Mar 24, 2024
1 parent 27367b4 commit f0e4c82
Show file tree
Hide file tree
Showing 28 changed files with 116 additions and 74 deletions.
6 changes: 3 additions & 3 deletions test/loader/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from torch_geometric.data import Data
from torch_geometric.loader import CachedLoader, NeighborLoader
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.testing import withDevice, withPackage


@withCUDA
@withDevice
@withPackage('pyg_lib')
def test_cached_loader(device):
x = torch.randn(14, 16)
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_cached_loader(device):
assert len(cached_loader._cache) == 0


@withCUDA
@withDevice
@withPackage('pyg_lib')
def test_cached_loader_transform(device):
x = torch.randn(14, 16)
Expand Down
4 changes: 2 additions & 2 deletions test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_random_edge_index,
get_random_tensor_frame,
onlyLinux,
withCUDA,
withDevice,
withPackage,
)

Expand All @@ -23,7 +23,7 @@
multiprocessing.set_start_method('spawn')


@withCUDA
@withDevice
@pytest.mark.parametrize('num_workers', num_workers_list)
def test_dataloader(num_workers, device):
if num_workers > 0 and device != torch.device('cpu'):
Expand Down
4 changes: 2 additions & 2 deletions test/nn/conv/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch_geometric.typing
from torch_geometric.nn import GATConv
from torch_geometric.testing import is_full_test, withCUDA
from torch_geometric.testing import is_full_test, withDevice
from torch_geometric.typing import Adj, Size, SparseTensor
from torch_geometric.utils import to_torch_csc_tensor

Expand Down Expand Up @@ -192,7 +192,7 @@ def test_gat_conv_with_edge_attr():
assert torch.allclose(conv(x, adj2.t()), out)


@withCUDA
@withDevice
def test_gat_conv_empty_edge_index(device):
x = torch.randn(0, 8, device=device)
edge_index = torch.empty(2, 0, dtype=torch.long, device=device)
Expand Down
4 changes: 2 additions & 2 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch_geometric.testing import (
get_random_edge_index,
onlyLinux,
withCUDA,
withDevice,
withPackage,
)

Expand Down Expand Up @@ -178,7 +178,7 @@ def test_hetero_conv_with_dot_syntax_node_types():
assert out_dict['author'].size() == (30, 64)


@withCUDA
@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
def test_compile_hetero_conv_graph_breaks(device):
Expand Down
9 changes: 7 additions & 2 deletions test/nn/conv/test_rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@

import torch_geometric.typing
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.testing import is_full_test, withCUDA, withPackage
from torch_geometric.testing import (
is_full_test,
withCUDA,
withDevice,
withPackage,
)
from torch_geometric.typing import SparseTensor

classes = [RGCNConv, FastRGCNConv]
confs = [(None, None), (2, None), (None, 2)]


@withCUDA
@withDevice
@withPackage('torch>=1.12.0') # TODO Investigate error
@pytest.mark.parametrize('conf', confs)
def test_rgcn_conv_equality(conf, device):
Expand Down
4 changes: 2 additions & 2 deletions test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
assert_module,
is_full_test,
onlyLinux,
withCUDA,
withDevice,
withPackage,
)
from torch_geometric.typing import SparseTensor
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_multi_aggr_sage_conv(aggr_kwargs):
assert_module(conv, x, edge_index, expected_size=(4, 32))


@withCUDA
@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
def test_compile_multi_aggr_sage_conv(device):
Expand Down
14 changes: 7 additions & 7 deletions test/nn/dense/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
import torch_geometric.backend
from torch_geometric.nn import HeteroDictLinear, HeteroLinear, Linear
from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.testing import withCUDA, withDevice, withPackage
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import cumsum

weight_inits = ['glorot', 'kaiming_uniform', None]
bias_inits = ['zeros', None]


@withCUDA
@withDevice
@pytest.mark.parametrize('weight', weight_inits)
@pytest.mark.parametrize('bias', bias_inits)
def test_linear(weight, bias, device):
Expand All @@ -30,7 +30,7 @@ def test_linear(weight, bias, device):
assert lin(x).size() == (3, 4, 32)


@withCUDA
@withDevice
@pytest.mark.parametrize('weight', weight_inits)
@pytest.mark.parametrize('bias', bias_inits)
def test_lazy_linear(weight, bias, device):
Expand All @@ -50,7 +50,7 @@ def test_lazy_linear(weight, bias, device):
assert copied_lin(x).size() == (3, 4, 32)


@withCUDA
@withDevice
@pytest.mark.parametrize('dim1', [-1, 16])
@pytest.mark.parametrize('dim2', [-1, 16])
@pytest.mark.parametrize('bias', [True, False])
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_copy_unintialized_parameter():
copy.deepcopy(weight)


@withCUDA
@withDevice
@pytest.mark.parametrize('lazy', [True, False])
def test_copy_linear(lazy, device):
lin = Linear(-1 if lazy else 16, 32).to(device)
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_lazy_hetero_linear(device):
assert out.size() == (3, 32)


@withCUDA
@withDevice
@pytest.mark.parametrize('bias', [True, False])
def test_hetero_dict_linear(bias, device):
x_dict = {
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_hetero_dict_linear_jit():
assert len(jit(x_dict)) == 2


@withCUDA
@withDevice
def test_lazy_hetero_dict_linear(device):
x_dict = {
'v': torch.randn(3, 16, device=device),
Expand Down
6 changes: 3 additions & 3 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
onlyLinux,
onlyNeighborSampler,
onlyOnline,
withCUDA,
withDevice,
withPackage,
)

Expand Down Expand Up @@ -203,7 +203,7 @@ def test_basic_gnn_inference(get_dataset, jk):
assert 'n_id' not in data


@withCUDA
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.0.0')
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_trim_to_layer():
assert torch.allclose(out1, out2)


@withCUDA
@withDevice
@onlyLinux
@withPackage('torch>=2.1.0')
@pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA])
Expand Down
6 changes: 3 additions & 3 deletions test/nn/models/test_deep_graph_infomax.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from torch_geometric.nn import GCN, DeepGraphInfomax
from torch_geometric.testing import is_full_test, withCUDA
from torch_geometric.testing import is_full_test, withDevice


@withCUDA
@withDevice
def test_infomax(device):
def corruption(z):
return z + 1
Expand Down Expand Up @@ -42,7 +42,7 @@ def corruption(z):
assert 0 <= acc <= 1


@withCUDA
@withDevice
def test_infomax_predefined_model(device):
def corruption(x, edge_index, edge_weight):
return (
Expand Down
4 changes: 2 additions & 2 deletions test/nn/models/test_metapath2vec.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from torch_geometric.nn import MetaPath2Vec
from torch_geometric.testing import withCUDA
from torch_geometric.testing import withDevice


@withCUDA
@withDevice
def test_metapath2vec(device):
edge_index_dict = {
('author', 'writes', 'paper'):
Expand Down
4 changes: 2 additions & 2 deletions test/nn/models/test_node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import torch_geometric.typing
from torch_geometric.nn import Node2Vec
from torch_geometric.testing import is_full_test, withCUDA, withPackage
from torch_geometric.testing import is_full_test, withDevice, withPackage


@withCUDA
@withDevice
@withPackage('pyg_lib|torch_cluster')
@pytest.mark.parametrize('p', [1.0])
@pytest.mark.parametrize('q', [1.0, 0.5])
Expand Down
6 changes: 3 additions & 3 deletions test/nn/norm/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch

from torch_geometric.nn import BatchNorm, HeteroBatchNorm
from torch_geometric.testing import is_full_test, withCUDA
from torch_geometric.testing import is_full_test, withDevice


@withCUDA
@withDevice
@pytest.mark.parametrize('conf', [True, False])
def test_batch_norm(device, conf):
x = torch.randn(100, 16, device=device)
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_batch_norm_single_element():
assert torch.allclose(out, x)


@withCUDA
@withDevice
@pytest.mark.parametrize('conf', [True, False])
def test_hetero_batch_norm(device, conf):
x = torch.randn((100, 16), device=device)
Expand Down
6 changes: 3 additions & 3 deletions test/nn/norm/test_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch

from torch_geometric.nn import HeteroLayerNorm, LayerNorm
from torch_geometric.testing import is_full_test, withCUDA
from torch_geometric.testing import is_full_test, withDevice


@withCUDA
@withDevice
@pytest.mark.parametrize('affine', [True, False])
@pytest.mark.parametrize('mode', ['graph', 'node'])
def test_layer_norm(device, affine, mode):
Expand All @@ -27,7 +27,7 @@ def test_layer_norm(device, affine, mode):
assert torch.allclose(out1, out2[100:], atol=1e-6)


@withCUDA
@withDevice
@pytest.mark.parametrize('affine', [False, True])
def test_hetero_layer_norm(device, affine):
x = torch.randn((100, 16), device=device)
Expand Down
12 changes: 6 additions & 6 deletions test/nn/pool/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
L2KNNIndex,
MIPSKNNIndex,
)
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.testing import withDevice, withPackage


@withCUDA
@withDevice
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_l2(device, k):
Expand All @@ -34,7 +34,7 @@ def test_l2(device, k):
assert torch.equal(out.index, index[:, :k])


@withCUDA
@withDevice
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_mips(device, k):
Expand All @@ -58,7 +58,7 @@ def test_mips(device, k):
assert torch.equal(out.index, index[:, :k])


@withCUDA
@withDevice
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
@pytest.mark.parametrize('reserve', [None, 100])
Expand All @@ -82,7 +82,7 @@ def test_approx_l2(device, k, reserve):
assert out.index.min() >= 0 and out.index.max() < 10_000


@withCUDA
@withDevice
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
@pytest.mark.parametrize('reserve', [None, 100])
Expand All @@ -106,7 +106,7 @@ def test_approx_mips(device, k, reserve):
assert out.index.min() >= 0 and out.index.max() < 10_000


@withCUDA
@withDevice
@withPackage('faiss')
@pytest.mark.parametrize('k', [50])
def test_mips_exclude(device, k):
Expand Down
4 changes: 2 additions & 2 deletions test/nn/test_compile_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch_geometric.testing import (
onlyFullTest,
onlyLinux,
withCUDA,
withDevice,
withPackage,
)
from torch_geometric.utils import scatter
Expand Down Expand Up @@ -42,7 +42,7 @@ def fused_gather_scatter(x, edge_index, reduce=['sum', 'mean', 'max']):
return torch.cat(outs, dim=-1)


@withCUDA
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.0.0')
Expand Down
6 changes: 3 additions & 3 deletions test/nn/test_compile_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch_geometric.testing import (
onlyFullTest,
onlyLinux,
withCUDA,
withDevice,
withPackage,
)
from torch_geometric.utils import scatter
Expand All @@ -26,7 +26,7 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
return self.lin_src(out) + self.lin_dst(x)


@withCUDA
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.1.0')
Expand All @@ -49,7 +49,7 @@ def test_compile_conv(device, Conv):
assert torch.allclose(conv(x, edge_index), out, atol=1e-6)


@withCUDA
@withDevice
@onlyLinux
@onlyFullTest
@withPackage('torch>=2.2.0')
Expand Down

0 comments on commit f0e4c82

Please sign in to comment.