In [4]:
import time
import torch
import sys
import os

# Add project root to PYTHONPATH
sys.path.append('/home/vkapil/scratch/nl/torch-nlpp')

from ase import Atoms
from neighbourlist import NeighbourList

from vesin import ase_neighbor_list
from vesin import NeighborList as NeighbourList_vesin


from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data

In [9]:
cell = [
    [ 2.460394,  0.0,       0.0     ],
    [-1.26336,   2.044166,  0.0     ],
    [-0.139209, -0.407369,  6.809714]
]

positions = [
    [-0.03480225, -0.10184225, 1.70242850],
    [-0.10440675, -0.30552675, 5.10728550],
    [-0.05691216,  1.26093576, 1.70242850],
    [ 1.11473716,  0.37586124, 5.10728550],
]

carbon = Atoms(
    symbols="CCCC",
    positions=positions,
    cell=cell,
    pbc=True,
) * (4,4,2)

n_tries = 10
radius = 4.0
box = carbon.cell
num_atoms = len(carbon)

print (f"radius = {radius}")
print (f"num_atoms = {num_atoms}")
print (f"box sizes = {torch.linalg.norm(torch.tensor(box), dim=0)}")

print ('')
print ('VK torch-nl++ O^2 algo [ms]')

nl = NeighbourList(
    list_of_configurations=[carbon],
    radius=radius,
    batch_size=1,
    device='cuda'
)

nl.load_data()

out = nl.calculate_neighbourlist(use_torch_compile=False)

start_time = time.perf_counter()
for i in range(n_tries):
    out = nl.calculate_neighbourlist(use_torch_compile=False)
torch.cuda.synchronize()
print ((time.perf_counter() - start_time) / n_tries * 1000)

print ('')

print ('ASE O(N) algo [ms]')

start_time = time.perf_counter()
for i in range(n_tries):
    i, j, S, d = ase_neighbor_list("ijSd", carbon, cutoff=radius)
print ((time.perf_counter() - start_time) / n_tries * 1000)

print ('')

print ('vesin O(N) algo [ms]')

start_time = time.perf_counter()
for i in range(n_tries):
    calculator = NeighbourList_vesin(cutoff=4.0, full_list=True)
    i, j, S, d = calculator.compute(
        points=carbon.positions,
        box=carbon.cell,
        periodic=True,
        quantities="ijSd")
print ((time.perf_counter() - start_time) / n_tries * 1000)

print (' ')
print ('torch-nl O(N^2) algo [ms]')

pos, cell_t, pbc, batch, _ = ase2data([carbon])

compute_neighborlist_n2(4.0, pos, cell_t, pbc, batch, False )

start = time.perf_counter()
for i in range(n_tries):
    compute_neighborlist_n2(4.0, pos, cell_t, pbc, batch, False )

print((time.perf_counter() - start) / n_tries * 1000)



radius = 4.0
num_atoms = 128
box sizes = tensor([11.0667,  8.2172, 13.6194], dtype=torch.float64)

VK torch-nl++ O^2 algo [ms]
0.7110321894288063

ASE O(N) algo [ms]
0.5415303632616997

vesin O(N) algo [ms]
0.27181394398212433
 
torch-nl O(N^2) algo [ms]
8.372658118605614


In [4]:
print ('Number of batches = ', len(out))

Number of batches =  3


In [5]:
print ('Batch_size = ', len(out[0]))

Batch_size =  2


In [7]:
print ('Number of entries in the list', len(out[0][0]))

Number of entries in the list 4


In [8]:
from vesin import ase_neighbor_list

i, j, S, d = ase_neighbor_list("ijSd", carbon, cutoff=3.0)

In [9]:
i

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3], dtype=uint64)

In [10]:
d

array([2.36886155, 2.460394  , 2.40305912, 2.40305912, 2.36886155,
       2.460394  , 2.72591262, 1.39663935, 2.83195861, 1.41597718,
       1.36295736, 2.79327643, 2.36886155, 2.460394  , 2.40305912,
       2.40305912, 2.460394  , 2.36886155, 2.72591262, 2.79327643,
       1.36295736, 1.41597718, 2.83195861, 1.39663935, 1.39663935,
       2.72591262, 2.83195861, 1.41597718, 1.36295736, 2.79327643,
       2.460394  , 2.40305912, 2.36886155, 2.40305912, 2.460394  ,
       2.36886155, 1.39663935, 2.83195861, 1.41597718, 1.36295736,
       2.79327643, 2.72591262, 2.36886155, 2.460394  , 2.40305912,
       2.40305912, 2.460394  , 2.36886155])

In [11]:
d

array([2.36886155, 2.460394  , 2.40305912, 2.40305912, 2.36886155,
       2.460394  , 2.72591262, 1.39663935, 2.83195861, 1.41597718,
       1.36295736, 2.79327643, 2.36886155, 2.460394  , 2.40305912,
       2.40305912, 2.460394  , 2.36886155, 2.72591262, 2.79327643,
       1.36295736, 1.41597718, 2.83195861, 1.39663935, 1.39663935,
       2.72591262, 2.83195861, 1.41597718, 1.36295736, 2.79327643,
       2.460394  , 2.40305912, 2.36886155, 2.40305912, 2.460394  ,
       2.36886155, 1.39663935, 2.83195861, 1.41597718, 1.36295736,
       2.79327643, 2.72591262, 2.36886155, 2.460394  , 2.40305912,
       2.40305912, 2.460394  , 2.36886155])