# Notebook Description

If we don't know the underlying dynamics of a system, we'll have to either fit or "learn" the dynamics.

In this notebook, we'll explore techniques for fitting and learning dynamics.

A dynamical system given by

$$\dot{x}=f(x) + g(u)$$
 
but with synapses, ensembles implement

$$\tau_{syn}\dot{x}=-x+f'(x)+g'(u)$$

To make the network implement the desired $f(x)$ and $g(u)$, we decode $f'(x)$ and feed in $g'(u)$, where

\begin{align}
f'(x) &= \tau_{syn} f(x) + x \\
g'(u) &= \tau_{syn} g(u)
\end{align}

## Workflow
 - Generate dynamical system
 - Fit dynamics to basis functions
 - Approximate dynamical system with nengo
 - Load in real accelerometer data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

import nengo
from nengo.utils.ensemble import tuning_curves
from nengo.utils.functions import piecewise

In [None]:
class DSProcessor:
    """Dynamical Systems Processor

    Builds and simulates a nengo network containing one recurrently
    connected ensemble. The recurrent connection decodes
    ``tau * f(x) + x`` and the input is scaled by ``tau`` so that, after
    filtering through a synapse with time constant ``tau``, the ensemble
    is intended to follow ``dx/dt = f(x) + g(u)``.

    Parameters
    ----------
    dim: int
        dimensions
    nrns: int
        number of neurons
    dt: float
        simulator timestep
    fb_fn: function or (state, dstate) tuple
        closed form function or input/output data pairing
    in_fn: function, optional
        output of the stimulus node (passed to ``nengo.Node``);
        no stimulus is created when None
    radius: float
        radius of the ensemble
    tau: float
        time constant used for every synapse and for scaling the
        feedback function and the stimulus
    seed: int
        seed given to both the network and the ensemble
    neuron_type: nengo neuron type
        neuron model of the ensemble
    max_rates: nengo distribution
        maximum firing rates of the ensemble

    Attributes
    ----------
    dt, net, ens, conn, probe, stim_probe, sim
        timestep, network, ensemble, recurrent connection, readout probe,
        stimulus probe (None without a stimulus), and simulator
    """
    def __init__(self, dim, nrns, dt, fb_fn,
                 in_fn=None, radius=1, tau=0.1, seed=0,
                 neuron_type=nengo.LIFRate(), max_rates=nengo.dists.Uniform(200, 400)):
        self.dt = dt
        if callable(fb_fn):
            # closed form dynamics: decode f'(x) = tau*f(x) + x
            def function(x):
                dx = fb_fn(x)
                ret = [tau*dx_val + x_val for dx_val, x_val in zip(dx, x)]
                return ret
            eval_points = None
        else:
            # observed data: the targets are tau*dstate + state,
            # evaluated at the observed states
            state, dstate = fb_fn
            function = tau*dstate + state
            eval_points = state

        self.net = nengo.Network(seed=seed)
        with self.net:
            ens = nengo.Ensemble(
                nrns, dim, neuron_type=neuron_type,
                radius=radius, seed=seed, max_rates=max_rates)
            # The readout receives the same decoded feedback and stimulus as
            # the ensemble, through the same synapse, and is probed unfiltered.
            readout = nengo.Node(None, size_in=dim)
            self.probe = nengo.Probe(readout, synapse=None)

            self.conn = nengo.Connection(ens, ens, function=function, eval_points=eval_points, synapse=tau)
            nengo.Connection(ens, readout, function=function, eval_points=eval_points, synapse=tau)
            if in_fn is not None:
                stim = nengo.Node(in_fn)
                # g'(u) = tau*g(u)
                nengo.Connection(stim, ens, transform=tau, synapse=tau)
                nengo.Connection(stim, readout, transform=tau, synapse=tau)
                self.stim_probe = nengo.Probe(stim, synapse=None)
            else:
                self.stim_probe = None
        self.ens = ens
        self.sim = nengo.Simulator(self.net, dt)

    def run(self, sim_time):
        """Run the simulator for sim_time and return the observed trajectory

        Returns
        -------
        (time, state, dstate_dt)
            The simulator time points and readout state, each with the last
            sample dropped, and the forward finite difference of the state
            divided by the timestep.
        """
        self.sim.run(sim_time)
        state = self.sim.data[self.probe]
        # Use the timestep this simulator was built with; a bare `dt` would
        # resolve to a module-level variable (or raise NameError).
        dstate_dt = np.diff(state, axis=0) / self.dt
        return self.sim.trange()[:-1], state[:-1], dstate_dt

    def get_target_decode(self, dstate_dt):
        """Compute the target decode points that an Ensemble's decoders should be optimized for

        Subtracts the probed stimulus (last sample dropped) from dstate_dt
        when a stimulus exists; otherwise returns dstate_dt unchanged.

        Run simulator first
        """
        if self.stim_probe is not None:
            fb_tgts = dstate_dt - self.sim.data[self.stim_probe][:-1]
        else:
            fb_tgts = dstate_dt
        return fb_tgts

    def get_fit(self, test_inputs):
        """Return the recurrent connection's decoded values at test_inputs

        The ensemble's tuning-curve activities at test_inputs are multiplied
        by the weights the simulator built for the recurrent connection.
        """
        test_inputs, activities = tuning_curves(self.ens, self.sim, test_inputs)
        decoded_values = np.dot(activities, self.sim.data[self.conn].weights.T)
        return decoded_values

First we'll test our paradigm with a 1D low-pass filter.

Desired dynamics are given by
$$\dot{x} = \frac{-1}{\tau_{sys}}x + \frac{1}{\tau_{sys}}u$$

Therefore we train our ensemble to decode and feed back

$$\tau_{syn}\left(\frac{-1}{\tau_{sys}}x\right)+x=\left(1-\frac{\tau_{syn}}{\tau_{sys}}\right)x$$

When $\tau_{syn}>\tau_{sys}$, we're training for negative feedback, and when $\tau_{syn}<\tau_{sys}$, we're training for positive feedback.

In [None]:
def test_1d_lds():
    """Fit a 1D low-pass filter from data observed on a reference network.

    A reference DSProcessor is built from the closed-form dynamics, its
    observed (state, dstate) data is used to build a second DSProcessor,
    and both are plotted side by side: state over time, dstate/dt over
    state, and the feedback decode over a grid of test states.
    """
    sys_tau = 0.2
    dt = 0.001
    sim_time = 5*sys_tau

    # grid of states at which the two feedback decodes are compared
    test_state = np.linspace(-1, 1.0).reshape((-1, 1))

    def fb_fn(x):
        return [-x/sys_tau]

    def in_fn(t):
        return [1/sys_tau]

    def plot_system(axs, time, state, dstate, decode, label):
        """Draw one system's three curves on the three axes."""
        axs[0].plot(time, state, label=label)
        axs[1].plot(state, dstate, label=label)
        axs[2].plot(test_state, decode, label=label)

    # reference system, built from the closed-form dynamics
    dsp_ref = DSProcessor(1, 100, dt, fb_fn, in_fn=in_fn)
    time, state, dstate = dsp_ref.run(sim_time)
    ref_decode = dsp_ref.get_fit(test_state)

    fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
    plot_system(axs, time, state, dstate, ref_decode, "reference")

    # approximating system, built from the reference's observed data
    tgt_dstate = dsp_ref.get_target_decode(dstate)
    dsp_appx = DSProcessor(1, 64, dt, (state, tgt_dstate), in_fn=in_fn)
    time, state_appx, dstate_appx = dsp_appx.run(sim_time)
    decode_appx = dsp_appx.get_fit(test_state)
    plot_system(axs, time, state_appx, dstate_appx, decode_appx, "fit")

    axis_labels = [
        ("time", "state"),
        ("state", "observed dstate/dt"),
        ("state", "feedback decode"),
    ]
    for ax, (xlabel, ylabel) in zip(axs, axis_labels):
        ax.legend(loc="best")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    plt.tight_layout()
test_1d_lds()

Observed reference system data was used for fitting. Fit and reference state and dstate/dt look reasonable. Note how the fit feedback decode only well-approximates the reference decode over the state range used as inputs. 

--------------

## TODO next

cleanup van der pol oscillator
 - use new DSProcessor
 - plot fit over x, y space

----------

Let's try a 2D, nonlinear, non-chaotic dynamical system: the [Van der Pol oscillator](https://en.wikipedia.org/wiki/Van_der_Pol_oscillator).

\begin{align*}
\dot{x} &= y \\
\dot{y} &= \mu(1-x^2)y-x \\
\end{align*}

In [None]:
# Settings for the Van der Pol reference simulation
dt = 0.001
sim_time = 20.  # duration of simulation

# Piecewise-constant input: [1, 0] from t=0, [0, 0] from t=1 onward
input_fn = piecewise({0.:[1., 0.], 1.:[0., 0.]})  # Use this as the input function to your network
mu = 1.  # damping coefficient
def vanderpol(x, mu=1.):
    """Van der Pol dynamics: returns [dx/dt, dy/dt] for state x = [x, y]."""
    ret = np.array([x[1], mu*(1.-x[0]**2)*x[1]-x[0]])
    return ret
# DSProcessor takes (dim, nrns, dt, fb_fn, in_fn=...) and produces data via run().
# The ensemble is in Direct mode, so the neuron count (1) is only a placeholder.
dsp = DSProcessor(2, 1, dt, vanderpol, in_fn=input_fn, radius=1, neuron_type=nengo.Direct())
ref_time, ref_state, ref_dstate = dsp.run(sim_time)

In [None]:
# Choose the part of the reference trajectory used for training:
# every sample from time t_clip onward.
t_clip = 5
t_idx = np.searchsorted(ref_time, t_clip)

# time series: whole trajectory with the training portion drawn over it
plt.plot(ref_time, ref_state)
plt.plot(ref_time[t_idx:], ref_state[t_idx:])
plt.xlabel(r'$t$')

# phase plane: whole trajectory with the training portion drawn over it
plt.figure()
plt.plot(*ref_state.T)
plt.plot(ref_state[t_idx:, 0], ref_state[t_idx:, 1])

In [None]:
# Presentation figure 1: both state variables over time, ticks hidden
plt.figure(figsize=(6, 4))
for column, label in zip(ref_state.T, (r"$x$", r"$y$")):
    plt.plot(ref_time, column, label=label)
plt.xlabel('Time', fontsize=14)
plt.ylabel(" ", fontsize=16)
plt.legend(loc="best", fontsize=16)
plt.xlim((0, ref_time.max()))
plt.xticks([])
plt.yticks([])

# Presentation figure 2: phase portrait, ticks hidden
plt.figure(figsize=(6, 4))
plt.plot(ref_state[:, 0], ref_state[:, 1], color="#2ca02c")
plt.xticks([])
plt.yticks([])
plt.xlabel(r'$x$', fontsize=16)
plt.ylabel(r'$y$', fontsize=16)

In [None]:
def plot_vdp(ref_time, ref_state, appx_time, appx_state):
    """Plot a 2D approximated trajectory against its reference.

    Draws two figures with ticks hidden: the two state variables over time,
    and the phase portrait. Reference curves are dashed and faded; the
    approximation is solid and reuses the reference colors.
    """
    # --- time series ---
    plt.figure(figsize=(6, 4))
    ref_colors = []
    for k in (0, 1):
        ref_line = plt.plot(ref_time, ref_state[:, k], '--', alpha=0.3)[0]
        ref_colors.append(ref_line.get_color())
    for k, color in enumerate(ref_colors):
        plt.plot(appx_time, appx_state[:, k], color=color)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel(" ", fontsize=16)
    plt.xlim((0, np.max(ref_time)))
    plt.xticks([])
    plt.yticks([])

    # --- phase portrait ---
    plt.figure(figsize=(6, 4))
    green = "#2ca02c"
    plt.plot(*ref_state.T, '--', color=green, alpha=0.3)
    plt.plot(*appx_state.T, color=green)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(r'$x$', fontsize=16)
    plt.ylabel(r'$y$', fontsize=16)
    
# Fit the Van der Pol data from t_idx onward with several values of npd and
# plot each result against the reference.
# NOTE(review): approximate_ds is not defined in the visible part of this
# notebook - confirm where it comes from before running this cell.
for npd in (16, 32, 64):  # fuller sweep: (16, 32, 64, 128, 256, 512)
    appx_time, appx_state = approximate_ds(
        2, npd, ref_state[t_idx:], ref_dstate[t_idx:], sim_time,
        radius=1, dt=dt, input_fn=input_fn, neuron_type=nengo.LIF(),
    )
    plot_vdp(ref_time, ref_state, appx_time, appx_state)

# Same fit with 64, but with maximum firing rates drawn from Uniform(100, 200)
appx_time, appx_state = approximate_ds(
    2, 64, ref_state[t_idx:], ref_dstate[t_idx:], sim_time,
    radius=1, dt=dt, input_fn=input_fn,
    neuron_type=nengo.LIF(), max_rates=nengo.dists.Uniform(100, 200),
)
plot_vdp(ref_time, ref_state, appx_time, appx_state)

In [None]:
# Use these in your simulation
T = 20.0  # duration of simulation
stim = piecewise({0.0: [1.0, 0.0], 1.0: [0.0, 0.0]})  # Use this as the input function to your network
mu = 1.5  # damping coefficient

### BEGIN SOLUTION
tau_syn = 0.1  # synaptic time constant
N = 400        # number of neurons

def vanderpol(x, mu):
    """Return the Van der Pol derivative [x1, mu*(1 - x0**2)*x1 - x0] as an array."""
    x0, x1 = x[0], x[1]
    return np.array([x1, mu*(1.-x0**2)*x1-x0])

def decode(x, nonlin_fun, *nonlin_args):
    """Return tau_syn * nonlin_fun(x, *nonlin_args) + x, the function fed back through the synapse."""
    dx = nonlin_fun(x, *nonlin_args)
    return tau_syn*dx + x

def feedback(x):
    """Recurrent function of the ensemble: the synapse-compensated Van der Pol dynamics."""
    return decode(x, vanderpol, mu)

net = nengo.Network()
with net:
    stim = nengo.Node(stim)
    ens = nengo.Ensemble(N, 2, radius=3.)
    nengo.Connection(stim, ens, transform=tau_syn, synapse=tau_syn)
    nengo.Connection(ens, ens, function=feedback, synapse=tau_syn)

    probe_ens = nengo.Probe(ens, synapse=.01)

# first run: ensemble with its default neuron type
sim = nengo.Simulator(net)
sim.run(T, progress_bar=False)
spiking_state = sim.data[probe_ens]

# second run: same network rebuilt with the ensemble switched to Direct mode
ens.neuron_type = nengo.Direct()
sim = nengo.Simulator(net)
sim.run(T, progress_bar=False)
reference_state = sim.data[probe_ens]

t = sim.trange()

# state variables over time
plt.plot(t, spiking_state[:, 0], 'r-', label=r'spiking $x_0$')
plt.plot(t, spiking_state[:, 1], 'b-', label=r'spiking $x_1$')
plt.plot(t, reference_state[:, 0], 'r--', label=r'direct $x_0$')
plt.plot(t, reference_state[:, 1], 'b--', label=r'direct $x_1$')
plt.legend(loc='upper left', bbox_to_anchor=(1., 1.))
plt.xlabel(r'$t$')

# phase portrait
plt.figure()
plt.plot(*spiking_state[:, :2].T, 'b', label='spiking mode')
plt.plot(*reference_state[:, :2].T, 'r', label='direct mode')
plt.legend(loc='upper left', bbox_to_anchor=(1., 1.));
### END SOLUTION

Let's try a canonical, nonlinear, chaotic dynamical system: the Lorenz "butterfly" attractor.  The equations are:
        
$$
\dot{x}_0 = \sigma(x_1 - x_0) \\\
\dot{x}_1 = x_0 (\rho - x_2) - x_1  \\\
\dot{x}_2 = x_0 x_1 - \beta x_2 
$$

Since $x_2$ is centered around approximately $\rho$, and since NEF ensembles are usually optimized to represent values within a certain radius of the origin, we substitute $x_2' = x_2 - \rho$, giving these equations:
$$
\dot{x}_0 = \sigma(x_1 - x_0) \\\
\dot{x}_1 = - x_0 x_2' - x_1\\\
\dot{x}_2' = x_0 x_1 - \beta (x_2' + \rho) - \rho
$$

In [None]:
def test_lorenz():
    """Generate Lorenz reference data, fit it, and plot truth, fit and difference.

    NOTE(review): the DSProcessor call below does not match the
    DSProcessor(dim, nrns, dt, fb_fn, ...) signature defined in this
    notebook, and neither generate_data nor approximate_ds is defined in
    the visible source - TODO port to the current API.
    """
    # simulation settings
    dt = 0.0001
    sim_time = 5

    radius = 20

    # Lorenz parameters
    sigma = 10
    beta = 8.0 / 3
    rho = 28

    def lorentz(x):
        """Shifted Lorenz dynamics, matching the equations given in this notebook."""
        x0, x1, x2 = x[0], x[1], x[2]
        return [
            -sigma * x0 + sigma * x1,
            -x0 * x2 - x1,
            x0 * x1 - beta * (x2 + rho) - rho,
        ]

    # reference data
    dsp = DSProcessor(lorentz, dim=3, seed=0, radius=radius, neuron_type=nengo.LIFRate())
    ref_time, ref_state, ref_dstate = dsp.generate_data(sim_time, dt=dt)

    # fit; the last fitted sample is dropped when comparing against the reference
    appx_time, appx_state = approximate_ds(3, 1000, ref_state, ref_dstate, sim_time, radius=1, dt=dt)
    appx_trimmed = appx_state[:-1]
    diff_state = ref_state - appx_trimmed

    def plot_lorenz(max_t):
        """Plot all three trajectories up to time max_t, in 3D and as time series."""
        idx = np.searchsorted(ref_time, max_t)

        # axis limits shared by the 3D panels, taken over the full trajectories
        full = (ref_state, appx_trimmed, diff_state)
        limits = []
        for k in range(3):
            column_k = [traj[:, k] for traj in full]
            limits.append((np.min(column_k), np.max(column_k)))

        fig_3d = plt.figure(figsize=(14, 4))
        panels = (
            ("ground truth", ref_state),
            ("fit", appx_state),
            ("difference", diff_state),
        )
        for position, (title, traj) in enumerate(panels, start=1):
            ax = fig_3d.add_subplot(1, 3, position, projection='3d')
            ax.plot(*traj[:idx].T)
            ax.set_title(title)
            ax.set_xlim(limits[0])
            ax.set_ylim(limits[1])
            ax.set_zlim(limits[2])

        fig_ts, axs_ts = plt.subplots(ncols=3, sharey=True, figsize=(14, 4))
        series = (
            (ref_time, ref_state),
            (appx_time, appx_state),
            (ref_time, diff_state),
        )
        for ax, (time, traj) in zip(axs_ts, series):
            ax.plot(time[:idx], traj[:idx])

    plot_lorenz(1)
    plot_lorenz(20)
test_lorenz()

In [None]:
def test_lorenz():
    """Longer Lorenz run: generate reference data, fit it, and plot the comparison.

    NOTE(review): the DSProcessor call below does not match the
    DSProcessor(dim, nrns, dt, fb_fn, ...) signature defined in this
    notebook, and neither generate_data nor approximate_ds is defined in
    the visible source - TODO port to the current API.
    """
    # simulation settings and Lorenz parameters
    dt = 0.0001
    sim_time = 10
    sigma = 10
    beta = 8.0 / 3
    rho = 28

    def lorentz(x):
        """Shifted Lorenz dynamics, matching the equations given in this notebook."""
        x0, x1, x2 = x[0], x[1], x[2]
        dx0 = -sigma * x0 + sigma * x1
        dx1 = -x0 * x2 - x1
        dx2 = x0 * x1 - beta * (x2 + rho) - rho
        return [dx0, dx1, dx2]

    dsp = DSProcessor(lorentz, dim=3, seed=0)
    ref_time, ref_state, ref_dstate = dsp.generate_data(sim_time, dt=dt)

    appx_time, appx_state = approximate_ds(3, 1000, ref_state, ref_dstate, sim_time, dt=dt)
    # the last fitted sample is dropped when comparing against the reference
    appx_trimmed = appx_state[:-1]
    diff_state = ref_state - appx_trimmed

    def plot_lorenz(max_t):
        """Plot reference, fit and difference up to time max_t."""
        idx = np.searchsorted(ref_time, max_t)

        def shared_limits(k):
            # (min, max) of state dimension k over all three full trajectories
            columns = [ref_state[:, k], appx_trimmed[:, k], diff_state[:, k]]
            return (np.min(columns), np.max(columns))

        xlims, ylims, zlims = shared_limits(0), shared_limits(1), shared_limits(2)

        # 3D trajectories
        fig_3d = plt.figure(figsize=(14, 4))
        axs_3d = []
        for position in (1, 2, 3):
            axs_3d.append(fig_3d.add_subplot(1, 3, position, projection='3d'))
        trajectories = (ref_state, appx_state, diff_state)
        titles = ("ground truth", "fit", "difference")
        for ax, traj, title in zip(axs_3d, trajectories, titles):
            ax.plot(*traj[:idx].T)
            ax.set_title(title)
        for ax in axs_3d:
            ax.set_xlim(xlims)
            ax.set_ylim(ylims)
            ax.set_zlim(zlims)

        # time series
        fig_ts, axs_ts = plt.subplots(ncols=3, sharey=True, figsize=(14, 4))
        times = (ref_time, appx_time, ref_time)
        for ax, time, traj in zip(axs_ts, times, trajectories):
            ax.plot(time[:idx], traj[:idx])

    plot_lorenz(1)
    plot_lorenz(20)

# Factors that affect dynamics fit quality

- Noise: in measurements and observations
- Lag
- Mismatch between data dimensionality and model capacity (model dimensions)