In [33]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import numpy as np

sys.path.append("..")
from dm21cm.interpolators import BatchInterpolator as BatchInterpolatorSciPy
from dm21cm.interpolators_jax import BatchInterpolator as BatchInterpolatorJax

In [26]:
# Inputs for the BatchInterpolator benchmark: a scalar `rs`, a 500-sample
# `in_spec`, and three 64**3-sample arrays that are later passed to the
# interpolators as nBs_s, x_s and sum_weight respectively.
rs = 20.0
in_spec = np.linspace(start=3.0, stop=4.0, num=500)
n_in = np.linspace(start=0.5, stop=1.5, num=64 ** 3)
x_in = np.linspace(start=0.01, stop=0.99, num=64 ** 3)
w = np.linspace(start=0.5, stop=1.5, num=64 ** 3)

In [None]:
# Build the BatchInterpolator imported from dm21cm.interpolators (aliased as the
# SciPy variant) from the phot_prop.h5 file under $DM21CM_DATA_DIR.
# os.environ[...] raises KeyError if that environment variable is not set.
bi_scipy = BatchInterpolatorSciPy(f"{os.environ['DM21CM_DATA_DIR']}/tf/230629/phot/phot_prop.h5")

In [34]:
# Build the BatchInterpolator imported from dm21cm.interpolators_jax (aliased as
# the JAX variant) from the same phot_prop.h5 file under $DM21CM_DATA_DIR.
# os.environ[...] raises KeyError if that environment variable is not set.
bi_jax = BatchInterpolatorJax(f"{os.environ['DM21CM_DATA_DIR']}/tf/230629/phot/phot_prop.h5")

| Exec time (s) on | 48CPU 32GB mem        | 48CPU A100            |
|------------------|-----------------------|-----------------------|
| scipy (base)     | 3.997290 +/- 0.004115 | 4.304705 +/- 0.068191 |
| jax (base)       | 0.506022 +/- 0.006346 | 0.010241 +/- 0.000188 |

In [15]:
# Single evaluation of the SciPy interpolator on the benchmark inputs.
# NOTE(review): sum_result=True with sum_weight=w presumably requests a
# w-weighted sum over the batch — confirm against BatchInterpolator.
result_scipy = bi_scipy(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

In [28]:
# Single evaluation of the JAX interpolator with the same arguments as the
# SciPy call, so the two results can be compared directly.
# NOTE(review): sum_result=True with sum_weight=w presumably requests a
# w-weighted sum over the batch — confirm against BatchInterpolator.
result_jax = bi_jax(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

In [21]:
%timeit result_scipy = bi_scipy(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

3.95 s ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
%timeit result_jax = bi_jax(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w).block_until_ready()

157 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
result_scipy / result_jax

In [31]:
type(result_jax)

jaxlib.xla_extension.ArrayImpl

## A. Interpolators from scratch

In [2]:
# Synthetic benchmark data. The global NumPy RNG is seeded first so that a
# Restart & Run All reproduces the same table, grids and input spectrum.
np.random.seed(0)

# z: random table with axes (a0, a1, a2, in, out) = (10, 10, 10, 500, 500);
# as float64 this is about 2 GB.
z = np.random.uniform(0, 1, 1000*500*500).reshape(10, 10, 10, 500, 500)
# a0, a1, a2: sorted 10-point grids on [0, 1], each made of 8 random interior
# points plus both endpoints.
a0 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
a1 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
a2 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
# in_spec: random vector contracted against the 'in' axis of z.
in_spec = np.random.uniform(0, 1, 500)

In [6]:
# Query points for the from-scratch benchmark: one scalar coordinate along the
# first table axis, 64**3 (in1, in2) coordinate pairs, and one weight per pair.
# Note: `w` rebinds the name used for the weights of the BatchInterpolator
# benchmark earlier in the notebook.
in0 = np.random.uniform(low=0, high=1)
in1 = np.random.uniform(low=0, high=1, size=64 ** 3)
in2 = np.random.uniform(low=0, high=1, size=64 ** 3)
in12 = np.stack((in1, in2), axis=-1)  # shape (64**3, 2)
w = np.random.uniform(low=0, high=1, size=64 ** 3)

## 1. SciPy

In [10]:
from scipy import interpolate

In [15]:
def fscipy():
    """SciPy reference: 1D interpolation along a0, then 2D on (a1, a2).

    Reads the notebook globals in_spec, z, a0, a1, a2, in0, in12 and w.
    Returns the w-weighted sum, over all rows of in12, of the interpolated
    output vectors.
    """
    # Contract the 'in' axis: (a0, a1, a2, in, out) -> (a0, a1, a2, out).
    table = np.einsum('i,abcio->abco', in_spec, z)
    # Interpolate along the a0 axis at the scalar in0 -> (a1, a2, out).
    along_a0 = interpolate.interp1d(a0, table, axis=0, copy=False)
    slab = along_a0(in0)
    # Interpolate on the (a1, a2) grid at every (in1, in2) pair.
    on_grid = interpolate.RegularGridInterpolator((a1, a2), slab)
    per_point = on_grid(in12)
    return np.dot(w, per_point)

In [18]:
%timeit s = fscipy()

3.95 s ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
# 3D query points (in0, in1, in2): the scalar in0 is repeated for every pair.
in012 = np.stack(
    (np.full_like(in1, in0), in1, in2),
    axis=-1,
)

In [21]:
def fscipy2():
    """SciPy variant that interpolates on the full (a0, a1, a2) grid at once.

    Reads the notebook globals in_spec, z, a0, a1, a2, in012 and w.
    Returns the w-weighted sum, over all rows of in012, of the interpolated
    output vectors.
    """
    # Contract the 'in' axis: (a0, a1, a2, in, out) -> (a0, a1, a2, out).
    table = np.einsum('i,abcio->abco', in_spec, z)
    on_grid = interpolate.RegularGridInterpolator((a0, a1, a2), table)
    per_point = on_grid(in012)
    return np.dot(w, per_point)

In [22]:
%timeit s2 = fscipy2()

7.72 s ± 4.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
# Sanity check: the two-step (1D then 2D) and the single-step 3D SciPy
# interpolations agree element-wise within np.isclose's default tolerances.
np.all(np.isclose(fscipy(), fscipy2()))

True

## 2. JAX

In [24]:
import jax.numpy as jnp
from jax import jit, vmap

In [25]:
def interp1d(f, x, xv):
    """Linearly interpolates f(x) at the value xv.

    f : (n>=1 D) array of function values; its first dimension runs along x.
    x : 1D sorted array of input values (at least two), corresponding to the
        first dimension of f.
    xv : x value to interpolate at (interp1d_vmap maps over an array of them).

    Returns f interpolated at xv, with the shape of f[0].

    Does not do bound checks: the bracketing interval is clamped to the first
    or last grid interval, so an xv outside [x[0], x[-1]] is linearly
    extrapolated from the nearest interval.
    """
    # side='right' matches interp2d. Clipping keeps both li and li+1 valid, so
    # xv == x[0] and xv == x[-1] use a real grid interval instead of wrapping
    # around to the last grid point (index -1) or reading past the end.
    li = jnp.clip(jnp.searchsorted(x, xv, side='right') - 1, 0, x.shape[0] - 2)
    lx = x[li]
    rx = x[li+1]
    p = (xv-lx) / (rx-lx)
    fl = f[li]
    return fl + (f[li+1]-fl) * p

# Batched version: maps over the leading axis of xv, sharing f and x.
interp1d_vmap = jit(vmap(interp1d, in_axes=(None, None, 0)))


def interp2d(f, x0, x1, xv):
    """Bilinearly interpolates f(x0, x1) at the point xv.

    f : (n>=2 D) array of function values.
    x0 : 1D sorted array of input values (at least two), corresponding to the
        first dimension of f.
    x1 : 1D sorted array of input values (at least two), corresponding to the
        second dimension of f.
    xv : [x0, x1] value to interpolate at (interp2d_vmap maps over rows of them).

    Returns f interpolated at xv, with the shape of f[0, 0].

    Does not do bound checks: along each axis the bracketing interval is
    clamped to the first or last grid interval, so a coordinate outside the
    grid is linearly extrapolated from the nearest interval.
    """
    xv0, xv1 = xv

    # Clipping keeps li and li+1 valid. Without it, a coordinate equal to the
    # last grid point gives li+1 out of range; JAX clamps that index, so
    # rx == lx and the weight becomes 0/0 = NaN.
    li0 = jnp.clip(jnp.searchsorted(x0, xv0, side='right') - 1, 0, x0.shape[0] - 2)
    lx0 = x0[li0]
    rx0 = x0[li0+1]
    wl0 = (rx0-xv0) / (rx0-lx0)
    wr0 = 1 - wl0

    li1 = jnp.clip(jnp.searchsorted(x1, xv1, side='right') - 1, 0, x1.shape[0] - 2)
    lx1 = x1[li1]
    rx1 = x1[li1+1]
    wl1 = (rx1-xv1) / (rx1-lx1)
    wr1 = 1 - wl1

    # Weighted sum of the four corners of the enclosing grid cell.
    return f[li0,li1]*wl0*wl1 + f[li0+1,li1]*wr0*wl1 + f[li0,li1+1]*wl0*wr1 + f[li0+1,li1+1]*wr0*wr1

# Batched version: maps over the leading axis of xv, sharing f, x0 and x1.
interp2d_vmap = jit(vmap(interp2d, in_axes=(None, None, None, 0)))

In [26]:
def fjax():
    """JAX counterpart of fscipy, built on interp1d and interp2d_vmap.

    Reads the notebook globals in_spec, z, a0, a1, a2, in0, in12 and w.
    Returns the w-weighted sum, over all rows of in12, of the interpolated
    output vectors.
    """
    # Contract the 'in' axis: (a0, a1, a2, in, out) -> (a0, a1, a2, out).
    table = jnp.einsum('i,abcio->abco', in_spec, z)
    # Interpolate along the a0 axis at the scalar in0 -> (a1, a2, out).
    slab = interp1d(table, a0, in0)
    # Interpolate on the (a1, a2) grid at every (in1, in2) pair.
    per_point = interp2d_vmap(slab, a1, a2, in12)
    return jnp.dot(w, per_point)

In [None]:
np.isclose?

In [32]:
# Compare with a relative tolerance suited to single precision (JAX computes in
# float32 unless x64 is enabled). The previous atol=1e100 accepted any finite
# values, so the check could never fail.
np.all(np.isclose(fscipy(), fjax(), rtol=1e-3))

True

In [33]:
%timeit s3 = fjax()

514 ms ± 9.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
