In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial

import flax.linen as nn
import jax
import jax.numpy as np
import jax.test_util as jtu
import jraph

In [3]:
import sys
sys.path.append("../")

from models.graph_utils import nearest_neighbors
import jaxkdtree

In [42]:
n_nodes = 5000
HALO_FILE = "data/halos_small.npy"

# Keep the first 4 entries of the leading axis (each is treated as a separate
# point cloud by the vmapped calls below), the first n_nodes points of each,
# and the first 3 columns of the last axis (used as coordinates).
x = np.load(HALO_FILE)[:4, :n_nodes, :3]
x.shape

(4, 5000, 3)

In [43]:
%%timeit
# kD-tree neighbour search (k=16, max_radius=100.0) over the 4 point clouds.
# JAX dispatches work asynchronously, so block on the result; otherwise
# %%timeit only measures how long it takes to enqueue the computation.
res = jax.vmap(jaxkdtree.kNN, in_axes=(0, None, None))(x[:4, :, :3], 16, 100.0)
res.block_until_ready()

10.9 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [44]:
%%timeit
# Brute-force pairwise-distance neighbours (k=16) over the same 4 point clouds.
# Block on every output so the timing includes the actual computation and is
# comparable with the kD-tree cell (JAX dispatch is asynchronous).
sources, targets, dist = jax.vmap(nearest_neighbors, in_axes=(0, None))(x[:4, :, :3], 16)
np.array([sources, targets]).block_until_ready()
dist.block_until_ready()

23.2 ms ± 9.01 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [45]:
# Brute-force reference neighbours (k=8), used below to validate the kD-tree wrapper.
batched_nearest_neighbors = jax.vmap(nearest_neighbors, in_axes=(0, None))
sources, targets, dist = batched_nearest_neighbors(x[:4, :, :3], 8)

In [108]:
@jax.custom_vjp
def nearest_neighbors_kd(x, k, max_radius=2000.0):
    # Implementation of nearest neighbors op
    res = jaxkdtree.kNN(x, 8, max_radius)
    sources = np.repeat(np.arange(x.shape[0]), 8)
    targets = res.reshape(-1)
    dr = x[sources] - x[targets]
    distances = np.sum(dr**2, axis=-1)
    return sources, targets, distances

def nearest_neighbors_kd_fwd(x, k, max_radius):
    sources, targets, distances = nearest_neighbors_kd(x, k, max_radius)
    return (sources, targets, distances), (0.,)

def nearest_neighbors_kd_bwd(res, g):
    return lambda g: g, lambda g: None

nearest_neighbors_kd.defvjp(nearest_neighbors_kd_fwd, nearest_neighbors_kd_bwd)

In [109]:
# Same k=8 query through the kD-tree wrapper (max_radius=1000.), for comparison
# against the brute-force reference.
batched_nearest_neighbors_kd = jax.vmap(nearest_neighbors_kd, in_axes=(0, None, None))
sources_knn, targets_knn, dist_knn = batched_nearest_neighbors_kd(x[:4, :, :3], 8, 1000.)

In [110]:
# Neighbour indices are integers and must match exactly (allclose's relative
# tolerance would accept off-by-one indices once they exceed ~1e5); distances
# are floats, so compare them with a tolerance.
(
    np.array_equal(sources, sources_knn),
    np.array_equal(targets, targets_knn),
    np.allclose(dist, dist_knn),
)

(Array(True, dtype=bool), Array(True, dtype=bool), Array(True, dtype=bool))

In [118]:
def dummy_fn(x):
    """Scalar test loss: mean of the kD-tree kNN distances over a batch of point clouds.

    Runs ``nearest_neighbors_kd`` over the leading axis of ``x`` with k=8 and
    max_radius=1000., so the gradient goes through its custom VJP.
    """
    _, _, dist_knn = jax.vmap(nearest_neighbors_kd, in_axes=(0, None, None))(x, 8, 1000.)
    return np.mean(dist_knn)


dummy_fn(x[:4])

# Check the reverse-mode gradient numerically. Uses the public jax.test_util
# (imported at the top as jtu) instead of the private jax._src.test_util.
jtu.check_grads(dummy_fn, (x[:4],), modes=["rev"], order=1)

In [121]:
@jax.custom_vjp
def f(x, y):
    """Toy custom-VJP example: sin(x) * y."""
    return np.sin(x) * y


def f_fwd(x, y):
    """Forward rule: primal value plus (cos x, sin x, y) residuals."""
    sin_x = np.sin(x)
    residuals = (np.cos(x), sin_x, y)
    return sin_x * y, residuals


def f_bwd(res, g):
    """Backward rule: d/dx = cos(x) * y * g and d/dy = sin(x) * g."""
    cos_x, sin_x, y = res
    grad_x = cos_x * g * y
    grad_y = sin_x * g
    return grad_x, grad_y


f.defvjp(f_fwd, f_bwd)

###

def dummy_fn(x, y):
    """Thin wrapper so check_grads exercises f through an ordinary Python function."""
    return f(x, y)


dummy_fn(1., 3.)

jtu.check_grads(dummy_fn, (1., 3.), modes=["rev"], order=1)

In [156]:
@jax.custom_vjp
def nearest_neighbors_kd(x, k, max_radius=2000.0):
    # Implementation of nearest neighbors op
    res = jaxkdtree.kNN(x, 8, max_radius)
    sources = np.repeat(np.arange(x.shape[0]), 8)
    targets = res.reshape(-1)
    dr = x[sources] - x[targets]
    distances = np.sum(dr**2, axis=-1)
    return sources, targets, distances

def nearest_neighbors_kd_fwd(x, k, max_radius):

  sources, targets, distances = nearest_neighbors_kd(x, k, max_radius)

  # Return dummy outputs
  return (0., 0., 0.), (sources, targets, distances)

def nearest_neighbors_kd_bwd(res, g):
  return lambda g: g, lambda g: None

nearest_neighbors_kd.defvjp(nearest_neighbors_kd_fwd, nearest_neighbors_kd_bwd)

###

def dummy_fn(x):
    """Scalar loss for gradient tests: mean of the distances from nearest_neighbors_kd.

    Maps ``nearest_neighbors_kd`` over the leading axis of ``x`` with k=8 and
    max_radius=1000.
    """
    batched = jax.vmap(nearest_neighbors_kd, in_axes=(0, None, None))
    _, _, dist_knn = batched(x, 8, 1000.)
    return np.mean(dist_knn)


dummy_fn(x[:4])

# Value and gradient on a single point cloud.
jax.value_and_grad(dummy_fn)(x[:1])

# TODO: once the gradient looks right, also validate it numerically:
# jtu.check_grads(dummy_fn, (x[:4],), modes=["rev"], order=1)

ValueError: too many values to unpack (expected 3)