In [1]:
# Enable 64-bit (float64) computation in JAX; the results below are reported
# with dtype=float64. Set at startup, before other JAX work.
from jax import config
config.update("jax_enable_x64", True)

In [2]:
from refnx._lib import flatten
import numpy as np
from refnx.analysis import Parameter, Model, Parameters, Objective
from refnx.dataset import Data1D
from refnx._lib import flatten, unique
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [3]:
# Seed the legacy global NumPy RNG so the synthetic data are reproducible.
np.random.seed(123)

# "True" line parameters (y = m*x + b) and the fractional scatter f.
m_true, b_true, f_true = -0.9594, 4.294, 0.534

# Synthetic data set. The draw order (rand, rand, randn, randn) must stay
# fixed so that the seeded sequence reproduces the same values.
N = 50000
x = np.sort(10 * np.random.rand(N))
yerr = 0.1 + 0.5 * np.random.rand(N)

y = m_true * x + b_true
# Fractional scatter proportional to |f * y|, then per-point noise.
y = y + np.abs(f_true * y) * np.random.randn(N)
y = y + yerr * np.random.randn(N)

# NOTE(review): `yerr` is generated but not passed to Data1D, so the data set
# is built without uncertainties; use (x, y, yerr) if a weighted logl is intended.
data = Data1D(data=(x, y))

In [4]:
def make_evaluator(objective):
    """Build a JAX-differentiable log-likelihood and its gradient for ``objective``.

    Parameters
    ----------
    objective : Objective
        Object providing ``varying_parameters()`` (items with a ``_value``
        attribute) and ``logl()``.

    Returns
    -------
    func : callable
        ``func(pvs)`` writes each entry of ``pvs`` into the corresponding
        varying parameter, evaluates ``objective.logl()`` and returns it. The
        previous parameter values are restored afterwards, so values passed
        in, including tracers created by ``jax.grad``, do not leak into
        ``objective``.
    grad_func : callable
        ``jax.grad(func)``: the gradient of ``logl`` with respect to ``pvs``.

    Raises
    ------
    ValueError
        From ``func``/``grad_func`` when ``len(pvs)`` differs from the number
        of varying parameters (``zip`` would otherwise silently truncate).
    """
    vpars = objective.varying_parameters()

    def func(pvs):
        if len(pvs) != len(vpars):
            raise ValueError(
                f"expected {len(vpars)} parameter values, got {len(pvs)}"
            )
        # Writes go to the private `_value` slot, presumably to bypass the
        # public setter, which may not accept JAX tracers. TODO confirm.
        saved = [vpar._value for vpar in vpars]
        try:
            for vpar, pv in zip(vpars, pvs):
                vpar._value = pv
            return objective.logl()
        finally:
            # JAX transformations expect pure functions: put the old values
            # back so no tracer escapes into `objective` after grad returns.
            for vpar, old in zip(vpars, saved):
                vpar._value = old

    return func, grad(func)

In [5]:
# Slope `m` and intercept `c` for the straight-line model.
m, c = Parameter(1), Parameter(0)

In [6]:
class Line(Model):
    """Straight-line model ``y = parameters[0] * x + parameters[1]``.

    ``pars`` is a two-element sequence: the slope first, then the intercept.
    """

    def __init__(self, pars):
        # Model.__init__ is not called; these attributes are set by hand.
        self._parameters = Parameters(pars)
        self.fitfunc = None
        self.fcn_args = None
        self.fcn_kwds = None
        self.pars = pars

    def model(self, x, p=None, x_err=None):
        """Evaluate the line at ``x``.

        If ``p`` is given, it is first stored into ``self.parameters.pvals``.
        ``x_err`` is accepted for interface compatibility and is not used.
        """
        if p is not None:
            self.parameters.pvals = np.array(p)

        slope = self.parameters[0].value
        intercept = self.parameters[1].value
        return slope * x + intercept

In [7]:
# Line model over [m, c]; only the slope `m` is explicitly set to vary here,
# so the evaluators below are called with a length-1 vector
# (assumes `c` does not vary by default — TODO confirm).
l = Line([m, c])
m.vary = True

In [8]:
# Objective combining the line model with the synthetic data set.
objective = Objective(l, data)

In [9]:
# f(pvs) -> logl, g(pvs) -> d(logl)/d(pvs) via jax.grad.
f, g = make_evaluator(objective)

In [10]:
# logl with the varying parameter (slope m) set to 2.0.
f(jnp.array([2.0]))

Array(-4659324.74119686, dtype=float64)

In [11]:
# Gradient of logl with the varying parameter (slope m) set to -1.0.
g(jnp.array([-1.0]))

Array([1144222.39350771], dtype=float64)

In [12]:
# logl evaluated through refnx's own API with parameter values [1.0, 0.0].
objective.logl([1.0, 0.0])

-1622769.8377402509

In [13]:
# Inspect the types held in the private `_value` slots after the calls above.
type(l.parameters[0]._value), type(l.parameters[1]._value)

(numpy.float64, numpy.float64)

In [14]:
# g also accepts a plain list; the gradient then comes back as a list.
g([-1.])

[Array(1144222.39350771, dtype=float64, weak_type=True)]

In [15]:
# %timeit g(jnp.array([1.0]))

In [16]:
from scipy.optimize._numdiff import approx_derivative

In [17]:
# %timeit approx_derivative(objective.logl, [-1.])

In [18]:
from refnx.reflect import reflect_model, Slab, SLD, Structure, ReflectModel, abeles, use_reflect_backend
from refnx.reflect._jax_reflect import abeles_jax
# Select the JAX implementation (abeles_jax) as reflect_model's kernel.
reflect_model.kernel = abeles_jax

In [19]:
# Materials and layer stack: air | sio2(15, 3) | si(0, 3).
air = SLD(0.0)
si = SLD(2.07)
sio2 = SLD(3.47)
s = air | sio2(15, 3) | si(0, 3)

# Bounded varying parameters of the sio2 slab (s[-2]).
s[-2].thick.setp(vary=True, bounds=(10, 20))
s[-2].rough.setp(vary=True, bounds=(1, 6))

model = ReflectModel(s)

# Remaining varying parameters, set in the same order as before.
# NOTE(review): s[-2].rough is already set to vary (with bounds) above;
# repeating it here looks redundant.
for _par in (model.scale, model.bkg, sio2.real, si.real, s[-1].rough, s[-2].rough):
    _par.setp(vary=True)

data = Data1D('c_PLP0000708.dat')
objective = Objective(model, data)
# Starting vectors: varying parameters and all parameters.
arr = np.array(objective.varying_parameters())
sarr = np.array(objective.parameters)

In [20]:
# Time a 2-point finite-difference gradient of logl using the 'c' reflectivity
# backend, single-threaded, starting from the parameter vector `arr`.
with use_reflect_backend('c'):
    # NOTE(review): threads is not reset after the `with` block; presumably it
    # stays at 1 for later cells — confirm.
    model.threads=1
    objective.setp(np.copy(arr))
    %timeit approx_derivative(objective.logl, arr, method='2-point')

3.88 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
# Re-select the JAX kernel and build value/gradient functions for the
# reflectometry objective.
reflect_model.kernel = abeles_jax
f, g = make_evaluator(objective)

In [21]:
# Time jax.grad of logl. NOTE(review): `g` is not wrapped in jax.jit, so every
# call re-traces the Python objective code; that overhead is in this timing.
%timeit g(arr)

11.8 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Show the type held by each (flattened) parameter value of the objective.
for param in flatten(objective.parameters):
    print(type(param.value))

The basic timings for calculating d(Objective.logl) with finite differences vs `jax.grad` are clear: it's faster to calculate the gradient using finite differences than using automatic differentiation.


| method | Time | 
|--------|------|
| finite differences | 3.88 ms |
| jax.grad | 11.8 ms |


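I think the difference is so stark because of the speed of the underlying reflectivity kernel. The finite-difference calculation uses a C-based kernel that is very fast (even when single-threaded), even though the 2-point scheme needs an extra `logl` evaluation for every varying parameter. In comparison, `jax.grad` has to use a (jitted) JAX kernel, which is much slower than the C kernel: overall, `jax.grad` is about 3× slower (11.8 ms vs 3.88 ms). One caveat: the gradient function `g` is not itself wrapped in `jax.jit`, so every call re-traces the Python-level objective code, and that overhead is included in the `jax.grad` timing. With that caveat, it does not look worth using JAX for gradient estimation when doing NUTS sampling with `pymc`.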