In [1]:
# Development convenience: automatically reload edited modules (e.g. jaxknn)
# so changes to the package are picked up without restarting the kernel.
# `%autoreload 2` reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import jax

# Report which accelerator the timings below were taken on. Label the line with
# the device's own platform instead of a hard-coded "GPU", so a silent fallback
# to CPU (e.g. a JAX install without CUDA support) shows up as "CPU: ..."
# rather than the misleading "GPU: cpu".
device = jax.devices()[0]
print(f"{device.platform.upper()}: {device.device_kind}")

GPU: NVIDIA GeForce RTX 2080 Ti


In [3]:
import numpy as np

from jaxknn.cuda_kdtree import knn_cuda
from jaxknn.jax_kdtree import build_tree
from jaxknn.scipy_kdtree import knn_scipy
from jaxknn.utils import generate_uniform_random_points

In [4]:
# Problem size for all benchmarks below. The dimensionality is selected here
# instead of keeping one cell per size: previously a 3-D cell silently
# overwrote the 2-D one, and the saved execution counts (In [4], In [8]) show
# the cells were run out of order, so it was unclear which size a timing used.
N_POINT_BY_DIM = {2: 256**2, 3: 64**3}  # n_dim -> number of points
N_DIM = 3  # the size that ran last before; set to 2 for the 256**2 case
k = 16  # neighbours per query

# NOTE(review): no seed is passed; if generate_uniform_random_points accepts
# one, pin it so timings are reproducible across kernels.
points, box_size = generate_uniform_random_points(
    n_point=N_POINT_BY_DIM[N_DIM], n_dim=N_DIM
)

In [5]:
# Host-side NumPy copy of the points for the SciPy reference implementation.
# Converted once here so the conversion is not part of the timed SciPy cell.
# NOTE(review): presumably `points` is a JAX device array (it is passed
# directly to knn_cuda below) — confirm.
points_cpu = np.asarray(points)

In [6]:
%%timeit
scipy_idx = knn_scipy(
    points=points_cpu, queries=points_cpu, k=k, box_size=box_size
)
scipy_idx.block_until_ready()

102 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
%%timeit
cuda_idx = knn_cuda(
    points=points, queries=points, k=k, box_size=box_size,
)
cuda_idx.block_until_ready()

4.34 ms ± 12.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


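Fewer queries (only the first 100 points). The runtime barely drops (3.64 ms vs 4.34 ms for all points), so it does not scale with the number of queries; there is probably some fixed overhead that still needs to be understood.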

In [11]:
%%timeit
cuda_idx = knn_cuda(
    points=points, queries=points[:100], k=k, box_size=box_size,
)
cuda_idx.block_until_ready()

3.64 ms ± 31.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Put the points into tree order before running the k-NN query. At the moment this does not appear to speed things up (4.41 ms vs 4.34 ms).

In [8]:
# Build the tree once, outside any timed cell.
# l_max = floor(log2(n_points)); e.g. 18 when n_points = 64**3 = 2**18.
# NOTE(review): meaning of l_max (number of tree levels?) is assumed from its
# name — confirm against build_tree.
l_max = int(np.log2(points.shape[0]))
nodes, order = build_tree(points, l_max=l_max)

In [9]:
%%timeit
cuda_idx = knn_cuda(
    points=nodes, queries=points, k=k, box_size=box_size,
)
cuda_idx.block_until_ready()

4.41 ms ± 21.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
