# Implementing and fitting a simple syrinx model

Model based on [Mindlin et al., 2003](https://journals.aps.org/pre/abstract/10.1103/PhysRevE.68.041908). 

We want to fit the system of first-order nonlinear equations:

$$
\begin{aligned}
\dot{x} &= y \\
\dot{y} &= -\epsilon x - C x^2 y + By - D_0
\end{aligned}
$$
where
$$
\begin{aligned}
\epsilon &= \epsilon_1 + \epsilon_2 K(t) \\
B &= \beta_1 + \beta_2 P(t) \\
D_0 &= \delta D(t)
\end{aligned}
$$
and $K(t)$, $D(t)$, and $P(t)$ are the (linear envelopes of) tension in the ventral syringeal muscle (vS), the tracheobronchialis dorsalis (dTB), and sub-syringeal air pressure, respectively.

From the original paper, we take the parameter values below (note that the code in the next cell deliberately uses a $\beta_2$ ten times larger than the value listed here):
$$
\begin{aligned}
\epsilon_1 &= 1.25 \times 10^8 \, \mathrm{s}^{-2} \\
\epsilon_2 &= 7.5 \times 10^9 \, \mathrm{V}^{-1}\cdot \mathrm{s}^{-2} \\
C &= 2 \times 10^8 \, \mathrm{cm}^{-2} \cdot \mathrm{s}^{-1} \\
\beta_1 &= -2 \times 10^3 \, \mathrm{s}^{-1} \\
\beta_2 &= 5.3 \times 10^4 \, \mathrm{V}^{-1}\cdot \mathrm{s}^{-1} \\
\delta &= 15 \times 10^6 \, \mathrm{cm}\cdot\mathrm{V}^{-1} \cdot \mathrm{s}^{-2}
\end{aligned}
$$

In [None]:
# Physical parameters of the syrinx model (units as listed in the table above).
# Restoring-force coefficients: eps = eps1 + eps2 * K(t)
eps1, eps2 = 1.25e8, 7.5e9  # s^-2, V^-1 s^-2
# Linear damping/gain coefficients: B = beta1 + beta2 * P(t)
# NOTE: beta2 is 10x higher than in the paper!
beta1, beta2 = -2e3, 5.3e5  # s^-1, V^-1 s^-1
# Nonlinear dissipation coefficient and dTB forcing gain (D0 = delta * D(t))
C, delta = 2e8, 15e6  # cm^-2 s^-1, cm V^-1 s^-2

But whereas the original paper used measured values for $K$, $P$, and $D$, we use simpler time series constructed to have the same shape:

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax
import matplotlib.pyplot as plt

# IPython magic: render inline figures in the 'retina' (high-resolution) format.
%config InlineBackend.figure_format = 'retina'
# Typeset all figure text through LaTeX and request Helvetica as the font.
# NOTE(review): "text.usetex" needs a LaTeX toolchain on the machine running
# the notebook -- confirm it is available in the target environment.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "Helvetica"
})

In [None]:
# Simulation time base: T seconds sampled at sr Hz (T itself is excluded).
T = 1.2
sr = 4e4  # data sampling rate (Hz)
t_axis = jnp.arange(0, T, 1/sr)

# Ground-truth parameters, each divided by a fixed power of ten; gradfun
# multiplies the same factors back in, so the optimizer works on rescaled
# values. Order: eps1, eps2, beta1, beta2, C, delta.
params_true = jnp.array([
    eps1/1e8,
    eps2/1e8,
    beta1/1e3,
    beta2/1e3,
    C/1e8,
    delta/1e7,
])

In [None]:
def make_pulse_sequence(pulse_fun, arg_list):
    """Build a function that sums ``pulse_fun`` over a set of pulse parameters.

    Args:
        pulse_fun: callable ``pulse_fun(t, *pulse_args)`` giving the value of
            a single pulse at time ``t``.
        arg_list: tuple of arrays scanned together along their leading axis;
            slice ``i`` of every array forms the extra arguments of pulse ``i``.

    Returns:
        A function of a single time ``t`` returning the sum of all pulses
        evaluated at ``t``.
    """
    def scan_and_sum(t):
        # The scan carry is just the (unchanged) time; each step emits the
        # value of one pulse at that time.
        def one_pulse(time, pulse_args):
            return time, pulse_fun(time, *pulse_args)

        _, per_pulse = lax.scan(one_pulse, t, arg_list)
        return jnp.sum(per_pulse)

    return scan_and_sum

In [None]:
# Sub-syringeal pressure envelope P(t): one Gaussian bump per entry of plocs.
pfreq = 2  # NOTE(review): not referenced in the visible code; centers are set by plocs
pA = 0.025  # pressure amplitude (Volts)
p0 = -0.005  # pressure DC offset (Volts)
pwid = 0.08  # Gaussian width of each bump (same units as t)
plocs = jnp.array([0.3, 0.9])  # bump centers


def ppulse(t, loc):
    """Single Gaussian pressure bump centered at ``loc``, plus the offset ``p0``."""
    # NOTE(review): p0 is added inside every pulse, so after summing over
    # plocs the baseline of P is len(plocs) * p0 rather than p0 -- confirm
    # this is intended.
    return pA * jnp.exp(-0.5 * (t - loc)**2/pwid**2) + p0


# Sum the bumps at a single time, then vectorize over an array of times.
P = jax.vmap(make_pulse_sequence(ppulse, (plocs,)))

plt.plot(t_axis, P(t_axis))

In [None]:
def make_tension_pulse_fn(shape=1, scale=1, peak=1):
    """Return a time-reversed gamma-shaped pulse ``fn(t, loc)``.

    With ``u = loc - t`` the pulse is proportional to ``u**shape * exp(-u/scale)``
    for ``u > 0`` and is zero for ``t >= loc``. The prefactor is chosen so that
    the maximum, reached at ``u = shape * scale``, equals ``peak``.

    Args:
        shape: exponent of the power-law factor.
        scale: time constant of the exponential factor.
        peak: maximum value of the pulse.

    Returns:
        Callable ``fn(t, loc)`` evaluating the pulse that ends at ``loc``.
    """
    # log of exp(shape) / (shape * scale)**shape, the inverse of the
    # un-normalized maximum
    log_norm = shape - shape * jnp.log(shape) - shape * jnp.log(scale)
    norm = peak * jnp.exp(log_norm)

    def fn(t, loc):
        decay = jnp.exp((t - loc)/scale)
        # clamp at zero so the pulse vanishes once t has passed loc
        rise = jnp.maximum((loc - t), 0)**shape
        return norm * decay * rise

    return fn
    
        
# vS tension envelope K(t): gamma-shaped pulses, one ending at each entry of klocs.
klocs = jnp.array([0, 0.5, 1, 1.5])
kpeak = 0.06  # peak value (Volts)
kshape = 5  # shape parameter of gamma function
kscale = 0.025  # scale parameter of gamma function (s)

pulse = make_tension_pulse_fn(kshape, kscale, kpeak)
# Sum the pulses at a single time, then vectorize over an array of times.
K = jax.vmap(make_pulse_sequence(pulse, (klocs,)))

plt.plot(t_axis, K(t_axis))

In [None]:
# dTB tension envelope D(t): two interleaved trains of narrow Gaussian pulses.
dfreq = 2  # repetition rate of each pulse train (Hz)
dwid = 0.01  # Gaussian width of every pulse (same units as t)

# Pulse centers of the two trains, spaced 1/dfreq apart and merged in time order.
dlocs1 = jnp.arange(0.1, t_axis[-1], 1/dfreq)
dlocs2 = jnp.arange(0.45, t_axis[-1], 1/dfreq)
dlocs = jnp.sort(jnp.concatenate([dlocs1, dlocs2]))
# One amplitude per pulse center: dlocs and dA are scanned together, so their
# lengths must agree.
dA = jnp.array([0.05, 0.02, 0.01, 0.05, 0.03,])


def gpulse(t, loc, amp):
    """Gaussian pulse of height ``amp`` centered at ``loc``."""
    return amp * jnp.exp(-0.5 * (t - loc)**2/dwid**2)


D = jax.vmap(make_pulse_sequence(gpulse, (dlocs, dA)))

plt.plot(t_axis, D(t_axis))

In [None]:
# Overlay the three driving envelopes on a common time axis.
plt.plot(t_axis, K(t_axis), label='vS tension')
plt.plot(t_axis, D(t_axis), label='dTB tension')
plt.plot(t_axis, P(t_axis), label='Pressure')
plt.xlabel("time (s)")
plt.ylabel("envelope (V)")
# trailing semicolon suppresses the notebook's echo of the Legend object
plt.legend();

Now let's define and integrate the ODE:

In [None]:
import diffrax

In [None]:
def gradfun(t, y, args):
    """Right-hand side of the syrinx ODE in the ``f(t, y, args)`` form used by diffrax.

    Args:
        t: scalar time.
        y: state; ``y[0]`` is the displacement x and ``y[1]`` its time derivative.
        args: ``(params, (K, D, P))`` where ``params`` holds the rescaled
            values (eps1, eps2, beta1, beta2, C, delta) and ``K``, ``D``,
            ``P`` are envelope callables that accept a 1-D array of times.

    Returns:
        Length-2 array ``(xdot, ydot)``.
    """
    params, (K, D, P) = args
    x, v = y[0], y[1]

    # The envelope callables take an array of times, so wrap the scalar t in
    # a length-1 array; the resulting length-1 ydot is unwrapped on return.
    t_arr = jnp.array([t])

    # Undo the power-of-ten rescaling applied to the parameters.
    eps = (params[0] + params[1] * K(t_arr)) * 1e8
    B = (params[2] + params[3] * P(t_arr)) * 1e3
    cubic = params[4] * 1e8
    D0 = params[5] * D(t_arr) * 1e7

    ydot = -eps * x - cubic * x**2 * v + B * v - D0

    return jnp.array((v, ydot[0]))


# Solver configuration for generating the synthetic recording.
term = diffrax.ODETerm(gradfun)
solver = diffrax.Dopri5()
# Save the solution directly on the audio sample grid. The previous
# jnp.linspace(0, 1.2, int(sr)) hard-coded the duration and produced int(sr)
# points over T seconds, i.e. a grid that did not match t_axis and forced a
# lossy linear resampling afterwards.
saveat = diffrax.SaveAt(ts=t_axis)
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)

In [None]:
# Integrate the model with the ground-truth parameters from rest (x = y = 0).
# The initial state is given as floats: an integer-typed y0 relies on implicit
# dtype promotion inside the solver.
soln = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=T,
    dt0=0.5/sr,
    y0=jnp.array((0.0, 0.0)),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    args=(params_true, (K, D, P)),
    max_steps=int(1e6),
)

In [None]:
# First state component (displacement x) of the solution at the saved times.
plt.plot(soln.ts, soln.ys[:, 0])

In [None]:
# Linearly interpolate the displacement onto t_axis; this is the "audio" signal.
audio = jnp.interp(t_axis, soln.ts, soln.ys[:, 0])

If we wanted to save this generated data, we could do:
```python
import numpy as np
from scipy.io import wavfile

wavfile.write('test.wav', int(sr), np.array(audio))
```

In [None]:
from scipy.signal import stft

# Short-time Fourier transform of the simulated audio: 512-sample windows
# overlapping by 480 samples, i.e. a hop of 32 samples between frames.
freqs, times, spec = stft(audio, fs=sr, 
                  nperseg=512, 
                  noverlap=480)

In [None]:
# Left: the three driving envelopes. Right: spectrogram of the simulated audio.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].plot(t_axis, K(t_axis), label='vS tension')
axs[0].plot(t_axis, D(t_axis), label='dTB tension')
axs[0].plot(t_axis, P(t_axis), label='Pressure')
axs[0].set_xlabel("time (s)")
axs[0].set_ylabel("envelope (V)")
axs[0].set_xlim([0, 1.2])
axs[0].legend();

# Row 0 of `spec` is the lowest frequency. imshow's default origin ('upper')
# draws row 0 at the top while `extent` labels the top as freqs[-1], which
# turns the frequency axis upside down. origin='lower' puts row 0 at freqs[0]
# so the tick labels are true frequencies.
axs[1].imshow(jnp.abs(spec), extent=[times[0], times[-1], freqs[0], freqs[-1]],
              aspect='auto', origin='lower')
axs[1].set_xlabel('time (s)')
axs[1].set_ylabel('frequency (Hz)')
# Same band that was visible with the inverted axis (where it was labeled
# 12500-19000), now expressed in true frequencies.
axs[1].set_ylim([1000, 7500])
axs[1].set_xlim([0, 1.2])

plt.tight_layout()

# Fit an ODE model to data

Now we'll use data generated from the model above and see if we can recover the parameters of the model.

In [None]:
# Resample every state variable (each column of soln.ys) onto t_axis; the
# result has shape (len(t_axis), number of state variables).
ys_full = jnp.stack(
    [jnp.interp(t_axis, soln.ts, state) for state in soln.ys.T]
).T

# Finite-difference estimate of the time derivative. Zeros are prepended so
# the output keeps the length of ys_full; multiplying by sr converts the
# per-sample difference into a per-second rate.
diff_order = 1
ys_full_grads = sr * jnp.diff(
    ys_full,
    n=diff_order,
    axis=0,
    prepend=jnp.zeros((diff_order, ys_full.shape[1])),
)

# Cut the long recording into consecutive, non-overlapping snippets and stack
# them into a dataset of shape (n_snippets, T_snippet, n_states).
T_snippet = 100  # samples
ys = ys_full.reshape((-1, T_snippet, ys_full.shape[1]))
ys_grads = ys_full_grads.reshape(ys.shape)
ts = t_axis.reshape((-1, T_snippet))

print(ys.shape, ts.shape)

In [None]:
# Finite-difference estimate of the second state's derivative, zoomed to a
# 4 ms window.
plt.plot(t_axis, ys_full_grads[:, 1])
plt.xlim([0.3, 0.304])

# Derivative given by the model itself: evaluate gradfun at every sample with
# the ground-truth parameters (vmapped over time and state, args shared).
true_grads = jax.vmap(gradfun, in_axes=(0, 0, None))(t_axis, ys_full, (params_true, (K, D, P)))

plt.figure()
plt.plot(t_axis, true_grads[:, 1])
plt.xlim([0.3, 0.304])

# Difference between the model derivative and the finite-difference estimate.
plt.figure()
plt.plot(t_axis, true_grads[:, 1] - ys_full_grads[:, 1])
plt.xlim([0.3, 0.304])

In [None]:
# taken from https://docs.kidger.site/diffrax/examples/neural_ode/
def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches forever.

    Each pass over the data draws a fresh permutation and yields every full
    batch of ``batch_size`` rows; a trailing partial batch is dropped.

    Args:
        arrays: sequence of arrays sharing the same leading dimension.
        batch_size: number of rows per batch; must satisfy
            ``0 < batch_size <= dataset size``.
        key: JAX PRNG key used to shuffle; it is re-split after every pass.

    Yields:
        Tuple with one batch per input array, all indexed by the same rows.

    Raises:
        ValueError: if ``batch_size`` is not in ``(0, dataset size]``. Without
            this check the generator would loop forever without yielding.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    if not 0 < batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in (0, {dataset_size}], got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        # Start offsets of all full batches, including the last one when
        # dataset_size is an exact multiple of batch_size.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)

In [None]:
import equinox as eqx
import optax
import time
from typing import Callable, Any

In [None]:
class NeuralODE(eqx.Module):
    """Model that predicts a trajectory by integrating ``func`` with diffrax.

    The trainable quantity is ``params``; ``extra_args`` is handed to ``func``
    unchanged alongside it.
    """

    func: Callable  # ODE right-hand side, called as func(t, y, (params, extra_args))
    params: jax.Array  # parameter vector passed to func
    extra_args: Any  # tuple of the extra positional constructor arguments

    def __init__(self, gradfun, params, *args, **kwargs):
        super().__init__(**kwargs)
        self.func = gradfun
        self.params = params
        self.extra_args = args

    def __call__(self, ts, y0):
        """Integrate from ``ts[0]`` to ``ts[-1]`` starting at ``y0``.

        Args:
            ts: 1-D array of save times; its first spacing is used as the
                initial step size.
            y0: initial state at ``ts[0]``.

        Returns:
            The solution evaluated at every entry of ``ts``.
        """
        ode_term = diffrax.ODETerm(self.func)
        controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
        first_step = ts[1] - ts[0]
        sol = diffrax.diffeqsolve(
            ode_term,
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=first_step,
            y0=y0,
            args=(self.params, self.extra_args),
            saveat=diffrax.SaveAt(ts=ts),
            stepsize_controller=controller,
            max_steps=int(1e6),
        )
        return sol.ys

In [None]:
class GradPredictor(eqx.Module):
    """Model that predicts time derivatives by evaluating ``func`` pointwise.

    No ODE is integrated: given times and states, ``func`` is evaluated at
    every sample with the current ``params``.
    """

    func: Callable  # ODE right-hand side, called as func(t, y, (params, extra_args))
    params: jax.Array  # parameter vector passed to func
    extra_args: Any  # tuple of the extra positional constructor arguments

    def __init__(self, gradfun, params, *args, **kwargs):
        super().__init__(**kwargs)
        self.func = gradfun
        self.params = params
        self.extra_args = args

    def __call__(self, ts, ys):
        """Evaluate ``func`` at each ``(ts[i], ys[i])`` pair.

        Args:
            ts: array of times, mapped over its leading axis.
            ys: array of states, mapped over its leading axis.

        Returns:
            Stacked outputs of ``func``, one per sample.
        """
        shared_args = (self.params, self.extra_args)
        # Map over time and state; the argument tuple is shared by all samples.
        pointwise = jax.vmap(self.func, in_axes=(0, 0, None))
        return pointwise(ts, ys, shared_args)

In [None]:
from tensorboardX import SummaryWriter
import uuid

def main(
    init,
    batch_size=64,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 2500), 
    length_strategy=(0.1, 1),
    seed=5678,
    plot=True,
    print_every=25,
    train_grad=False,
    runnum=None
):
    """Fit the syrinx-model parameters to the snippet dataset with Adam.

    The data and model ingredients are read from module-level names: ``ts``,
    ``ys``, ``ys_grads``, ``params_true``, ``sr``, ``gradfun``, ``K``, ``D``
    and ``P``. Training runs in phases; phase ``i`` uses ``lr_strategy[i]``,
    ``steps_strategy[i]`` steps, and the leading ``length_strategy[i]``
    fraction of every snippet. Loss and parameter ratios (fitted / true) are
    written to TensorBoard under ``logs/run{runnum}``.

    Args:
        init: initial (rescaled) parameter vector.
        batch_size: number of snippets per optimization step.
        lr_strategy: Adam learning rate for each phase.
        steps_strategy: number of steps for each phase.
        length_strategy: fraction of each snippet used in each phase.
        seed: seed of the PRNG key that drives batch shuffling.
        plot: if true, plot the first snippet against the model output and
            save the figure to ``neural_ode.png``.
        print_every: print/log every this many steps (and on the last step
            of each phase).
        train_grad: if true, fit ``GradPredictor`` to the finite-difference
            derivatives; otherwise fit ``NeuralODE`` to the trajectories.
        runnum: identifier used in the log directory; a random UUID is used
            when it is ``None``.

    Returns:
        Tuple ``(ts, ys, model)``: the module-level data and the fitted model.
    """
    param_names = ('eps1', 'eps2', 'beta1', 'beta2', 'C', 'delta')

    # Compare against None so that an explicit run number of 0 is honoured.
    if runnum is None:
        runnum = uuid.uuid4()

    # Split three ways and keep the third key; the first two are not needed.
    _, _, loader_key = jr.split(jr.PRNGKey(seed), 3)

    length_size = ys.shape[1]

    # eps1, eps2, beta1, beta2, C, delta
    if train_grad:
        model = GradPredictor(gradfun, init, K, D, P)
    else:
        model = NeuralODE(gradfun, init, K, D, P)

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi, dyi):
        if train_grad:
            # derivative matching; dividing by sr**2 undoes the sr factor
            # carried by the finite-difference targets
            y_pred = jax.vmap(model)(ti, yi)
            return jnp.mean((dyi - y_pred) ** 2/sr**2)
        else:
            # trajectory matching on the first state component, integrating
            # from the first sample of every snippet
            y_pred = jax.vmap(model)(ti, yi[:, 0])
            return jnp.mean((yi[:, :, 0] - y_pred[:, :, 0]) ** 2)

    def build_step(optim):
        """Return a jitted update step bound to ``optim``.

        A fresh jitted function is built per phase so that a compiled step
        from an earlier phase (with that phase's optimizer baked in) can never
        be reused when the batch shapes happen to be identical.
        """
        @eqx.filter_jit
        def make_step(ti, yi, dyi, model, opt_state):
            loss, grads = grad_loss(model, ti, yi, dyi)
            updates, opt_state = optim.update(grads, opt_state)
            model = eqx.apply_updates(model, updates)
            return loss, model, opt_state

        return make_step

    writer = SummaryWriter(f"logs/run{runnum}", flush_secs=1)

    globstep = 0
    try:
        # Training in phases. With the default strategies the first phase
        # trains on only the first 10% of each snippet -- a standard trick to
        # avoid getting caught in a local minimum.
        for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
            optim = optax.adam(lr)
            make_step = build_step(optim)
            opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
            n_keep = int(length_size * length)
            _ts = ts[:, :n_keep]
            _ys = ys[:, :n_keep]
            _dys = ys_grads[:, :n_keep]
            # range() comes first so zip stops after `steps` batches without
            # pulling an extra one from the infinite loader.
            for step, (ti, yi, dyi) in zip(
                range(steps), dataloader((_ts, _ys, _dys), batch_size, key=loader_key)
            ):
                start = time.time()
                loss, model, opt_state = make_step(ti, yi, dyi, model, opt_state)
                end = time.time()
                globstep += 1
                if (step % print_every) == 0 or step == steps - 1:
                    print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")
                    writer.add_scalar('loss', loss, globstep)
                    # ratio of each fitted parameter to its true value
                    writer.add_scalars('parameters', {
                        name: model.params[idx]/params_true[idx]
                        for idx, name in enumerate(param_names)
                    }, globstep, end)
    finally:
        # release the log file even if training is interrupted
        writer.close()

    if plot:
        plt.plot(ts[0], ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts[0], ys[0, :, 1], c="dodgerblue")
        if train_grad:
            model_y = model(ts[0], ys[0])
        else:
            model_y = model(ts[0], ys[0, 0])
        plt.plot(ts[0], model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts[0], model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model

In [None]:
# NOTE(review): `key` is not used in the visible code; main() derives its own
# key from its `seed` argument.
key = jr.PRNGKey(12345)
# Initial guess, in the same rescaled units and order as params_true
# (eps1, eps2, beta1, beta2, C, delta).
params0 = jnp.array([1., 100., 1, 1000., 1, 1])
# Trajectory fitting (train_grad=False) in two phases of 500 and 40000 steps;
# length_strategy is left at its default.
ts, ys, model = main(init=params0, train_grad=False, steps_strategy=(500, 40000), 
                     lr_strategy=(1e-3, 1e-3), batch_size=64, print_every=100)

In [None]:
# Display true, initial and fitted (rescaled) parameters side by side.
params_true, params0, model.params

In [None]:
# Reference signal: first state component of the generated data.
plt.plot(t_axis, ys_full[:, 0])

# Run the fitted model once over the whole time axis, starting from the first
# data sample, and reuse the result for both figures below (the solve is
# deterministic, so calling it a second time only repeats the work).
y_pred = model(t_axis, ys_full[0])

plt.figure()
plt.plot(t_axis, y_pred[:, 0])

# Residual between the fitted model and the reference signal.
plt.figure()
plt.plot(t_axis, y_pred[:, 0] - ys_full[:, 0])

In [None]:
# Spectrogram of the fitted model's output, with the same window and overlap
# as used for the original audio.
freqs, times, spec_pred = stft(y_pred[:, 0], fs=sr, 
                  nperseg=512, 
                  noverlap=480)

In [None]:
# Top: driving envelopes. Middle: spectrogram of the original audio.
# Bottom: spectrogram of the fitted model's output.
fig, axs = plt.subplots(3, 1, figsize=(5, 12))
axs[0].plot(t_axis, K(t_axis), label='vS tension')
axs[0].plot(t_axis, D(t_axis), label='dTB tension')
axs[0].plot(t_axis, P(t_axis), label='Pressure')
axs[0].set_xlabel("time (s)")
axs[0].set_ylabel("envelope (V)")
axs[0].set_xlim([0, 1.2])
axs[0].legend();

# Row 0 of each spectrogram is the lowest frequency. imshow's default origin
# ('upper') would draw it at the top while `extent` labels the top as
# freqs[-1], inverting the frequency axis; origin='lower' keeps the tick
# labels equal to the true frequencies.
axs[1].imshow(jnp.abs(spec), extent=[times[0], times[-1], freqs[0], freqs[-1]],
              aspect='auto', origin='lower')
axs[1].set_xlabel('time (s)')
axs[1].set_ylabel('frequency (Hz)')
# Same band that was visible with the inverted axis (where it was labeled
# 12500-19000), now expressed in true frequencies.
axs[1].set_ylim([1000, 7500])
axs[1].set_xlim([0, 1.2])
curr_fontsize = plt.rcParams['font.size']
axs[1].annotate('Original', (0.025, 0.9), 
                xycoords='axes fraction', 
                color='white',
                fontsize=2 * curr_fontsize)

axs[2].imshow(jnp.abs(spec_pred), extent=[times[0], times[-1], freqs[0], freqs[-1]],
              aspect='auto', origin='lower')
axs[2].set_xlabel('time (s)')
axs[2].set_ylabel('frequency (Hz)')
axs[2].set_ylim([1000, 7500])
axs[2].set_xlim([0, 1.2])
axs[2].annotate('Fitted', (0.025, 0.9), 
                xycoords='axes fraction', 
                color='white',
                fontsize=2 * curr_fontsize)

plt.tight_layout()