Skip to content

Commit

Permalink
Pass strategy string as const ref
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jun 8, 2023
1 parent 5970fc0 commit dddb3b1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion torchmdnet/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
}

TORCH_LIBRARY_IMPL(torchmdnet_neighbors, CPU, m) {
m.impl("get_neighbor_pairs", [](std::string strategy, const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
m.impl("get_neighbor_pairs", [](const std::string &strategy, const Tensor& positions, const Tensor& batch, const Tensor& box_vectors,
bool use_periodic, const Scalar& cutoff_lower, const Scalar& cutoff_upper,
const Scalar& max_num_pairs, bool loop, bool include_transpose) {
return forward(positions, batch, box_vectors, use_periodic, cutoff_lower, cutoff_upper, max_num_pairs, loop, include_transpose);
Expand Down
7 changes: 4 additions & 3 deletions torchmdnet/neighbors/neighbors_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public:

TORCH_LIBRARY_IMPL(torchmdnet_neighbors, AutogradCUDA, m) {
m.impl("get_neighbor_pairs",
[](std::string strategy, const Tensor& positions, const Tensor& batch,
[](const std::string& strategy, const Tensor& positions, const Tensor& batch,
const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
bool include_transpose) {
auto final_strategy = strategy;
if (positions.size(0) >= 32768 && strategy == "brute") {
strategy = "shared";
final_strategy = "shared";
}
auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors,
auto result = NeighborAutograd::apply(final_strategy, positions, batch, box_vectors,
use_periodic, cutoff_lower, cutoff_upper,
max_num_pairs, loop, include_transpose);
return std::make_tuple(result[0], result[1], result[2], result[3]);
Expand Down

0 comments on commit dddb3b1

Please sign in to comment.