In [1]:
from jax import config
# Switch JAX to 64-bit (double precision) floats. This should run before any
# JAX arrays are created. The outputs below are float64 because of it, and the
# np.testing.assert_allclose checks against the refnx reference functions
# (abeles, _smeared_kernel_pointwise) use a tight default tolerance that float32
# would likely not meet.
config.update("jax_enable_x64", True)

In [62]:
import numpy as np
from functools import partial
from scipy.optimize._numdiff import approx_derivative
import jax.numpy as jnp
from jax import jit, grad
from refnx.reflect._jax_reflect import abeles_jax, jabeles
from refnx.reflect import abeles
from refnx.reflect.reflect_model import gauss_legendre, _smeared_kernel_pointwise, available_backends, get_reflect_backend

_FWHM = 2 * np.sqrt(2 * np.log(2.0))
_INTLIMIT = 3.5

q = np.linspace(0.01, 0.5, 1000)
w = np.array([[0, 2.07, 0, 0],
              [100, 3.47, 0, 3],
              [500, -0.5, 0.00001, 3],
              [0, 6.36, 0, 3]])

In [11]:
# The JAX reflectivity kernel must reproduce refnx's abeles at every q point.
np.testing.assert_allclose(
    actual=jabeles(q, w),
    desired=abeles(q, w),
)

In [51]:
def jax_smeared_kernel_pointwise(qvals, w, dqvals, quad_order=17, threads=-1):
    """Resolution-smeared reflectivity, written so that JAX can trace it.

    Mirrors refnx's ``_smeared_kernel_pointwise``. Each point in ``qvals`` is
    convolved with a Gaussian resolution kernel, integrated by Gauss-Legendre
    quadrature over +/- ``_INTLIMIT`` standard deviations, with ``jabeles``
    evaluated at every quadrature node.

    Parameters
    ----------
    qvals : array, 1-D
        Q points at which the smeared reflectivity is wanted.
    w : array
        Layer array, passed straight through to ``jabeles``.
    dqvals : array or scalar
        Resolution FWHM at each point of ``qvals`` (must broadcast against
        ``qvals``); it is divided by ``_FWHM`` to get the Gaussian sigma.
    quad_order : int
        Number of quadrature nodes. Must be a concrete Python int, because it
        sets array sizes; the jitted wrapper ``smeared`` marks it static.
    threads : int
        Unused. Kept so the signature matches ``_smeared_kernel_pointwise``.

    Returns
    -------
    jax.Array
        Smeared reflectivity, shape ``(qvals.size,)``.
    """
    # Quadrature nodes and weights come from the static quad_order, so they are
    # constants rather than traced values.
    abscissa, weights = gauss_legendre(quad_order)

    # Standard normal pdf at the nodes, rescaled from [-1, 1] to +/- _INTLIMIT sigma.
    prefactor = 1.0 / np.sqrt(2 * np.pi)
    scaled_nodes = abscissa * _INTLIMIT
    gaussvals = prefactor * np.exp(-0.5 * scaled_nodes * scaled_nodes)

    # Integration limits at -/+ _INTLIMIT sigma around each Q point.
    # dqvals is a FWHM, so dqvals / _FWHM is the standard deviation.
    half_width = _INTLIMIT * dqvals / _FWHM
    va = (qvals - half_width)[:, np.newaxis]
    vb = (qvals + half_width)[:, np.newaxis]

    # Affine map of the nodes from [-1, 1] onto [va, vb] for every Q point.
    qvals_for_res = (np.atleast_2d(abscissa) * (vb - va) + vb + va) / 2.0
    smeared_rvals = jabeles(qvals_for_res, w)

    # jabeles returns a JAX array (a tracer under jit/grad), so use jax.numpy
    # for the remaining operations instead of relying on NumPy duck-typing.
    smeared_rvals = jnp.reshape(smeared_rvals, (qvals.size, abscissa.size))

    # JAX arrays are immutable: rebind instead of updating in place.
    smeared_rvals = smeared_rvals * np.atleast_2d(gaussvals * weights)
    return jnp.sum(smeared_rvals, axis=1) * _INTLIMIT


# quad_order determines the number of quadrature nodes (Python-level control and
# array shapes). With a plain jit, passing it explicitly would make it a tracer,
# and gauss_legendre / the reshape presumably need a concrete int. Mark it static.
smeared = jit(jax_smeared_kernel_pointwise, static_argnames=("quad_order",))

In [52]:
# The jitted JAX smearing must agree with refnx's reference implementation
# for a resolution FWHM of 5 % of Q.
np.testing.assert_allclose(
    actual=smeared(q, w, 0.05 * q),
    desired=_smeared_kernel_pointwise(q, w, 0.05 * q),
)

In [53]:
# Reference "data": the pointwise (unsmeared) reflectivity of w at every q.
# The chi2 objectives compare the *smeared* model against it, so the residuals,
# and therefore the gradients checked below, are in general non-zero even at
# the parameters that generated it.
data = abeles(q, w)

In [65]:
def chi2(q, w, dq_frac=0.05):
    """Unweighted sum of squared residuals between the smeared model and ``data``.

    Parameters
    ----------
    q : array
        Q points; should be the same points ``data`` was computed at.
    w : array
        Layer array passed to ``smeared``; differentiable with ``jax.grad``.
    dq_frac : float, optional
        Resolution FWHM as a fraction of Q (``dqvals = dq_frac * q``). The
        default of 0.05 is the value previously hard-coded here.

    Returns
    -------
    jax.Array
        Scalar chi-squared.
    """
    residuals = smeared(q, w, dq_frac * q) - data
    # The residuals are a JAX array (a tracer under grad), so reduce with jnp.
    return jnp.sum(residuals**2)


def chi2_2(q, w, dq_frac=0.05):
    """``chi2`` for a flattened layer array (e.g. for SciPy finite differences).

    ``w`` is reshaped row-major to ``(-1, 4)`` and passed on to ``chi2``.
    """
    return chi2(q, jnp.reshape(w, (-1, 4)), dq_frac)

In [69]:
# Reverse-mode gradient of chi2 with respect to its second argument, the layer
# array w. The result has the same (4, 4) shape as w.
gsmeared = grad(chi2, argnums=1)
gsmeared(q, w)

Array([[ 0.00000000e+00, -1.23946190e-01,  0.00000000e+00,
         0.00000000e+00],
       [ 8.33112085e-05,  4.69593792e-03, -2.19750346e-03,
        -4.68062868e-07],
       [ 1.31681133e-05, -6.29836800e-03,  3.57718214e-02,
        -1.04544939e-05],
       [ 0.00000000e+00,  1.25548620e-01,  1.04202793e-01,
        -2.22486536e-05]], dtype=float64)

In [68]:
# Finite-difference cross-check of the JAX gradient above. q is frozen with
# partial, and chi2_2 takes a flat parameter vector, so the derivative comes
# back flattened: 16 entries in row-major order of w.
# NOTE(review): the entries for w[1, 2] and w[3, 2] (both exactly 0 in w) are 0
# here but non-zero in the JAX gradient; the other entries agree closely.
# Presumably the model is not smooth (or is symmetric) in those parameters at 0,
# and a central difference straddling 0 cancels out. Investigate before
# trusting either value.
part_chi2 = partial(chi2_2, q)
approx_derivative(part_chi2, w.ravel())

array([ 0.00000000e+00, -1.23946207e-01,  0.00000000e+00,  0.00000000e+00,
        8.33112085e-05,  4.69593792e-03,  0.00000000e+00, -4.68063359e-07,
        1.31681142e-05, -6.29836801e-03,  3.57718149e-02, -1.04544951e-05,
        0.00000000e+00,  1.25548788e-01,  0.00000000e+00, -2.22486538e-05])