In [None]:
import numpy as np
import os
# os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/foo'
# import jax.numpy as jnp
# from jax import grad, jacfwd, jacrev, jit
# from jax.config import config
# from jax.ops import index, index_add, index_update
# config.update("jax_enable_x64", True)
from functools import reduce
from scipy.optimize._numdiff import approx_derivative
import matplotlib.pyplot as plt
%load_ext line_profiler

# Tiny imaginary SLD offset used by abeles/abeles2 so that the complex sqrt
# of the wavevector lands on the intended branch cut.
TINY = 1e-30

# Momentum-transfer grid: 1001 evenly spaced points from 0.01 to 0.5.
q = np.linspace(start=0.01, stop=0.5, num=1001)

# Layer model, one row per layer; SLDs are taken relative to row 0.
# Columns, as consumed by abeles/abeles2:
#   0: used in the phase factor exp(1j * k * |d|) -- presumably thickness
#      (ignored for the first and last rows)
#   1: real SLD (multiplied by 1e-6 inside abeles)
#   2: imaginary SLD (absolute value taken, multiplied by 1e-6)
#   3: used in the interfacial damping factor -- presumably roughness
#      (ignored for row 0)
w = np.array(
    [
        [0.0, 2.07, 0.0, 0.0],
        [100.0, 3.47, 1e-4, 3.0],
        [500.0, -0.5, 1e-5, 3.0],
        [0.0, 6.36, 0.0, 3.0],
    ]
)

In [None]:
def abeles(layers, q, bkg=0):
    """Reflectivity of a stratified model using the Abeles matrix method.

    Parameters
    ----------
    layers : array_like, shape (N + 2, 4)
        One row per layer; ``N = len(layers) - 2`` internal layers sit between
        the first and last rows. SLDs are taken relative to row 0. Columns:

        0 - enters the phase factor ``exp(1j * k * |layers[j, 0]|)``
            (presumably layer thickness); only rows 1..N are used.
        1 - real part of the SLD, multiplied by 1e-6 internally.
        2 - imaginary part of the SLD (absolute value taken), multiplied by
            1e-6 internally.
        3 - enters the interfacial factor
            ``exp(-2 * k_j * k_{j+1} * layers[j + 1, 3] ** 2)``
            (presumably roughness); row 0 is unused.
    q : array_like
        Momentum-transfer values, any shape.
    bkg : float, optional
        Constant background added to the reflectivity.

    Returns
    -------
    np.ndarray
        Real-valued reflectivity with the same shape as ``q``.
    """
    # Accept any array_like model, not only an ndarray exposing ``.shape``.
    layers = np.asarray(layers, dtype=np.float64)
    # np.asfarray was removed in NumPy 2.0; this is the equivalent conversion.
    qvals = np.asarray(q, dtype=np.float64)
    flatq = qvals.ravel()

    nlayers = layers.shape[0] - 2
    npnts = flatq.size

    # addition of TINY is to ensure the correct branch cut
    # in the complex sqrt calculation of kn.
    sld = np.zeros(nlayers + 2, np.complex128)
    sld[1:] += (
        (layers[1:, 1] - layers[0, 1]) + 1j * (np.abs(layers[1:, 2]) + TINY)
    ) * 1.0e-6

    # kn is a 2D array. Rows are Q points, columns are kn in a layer.
    kn = np.sqrt(flatq[:, np.newaxis] ** 2.0 / 4.0 - 4.0 * np.pi * sld)

    # Fresnel coefficient of each interface, damped by the column-3 factor.
    # rj.shape = (npnts, nlayers + 1)
    rj = kn[:, :-1] - kn[:, 1:]
    rj /= kn[:, :-1] + kn[:, 1:]
    rj *= np.exp(-2.0 * kn[:, :-1] * kn[:, 1:] * layers[1:, 3] ** 2)

    # characteristic matrices for each layer; miNN.shape = (npnts, nlayers + 1).
    # The first interface carries no phase factor, so mi00[:, 0] stays 1.
    mi00 = np.ones((npnts, nlayers + 1), np.complex128)
    if nlayers:
        mi00[:, 1:] = np.exp(kn[:, 1:-1] * 1j * np.fabs(layers[1:-1, 0]))
    mi11 = 1.0 / mi00
    mi10 = rj * mi00
    mi01 = rj * mi11

    # initialise matrix total
    mrtot00 = mi00[:, 0]
    mrtot01 = mi01[:, 0]
    mrtot10 = mi10[:, 0]
    mrtot11 = mi11[:, 0]

    # propagate characteristic matrices
    for idx in range(1, nlayers + 1):
        # matrix multiply mrtot by characteristic matrix
        p0 = mrtot00 * mi00[:, idx] + mrtot10 * mi01[:, idx]
        p1 = mrtot00 * mi10[:, idx] + mrtot10 * mi11[:, idx]
        mrtot00 = p0
        mrtot10 = p1

        p0 = mrtot01 * mi00[:, idx] + mrtot11 * mi01[:, idx]
        p1 = mrtot01 * mi10[:, idx] + mrtot11 * mi11[:, idx]
        mrtot01 = p0
        mrtot11 = p1

    r = mrtot01 / mrtot00
    reflectivity = r * np.conj(r)
    reflectivity += bkg
    return np.real(np.reshape(reflectivity, qvals.shape))


In [None]:
def abeles2(layers, q, bkg=0):
    """Abeles-matrix reflectivity, storing characteristic matrices in one array.

    Same contract as ``abeles``, but the per-interface 2x2 matrices live in a
    single ``(npnts, nlayers + 1, 2, 2)`` complex array.

    Parameters
    ----------
    layers : array_like, shape (N + 2, 4)
        One row per layer; SLDs are taken relative to row 0. Column 0 enters
        the phase factor (presumably thickness, rows 1..N only), columns 1/2
        are the real/imaginary SLD (multiplied by 1e-6 internally, absolute
        value taken for the imaginary part), and column 3 enters the
        interfacial damping factor (presumably roughness, row 0 unused).
    q : array_like
        Momentum-transfer values, any shape.
    bkg : float, optional
        Constant background added to the reflectivity.

    Returns
    -------
    np.ndarray
        Real-valued reflectivity with the same shape as ``q``.
    """
    # Accept any array_like model, not only an ndarray exposing ``.shape``.
    layers = np.asarray(layers, dtype=np.float64)
    # np.asfarray was removed in NumPy 2.0; this is the equivalent conversion.
    qvals = np.asarray(q, dtype=np.float64)
    flatq = qvals.ravel()
    q2 = flatq**2 / 4.0

    nlayers = layers.shape[0] - 2
    npnts = flatq.size

    # addition of TINY is to ensure the correct branch cut
    # in the complex sqrt calculation of kn.
    sld = np.zeros(nlayers + 2, np.complex128)
    sld[1:] += (
        (layers[1:, 1] - layers[0, 1]) + 1j * (np.abs(layers[1:, 2]) + TINY)
    ) * 1.0e-6

    # kn is a 2D array. Rows are Q points, columns are kn in a layer.
    kn = np.sqrt(q2[:, np.newaxis] - 4.0 * np.pi * sld)

    # Fresnel coefficient of each interface, damped by the column-3 factor.
    # rj.shape = (npnts, nlayers + 1)
    rj = kn[:, :-1] - kn[:, 1:]
    rj /= kn[:, :-1] + kn[:, 1:]
    rj *= np.exp(-2.0 * kn[:, :-1] * kn[:, 1:] * layers[1:, 3] ** 2)

    # characteristic matrices for each layer: mi[:, j] is the 2x2 matrix of
    # interface j at every q point. The first interface carries no phase
    # factor, so mi[:, 0, 0, 0] stays 1.
    mi = np.zeros((npnts, nlayers + 1, 2, 2), np.complex128)
    mi[:, :, 0, 0] = 1.0
    if nlayers:
        mi[:, 1:, 0, 0] = np.exp(kn[:, 1:-1] * 1j * np.fabs(layers[1:-1, 0]))
    mi[:, :, 1, 1] = 1.0 / mi[:, :, 0, 0]
    mi[:, :, 1, 0] = rj * mi[:, :, 0, 0]
    mi[:, :, 0, 1] = rj * mi[:, :, 1, 1]

    # initialise matrix total
    mrtot00 = mi[:, 0, 0, 0]
    mrtot01 = mi[:, 0, 0, 1]
    mrtot10 = mi[:, 0, 1, 0]
    mrtot11 = mi[:, 0, 1, 1]

    # propagate characteristic matrices
    for idx in range(1, nlayers + 1):
        # matrix multiply mrtot by characteristic matrix
        p0 = mrtot00 * mi[:, idx, 0, 0] + mrtot10 * mi[:, idx, 0, 1]
        p1 = mrtot00 * mi[:, idx, 1, 0] + mrtot10 * mi[:, idx, 1, 1]
        mrtot00 = p0
        mrtot10 = p1

        p0 = mrtot01 * mi[:, idx, 0, 0] + mrtot11 * mi[:, idx, 0, 1]
        p1 = mrtot01 * mi[:, idx, 1, 0] + mrtot11 * mi[:, idx, 1, 1]
        mrtot01 = p0
        mrtot11 = p1

    r = mrtot01 / mrtot00
    reflectivity = r * np.conj(r)
    reflectivity += bkg
    return np.real(np.reshape(reflectivity, qvals.shape))

In [None]:
# The single-array variant must reproduce the reference implementation.
np.testing.assert_allclose(actual=abeles2(w, q), desired=abeles(w, q))

In [None]:
# abeles and abeles2 return a reflectivity array shaped like q, so indexing
# their output as [:, :, i, j] fails. Instead, check that the two agree on edge
# cases: multi-dimensional q, a model with no internal layers, and a non-zero
# background.
q_grid = np.stack([q, q[::-1]])
np.testing.assert_allclose(abeles2(w, q_grid), abeles(w, q_grid))
np.testing.assert_allclose(abeles2(w[[0, -1]], q), abeles(w[[0, -1]], q))
np.testing.assert_allclose(abeles2(w, q, bkg=1e-6), abeles(w, q, bkg=1e-6))

In [None]:
# Compare the runtime of the single-array variant (abeles2) against the
# separate-array implementation (abeles) on the same model w and grid q.
%timeit abeles2(w, q)
%timeit abeles(w, q)

In [None]:
# Overlay both implementations. The reference curve is dashed so it stays
# visible where the two coincide, and both curves are labelled in the legend.
fig, ax = plt.subplots()
ax.plot(q, abeles2(w, q), label='new')
ax.plot(q, abeles(w, q), linestyle='--', label='abeles (reference)')
ax.set_yscale('log')
ax.set_xlabel('q')
ax.set_ylabel('reflectivity')
ax.legend();

In [None]:
# Line-by-line profile of abeles; %lprun comes from the line_profiler
# extension loaded with %load_ext in the first cell.
%lprun -f abeles abeles(w, q)

In [None]:
# Reproducible stack of 25 random 2x2 matrices, used to experiment with reducing
# a stack of matrices by matrix multiplication. Seeded so re-runs give the same data.
rng = np.random.default_rng(0)
a = rng.uniform(size=(25, 2, 2))

In [None]:
# np.dot is not a ufunc, so it has no .reduce method (np.dot.reduce() raises
# AttributeError). Fold the stack with functools.reduce instead; the result
# is the chained product a[0] @ a[1] @ ... @ a[-1].
reduce(np.matmul, a)

In [None]:
# Inspect the identity element of the np.multiply ufunc: the value that
# np.multiply.reduce returns for an empty input.
np.multiply.identity