In [1]:
import jax
import jax.numpy as jnp

In [2]:
N_pts = 200_000
N_pix = 1_000_000

In [3]:
key = jax.random.PRNGKey(123)
pix_idxs = jax.vmap(
    lambda key: jax.random.randint(key, (), 0, N_pts // 4)
)(jax.random.split(key, N_pts))
pix_idxs.shape

(200000,)

In [4]:
def _get_first_n(pix_idxs, n):
    """
    Returns an array of shape (N_pix, n) containing the first n indices
    into N_pts containing index i.
    """
    arr = -jnp.ones((pix_idxs.shape[0], n), dtype=int)
    for t in range(n):
        vals, idx, cts = jnp.unique(pix_idxs, size=N_pix, return_counts=True, return_index=True)
        arr = arr.at[idx, t].set(vals)
        pix_idxs = pix_idxs.at[idx].set(-1)
    return arr
get_first_n = jax.jit(_get_first_n, static_argnums=(1,))

In [5]:
import time

In [6]:
st = time.time()
get_first_n(pix_idxs, 1)
time.time() - st

1.5504238605499268

In [10]:
tiled = jnp.tile(pix_idxs, (400, 1))

In [11]:
get_first_n_vmapped = jax.jit(jax.vmap(get_first_n, in_axes=(0, None)), static_argnums=(1,))

In [14]:
st = time.time()
get_first_n_vmapped(tiled, 1)
time.time() - st

0.6735115051269531

In [None]:
# 400 max vs 4000 max; .67s vs .027s (25x faster)

In [15]:
def _get_first_n_v2(key ,pix_idxs, n):
    arr = -jnp.ones((pix_idxs.shape[0], n), dtype=int)
    pi = jax.random.shuffle(key, pix_idxs)
    random_indices = jax.random.randint(key, pix_idxs.shape, 0, 10)
    return arr.at[pix_idxs, random_indices].set(jnp.arange(pix_idxs.shape[0]))
get_first_n_v2 = jax.jit(_get_first_n_v2, static_argnums=(2,))

In [16]:
st = time.time()
a = get_first_n_v2(key, pix_idxs, 1)
time.time() - st
# ~~ 0.0012 seconds

0.7456269264221191

In [26]:
tiled = jnp.tile(pix_idxs, (500, 1))

In [27]:
get_first_n_v2_vmapped = jax.jit(jax.vmap(get_first_n_v2, in_axes=(None, 0, None)), static_argnums=(2,))

In [28]:
st = time.time()
a = get_first_n_v2_vmapped(key, tiled, 10)
time.time() - st

0.7378904819488525

In [18]:
vals, idx, cts = jnp.unique(pix_idxs, size=100_000, return_counts=True, return_index=True)
vals

Array([    0,     1,     2, ..., 99997, 99998, 99999], dtype=int32)