torch.compile neighbors without graph breaks #305

Merged · 9 commits · Apr 15, 2024
tests/test_neighbors.py (93 changes: 79 additions & 14 deletions)
@@ -9,6 +9,7 @@
import numpy as np
from torchmdnet.models.utils import OptimizedDistance


def sort_neighbors(neighbors, deltas, distances):
i_sorted = np.lexsort(neighbors)
return neighbors[:, i_sorted], deltas[i_sorted], distances[i_sorted]
@@ -69,7 +70,10 @@ def compute_ref_neighbors(pos, batch, loop, include_transpose, cutoff, box_vectors
return ref_neighbors, ref_distance_vecs, ref_distances


@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])
@pytest.mark.parametrize(
("device", "strategy"),
[("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
)
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4, 128])
@pytest.mark.parametrize("cutoff", [0.1, 1.0, 3.0, 4.9])
@pytest.mark.parametrize("loop", [True, False])
@@ -92,7 +96,7 @@ def test_neighbors(
).to(device)
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
lbox = 10.0
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
# Ensure there is at least one pair
pos[0, :] = torch.zeros(3)
pos[1, :] = torch.zeros(3)
@@ -141,7 +145,11 @@ def test_neighbors(
assert np.allclose(distances, ref_distances)
assert np.allclose(distance_vecs, ref_distance_vecs)

@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])

@pytest.mark.parametrize(
("device", "strategy"),
[("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
)
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@@ -249,10 +257,14 @@ def test_neighbor_grads(
else:
assert np.allclose(ref_pos_grad_sorted, pos_grad_sorted, atol=1e-8, rtol=1e-5)

@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])

@pytest.mark.parametrize(
("device", "strategy"),
[("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
)
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
@pytest.mark.parametrize("num_atoms", [1,2,10])
@pytest.mark.parametrize("num_atoms", [1, 2, 10])
@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
def test_neighbor_autograds(
device, strategy, loop, include_transpose, num_atoms, box_type
@@ -293,8 +305,12 @@ def test_neighbor_autograds(
neighbors, distances, deltas = nl(positions, batch)
# Lambda that returns only the distances and deltas
lambda_dist = lambda x, y: nl(x, y)[1:]
torch.autograd.gradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
torch.autograd.gradgradcheck(lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4)
torch.autograd.gradcheck(
lambda_dist, (positions, batch), eps=1e-4, atol=1e-4, rtol=1e-4, nondet_tol=1e-4
)
torch.autograd.gradgradcheck(
lambda_dist, (positions, batch), eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
)


@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
@@ -353,7 +369,11 @@ def test_large_size(strategy, n_batches):
assert np.allclose(distances, ref_distances)
assert np.allclose(distance_vecs, ref_distance_vecs)

@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")])

@pytest.mark.parametrize(
("device", "strategy"),
[("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared"), ("cuda", "cell")],
)
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("cutoff", [1.0])
@pytest.mark.parametrize("loop", [True, False])
@@ -504,6 +524,7 @@ def test_cuda_graph_compatible_forward(
assert np.allclose(distances, ref_distances)
assert np.allclose(distance_vecs, ref_distance_vecs)


@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
@pytest.mark.parametrize("n_batches", [1, 128])
@@ -578,12 +599,12 @@ def test_cuda_graph_compatible_backward(
torch.cuda.synchronize()


@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")])
@pytest.mark.parametrize(
("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")]
)
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("use_forward", [True, False])
def test_per_batch_box(
device, strategy, n_batches, use_forward
):
def test_per_batch_box(device, strategy, n_batches, use_forward):
dtype = torch.float32
cutoff = 1.0
include_transpose = True
@@ -599,7 +620,7 @@ def test_per_batch_box(
).to(device)
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
lbox = 10.0
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0*lbox
pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox - 10.0 * lbox
# Ensure there is at least one pair
pos[0, :] = torch.zeros(3)
pos[1, :] = torch.zeros(3)
@@ -625,7 +646,9 @@
include_transpose=include_transpose,
)
batch.to(device)
neighbors, distances, distance_vecs = nl(pos, batch, box=box if use_forward else None)
neighbors, distances, distance_vecs = nl(
pos, batch, box=box if use_forward else None
)
neighbors = neighbors.cpu().detach().numpy()
distance_vecs = distance_vecs.cpu().detach().numpy()
distances = distances.cpu().detach().numpy()
@@ -639,3 +662,45 @@
assert np.allclose(neighbors, ref_neighbors)
assert np.allclose(distances, ref_distances)
assert np.allclose(distance_vecs, ref_distance_vecs)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float64])
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
def test_torch_compile(device, dtype, loop, include_transpose):
if torch.__version__ < "2.0.0":
pytest.skip("Not available in this version")
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
np.random.seed(123456)
example_pos = 10 * torch.rand(50, 3, requires_grad=True, dtype=dtype, device=device)
model = OptimizedDistance(
cutoff_lower=0.1,  # keep pairs away from r = 0, where the distance is not differentiable and finite differences misbehave
cutoff_upper=10,
return_vecs=True,
loop=loop,
max_num_pairs=-example_pos.shape[0],
include_transpose=include_transpose,
resize_to_fit=False,
check_errors=False,
).to(device)
for _ in range(50):
model(example_pos)
example_pos = example_pos.detach().requires_grad_(True)
edge_index, edge_vec, edge_distance = model(example_pos)
edge_vec.sum().backward()
example_pos.grad.zero_()
fullgraph = torch.__version__ >= "2.2.0"
model = torch.compile(
model,
fullgraph=fullgraph,
backend="inductor",
mode="reduce-overhead",
)
edge_index, edge_vec, edge_distance = model(example_pos)
edge_vec.sum().backward()
lambda_dist = lambda x: model(x)[1:]
torch.autograd.gradcheck(
lambda_dist, example_pos, eps=1e-5, atol=1e-4, rtol=1e-4, nondet_tol=1e-3
)
torchmdnet/extensions/__init__.py (78 changes: 60 additions & 18 deletions)
@@ -5,9 +5,12 @@
# Place here any short extensions to torch that you want to use in your code.
# The extensions present in extensions.cpp will be automatically compiled in setup.py and loaded here.
# The extensions will be available under torch.ops.torchmdnet_extensions, but you can add wrappers here to make them more convenient to use.
# Also place here any meta registrations your extensions require.

import os.path as osp
import torch
import importlib.machinery
from torch import Tensor
from typing import Tuple


@@ -29,6 +32,8 @@ def _load_library(library):

_load_library("torchmdnet_extensions")

__all__ = ["is_current_stream_capturing", "get_neighbor_pairs_kernel"]


def is_current_stream_capturing():
"""Returns True if the current CUDA stream is capturing.
@@ -45,30 +50,29 @@

def get_neighbor_pairs_kernel(
strategy: str,
positions: torch.Tensor,
batch: torch.Tensor,
box_vectors: torch.Tensor,
positions: Tensor,
batch: Tensor,
box_vectors: Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Computes the neighbor pairs for a given set of atomic positions.

The list is generated as a list of pairs (i,j) without any enforced ordering.
The list is padded with -1 to the maximum number of pairs.

Parameters
----------
strategy : str
Strategy to use for computing the neighbor list. Can be one of :code:`["shared", "brute", "cell"]`.
positions : torch.Tensor
positions : Tensor
A tensor with shape (N, 3) representing the atomic positions.
batch : torch.Tensor
batch : Tensor
A tensor with shape (N,). Specifies the batch for each atom.
box_vectors : torch.Tensor
box_vectors : Tensor
The vectors defining the periodic box with shape `(3, 3)` or `(max(batch)+1, 3, 3)` if a different box is used for each sample.
use_periodic : bool
Whether to apply periodic boundary conditions.
@@ -85,18 +89,14 @@

Returns
-------
neighbors : torch.Tensor
neighbors : Tensor
List of neighbors for each atom. Shape (2, max_num_pairs).
distances : torch.Tensor
distances : Tensor
List of distances for each atom. Shape (max_num_pairs,).
distance_vecs : torch.Tensor
distance_vecs : Tensor
List of distance vectors for each atom. Shape (max_num_pairs, 3).
num_pairs : torch.Tensor
num_pairs : Tensor
The number of pairs found.

Notes
-----
This function is a torch extension loaded from `torch.ops.torchmdnet_extensions.get_neighbor_pairs`.
"""
return torch.ops.torchmdnet_extensions.get_neighbor_pairs(
strategy,
@@ -112,7 +112,49 @@
)
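
For illustration, a minimal call to this wrapper might look like the sketch below (hypothetical toy values; assumes the extension library loaded successfully and uses the CPU-capable "brute" strategy exercised in the tests above):

import torch
from torchmdnet.extensions import get_neighbor_pairs_kernel

pos = torch.rand(10, 3)                    # 10 atoms in one sample
batch = torch.zeros(10, dtype=torch.long)  # all atoms belong to sample 0
box = torch.zeros(3, 3)                    # ignored when use_periodic=False
neighbors, distances, distance_vecs, num_pairs = get_neighbor_pairs_kernel(
    strategy="brute",
    positions=pos,
    batch=batch,
    box_vectors=box,
    use_periodic=False,
    cutoff_lower=0.0,
    cutoff_upper=1.0,
    max_num_pairs=90,  # entries past num_pairs are padded with -1
    loop=False,
    include_transpose=True,
)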


# For some unknown reason torch.compile is not able to compile this function
if int(torch.__version__.split(".")[0]) >= 2:
def get_neighbor_pairs_bkwd_meta(
grad_edge_vec: Tensor,
grad_edge_weight: Tensor,
edge_index: Tensor,
edge_vec: Tensor,
edge_weight: Tensor,
num_atoms: int,
):
return torch.zeros((num_atoms, 3), dtype=edge_vec.dtype, device=edge_vec.device)


def get_neighbor_pairs_fwd_meta(
strategy: str,
positions: Tensor,
batch: Tensor,
box_vectors: Tensor,
use_periodic: bool,
cutoff_lower: float,
cutoff_upper: float,
max_num_pairs: int,
loop: bool,
include_transpose: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Returns empty vectors with the correct shape for the output of get_neighbor_pairs_kernel."""
size = max_num_pairs
edge_index = torch.empty((2, size), dtype=torch.long, device=positions.device)
edge_distance = torch.empty((size,), dtype=positions.dtype, device=positions.device)
edge_vec = torch.empty((size, 3), dtype=positions.dtype, device=positions.device)
num_pairs = torch.empty((1,), dtype=torch.long, device=positions.device)
return edge_index, edge_vec, edge_distance, num_pairs


if torch.__version__ >= "2.2.0":
from torch.library import impl_abstract

impl_abstract(
"torchmdnet_extensions::get_neighbor_pairs_bkwd", get_neighbor_pairs_bkwd_meta
)
impl_abstract(
"torchmdnet_extensions::get_neighbor_pairs_fwd", get_neighbor_pairs_fwd_meta
)
elif torch.__version__ < "2.2.0" and torch.__version__ >= "2.0.0":
# torch.compile is not able to compile this function in old versions
import torch._dynamo as dynamo

dynamo.disallow_in_graph(torch.ops.torchmdnet_extensions.get_neighbor_pairs)
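
With these registrations in place, torch.compile can trace through the op via its meta (fake tensor) implementation instead of falling back to a graph break. A quick sanity check might look like this sketch (an assumption-laden example: it presumes torch >= 2.2.0 so the impl_abstract branch above is taken, and that the compiled extension loads):

import torch
import torchmdnet.extensions  # loads the op library and registers the meta impls

# Meta-device inputs never reach a real kernel; only the abstract impl runs,
# so this verifies that fake-tensor tracing can see through the op.
pos = torch.empty(10, 3, device="meta")
batch = torch.empty(10, dtype=torch.long, device="meta")
box = torch.empty(3, 3, device="meta")
out = torch.ops.torchmdnet_extensions.get_neighbor_pairs_fwd(
    "brute", pos, batch, box, False, 0.0, 1.0, 90, False, True
)
print([tuple(t.shape) for t in out])  # expected per the meta impl: (2, 90), (90, 3), (90,), (1,)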
torchmdnet/extensions/extensions.cpp (21 changes: 20 additions & 1 deletion)
@@ -33,8 +33,27 @@ bool is_current_stream_capturing() {
#endif
}

#define TORCH_VERSION_CODE(MAJOR, MINOR, PATCH) ((MAJOR)*10000 + (MINOR)*100 + (PATCH))
// True when the torch version being compiled against is >= the given version.
#define TORCH_VERSION_COMPARE_GE(MAJOR, MINOR, PATCH) \
(TORCH_VERSION_CODE(TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR, TORCH_VERSION_PATCH) >= \
TORCH_VERSION_CODE(MAJOR, MINOR, PATCH))

TORCH_LIBRARY(torchmdnet_extensions, m) {
m.def("is_current_stream_capturing", is_current_stream_capturing);
m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
#if TORCH_VERSION_COMPARE_GE(2, 2, 0)
// This line signals to torch that the meta registrations for these ops are implemented in Python.
// Specifically, it will look for them in the torchmdnet.extensions module.
m.impl_abstract_pystub("torchmdnet.extensions");
#endif
m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
"bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
"loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
"distance_vecs, Tensor num_pairs)");
// The individual fwd and bkwd functions must be exposed in order to register their meta implementations Python-side.
m.def("get_neighbor_pairs_fwd(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, "
"bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool "
"loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor "
"distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_bkwd(Tensor grad_edge_vec, Tensor grad_edge_weight, Tensor edge_index, "
"Tensor edge_vec, Tensor edge_weight, int num_atoms) -> Tensor");
}
torchmdnet/extensions/neighbors/common.cuh (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
return tensor.packed_accessor32<scalar_t, num_dims, torch::RestrictPtrTraits>();
};

template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){};
template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){return ::sqrt(x);};
template <> __device__ __forceinline__ float sqrt_(float x) {
return ::sqrtf(x);
};
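
Note on the change above: the unspecialized sqrt_ template previously had an empty body, so any instantiation other than the float specialization returned an indeterminate value; it now forwards to ::sqrt, while float inputs keep using ::sqrtf.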