Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix neighbor backward #179

Merged
merged 18 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 213 additions & 125 deletions tests/test_neighbors.py

Large diffs are not rendered by default.

38 changes: 18 additions & 20 deletions torchmdnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_cluster import radius_graph
import torchmdnet.neighbors as neighbors
import warnings


Expand Down Expand Up @@ -78,7 +77,6 @@ def forward(self, z, x, edge_index, edge_weight, edge_attr):
def message(self, x_j, W):
return x_j * W


class OptimizedDistance(torch.nn.Module):

def __init__(
Expand Down Expand Up @@ -168,11 +166,9 @@ def __init__(
lbox = cutoff_upper * 3.0
self.box = torch.tensor([[lbox, 0, 0], [0, lbox, 0], [0, 0, lbox]])
self.box = self.box.cpu() # All strategies expect the box to be in CPU memory
self._backends = neighbors.get_backends()
self.kernel = self._backends[self.strategy]
if self.kernel is None:
raise ValueError("Unknown strategy: {}".format(self.strategy))
self.check_errors = check_errors
from torchmdnet.neighbors import get_neighbor_pairs_kernel
self.kernel = get_neighbor_pairs_kernel;

def forward(
self, pos: Tensor, batch: Optional[Tensor] = None
Expand All @@ -186,13 +182,13 @@ def forward(
shape (N,)
Returns
-------
neighbors : torch.Tensor
edge_index : torch.Tensor
List of neighbors for each atom in the batch.
shape (2, num_found_pairs or max_num_pairs)
distances : torch.Tensor
edge_weight : torch.Tensor
List of distances for each atom in the batch.
shape (num_found_pairs or max_num_pairs,)
distance_vecs : torch.Tensor
edge_vec : torch.Tensor
List of distance vectors for each atom in the batch.
shape (num_found_pairs or max_num_pairs, 3)

Expand All @@ -206,13 +202,14 @@ def forward(
max_pairs = -self.max_num_pairs * pos.shape[0]
if batch is None:
batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
neighbors, distance_vecs, distances, num_pairs = self.kernel(
pos,
edge_index, edge_vec, edge_weight, num_pairs = self.kernel(
strategy=self.strategy,
positions=pos,
batch=batch,
max_num_pairs=max_pairs,
cutoff_lower=self.cutoff_lower,
cutoff_upper=self.cutoff_upper,
loop=self.loop,
batch=batch,
max_num_pairs=max_pairs,
include_transpose=self.include_transpose,
box_vectors=self.box,
use_periodic=self.use_periodic,
Expand All @@ -224,17 +221,18 @@ def forward(
num_pairs[0], max_pairs
)
)
edge_index = edge_index.to(torch.long)
# Remove (-1,-1) pairs
if self.resize_to_fit:
mask = neighbors[0] != -1
neighbors = neighbors[:, mask]
distances = distances[mask]
distance_vecs = distance_vecs[mask, :]
neighbors = neighbors.to(torch.long)
mask = edge_index[0] != -1
edge_index = edge_index[:, mask]
edge_weight = edge_weight[mask]
edge_vec = edge_vec[mask, :]

if self.return_vecs:
return neighbors, distances, distance_vecs
return edge_index, edge_weight, edge_vec
else:
return neighbors, distances, None
return edge_index, edge_weight, None


class GaussianSmearing(nn.Module):
Expand Down
24 changes: 7 additions & 17 deletions torchmdnet/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
import os
import torch as pt
import torch
from torch.utils import cpp_extension


def compile_extension():
src_dir = os.path.dirname(__file__)
sources = ["neighbors.cpp", "neighbors_cpu.cpp"] + (
["neighbors_cuda.cu", "backwards.cu"]
if pt.cuda.is_available()
else []
["neighbors_cuda.cu", "backwards.cu"] if torch.cuda.is_available() else []
)
sources = [os.path.join(src_dir, name) for name in sources]
cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False)

cpp_extension.load(
name="torchmdnet_neighbors", sources=sources, is_python_module=False
)

def get_backends():
compile_extension()
get_neighbor_pairs_brute = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_brute
get_neighbor_pairs_shared = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_shared
get_neighbor_pairs_cell = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_cell
return {
"brute": get_neighbor_pairs_brute,
"cell": get_neighbor_pairs_cell,
"shared": get_neighbor_pairs_shared,
}
compile_extension()
get_neighbor_pairs_kernel = torch.ops.torchmdnet_neighbors.get_neighbor_pairs
59 changes: 0 additions & 59 deletions torchmdnet/neighbors/backwards.cu

This file was deleted.

5 changes: 0 additions & 5 deletions torchmdnet/neighbors/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,4 @@ __device__ auto compute_distance(scalar3<scalar_t> pos_i, scalar3<scalar_t> pos_

} // namespace triclinic

/*
* Backward pass for the CUDA neighbor list operation.
* Computes the gradient of the positions with respect to the distances and deltas.
*/
tensor_list common_backward(AutogradContext* ctx, const tensor_list& grad_inputs);
#endif
4 changes: 1 addition & 3 deletions torchmdnet/neighbors/neighbors.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include <torch/extension.h>

TORCH_LIBRARY(torchmdnet_neighbors, m) {
m.def("get_neighbor_pairs_brute(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_shared(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
}
8 changes: 5 additions & 3 deletions torchmdnet/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
}

TORCH_LIBRARY_IMPL(torchmdnet_neighbors, CPU, m) {
m.impl("get_neighbor_pairs_brute", &forward);
m.impl("get_neighbor_pairs_shared", &forward);
m.impl("get_neighbor_pairs_cell", &forward);
m.impl("get_neighbor_pairs", [](const std::string &strategy, const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
return forward(positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, max_num_pairs, loop, include_transpose);
});
}
103 changes: 69 additions & 34 deletions torchmdnet/neighbors/neighbors_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,80 @@
Connection between the neighbor CUDA implementations and the torch extension.
See neighbors.cpp for the definition of the torch extension functions.
*/
#include <torch/extension.h>
#include "neighbors_cuda_brute.cuh"
#include "neighbors_cuda_cell.cuh"
#include "neighbors_cuda_shared.cuh"
#include <torch/extension.h>
template <class... T> auto call_forward_kernel(const std::string& kernel_name, const T&... args) {
if (kernel_name == "brute") {
return forward_brute(args...);
} else if (kernel_name == "cell") {
return forward_cell(args...);
} else if (kernel_name == "shared") {
return forward_shared(args...);
} else {
throw std::runtime_error("Unknown kernel name");
}
}

// This is the autograd function that is called when the user calls get_neighbor_pairs.
// It dispatches the required strategy for the forward function and implements the backward
// function. The backward function is written in full pytorch so that it can be differentiated a
// second time automatically via Autograd.
class NeighborAutograd : public torch::autograd::Function<NeighborAutograd> {
public:
static tensor_list forward(AutogradContext* ctx, const std::string& strategy,
const Tensor& positions, const Tensor& batch,
const Tensor& box_vectors, bool use_periodic,
const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
Tensor neighbors, deltas, distances, i_curr_pair;
std::tie(neighbors, deltas, distances, i_curr_pair) =
call_forward_kernel(strategy, positions, batch, box_vectors, use_periodic, cutoff_lower,
cutoff_upper, max_num_pairs, loop, include_transpose);
ctx->save_for_backward({neighbors, deltas, distances});
ctx->saved_data["num_atoms"] = positions.size(0);
return {neighbors, deltas, distances, i_curr_pair};
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
RaulPPelaez marked this conversation as resolved.
Show resolved Hide resolved
auto saved = ctx->get_saved_variables();
auto edge_index = saved[0];
auto edge_vec = saved[1];
auto edge_weight = saved[2];
auto num_atoms = ctx->saved_data["num_atoms"].toInt();
auto grad_edge_vec = grad_outputs[1];
auto grad_edge_weight = grad_outputs[2];
auto r0 = edge_weight.nonzero().squeeze(-1);
auto grad_positions = torch::zeros({num_atoms, 3}, edge_vec.options());
// We need to avoid dividing by 0. Otherwise Autograd fills the gradient with NaNs in the
// case of a double backwards. This is why we index_select like this.
auto grad_distances_ =
(edge_vec.index_select(0, r0) / edge_weight.index_select(0, r0).unsqueeze(-1)) *
grad_edge_weight.index_select(0, r0).unsqueeze(-1);
auto edge_index_no_r0 = edge_index.index_select(1, r0);
auto result = grad_edge_vec.index_select(0, r0) + grad_distances_;
grad_positions.index_add_(0, edge_index_no_r0[0], result);
grad_positions.index_add_(0, edge_index_no_r0[1], -result);
Tensor ignore;
return {ignore, grad_positions, ignore, ignore, ignore, ignore,
ignore, ignore, ignore, ignore, ignore};
}
};

TORCH_LIBRARY_IMPL(torchmdnet_neighbors, AutogradCUDA, m) {
m.impl("get_neighbor_pairs_brute",
[](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
tensor_list results;
if (positions.size(0) >= 32768) {
// Revert to shared if there are too many particles, which brute can't handle
results = AutogradSharedCUDA::apply(positions, batch, cutoff_lower, cutoff_upper,
box_vectors, use_periodic, max_num_pairs,
loop, include_transpose);
} else {
results = AutogradBruteCUDA::apply(positions, batch, cutoff_lower, cutoff_upper,
box_vectors, use_periodic, max_num_pairs,
loop, include_transpose);
m.impl("get_neighbor_pairs",
[](const std::string& strategy, const Tensor& positions, const Tensor& batch,
const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
bool include_transpose) {
auto final_strategy = strategy;
if (positions.size(0) >= 32768 && strategy == "brute") {
final_strategy = "shared";
}
return std::make_tuple(results[0], results[1], results[2], results[3]);
});
m.impl("get_neighbor_pairs_shared",
[](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
const tensor_list results = AutogradSharedCUDA::apply(
positions, batch, cutoff_lower, cutoff_upper, box_vectors, use_periodic,
max_num_pairs, loop, include_transpose);
return std::make_tuple(results[0], results[1], results[2], results[3]);
});
m.impl("get_neighbor_pairs_cell",
[](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
const tensor_list results = AutogradCellCUDA::apply(
positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper,
max_num_pairs, loop, include_transpose);
return std::make_tuple(results[0], results[1], results[2], results[3]);
auto result = NeighborAutograd::apply(final_strategy, positions, batch, box_vectors,
use_periodic, cutoff_lower, cutoff_upper,
max_num_pairs, loop, include_transpose);
return std::make_tuple(result[0], result[1], result[2], result[3]);
});
}
Loading
Loading