diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index 6c4768a8..b8488cd9 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -137,10 +137,10 @@ def get_neighbor_pairs_fwd_meta( ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel.""" size = max_num_pairs - edge_index = torch.empty((2, size), dtype=torch.long, device=positions.device) + edge_index = torch.empty((2, size), dtype=torch.int, device=positions.device) edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device) edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device) - num_pairs = torch.empty((1,), dtype=torch.long, device=positions.device) + num_pairs = torch.empty((1,), dtype=torch.int, device=positions.device) return edge_index, edge_vec, edge_distance, num_pairs