from functools import partial
from time import monotonic, perf_counter

from ase.build import bulk
from metatomic.torch import NeighborListOptions
from metatrain.utils.neighbor_lists import _compute_single_neighbor_list
from vesin import ase_neighbor_list as neighbor_list

repeats = 100  # timed calls per implementation and per system size
cutoff = 5.0  # neighbor-list cutoff passed to both implementations


def benchmark_per_atom(compute, n_atoms, n_repeats, n_warmup=1):
    """Time ``compute()`` and return its mean cost in microseconds per atom.

    ``compute`` is a zero-argument callable. It is first called ``n_warmup``
    times without timing, so one-off costs (lazy initialization, caches) do
    not skew the result. It is then called ``n_repeats`` times, measured with
    ``time.perf_counter``.

    Raises:
        ValueError: if ``n_atoms`` or ``n_repeats`` is not positive, since the
            per-atom average would otherwise be undefined.
    """
    if n_atoms < 1 or n_repeats < 1:
        raise ValueError("n_atoms and n_repeats must both be positive")
    for _ in range(n_warmup):
        compute()
    start = perf_counter()
    for _ in range(n_repeats):
        compute()
    elapsed = perf_counter() - start
    return elapsed * 1e6 / n_repeats / n_atoms


def main():
    """Print per-atom neighbor-list timings for metatrain and vesin on bulk Ar supercells."""
    # Build the options once: constructing them is not part of the
    # neighbor-list computation and should not be charged to metatrain.
    options = NeighborListOptions(cutoff=cutoff, full_list=True, strict=True)

    for multiplier in [5, 10, 15]:
        atoms = bulk("Ar") * [multiplier, multiplier, multiplier]
        n_atoms = len(atoms)
        print(f"{n_atoms} atoms")

        metatrain_time = benchmark_per_atom(
            partial(_compute_single_neighbor_list, atoms, options), n_atoms, repeats
        )
        print(f"  metatrain: {metatrain_time:.4f}µs/atom")

        # "ijSD" is an ASE-style quantity string; presumably pair indices,
        # cell shifts and distance vectors -- confirm against vesin docs.
        vesin_time = benchmark_per_atom(
            partial(neighbor_list, "ijSD", atoms, cutoff), n_atoms, repeats
        )
        print(f"  vesin:     {vesin_time:.4f}µs/atom")


if __name__ == "__main__":
    main()
