From 45a4d9854a6af081c2e8b8855cc23a20900adc1c Mon Sep 17 00:00:00 2001 From: james77777778 <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jun 2020 22:53:24 +0800 Subject: [PATCH 1/2] fix the bug from metis if num_parts == 1 --- test/test_metis.py | 5 +++++ torch_sparse/metis.py | 9 +++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_metis.py b/test/test_metis.py index 89e35f3f..8c7e47b3 100644 --- a/test/test_metis.py +++ b/test/test_metis.py @@ -30,3 +30,8 @@ def test_metis(device): weighted=False) assert partptr.numel() == 3 assert perm.numel() == 6 + + _, partptr, perm = mat.partition(num_parts=1, recursive=False, + weighted=True) + assert partptr.numel() == 2 + assert perm.numel() == 6 diff --git a/torch_sparse/metis.py b/torch_sparse/metis.py index b00c44a6..d34e6c4f 100644 --- a/torch_sparse/metis.py +++ b/torch_sparse/metis.py @@ -33,8 +33,13 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False, else: value = None - cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts, - recursive) + if num_parts > 1: + cluster = torch.ops.torch_sparse.partition(rowptr, col, value, + num_parts, recursive) + elif num_parts == 1: + cluster = torch.zeros((src.size(0)), dtype=torch.long) + else: + raise ValueError cluster = cluster.to(src.device()) cluster, perm = cluster.sort() From eb8c2ec02be499597a613ee8d2cc8f7f8d98d3f8 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 23 Jun 2020 09:54:52 +0200 Subject: [PATCH 2/2] cleanup --- torch_sparse/metis.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torch_sparse/metis.py b/torch_sparse/metis.py index d34e6c4f..8e9617ff 100644 --- a/torch_sparse/metis.py +++ b/torch_sparse/metis.py @@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]: def partition(src: SparseTensor, num_parts: int, recursive: bool = False, weighted=False ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: + + assert num_parts >= 1 + if num_parts == 1: + partptr = torch.tensor([0, src.size(0)], device=src.device()) + perm = torch.arange(src.size(0), device=src.device()) + return src, partptr, perm + rowptr, col, value = src.csr() rowptr, col = rowptr.cpu(), col.cpu() @@ -33,13 +40,8 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False, else: value = None - if num_parts > 1: - cluster = torch.ops.torch_sparse.partition(rowptr, col, value, - num_parts, recursive) - elif num_parts == 1: - cluster = torch.zeros((src.size(0)), dtype=torch.long) - else: - raise ValueError + cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts, + recursive) cluster = cluster.to(src.device()) cluster, perm = cluster.sort()