Skip to content

Commit

Permalink
Fix macOS ARM tests (#9020)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rusty1s and pre-commit-ci[bot] committed Mar 5, 2024
1 parent db9d8f5 commit 0d30e89
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 6 deletions.
2 changes: 2 additions & 0 deletions test/loader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch_geometric.testing import (
get_random_edge_index,
get_random_tensor_frame,
onlyLinux,
withCUDA,
withPackage,
)
Expand Down Expand Up @@ -72,6 +73,7 @@ def test_dataloader(num_workers, device):
assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]


@onlyLinux
@pytest.mark.parametrize('num_workers', num_workers_list)
def test_dataloader_on_disk_dataset(tmp_path, num_workers):
dataset = OnDiskDataset(tmp_path)
Expand Down
2 changes: 1 addition & 1 deletion test/test_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ def test_spspmm(device, reduce, transpose, is_undirected):
assert isinstance(out, EdgeIndex)
assert out.is_sorted_by_row
assert out._sparse_size == (3, 3)
if not torch_geometric.typing.WITH_WINDOWS:
if torch_geometric.typing.WITH_MKL:
assert out._indptr is not None
assert torch.allclose(out.to_dense(value), adj1_dense @ adj2_dense)
else:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,7 +1823,7 @@ def matmul(

transpose &= not input.is_undirected or input_value is not None

if torch_geometric.typing.WITH_WINDOWS: # pragma: no cover
if not torch_geometric.typing.WITH_MKL: # pragma: no cover
sparse_input = input.to_sparse_coo(input_value)
elif input.is_sorted_by_col:
sparse_input = input.to_sparse_csc(input_value)
Expand All @@ -1833,7 +1833,7 @@ def matmul(
if transpose:
sparse_input = sparse_input.t()

if torch_geometric.typing.WITH_WINDOWS: # pragma: no cover
if not torch_geometric.typing.WITH_MKL: # pragma: no cover
other = other.to_sparse_coo(other_value)
elif other.is_sorted_by_col:
other = other.to_sparse_csc(other_value)
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ def forward(self, data: Data) -> Data:
adj = torch.zeros((N, N), device=row.device)
adj[row, col] = value
loop_index = torch.arange(N, device=row.device)
elif torch_geometric.typing.WITH_WINDOWS:
adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())
else:
elif torch_geometric.typing.WITH_MKL:
adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())
else:
adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())

def get_pe(out: Tensor) -> Tensor:
if is_torch_sparse_tensor(out):
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13

WITH_WINDOWS = os.name == 'nt'
WITH_MKL = 'USE_MKL=OFF' not in torch.__config__.show()

MAX_INT64 = torch.iinfo(torch.int64).max

Expand Down

0 comments on commit 0d30e89

Please sign in to comment.