In [1]:
%reload_ext autoreload
%autoreload 2

import os
import sys
import numpy as np

import cProfile
import pstats

sys.path.append("..")
from dm21cm.interpolators import BatchInterpolator as BatchInterpolatorSciPy
from dm21cm.interpolators_jax import BatchInterpolator as BatchInterpolatorJax

In [2]:
import h5py
import numpy as np
from functools import partial

import jax.numpy as jnp
from jax import jit, vmap, device_put

EPSILON = 1e-6

In [16]:
@jit
def interp1d(fp, xp, x):
    """Linearly interpolates f(x), described by points fp and xp, at values in x.

    Args:
        fp (array): n(>=1)D array of function values. First dimension will be interpolated.
        xp (array): 1D array of x values (at least two points).
        x (array): x values to interpolate (scalar or array).

    Returns:
        array: Interpolated values. The leading dimensions follow x, the
            trailing dimensions are fp.shape[1:].

    Notes:
        xp must be sorted. Does not do bound checks: values outside
        [xp[0], xp[-1]] are extrapolated linearly from the nearest edge cell.
    """
    # Index of the left grid point of the cell containing x. Clamp it to a valid
    # cell so that x == xp[-1] uses the last cell instead of indexing past the
    # end of the grid (which would give a zero-width cell and a 0/0 weight).
    il = jnp.clip(jnp.searchsorted(xp, x, side='right') - 1, 0, xp.shape[0] - 2)
    wl = (xp[il+1] - x) / (xp[il+1] - xp[il])
    # Append singleton axes so the weight broadcasts over fp's trailing dimensions.
    wl = jnp.reshape(wl, wl.shape + (1,) * (fp.ndim - 1))
    return fp[il] * wl + fp[il+1] * (1 - wl)

@jit
@partial(vmap, in_axes=(None, None, None, 0))
def interp2d(fp, x0p, x1p, x01):
    """Bilinearly interpolates f(x0, x1), described by points fp, x0p, and x1p, at values in x01.

    The function is vectorized over the first axis of x01; fp, x0p and x1p are
    shared by all query points.

    Args:
        fp (array): n(>=2)D array of function values. First two dimensions will be interpolated.
        x0p (array): 1D array of x0 values (first dimension of fp), at least two points.
        x1p (array): 1D array of x1 values (second dimension of fp), at least two points.
        x01 (array): [x0, x1] values to interpolate, one pair per row.

    Returns:
        array: One interpolated value (of shape fp.shape[2:]) per row of x01.

    Notes:
        x0p and x1p must be sorted. Does not do bound checks: values outside
        the grid are extrapolated linearly from the nearest edge cell.
    """
    x0, x1 = x01

    # Clamp the left-corner indices to a valid cell so that a query sitting
    # exactly on the upper grid edge does not index past the end of the grid
    # (which would give a zero-width cell and a 0/0 weight).
    i0l = jnp.clip(jnp.searchsorted(x0p, x0, side='right') - 1, 0, x0p.shape[0] - 2)
    wl0 = (x0p[i0l+1] - x0) / (x0p[i0l+1] - x0p[i0l])
    wr0 = 1 - wl0

    i1l = jnp.clip(jnp.searchsorted(x1p, x1, side='right') - 1, 0, x1p.shape[0] - 2)
    wl1 = (x1p[i1l+1] - x1) / (x1p[i1l+1] - x1p[i1l])
    wr1 = 1 - wl1

    # Weighted sum of the four corners of the enclosing cell.
    return fp[i0l,i1l]*wl0*wl1 + fp[i0l+1,i1l]*wr0*wl1 + fp[i0l,i1l+1]*wl0*wr1 + fp[i0l+1,i1l+1]*wr0*wr1

In [17]:
def bound_action(v, absc, out_of_bounds_action):
    """Applies an out-of-bounds policy to v with respect to the range of absc.

    Args:
        v (array): Values to check.
        absc (array): Abscissa values; only their minimum and maximum are used.
        out_of_bounds_action (str): 'clip' clips v to slightly inside the
            abscissa range. Any other value makes out-of-range input an error.

    Returns:
        array: The clipped values for 'clip', otherwise v itself.

    Raises:
        ValueError: If the action is not 'clip' and some value of v lies
            outside [min(absc), max(absc)].
    """
    lower = jnp.min(absc)
    upper = jnp.max(absc)

    if out_of_bounds_action != 'clip':
        if not jnp.all(v >= lower) or not jnp.all(v <= upper):
            raise ValueError('value out of bounds.')
        return v

    # Clip limits are pulled inward by a relative factor of (1+EPSILON).
    return jnp.clip(v, lower*(1+EPSILON), upper/(1+EPSILON))

In [24]:
class BatchInterpolator:
    """Interpolator for multidimensional data. Currently supports axes=('rs', 'Ein', 'nBs', 'x', 'out').

    Args:
        filename (str): HDF5 data file name.
        on_device (bool, optional): Whether to save data on device (GPU). Default: True.

    Attributes:
        axes (list): List of axis names.
        abscs (dict): Abscissas of axes.
        data (array): Grid data consistent with axes and abscs.
        on_device (bool): The on_device flag passed at construction.
        fixed_in_spec (array or None): Input spectrum cached by set_fixed_in_spec.
        fixed_in_spec_data (array or None): data contracted with fixed_in_spec
            over the 'Ein' axis.
    """

    def __init__(self, filename, on_device=True):

        with h5py.File(filename, 'r') as hf:
            self.axes = hf['axes'][:]
            self.abscs = {k: item[:] for k, item in hf['abscs'].items()}
            self.data = jnp.array(hf['data'][:]) # load into memory

        self.on_device = on_device
        if self.on_device:
            self.data = device_put(self.data)

        self.fixed_in_spec = None
        self.fixed_in_spec_data = None


    def set_fixed_in_spec(self, in_spec):
        """Caches the contraction of data with in_spec over the 'Ein' axis.

        Later calls that pass an equal in_spec reuse this cached array instead
        of contracting the full grid again.

        Args:
            in_spec (array): 1D input spectrum along the 'Ein' axis.
        """

        self.fixed_in_spec = in_spec
        self.fixed_in_spec_data = jnp.einsum('e,renxo->rnxo', in_spec, self.data)
        if self.on_device:
            self.fixed_in_spec_data = device_put(self.fixed_in_spec_data)


    def _matches_fixed_in_spec(self, in_spec):
        """Returns True if in_spec equals the spectrum cached by set_fixed_in_spec."""

        fixed = self.fixed_in_spec
        if fixed is None or in_spec is None:
            return False
        if in_spec is fixed:
            return True # same object: skip the element-wise comparison
        if np.shape(in_spec) != np.shape(fixed):
            return False # avoid broadcasting spectra of different shapes
        return bool(jnp.all(in_spec == fixed))


    def __call__(self, rs=None, in_spec=None, nBs_s=None, x_s=None,
                 sum_result=False, sum_weight=None, sum_batch_size=256**3,
                 out_of_bounds_action='error'):
        """Batch interpolate in the nBs and x directions.

        Contracts the 'Ein' axis with in_spec and interpolates to the rs point
        (reusing the cache built by set_fixed_in_spec when in_spec equals the
        cached spectrum), then interpolates at every (nBs_s[i], x_s[i]) pair.
        If sum_result is True, the interpolated values are summed over the batch.

        Parameters:
            rs : [1] scalar value on the 'rs' axis.
            in_spec : [N * ...] 1D input spectrum, contracted with the 'Ein' axis.
            nBs_s : [1] 1D array of values on the 'nBs' axis.
            x_s : [1] 1D array of values on the 'x' axis, same length as nBs_s.
            sum_result : if True, return the (weighted) sum over the batch dimension.
            sum_weight : if None, just sum. otherwise dot with the interpolated batch.
            sum_batch_size : accepted but currently not used by this implementation.
            out_of_bounds_action : {'error', 'clip'}

        Return:
            interpolated box or (weighted) sum of interpolated box.
        """

        rs = bound_action(rs, self.abscs['rs'], out_of_bounds_action)
        x_s = bound_action(x_s, self.abscs['x'], out_of_bounds_action)
        nBs_s = bound_action(nBs_s, self.abscs['nBs'], out_of_bounds_action)

        # 1. rs interpolation and in_spec sum
        rs_absc = jnp.asarray(self.abscs['rs'])
        if self._matches_fixed_in_spec(in_spec):
            data_at_rs_at_spec = interp1d(self.fixed_in_spec_data, rs_absc, rs) # nxo
        else:
            data_at_rs = interp1d(self.data, rs_absc, rs) # enxo
            data_at_rs_at_spec = jnp.tensordot(in_spec, data_at_rs, axes=(0, 0)) # nxo

        # 2. (nBs) x interpolation (and sum)
        nBs_x_in = jnp.stack([nBs_s, x_s], axis=-1)
        interp_result = interp2d(
            data_at_rs_at_spec,
            jnp.asarray(self.abscs['nBs']),
            jnp.asarray(self.abscs['x']),
            nBs_x_in
        )
        if not sum_result:
            return interp_result
        if sum_weight is None:
            return jnp.sum(interp_result, axis=0)
        return jnp.dot(sum_weight, interp_result)


    def point_interp(self, rs=None, nBs=None, x=None, out_of_bounds_action='error'):
        """Returns the transfer function at a (rs, nBs, x) point."""

        rs = bound_action(rs, self.abscs['rs'], out_of_bounds_action)
        nBs = bound_action(nBs, self.abscs['nBs'], out_of_bounds_action)
        x = bound_action(x, self.abscs['x'], out_of_bounds_action)

        data = interp1d(self.data, jnp.asarray(self.abscs['rs']), rs) # enxo
        # Move nBs and x to the front so each can be interpolated as a leading
        # axis; jnp keeps the array in JAX instead of round-tripping through NumPy.
        data = jnp.einsum('enxo->nxeo', data) # nxeo
        data = interp1d(data, jnp.asarray(self.abscs['nBs']), nBs) # xeo
        data = interp1d(data, jnp.asarray(self.abscs['x']), x) # eo
        return data

## Benchmark with fake data

In [3]:
# Fake benchmark inputs: a single rs value, a random input spectrum, and a
# box_dim**3 batch of random (nBs, x) query points with summation weights.
rng = np.random.default_rng(0)  # fixed seed so the benchmark inputs are reproducible

rs = 10.
in_spec = jnp.array(rng.uniform(size=(500,)))

box_dim = 128
n_cells = box_dim**3

n_in = jnp.array(rng.uniform(size=(n_cells,)))
x_in = jnp.array(rng.uniform(size=(n_cells,)))
w = jnp.array(rng.uniform(size=(n_cells,)))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
bi = BatchInterpolatorJax(f"{os.environ['DM21CM_DATA_DIR']}/tf/zf01/data/phot_scat.h5")

In [38]:
%%timeit

result = bi(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w, out_of_bounds_action='clip').block_until_ready()

545 ms ± 3.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
pr = cProfile.Profile()
pr.enable()
for i in range(20):
    result = bi(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w, out_of_bounds_action='clip').block_until_ready()
pr.disable()

In [6]:
ps = pstats.Stats(pr).sort_stats('cumulative')
# Print the top 30 functions
ps.print_stats(30)

         589413 function calls (584284 primitive calls) in 11.368 seconds

   Ordered by: cumulative time
   List reduced from 1369 to 30 due to restriction <30>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000   11.379    5.689 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3490(run_code)
        2    0.000    0.000   11.379    5.689 {built-in method builtins.exec}
        1    0.166    0.166   11.379   11.379 /tmp/ipykernel_62324/2194605283.py:1(<module>)
       20    8.902    0.445   11.213    0.561 /n/home07/yitians/dm21cm/DM21cm/benchmarking/../dm21cm/interpolators_jax.py:111(__call__)
  469/260    0.003    0.000    2.221    0.009 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/jax/_src/profiler.py:311(wrapper)
 1058/479    0.002    0.000    2.154    0.004 /n/home07/yitians/.conda/envs/dm21cm/lib/python3.11/site-packages/jax/_src/core.py:388(bind_with_trace)
  

<pstats.Stats at 0x7f163ef43790>

In [None]:
result

In [14]:
3*80*2000/60/60

133.33333333333334

## Old benchmark: SciPy vs. JAX interpolator

In [26]:
rs = 20.
in_spec = np.linspace(3., 4., 500)
n_in = np.linspace(0.5, 1.5, 64**3)
x_in = np.linspace(0.01, 0.99, 64**3)
w = np.linspace(0.5, 1.5, 64**3)

In [None]:
bi_scipy = BatchInterpolatorSciPy(f"{os.environ['DM21CM_DATA_DIR']}/tf/230629/phot/phot_prop.h5")

In [34]:
bi_jax = BatchInterpolatorJax(f"{os.environ['DM21CM_DATA_DIR']}/tf/230629/phot/phot_prop.h5")

| Exec time (s) on | 48CPU 32GB mem        | 48CPU A100            |
|------------------|-----------------------|-----------------------|
| scipy (base)     | 3.997290 +/- 0.004115 | 4.304705 +/- 0.068191 |
| jax (base)       | 0.506022 +/- 0.006346 | 0.010241 +/- 0.000188 |

In [15]:
result_scipy = bi_scipy(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

In [28]:
result_jax = bi_jax(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

In [21]:
%timeit result_scipy = bi_scipy(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w)

3.95 s ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
%timeit result_jax = bi_jax(rs=rs, in_spec=in_spec, nBs_s=n_in, x_s=x_in, sum_result=True, sum_weight=w).block_until_ready()

157 ms ± 2.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
result_scipy / result_jax

In [31]:
type(result_jax)

jaxlib.xla_extension.ArrayImpl

## A. Interpolators from scratch

In [2]:
z = np.random.uniform(0, 1, 1000*500*500).reshape(10, 10, 10, 500, 500)
a0 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
a1 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
a2 = np.sort(list(np.random.uniform(0, 1, 8)) + [0., 1.])
in_spec = np.random.uniform(0, 1, 500)

In [6]:
in0 = np.random.uniform(0, 1)
in1 = np.random.uniform(0, 1, 64**3)
in2 = np.random.uniform(0, 1, 64**3)
in12 = np.stack([in1, in2], axis=-1)
w = np.random.uniform(0, 1, 64**3)

## 1. Scipy

In [10]:
from scipy import interpolate

In [15]:
def fscipy():
    """SciPy reference: 1D interpolation along axis 0, then a 2D grid interpolation.

    Contracts z with in_spec, interpolates the result along a0 at in0,
    interpolates on the (a1, a2) grid at the in12 points, and returns the
    w-weighted sum over those points.

    Reads the notebook globals in_spec, z, a0, a1, a2, in0, in12 and w.
    """
    summed = np.einsum('i,abcio->abco', in_spec, z)  # axes: (0, 1, 2, out)
    along_a0 = interpolate.interp1d(a0, summed, axis=0, copy=False)
    at_in0 = along_a0(in0)  # axes: (1, 2, out)
    grid_interp = interpolate.RegularGridInterpolator((a1, a2), at_in0)
    per_point = grid_interp(in12)
    return np.dot(w, per_point)

In [18]:
%timeit s = fscipy()

3.95 s ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
in012 = np.stack([np.full_like(in1, in0), in1, in2], axis=-1)

In [21]:
def fscipy2():
    """SciPy reference using a single 3D grid interpolation.

    Contracts z with in_spec, interpolates on the (a0, a1, a2) grid at the
    in012 points, and returns the w-weighted sum over those points.

    Reads the notebook globals in_spec, z, a0, a1, a2, in012 and w.
    """
    summed = np.einsum('i,abcio->abco', in_spec, z)  # axes: (0, 1, 2, out)
    grid_interp = interpolate.RegularGridInterpolator((a0, a1, a2), summed)
    per_point = grid_interp(in012)
    return np.dot(w, per_point)

In [22]:
%timeit s2 = fscipy2()

7.72 s ± 4.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
np.all(np.isclose(fscipy(), fscipy2()))

True

## 2. JAX

In [24]:
import jax.numpy as jnp
from jax import jit, vmap

In [25]:
def interp1d(f, x, xv):
    """Linearly interpolates f(x) at the value xv. Does not do bound checks.

    f : (n>=1 D) array of function value.
    x : 1D sorted array of input value (at least two points), corresponding
        to first dimension of f.
    xv : x value to interpolate. Values outside [x[0], x[-1]] are extrapolated
        linearly from the nearest edge cell.
    """
    f = jnp.asarray(f)
    x = jnp.asarray(x)
    # side='right' puts xv == x[0] in the first cell instead of producing index
    # -1 (which wraps around to the last grid point); the clamp keeps
    # xv == x[-1] inside the last cell.
    li = jnp.clip(jnp.searchsorted(x, xv, side='right') - 1, 0, x.shape[0] - 2)
    lx = x[li]
    rx = x[li+1]
    p = (xv-lx) / (rx-lx)
    fl = f[li]
    return fl + (f[li+1]-fl) * p

# Vectorized over the query values only; f and x are shared by all queries.
interp1d_vmap = jit(vmap(interp1d, in_axes=(None, None, 0)))


def interp2d(f, x0, x1, xv):
    """Bilinearly interpolates f(x0, x1) at the point xv. Does not do bound checks.

    f : (n>=2 D) array of function value.
    x0 : 1D sorted array of input value (at least two points), corresponding
         to first dimension of f.
    x1 : 1D sorted array of input value (at least two points), corresponding
         to second dimension of f.
    xv : [x0, x1] values to interpolate. Values outside the grid are
         extrapolated linearly from the nearest edge cell.
    """
    f = jnp.asarray(f)
    x0 = jnp.asarray(x0)
    x1 = jnp.asarray(x1)
    xv0, xv1 = xv

    # Clamp the left-corner indices to a valid cell so that a query exactly on
    # the upper grid edge does not index past the end of the grid (which would
    # give a zero-width cell and a 0/0 weight).
    li0 = jnp.clip(jnp.searchsorted(x0, xv0, side='right') - 1, 0, x0.shape[0] - 2)
    lx0 = x0[li0]
    rx0 = x0[li0+1]
    wl0 = (rx0-xv0) / (rx0-lx0)
    wr0 = 1 - wl0

    li1 = jnp.clip(jnp.searchsorted(x1, xv1, side='right') - 1, 0, x1.shape[0] - 2)
    lx1 = x1[li1]
    rx1 = x1[li1+1]
    wl1 = (rx1-xv1) / (rx1-lx1)
    wr1 = 1 - wl1

    # Weighted sum of the four corners of the enclosing cell.
    return f[li0,li1]*wl0*wl1 + f[li0+1,li1]*wr0*wl1 + f[li0,li1+1]*wl0*wr1 + f[li0+1,li1+1]*wr0*wr1

# Vectorized over the query points only; f, x0 and x1 are shared by all queries.
interp2d_vmap = jit(vmap(interp2d, in_axes=(None, None, None, 0)))

In [26]:
def fjax():
    """JAX counterpart of fscipy.

    Contracts z with in_spec, interpolates along a0 at in0 with interp1d,
    interpolates at the in12 points with interp2d_vmap, and returns the
    w-weighted sum over those points.

    Reads the notebook globals in_spec, z, a0, a1, a2, in0, in12 and w.
    """
    summed = jnp.einsum('i,abcio->abco', in_spec, z)  # axes: (0, 1, 2, out)
    at_in0 = interp1d(summed, a0, in0)  # axes: (1, 2, out)
    per_point = interp2d_vmap(at_in0, a1, a2, in12)
    return jnp.dot(w, per_point)

In [None]:
np.isclose?

In [32]:
# Compare with a relative tolerance: an absolute tolerance of 1e100 would make
# this check pass for any finite values. rtol is loose enough for JAX's default
# float32 arithmetic; atol=0 keeps the comparison purely relative.
np.allclose(fscipy(), fjax(), rtol=1e-3, atol=0.)

True

In [33]:
%timeit s3 = fjax()

514 ms ± 9.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
