In [None]:
import os

import numpy as np
import nengo
import nengo.utils.numpy as npext
from nengo_gui.ipython import IPythonViz

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
plt.rc('figure', figsize=(12, 3))

import phd

In [None]:
%%javascript
// Adds a toolbar button ("kill and run-first") that restarts the kernel and
// then re-executes the first cell of the notebook, restoring the selection
// afterwards. The guard below skips adding the button group when an element
// with id "kill-run-first" already exists in the toolbar, so re-running this
// cell does not add a duplicate button (presumably the second argument to
// add_buttons_group becomes that element's id -- verify against the
// notebook version in use).
if($(IPython.toolbar.selector.concat(' > #kill-run-first')).length == 0){
  IPython.toolbar.add_buttons_group([
    {
      'label'   : 'kill and run-first',
      'icon'    : 'fa fa-angle-double-down',
      'callback': function(){
        IPython.notebook.kernel.restart();
        // The restart is asynchronous: wait for the kernel to report ready.
        // jQuery's .one() runs this handler at most once, so each click
        // triggers a single re-run.
        $(IPython.events).one('kernel_ready.Kernel', function(){
          // Remember the current cell, run cell 0, then jump back.
          var idx = IPython.notebook.get_selected_index();
          IPython.notebook.select(0);
          IPython.notebook.execute_cell();
          IPython.notebook.select(idx);
        });
      }
    }
  ], 'kill-run-first');
}

## Trajectory detection

Basic idea:

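- We know the target trajectory in advance,
  and use our current position along it to continuously predict
  what we should observe next.
- When the observation matches our prediction, we advance our internal
  position along the trajectory (a phase variable, not wall-clock time).
- At any moment we can read out how far along the trajectory we are.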

In [None]:
import nengo.utils.numpy as npext

def predict(x):
    """Return the expected observation at position ``x`` along the trajectory.

    The toy trajectory steps through four discrete binary states. Each
    segment's upper bound (0.25, 0.5, 0.75) is inclusive. Anything that
    fails every bound check (``x`` > 0.75, or NaN) yields the final state.
    A fresh float64 array is returned on every call.
    """
    segments = (
        (0.25, (0, 0, 1, 1)),
        (0.5, (0, 1, 0, 1)),
        (0.75, (1, 0, 1, 0)),
    )
    for upper, state in segments:
        if x <= upper:
            return np.array(state, dtype=float)
    return np.array((1, 1, 0, 0), dtype=float)

def observe(t):
    """Return the toy observation at time ``t``.

    Walks through the same four states as ``predict``, but stays in the
    first state until t = 0.5 and then changes state at 0.6 and 0.7. Each
    boundary is inclusive on the earlier state. Anything later than 0.7,
    or NaN, gives the final state.
    """
    states = ((0, 0, 1, 1), (0, 1, 0, 1), (1, 0, 1, 0), (1, 1, 0, 0))
    boundaries = (0.5, 0.6, 0.7)
    # First boundary that t does not exceed; fall back to the last state.
    stage = next(
        (i for i, end in enumerate(boundaries) if t <= end), len(states) - 1)
    return np.array(states[stage], dtype=float)

def similarity(v1, v2):
    """Cosine similarity between the vectors ``v1`` and ``v2``.

    Each norm is floored at the smallest positive float, so a zero-length
    vector does not cause a division by zero.
    """
    tiny = np.nextafter(0, 1)  # smallest representable float above zero
    norm1 = max(npext.norm(v1), tiny)
    norm2 = max(npext.norm(v2), tiny)
    return np.dot(v1, v2) / norm1 / norm2

# Toy run: sweep time from 0 to 1 while tracking an internal position x
# along the known trajectory. x only moves forward when the prediction at
# x matches the observation at the current time.
t = np.linspace(0, 1, num=200)  # 200 time points from 0 to 1 inclusive
x = 0.  # current position along the trajectory, clipped to [0, 1]
c = 0.02  # amount x advances on each matching step
match_th = 0.9  # similarity above which prediction and observation "match"
x_hist = []
for tt in t:
    # Compare what we expect at position x with what we observe at time tt
    pred = predict(x)
    obs = observe(tt)

    # Advance x by c only on a match (the comparison acts as 0 or 1)
    x += (similarity(pred, obs) > match_th) * c
    x = np.clip(x, 0, 1)
    x_hist.append(x)

plt.plot(t, x_hist)
plt.xlabel("Time")
plt.ylabel("Position along trajectory (x)")
plt.title("Toy trajectory detection");

## Trajectory detection in Nengo

In [None]:
dt = 0.001  # sampling step for the gesture trajectories (seconds)
gests = ('pat', 'das')

# Parse each gesture score once and sample its trajectory at `dt`. The same
# arrays are used both for the trial input and for the syllable templates,
# instead of parsing every .ges file twice.
trajs = []
for ges in gests:
    path = phd.ges_path('ges-de-cvc', '%s.ges' % ges.lower())
    trajs.append(phd.vtl.parse_ges(path).trajectory(dt=dt))

model = phd.sermo.Recognition()
model.trial.dt = dt
# Trial input: the gesture trajectories concatenated in the order of `gests`.
model.trial.trajectory = np.vstack(trajs)
for ges, traj in zip(gests, trajs):
    model.add_syllable(label=ges.upper(),
                       n_per_d=400,
                       trajectory=traj,
                       similarity_th=0.85,
                       scale=0.67)

In [None]:
net = model.build()
with net:
    # Probes: input trajectory, each syllable's DMP state, memory output,
    # classifier and each syllable's reset signal (all with a 0.01 s synapse).
    p_traj = nengo.Probe(net.trajectory.output, synapse=0.01)
    p_dmps = [nengo.Probe(dmp.state, synapse=0.01) for dmp in net.syllables]
    p_mem = nengo.Probe(net.memory.output, synapse=0.01)
    p_class = nengo.Probe(net.classifier, synapse=0.01)
    p_resets = [nengo.Probe(dmp.reset, synapse=0.01) for dmp in net.syllables]

# Run for as long as the concatenated trial trajectory lasts. The generator
# uses `traj` so it does not shadow the time axis `t` assigned below.
# NOTE(review): the simulator uses nengo's default dt here; this assumes it
# matches the `dt` used to sample the trajectories.
sim = nengo.Simulator(net)
sim.run(sum(traj.shape[0] * dt for traj in trajs))

t = sim.trange()
plt.figure()
plt.plot(t, sim.data[p_traj])
plt.title("Input trajectory")
plt.xlabel("Time (s)")
plt.xlim(right=t[-1])
# NOTE(review): assumes net.syllables is ordered like `gests` (the order in
# which the syllables were added) -- confirm in phd.sermo.
for pr, label in zip(p_dmps, gests):
    plt.figure()
    plt.plot(t, sim.data[pr])
    plt.title("%s DMP state" % label)
    plt.xlabel("Time (s)")
    plt.xlim(right=t[-1])
plt.figure()
plt.plot(t, nengo.spa.similarity(sim.data[p_mem], net.vocab, True))
plt.title("Memory similarity to vocabulary")
plt.xlabel("Time (s)")
plt.ylabel("Similarity")
plt.xlim(right=t[-1])
# NOTE(review): legend labels assume the vocabulary columns follow `gests`.
plt.legend(gests, loc='best')
plt.figure()
plt.plot(t, sim.data[p_class])
plt.title("Classifier output")
plt.xlabel("Time (s)")
plt.xlim(right=t[-1]);

In [None]:
# Neuron budget: every ensemble in the network, then the ensembles of the
# first syllable only.
n_total = sum(ens.n_neurons for ens in net.all_ensembles)
print("Whole model: %d neurons" % n_total)
n_syllable = sum(ens.n_neurons for ens in net.syllables[0].all_ensembles)
print("Each syllable: %d neurons" % n_syllable)