From 500a8e48c35a9c2c7f476b72ed379877cac3a90d Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Wed, 20 Mar 2024 11:31:56 +0100 Subject: [PATCH] Fix incorrect type --- torchmdnet/extensions/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmdnet/extensions/__init__.py b/torchmdnet/extensions/__init__.py index 6c4768a8..b8488cd9 100644 --- a/torchmdnet/extensions/__init__.py +++ b/torchmdnet/extensions/__init__.py @@ -137,10 +137,10 @@ def get_neighbor_pairs_fwd_meta( ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel.""" size = max_num_pairs - edge_index = torch.empty((2, size), dtype=torch.long, device=positions.device) + edge_index = torch.empty((2, size), dtype=torch.int, device=positions.device) edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device) edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device) - num_pairs = torch.empty((1,), dtype=torch.long, device=positions.device) + num_pairs = torch.empty((1,), dtype=torch.int, device=positions.device) return edge_index, edge_vec, edge_distance, num_pairs