In [22]:
import jax
import jax.numpy as np
from functools import partial

## Brute force kNN

In [23]:
rng = jax.random.PRNGKey(42)

n_nodes = 10000
k = 100

# Generate random inputs
x = jax.random.normal(rng, (n_nodes, 3))

In [24]:
# Usual brute force kNN algo

@partial(jax.jit, static_argnums=(1,))
def nearest_neighbors(x, k, mask=None):
    """The shittiest implementation of nearest neighbours with masking in the world"""

    if mask is None:
        mask = np.ones((x.shape[0],), dtype=np.int32)

    n_nodes = x.shape[0]

    distance_matrix = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)

    distance_matrix = np.where(mask[:, None], distance_matrix, np.inf)
    distance_matrix = np.where(mask[None, :], distance_matrix, np.inf)

    indices = np.argsort(distance_matrix, axis=-1)[:, :k]

    sources = indices[:, 0].repeat(k)
    targets = indices.reshape(n_nodes * (k))

    return (sources, targets)

In [25]:
# Supposedly more efficient one from https://arxiv.org/abs/2206.14286

@partial(jax.jit, static_argnums=(1,))
def nearest_neighbors_ann(x, k):
    """Algorithm from https://arxiv.org/abs/2206.14286"""

    dots = np.einsum('ik,jk->ij', x, x)
    db_half_norm = np.linalg.norm(x, axis=1) ** 2 / 2.0
    dists = db_half_norm - dots
    dist, neighbours = jax.lax.approx_min_k(dists, k=k, recall_target=0.95)
    sources = np.arange(x.shape[0]).repeat(k)
    targets = neighbours.reshape(x.shape[0] * (k))
    return (sources, targets)

In [26]:
sources, targets = nearest_neighbors(x, k=k)
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

In [27]:
np.array_equal(sources_ann, sources), np.array_equal(targets_ann, targets)

(Array(True, dtype=bool), Array(False, dtype=bool))

In [28]:
(targets_ann == targets).sum()

Array(999888, dtype=int32)

In [29]:
%%timeit
sources_ann, targets_ann = nearest_neighbors_ann(x, k=k)

21.4 ms ± 5.65 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
%%timeit
sources, targets = nearest_neighbors(x, k=k)

28.7 ms ± 4.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Using jax-md

Scale up without having to recalculate pairwise distance matrix.

In [51]:
from jax_md import space, partition

In [87]:
class RadiusSearch:
    """ Update-able radius graph
    """
    def __init__(self, box_size, cutoff, boundary_cond="free", capacity_multiplier=1.5):
        
        self.box_size = np.array(box_size)
        
        if boundary_cond == "free":    
            self.displacement_fn, _ = space.free()  
        elif boundary_cond == "periodic":
            self.displacement_fn, _ = space.periodic(self.box_size)
        else:
            raise NotImplementedError
            
        self.disp = jax.vmap(self.displacement_fn)
        self.dist = jax.vmap(space.metric(self.displacement_fn))
        self.cutoff = cutoff
        self.neighbor_list_fn = partition.neighbor_list(self.displacement_fn,
                                                       self.box_size,
                                                       cutoff,
                                                       format=partition.Sparse,
                                                       dr_threshold= cutoff / 6.,
                                                       mask_self=False,
                                                       capacity_multiplier=capacity_multiplier)
        
        self.neighbor_list_fn_jit = jax.jit(self.neighbor_list_fn)
        self.neighbor_dist_jit = self.displacement_fn
        
        # Each time number of neighbours buffer overflows, reallocate
        self.n_times_reallocated = 0

    def init_neighbor_lst(self, pos):
        """ Allocate initial neighbour list
        """
        pos = np.mod(pos, self.box_size)
        nbr = self.neighbor_list_fn.allocate(pos)
        return nbr

    def update_neighbor_lst(self, pos, nbr):
        pos = np.mod(pos, self.box_size)
        nbr_update = jax.vmap(self.neighbor_list_fn_jit.update, in_axes=(0,None))(pos, nbr)
        
        # If buffer overflows, update capacity of neighbours.
        # NOTE: This reallocation strategy might be more efficient: https://github.com/jax-md/jax-md/issues/192#issuecomment-1114002995
        if np.any(nbr_update.did_buffer_overflow):
            nbr = self.neighbor_list_fn.allocate(pos[0], extra_capacity=2 ** self.n_times_reallocated)
            self.n_times_reallocated += 1
        
        return nbr_update, nbr


Let's see if it can scale to 50000 nodes.

In [88]:
n_batch = 4
n_nodes = 50000
n_pos = 3

# x = np.load("/n/holyscratch01/iaifi_lab/ccuesta/data_for_sid/halos.npy")[:n_batch, :n_nodes,:n_pos]
x = jax.random.uniform(rng, (n_batch, n_nodes, n_pos)) * 1000.

In [89]:
ns = RadiusSearch(1000., 100.)
nbr = ns.init_neighbor_lst(x[0])

In [90]:
# %%timeit
nbr_update, nbr = ns.update_neighbor_lst(x, nbr)

In [91]:
nbr_update.idx

Array([[[29664, 30783, 31251, ..., 50000, 50000, 50000],
        [    0,     0,     0, ..., 50000, 50000, 50000]],

       [[    0,   471,   552, ..., 50000, 50000, 50000],
        [    0,     0,     0, ..., 50000, 50000, 50000]],

       [[15035, 17264, 17428, ..., 50000, 50000, 50000],
        [    0,     0,     0, ..., 50000, 50000, 50000]],

       [[45270, 45927, 47561, ..., 50000, 50000, 50000],
        [    0,     0,     0, ..., 50000, 50000, 50000]]], dtype=int32)

In [92]:
nbr_update.idx.shape

(4, 2, 21889268)