diff --git a/test/test_metis.py b/test/test_metis.py index 89e35f3f..8c7e47b3 100644 --- a/test/test_metis.py +++ b/test/test_metis.py @@ -30,3 +30,8 @@ def test_metis(device): weighted=False) assert partptr.numel() == 3 assert perm.numel() == 6 + + _, partptr, perm = mat.partition(num_parts=1, recursive=False, + weighted=True) + assert partptr.numel() == 2 + assert perm.numel() == 6 diff --git a/torch_sparse/metis.py b/torch_sparse/metis.py index b00c44a6..8e9617ff 100644 --- a/torch_sparse/metis.py +++ b/torch_sparse/metis.py @@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]: def partition(src: SparseTensor, num_parts: int, recursive: bool = False, weighted=False ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: + + assert num_parts >= 1 + if num_parts == 1: + partptr = torch.tensor([0, src.size(0)], device=src.device()) + perm = torch.arange(src.size(0), device=src.device()) + return src, partptr, perm + rowptr, col, value = src.csr() rowptr, col = rowptr.cpu(), col.cpu()