diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 5824a6863..3d3fa3f7e 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -215,6 +215,7 @@ def forward( neighbors = neighbors[:, mask] distances = distances[mask] distance_vecs = distance_vecs[mask, :] + neighbors = neighbors.to(torch.long) if self.return_vecs: return neighbors, distances, distance_vecs else: