Skip to content

Commit

Permalink
Change name of extension so it does not collide with NNPOps one
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed May 22, 2023
1 parent 526d366 commit a75a521
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
8 changes: 4 additions & 4 deletions torchmdnet/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ def compile_extension():
else []
)
sources = [os.path.join(src_dir, name) for name in sources]
cpp_extension.load(name="neighbors", sources=sources, is_python_module=False)
cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False, verbose=True)


def get_backends():
compile_extension()
get_neighbor_pairs_brute = pt.ops.neighbors.get_neighbor_pairs_brute
get_neighbor_pairs_shared = pt.ops.neighbors.get_neighbor_pairs_shared
get_neighbor_pairs_cell = pt.ops.neighbors.get_neighbor_pairs_cell
get_neighbor_pairs_brute = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_brute
get_neighbor_pairs_shared = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_shared
get_neighbor_pairs_cell = pt.ops.torchmdnet_neighbors.get_neighbor_pairs_cell
return {
"brute": get_neighbor_pairs_brute,
"cell": get_neighbor_pairs_cell,
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/neighbors/neighbors.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <torch/extension.h>

TORCH_LIBRARY(neighbors, m) {
TORCH_LIBRARY(torchmdnet_neighbors, m) {
m.def("get_neighbor_pairs_brute(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_shared(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
m.def("get_neighbor_pairs_cell(Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, Scalar cutoff_lower, Scalar cutoff_upper, Scalar max_num_pairs, bool loop, bool include_transpose) -> (Tensor neighbors, Tensor distances, Tensor distance_vecs, Tensor num_pairs)");
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
return {neighbors, deltas, distances, num_pairs_found};
}

TORCH_LIBRARY_IMPL(neighbors, CPU, m) {
TORCH_LIBRARY_IMPL(torchmdnet_neighbors, CPU, m) {
m.impl("get_neighbor_pairs_brute", &forward);
m.impl("get_neighbor_pairs_shared", &forward);
m.impl("get_neighbor_pairs_cell", &forward);
Expand Down
3 changes: 2 additions & 1 deletion torchmdnet/neighbors/neighbors_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
Connection between the neighbor CUDA implementations and the torch extension.
See neighbors.cpp for the definition of the torch extension functions.
*/
#include <torch/extension.h>
#include "neighbors_cuda_brute.cuh"
#include "neighbors_cuda_cell.cuh"
#include "neighbors_cuda_shared.cuh"

TORCH_LIBRARY_IMPL(neighbors, AutogradCUDA, m) {
TORCH_LIBRARY_IMPL(torchmdnet_neighbors, AutogradCUDA, m) {
m.impl("get_neighbor_pairs_brute",
[](const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
Expand Down

0 comments on commit a75a521

Please sign in to comment.