diff --git a/torch_scatter/utils/gen.py b/torch_scatter/utils/gen.py index 45dee97b..aa6a6221 100644 --- a/torch_scatter/utils/gen.py +++ b/torch_scatter/utils/gen.py @@ -8,7 +8,8 @@ def maybe_dim_size(index, dim_size=None): if dim_size is not None: return dim_size - return index.max().item() + 1 if index.numel() > 0 else 0 + dim = index.max().item() + 1 if index.numel() > 0 else 0 + return int(dim) def broadcast(src, index, dim):