# Demonstration of the neighbor search operation

## Compile and import

In [1]:
import torch as pt
from torchmdnet.neighbors import get_neighbor_list

## Run

In [2]:
# Three points along the (1, 1, 1) diagonal: neighbouring points differ by
# (3, 3, 3), i.e. 3*sqrt(3) ~= 5.196 apart, and the outer pair is twice that.
pos = pt.tensor([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
])

### Forward

In [3]:
# CPU forward pass; the notebook displays the returned tuple of three tensors
# (two int32 index tensors and the matching pair distances).
get_neighbor_list(pos.cpu())

(tensor([1, 2, 2], dtype=torch.int32),
 tensor([0, 0, 1], dtype=torch.int32),
 tensor([ 5.1962, 10.3923,  5.1962]))

In [4]:
# Same forward pass on the current CUDA device; the result should match the
# CPU output above, just with device='cuda:0' tensors.
get_neighbor_list(pos.cuda())

(tensor([1, 2, 2], device='cuda:0', dtype=torch.int32),
 tensor([0, 0, 1], device='cuda:0', dtype=torch.int32),
 tensor([ 5.1962, 10.3923,  5.1962], device='cuda:0'))

### Forward and backward

In [5]:
# Fresh CPU leaf tensor that records gradients.
pos_ = pos.cpu().detach().requires_grad_()
res = get_neighbor_list(pos_)
# Third returned tensor holds the pair distances; backpropagate their sum.
distances = res[2]
distances.sum().backward()
pos_.grad

tensor([[-1.1547, -1.1547, -1.1547],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.1547,  1.1547,  1.1547]])

In [6]:
# Fresh CUDA leaf tensor that records gradients.
pos_ = pos.cuda().detach().requires_grad_()
res = get_neighbor_list(pos_)
# Third returned tensor holds the pair distances; backpropagate their sum.
distances = res[2]
distances.sum().backward()
pos_.grad

tensor([[-1.1547, -1.1547, -1.1547],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.1547,  1.1547,  1.1547]], device='cuda:0')