In [None]:
import numpy as np

import jax.numpy as jnp
from jax import grad, jacfwd, jacrev, jit
from jax.config import config
from jax.ops import index, index_add, index_update
config.update("jax_enable_x64", True)

from scipy.optimize._numdiff import approx_derivative


# Tiny positive offset added to every imaginary SLD inside the reflectivity
# kernels so that the complex square root lands on the intended branch.
TINY = 1e-30

# Two Q points, 0.01 and 0.5 (endpoints included).
q = np.linspace(start=0.01, stop=0.5, num=2)

# Layer table, one row per layer: [thickness, SLD, iSLD, roughness].
# The kernels scale the SLD columns by 1.0e-6. The first and last rows are the
# bounding media: their thickness (and the first row's iSLD/roughness) is
# never read by the kernels.
w = np.array(
    [
        [0, 2.07, 0, 0],
        [100, 3.47, 0.0001, 3],
        # [500, -0.5, 0.00001, 3],   # optional extra layer, currently disabled
        [0, 6.36, 0.0, 3],
    ]
)

In [None]:
def abeles(layers, q):
    """Specular reflectivity of a stratified system via the Abeles matrix method.

    Parameters
    ----------
    layers : array_like
        Layer table with four values per row -- ``[thickness, SLD, iSLD,
        roughness]`` -- given either as an ``(N, 4)`` array or flattened to
        ``4 * N`` values, with ``N >= 2``. The SLD columns are scaled by
        ``1.0e-6`` internally. The first and last rows are the bounding
        media: the thickness of rows 0 and ``N - 1``, and the iSLD/roughness
        of row 0, are never read.
    q : array_like
        Scattering-vector values (presumably inverse Angstrom, consistent
        with the SLD scaling -- confirm against the data source); any shape.

    Returns
    -------
    numpy.ndarray
        Real reflectivity ``|r|**2`` with the same shape as ``q``.

    Raises
    ------
    ValueError
        If ``layers`` cannot be split into rows of four values, or has fewer
        than two rows.
    """
    # asarray lets plain lists through; reshape(-1, 4) accepts both the flat
    # and the (N, 4) form (and raises ValueError for a ragged size).
    layers = np.reshape(np.asarray(layers), (-1, 4))
    if layers.shape[0] < 2:
        raise ValueError(
            "layers must contain at least two rows (the two bounding media)"
        )

    # np.asfarray was removed in NumPy 2.0; asarray(..., float64) is its exact
    # equivalent for the default dtype.
    qvals = np.asarray(q, dtype=np.float64)
    flatq = qvals.ravel()

    nlayers = layers.shape[0] - 2  # inner rows between the bounding media
    npnts = flatq.size

    # SLD of every row relative to row 0 (row 0 itself stays 0).
    # addition of TINY is to ensure the correct branch cut
    # in the complex sqrt calculation of kn.
    sld = np.zeros(nlayers + 2, np.complex128)
    sld[1:] += (
        (layers[1:, 1] - layers[0, 1]) + 1j * (np.abs(layers[1:, 2]) + TINY)
    ) * 1.0e-6

    # kn is a 2D array. Rows are Q points, columns are kn in a layer.
    kn = np.sqrt(flatq[:, np.newaxis] ** 2.0 / 4.0 - 4.0 * np.pi * sld)

    # Fresnel coefficient of every interface, damped by the roughness factor
    # exp(-2 k_j k_{j+1} sigma^2) (Nevot-Croce form).
    # rj.shape = (npnts, nlayers + 1)
    rj = kn[:, :-1] - kn[:, 1:]
    rj /= kn[:, :-1] + kn[:, 1:]
    rj *= np.exp(-2.0 * kn[:, :-1] * kn[:, 1:] * layers[1:, 3] ** 2)

    # characteristic matrices for each layer; mi00 carries the phase factor
    # exp(i k d) of each inner layer (column 0 has no preceding layer).
    # miNN.shape = (npnts, nlayers + 1)
    mi00 = np.ones((npnts, nlayers + 1), np.complex128)
    if nlayers:
        mi00[:, 1:] = np.exp(kn[:, 1:-1] * 1j * np.fabs(layers[1:-1, 0]))
    mi11 = 1.0 / mi00
    mi10 = rj * mi00
    mi01 = rj * mi11

    # initialise matrix total
    mrtot00 = mi00[:, 0]
    mrtot01 = mi01[:, 0]
    mrtot10 = mi10[:, 0]
    mrtot11 = mi11[:, 0]

    # propagate characteristic matrices
    for idx in range(1, nlayers + 1):
        # matrix multiply mrtot by characteristic matrix
        p0 = mrtot00 * mi00[:, idx] + mrtot10 * mi01[:, idx]
        p1 = mrtot00 * mi10[:, idx] + mrtot10 * mi11[:, idx]
        mrtot00 = p0
        mrtot10 = p1

        p0 = mrtot01 * mi00[:, idx] + mrtot11 * mi01[:, idx]
        p1 = mrtot01 * mi10[:, idx] + mrtot11 * mi11[:, idx]

        mrtot01 = p0
        mrtot11 = p1

    r = mrtot01 / mrtot00
    reflectivity = r * np.conj(r)  # |r|**2, still complex-typed
    return np.real(np.reshape(reflectivity, qvals.shape))


In [None]:
def jabeles(layers, q):
    """JAX version of the Abeles-matrix reflectivity kernel (jit/jacfwd-friendly).

    Computes the same quantity as the NumPy ``abeles`` kernel, using only
    ``jax.numpy`` operations so it can be traced and differentiated.

    Parameters
    ----------
    layers : array_like
        Layer table with four values per row -- ``[thickness, SLD, iSLD,
        roughness]`` -- as an ``(N, 4)`` array or flattened to ``4 * N``
        values, with ``N >= 2``. SLD columns are scaled by ``1.0e-6``. The
        thickness of the first and last rows, and the iSLD/roughness of the
        first row, are never read.
    q : array_like
        Scattering-vector values; any shape.

    Returns
    -------
    jax.Array
        Real reflectivity ``|r|**2`` with the same shape as ``q``.

    Raises
    ------
    ValueError
        If ``layers`` cannot be split into rows of four values, or has fewer
        than two rows.

    Notes
    -----
    JAX arrays are immutable, so element updates use the functional
    ``x.at[...].add/.set`` API instead of the ``jax.ops.index_add`` /
    ``index_update`` helpers, which are deprecated and removed in newer JAX.
    ``nlayers`` comes from the (static) array shape, so the Python ``if`` and
    ``for`` below are resolved while tracing.
    """
    layers = jnp.reshape(jnp.asarray(layers), (-1, 4))
    if layers.shape[0] < 2:
        raise ValueError(
            "layers must contain at least two rows (the two bounding media)"
        )

    qvals = jnp.asarray(q, dtype=jnp.float64)
    flatq = qvals.ravel()

    nlayers = layers.shape[0] - 2  # inner rows between the bounding media
    npnts = flatq.size

    # SLD of every row relative to row 0 (row 0 itself stays 0).
    # addition of TINY is to ensure the correct branch cut
    # in the complex sqrt calculation of kn.
    sld = jnp.zeros(nlayers + 2, jnp.complex128).at[1:].add(
        ((layers[1:, 1] - layers[0, 1]) + 1j * (jnp.abs(layers[1:, 2]) + TINY))
        * 1.0e-6
    )

    # kn: rows are Q points, columns are the wavevector in each layer.
    kn = jnp.sqrt(flatq[:, jnp.newaxis] ** 2.0 / 4.0 - 4.0 * jnp.pi * sld)

    # Fresnel coefficient of every interface with roughness damping.
    # rj.shape = (npnts, nlayers + 1)
    rj = (kn[:, :-1] - kn[:, 1:]) / (kn[:, :-1] + kn[:, 1:])
    rj = rj * jnp.exp(-2.0 * kn[:, :-1] * kn[:, 1:] * layers[1:, 3] ** 2)

    # characteristic matrices for each layer; mi00 carries the phase factor
    # of each inner layer (column 0 has no preceding layer).
    # miNN.shape = (npnts, nlayers + 1)
    mi00 = jnp.ones((npnts, nlayers + 1), jnp.complex128)
    if nlayers:
        mi00 = mi00.at[:, 1:].set(
            jnp.exp(kn[:, 1:-1] * 1j * jnp.fabs(layers[1:-1, 0]))
        )
    mi11 = 1.0 / mi00
    mi10 = rj * mi00
    mi01 = rj * mi11

    # initialise matrix total
    mrtot00 = mi00[:, 0]
    mrtot01 = mi01[:, 0]
    mrtot10 = mi10[:, 0]
    mrtot11 = mi11[:, 0]

    # propagate characteristic matrices
    for idx in range(1, nlayers + 1):
        # matrix multiply mrtot by characteristic matrix
        p0 = mrtot00 * mi00[:, idx] + mrtot10 * mi01[:, idx]
        p1 = mrtot00 * mi10[:, idx] + mrtot10 * mi11[:, idx]
        mrtot00 = p0
        mrtot10 = p1

        p0 = mrtot01 * mi00[:, idx] + mrtot11 * mi01[:, idx]
        p1 = mrtot01 * mi10[:, idx] + mrtot11 * mi11[:, idx]

        mrtot01 = p0
        mrtot11 = p1

    r = mrtot01 / mrtot00
    reflectivity = r * jnp.conj(r)  # |r|**2, still complex-typed
    return jnp.real(jnp.reshape(reflectivity, qvals.shape))

In [None]:
# Sanity check: the JAX port must reproduce the NumPy reference kernel
# (the JAX version is given the 2D table, the NumPy one the flattened form).
np.testing.assert_allclose(
    actual=jabeles(w, q),
    desired=abeles(w.flatten(), q),
)

In [None]:
# Compiled kernel, plus the compiled forward-mode Jacobian of the kernel with
# respect to its first argument (the layer parameters).
jitabeles = jit(jabeles)
j_jabeles = jit(jacfwd(jabeles))

# Timing comparisons (uncomment to run):
# %timeit abeles(w.flatten(), q)
# %timeit jitabeles(w.flatten(), q)
# %timeit num_diff = approx_derivative(abeles, w.flatten(), method='2-point', args=(q,))
# %timeit jax_diff = j_jabeles(w.flatten(), q)

def filter_p(jac, drop=(0, 2, 3, -4, -2)):
    """Remove Jacobian columns that cannot be meaningfully compared.

    Parameters
    ----------
    jac : array_like
        Jacobian whose last axis indexes the flattened layer parameters
        (four per row: ``[thickness, SLD, iSLD, roughness]``).
    drop : sequence of int, optional
        Column indices to remove; negative indices count from the end. The
        default is valid for any flattened table with at least two rows:

        * 0, 2, 3 -- thickness, iSLD and roughness of the first row, and
          -4 -- thickness of the last row. The reflectivity kernels never
          read these, so their derivatives are identically zero.
        * -2 -- iSLD of the last row. The kernels DO read it, but in ``w``
          it is exactly 0, where ``abs`` has a kink. A central difference
          gives 0 there, whereas JAX's derivative rule for ``abs`` at 0 is
          likely one-sided -- so the two Jacobians disagree by construction.

    Returns
    -------
    array
        ``jac`` with the ``drop`` columns removed from its last axis.
    """
    mask = np.ones(jac.shape[-1], dtype=bool)
    # list() so that a tuple is read as a list of positions, not a
    # multi-dimensional index.
    mask[list(drop)] = False
    return jac[..., mask]

# Central-difference (3-point) Jacobian of the NumPy kernel versus the exact
# forward-mode Jacobian of the JAX kernel, both reduced to the comparable
# parameter columns by `filter_p`.
# NOTE: approx_derivative comes from SciPy's private `_numdiff` module and may
# move between SciPy releases.
num_diff = filter_p(approx_derivative(abeles, w.flatten(), method='3-point', args=(q,)))
jax_diff = filter_p(j_jabeles(w.flatten(), q))

# Both Jacobians are already filtered above. Filtering them a second time (as
# before) dropped 4 of the 7 informative columns and only compared 3.
np.testing.assert_allclose(jax_diff, num_diff, rtol=2e-7)

In [None]:
# Shapes of the two filtered Jacobians -- these should be identical.
(num_diff.shape, jax_diff.shape)