In [1]:
# Switch JAX to 64-bit floating point via the "jax_enable_x64" flag.
# This cell runs first, before any other JAX work in the notebook.
from jax import config
config.update("jax_enable_x64", True)

In [62]:
import numpy as np
from functools import partial
from scipy.optimize._numdiff import approx_derivative
import jax.numpy as jnp
from jax import jit, grad
from refnx.reflect._jax_reflect import abeles_jax, jabeles
from refnx.reflect import abeles
from refnx.reflect.reflect_model import gauss_legendre, _smeared_kernel_pointwise, available_backends, get_reflect_backend

# Ratio between the FWHM and the standard deviation of a Gaussian:
# 2 * sqrt(2 * ln 2).
_FWHM = 2.0 * np.sqrt(2.0 * np.log(2.0))

# Half-width of the resolution integration window, in standard deviations.
_INTLIMIT = 3.5

# Q grid on which every kernel in this notebook is evaluated.
q = np.linspace(0.01, 0.5, num=1000)

# Layer table handed to the reflectivity kernels: one row per layer,
# four values per row.
# NOTE(review): columns presumably follow the refnx `abeles` convention
# (thickness, SLD, imaginary SLD, roughness) -- confirm against refnx docs.
w = np.array(
    [
        [0.0, 2.07, 0.0, 0.0],
        [100.0, 3.47, 0.0, 3.0],
        [500.0, -0.5, 0.00001, 3.0],
        [0.0, 6.36, 0.0, 3.0],
    ]
)

In [11]:
# Sanity check: the JAX kernel (`jabeles`) must agree with refnx's `abeles`
# on the same q grid and layer table, within assert_allclose's default tolerance.
np.testing.assert_allclose(jabeles(q, w), abeles(q, w))

In [51]:
def jax_smeared_kernel_pointwise(qvals, w, dqvals, quad_order=17, threads=-1):
    """
    Resolution-smeared reflectivity by Gauss-Legendre quadrature, using the
    JAX kernel `jabeles` so the result can be jitted and differentiated.

    Each point in `qvals` is smeared with a Gaussian whose FWHM is the
    matching entry of `dqvals`; the Gaussian is integrated over
    +/- `_INTLIMIT` standard deviations.

    Parameters
    ----------
    qvals : array, 1-D
        Q values at which the smeared reflectivity is wanted.
    w : array
        Layer table forwarded unchanged to `jabeles`.
    dqvals : array or scalar
        Resolution (FWHM) for each point; must broadcast against `qvals`.
    quad_order : int, optional
        Order of the Gauss-Legendre rule. Must be a concrete Python int
        (it is marked static in the jitted wrapper `smeared`).
    threads : int, optional
        Unused. Retained so the signature is unchanged for existing callers.

    Returns
    -------
    array, 1-D
        Smeared reflectivity, one value per entry of `qvals`.
    """
    # Nodes and weights depend only on the (static) quadrature order, so they
    # are ordinary NumPy arrays computed at trace time.
    abscissa, weights = gauss_legendre(quad_order)

    # Standard normal pdf evaluated at the abscissae scaled to +/- _INTLIMIT.
    prefactor = 1.0 / np.sqrt(2 * np.pi)
    scaled = abscissa * _INTLIMIT
    gaussvals = prefactor * np.exp(-0.5 * scaled * scaled)

    # Integration bounds: +/- _INTLIMIT sigma about each q (sigma = FWHM / _FWHM).
    va = qvals - _INTLIMIT * dqvals / _FWHM
    vb = qvals + _INTLIMIT * dqvals / _FWHM

    # Column vectors, so they broadcast against the row of abscissae.
    va = va[:, np.newaxis]
    vb = vb[:, np.newaxis]

    # Map the abscissae from [-1, 1] onto [va, vb] for every q point.
    qvals_for_res = (np.atleast_2d(abscissa) * (vb - va) + vb + va) / 2.0

    # Use jnp (not np) for anything that may be a traced value under jit/grad.
    smeared_rvals = jnp.reshape(
        jabeles(qvals_for_res, w), (qvals.size, abscissa.size)
    )

    # Weight each sample; JAX arrays are immutable, so assign rather than `*=`.
    smeared_rvals = smeared_rvals * np.atleast_2d(gaussvals * weights)
    return jnp.sum(smeared_rvals, axis=1) * _INTLIMIT

# quad_order selects the quadrature rule at trace time and threads is unused,
# so neither may be traced: mark both static to keep explicit values usable.
smeared = jit(
    jax_smeared_kernel_pointwise, static_argnames=("quad_order", "threads")
)

In [52]:
# Sanity check: the jitted JAX smearing kernel must match refnx's
# `_smeared_kernel_pointwise` when dq is 5% of q.
np.testing.assert_allclose(smeared(q, w, 0.05 * q), _smeared_kernel_pointwise(q, w, 0.05 * q))

In [53]:
# Target curve for the chi2 functions: refnx's `abeles` evaluated on q and w
# (no resolution smearing is applied in this call).
data = abeles(q, w)

In [76]:
def chi2(q, w):
    """
    Sum of squared residuals between the JAX smeared model and `data`.

    The model is `smeared(q, w, 0.05 * q)`, i.e. dq is 5% of q. `data` is the
    notebook-level target array. This function is differentiated with
    `jax.grad`, so the reduction uses `jnp.sum` to stay a genuine JAX op
    rather than relying on NumPy's fallback to the array's `.sum` method.

    Parameters
    ----------
    q : array, 1-D
        Q values.
    w : array
        Layer table forwarded to `smeared`.

    Returns
    -------
    scalar JAX array
    """
    residuals = smeared(q, w, 0.05 * q) - data
    return jnp.sum(residuals**2)

def chi2_2(q, w):
    """
    Same objective as `chi2`, but accepts a flattened layer table.

    `w` is reshaped to rows of four values before being passed to the jitted
    JAX kernel `smeared` (dq is 5% of q); the squared residuals against the
    notebook-level `data` are then summed.
    """
    layers = np.reshape(w, (-1, 4))
    residuals = smeared(q, layers, 0.05 * q) - data
    return np.sum(residuals**2)

def chi2_3(q, w):
    """
    Sum of squared residuals using refnx's `_smeared_kernel_pointwise`.

    Accepts a flattened layer table, reshapes it to rows of four values,
    evaluates the smeared model with dq at 5% of q and compares it against
    the notebook-level `data`.
    """
    layers = np.reshape(w, (-1, 4))
    model = _smeared_kernel_pointwise(q, layers, 0.05 * q)
    return np.sum((model - data) ** 2)

In [71]:
gsmeared = grad(chi2, argnums=1)
%timeit gsmeared(q, w)

12.7 ms ± 40.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [70]:
# Bind q so the objective is a function of the flattened layer table only,
# which is the call shape approx_derivative expects.
part_chi2 = partial(chi2_2, q)
# Finite-difference gradient (SciPy) of the objective built on the JAX kernel.
%timeit approx_derivative(part_chi2, w.ravel())

110 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [77]:
# Bind q so the objective is a function of the flattened layer table only.
part_chi3 = partial(chi2_3, q)
# Finite-difference gradient (SciPy) of the objective built on refnx's
# `_smeared_kernel_pointwise`.
%timeit approx_derivative(part_chi3, w.ravel())

15.4 ms ± 54.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [72]:
# Time the unsmeared kernel through refnx's `abeles_jax` entry point.
# NOTE(review): if this returns a JAX array, dispatch is asynchronous and the
# call should end with .block_until_ready() for a fair timing -- confirm.
%timeit abeles_jax(q, w)

264 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [73]:
# Baseline: refnx's `abeles` on the same q grid and layer table.
%timeit abeles(q, w)

70.5 µs ± 257 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [74]:
# Baseline: refnx's `_smeared_kernel_pointwise`, with dq at 5% of q.
%timeit _smeared_kernel_pointwise(q, w, 0.05 * q)

463 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [75]:
%timeit smeared(q, w, 0.05 * q)

3.27 ms ± 3.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
