In [None]:
# Common imports
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from JSAnimation import IPython_display
from skspeech.synthesis import kroger as kr

import nengo
import nengo.utils.numpy as npext
import nengo_gui.ipython

In [None]:
# Default figure size for every plot in this notebook
plt.rcParams['figure.figsize'] = (10, 8)

def shiftedcmap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    """Return a copy of `cmap` whose center is moved to `midpoint`.

    Handy for data spanning negative and positive values, when the middle of
    a diverging colormap should sit at zero rather than at the middle of the
    data range. The new colormap is also registered with pyplot under `name`.

    Parameters
    ----------
    cmap : matplotlib colormap
        The colormap to shift.
    start : float
        Offset from the lowest point of the colormap's range. Defaults to 0.0
        (no lower offset); should lie between 0.0 and `midpoint`.
    midpoint : float
        New center of the colormap, between 0.0 and 1.0. Defaults to 0.5
        (no shift). In general this should be ``1 - vmax / (vmax + abs(vmin))``;
        e.g. for data from -15.0 to +5.0 it is ``1 - 5 / (5 + 15) = 0.75``.
    stop : float
        Offset from the highest point of the colormap's range. Defaults to 1.0
        (no upper offset); should lie between `midpoint` and 1.0.
    name : str
        Name under which the new colormap is created and registered.

    Adapted from http://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
    """
    channels = ('red', 'green', 'blue', 'alpha')
    segments = dict((ch, []) for ch in channels)

    # 257 evenly spaced samples of the source colormap between start and stop...
    sample_points = np.linspace(start, stop, 257)
    # ...re-anchored so that the 129th sample lands exactly on `midpoint`
    anchor_points = np.concatenate((
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True),
    ))

    for src, dst in zip(sample_points, anchor_points):
        rgba = cmap(src)
        for ch, value in zip(channels, rgba):
            segments[ch].append((dst, value, value))

    shifted = matplotlib.colors.LinearSegmentedColormap(name, segments)
    plt.register_cmap(cmap=shifted)
    return shifted

# Recognition system

## Auditory periphery

This section makes heavy use of [Brian Hears](http://www.briansimulator.org/docs/hears.html);
other auditory periphery models should also be investigated.

### Brian filter models

In [None]:
import brian_no_units  # For speed
import brian as br
import brian.hears as bh

In [None]:
def whitenoise_sound(duration=100*br.ms, level=50*bh.dB, samplerate=50*br.kHz):
    """Return a ramped white-noise test sound.

    The noise is generated directly at `samplerate`. Assigning
    ``sound.samplerate`` after generation would only relabel the existing
    samples (changing the effective duration), not resample them.

    Parameters
    ----------
    duration : Brian time quantity, default 100 ms
    level : Brian dB quantity, default 50 dB
    samplerate : Brian frequency quantity, default 50 kHz
    """
    sound = bh.whitenoise(duration, samplerate=samplerate).ramp()
    sound.level = level
    return sound


def tone_sound(freq=500*br.Hz, duration=100*br.ms, level=50*bh.dB,
               samplerate=50*br.kHz):
    """Return a ramped pure-tone test sound.

    Generated directly at `samplerate` so that the tone frequency and the
    duration are exactly `freq` and `duration` (relabelling the sample rate
    afterwards would shift both).

    Parameters
    ----------
    freq : Brian frequency quantity, default 500 Hz
    duration : Brian time quantity, default 100 ms
    level : Brian dB quantity, default 50 dB
    samplerate : Brian frequency quantity, default 50 kHz
    """
    sound = bh.tone(freq, duration, samplerate=samplerate).ramp()
    sound.level = level
    return sound

def cochleogram(gt_mon, time=None, freqs=None):
    """Plot a filterbank / population response as a cochleogram image.

    Parameters
    ----------
    gt_mon : array, shape (n_samples, n_channels)
        Response to plot; it is transposed so channels run along the y axis.
    time : array, optional
        Sample times in seconds (e.g. ``sim.trange()``). Only the last value
        is used, converted to ms for the x extent. If omitted, the duration of
        the notebook-global ``sound`` is used.
    freqs : array, optional
        Channel center frequencies in Hz; only the first and last entries set
        the y extent. Defaults to the notebook-global ``cf``.

    Non-negative data use a sequential blue map, non-positive data a reversed
    red map, and mixed-sign data the diverging RdBu map with its white point
    shifted to zero.
    """
    if freqs is None:
        freqs = cf
    lo, hi = gt_mon.min(), gt_mon.max()
    if lo >= 0.0:
        cmap = plt.cm.Blues
    elif hi <= 0.0:
        # No positive values: the zero point would sit at (or beyond) the top
        # of the range, which a shifted RdBu cannot represent for hi < 0.
        cmap = plt.cm.Reds_r
    elif np.allclose(hi + lo, 0, atol=1e-5):
        cmap = plt.cm.RdBu  # already symmetric about zero
    else:
        # fraction of the data range that lies below zero
        cmap = shiftedcmap(plt.cm.RdBu, midpoint=np.abs(lo) / (hi - lo))

    if time is not None:
        duration = time[-1] * 1e3  # seconds -> ms, to match the x label
    else:
        duration = sound.duration / br.ms
    # 'lower' (not 'lower left') is the only accepted spelling in current matplotlib
    plt.imshow(gt_mon.T, aspect='auto', origin='lower', cmap=cmap,
               extent=(0, duration, freqs[0], freqs[-1]))
    plt.yscale('log')
    plt.ylabel('Frequency (Hz)')
    plt.xlabel('Time (ms)')
    plt.colorbar()

# Input sound for the filterbank demos below (use tone_sound() for a pure tone)
sound = whitenoise_sound()
n_center_frequencies = 200
n_cf = n_center_frequencies
center_frequencies = bh.erbspace(100*br.Hz, 1000*br.Hz, n_center_frequencies)
cf = center_frequencies

In [None]:
# Gammatone
from brian.hears.filtering.tan_carney import ZhangSynapseRate

b1 = 1.019  # factor determining the time constant of the filters
apply_middle_ear = False  # pass the sound through bh.MiddleEar first?

gammatone = bh.Gammatone(bh.MiddleEar(sound) if apply_middle_ear else sound,
                         center_frequencies, b=b1)

gt_mon = gammatone.process()
cochleogram(gt_mon)
plt.title('Cochleogram')

# Inner hair cells: half-wave rectification followed by 1/3 power-law compression
plt.figure()
ihc = bh.FunctionFilterbank(gammatone, lambda x: 3 * np.clip(x, 0, np.inf) ** (1. / 3.))
cochleogram(ihc.process())
plt.title('IHC activity')

# Auditory-nerve firing rates from the (rate-based) Zhang synapse
plt.figure()
syn = ZhangSynapseRate(ihc, cf)
s_mon = br.StateMonitor(syn, 's', record=True, clock=syn.clock)
net = br.Network(syn, s_mon)
net.run(sound.duration)
cochleogram(s_mon.values.T)
plt.title('AN spike rates')

# Same synapse model, spiking version
plt.figure()
syn = bh.ZhangSynapse(ihc, cf)
sp_mon = br.SpikeMonitor(syn)
net = br.Network(syn, sp_mon)
net.run(sound.duration)
br.raster_plot(sp_mon)

In [None]:
# Approximate Gammatone
# Per-channel bandwidths, given as a power law of the center frequencies
bw = 10**(0.037+0.785*np.log10(center_frequencies))
gammatone = bh.ApproximateGammatone(sound, center_frequencies, bw, order=3)
gt_mon = gammatone.process()
cochleogram(gt_mon)

In [None]:
# Log Gammachirp
c1, b1 = -2.96, 1.81  # glide slope; factor determining the filters' time constant
gamma_chirp = bh.LogGammachirp(sound, cf, c=c1, b=b1)
gamma_chirp_mon = gamma_chirp.process()
cochleogram(gamma_chirp_mon)

In [None]:
# Linear Gammachirp
c = 0.0  # glide slope
# time constants fall linearly from 3 ms (first channel) to 0.3 ms (last channel)
time_constant = br.ms * np.linspace(3, 0.3, n_center_frequencies)
gamma_chirp = bh.LinearGammachirp(sound, center_frequencies, time_constant, c)
gamma_chirp_mon = gamma_chirp.process()
cochleogram(gamma_chirp_mon)

In [None]:
# Tan & Carney
reproduce_paper_figs = False


def _run_tan_carney(tones, char_freq, n_channels, run_duration):
    """Run a TanCarney IHC + ZhangSynapse model on `tones` (one channel per row).

    Returns the StateMonitors for the synapse variables `s` and `R` and the
    SpikeMonitor, after simulating for `run_duration`.
    """
    ihc_tc = bh.TanCarney(bh.MiddleEar(tones), [char_freq] * n_channels, update_interval=1)
    syn_tc = bh.ZhangSynapse(ihc_tc, char_freq)
    s_mon = br.StateMonitor(syn_tc, 's', record=True, clock=syn_tc.clock)
    R_mon = br.StateMonitor(syn_tc, 'R', record=True, clock=syn_tc.clock)
    spike_mon = br.SpikeMonitor(syn_tc)
    br.Network(syn_tc, s_mon, R_mon, spike_mon).run(run_duration)
    return s_mon, R_mon, spike_mon


def _plot_tan_carney(levels, s_mon, R_mon, spike_mon, fig_num, xmax, text_x,
                     label_fmt, title):
    """Plot one row per level: `s` traces in figure `fig_num`, and `R` traces
    (with spikes marked as red crosses) in figure `fig_num + 1`, sharing the
    y-limits of the corresponding `s` row."""
    for idx, level in enumerate(levels):
        label = label_fmt % str(level*bh.dB)

        plt.figure(fig_num)
        plt.subplot(len(levels), 1, idx + 1)
        plt.plot(s_mon.times / br.ms, s_mon[idx])
        plt.xlim(0, xmax)
        plt.xlabel('Time (msec)')
        plt.ylabel('Sp/sec')
        plt.text(text_x, np.nanmax(s_mon[idx])/2., label)
        ymin, ymax = plt.ylim()
        if idx == 0:
            plt.title(title)

        plt.figure(fig_num + 1)
        plt.subplot(len(levels), 1, idx + 1)
        plt.plot(R_mon.times / br.ms, R_mon[idx])
        plt.xlabel('Time (msec)')
        plt.ylabel('Sp/sec')
        plt.text(text_x, np.nanmax(R_mon[idx])/2., label)
        plt.ylim(ymin, ymax)
        if idx == 0:
            plt.title(title + ' (with spikes and refractoriness)')
        spikes = spike_mon.spiketimes[idx] / br.ms
        plt.plot(spikes, np.ones(len(spikes)) * np.nanmax(R_mon[idx]), 'rx')


if reproduce_paper_figs:
    bh.set_default_samplerate(50*br.kHz)
    sample_length = 1 / bh.get_samplerate(None)
    # Local name so the notebook-global `cf` (channel array) is not clobbered
    paper_cf = 1000 * br.Hz

    print('Testing click response')
    duration = 25 * br.ms
    levels = [40, 60, 80, 100, 120]
    # a click of two samples followed by silence
    tones = bh.Sound([
        bh.Sound.sequence([bh.click(sample_length*2, peak=level*bh.dB),
                           bh.silence(duration=duration - sample_length)])
        for level in levels])
    s_mon, R_mon, spike_mon = _run_tan_carney(tones, paper_cf, len(levels), duration * 1.5)
    _plot_tan_carney(levels, s_mon, R_mon, spike_mon, fig_num=1, xmax=25,
                     text_x=15, label_fmt='Peak SPL=%s SPL',
                     title='Click responses')

    print('Testing tone response')
    br.reinit_default_clock()
    duration = 60*br.ms
    levels = [0, 20, 40, 60, 80]
    tones = bh.Sound([
        bh.Sound.sequence([
            bh.tone(paper_cf, duration).atlevel(level*bh.dB).ramp(
                when='both', duration=10*br.ms, inplace=False),
            bh.silence(duration=duration/2)])
        for level in levels])
    s_mon, R_mon, spike_mon = _run_tan_carney(tones, paper_cf, len(levels), duration * 1.5)
    _plot_tan_carney(levels, s_mon, R_mon, spike_mon, fig_num=3, xmax=120,
                     text_x=1.25 * duration/br.ms, label_fmt='%s SPL',
                     title='CF=%.0f Hz - Response to Tone at CF' % paper_cf)


ihc = bh.TanCarney(bh.MiddleEar(sound), cf, update_interval=1)
ihc_mon = ihc.process()
cochleogram(ihc_mon)

In [None]:
# Dual resonance nonlinear (DRNL) filter: the output is the sum of a linear
# pathway and a compressive nonlinear pathway, computed for every channel.
# (The result variables at the bottom are spelled `dnrl`.)

# Optional conversion to stapes velocity (the units expected by the following
# stages); currently disabled.
# sound = sound*0.00014

#### Linear Pathway ####

# bandpass filter: ApproximateGammatone of order `order_linear` (3), with center
# frequencies and bandwidths given as power laws of the channel CFs
center_frequencies_linear = 10**(-0.067+1.016*np.log10(center_frequencies))
bandwidth_linear = 10**(0.037+0.785*np.log10(center_frequencies))
order_linear = 3
gammatone = bh.ApproximateGammatone(sound, center_frequencies_linear,
                                    bandwidth_linear, order=order_linear)

# linear gain, per channel (decreases with CF)
g = 10**(4.2-0.48*np.log10(center_frequencies))
func_gain = lambda x: g * x
gain = bh.FunctionFilterbank(gammatone, func_gain)

# low pass filter: bh.LowPass at the linear-path center frequencies, applied via
# bh.Cascade with n=4. NOTE(review): `order_lowpass_linear` is defined but not
# passed to any filter.
cutoff_frequencies_linear = center_frequencies_linear
order_lowpass_linear = 2
lp_l = bh.LowPass(gain, cutoff_frequencies_linear)
lowpass_linear = bh.Cascade(gain, lp_l, 4)

#### Nonlinear Pathway ####

# first bandpass filter: ApproximateGammatone of order `order_nonlinear` (3)
center_frequencies_nonlinear = center_frequencies
bandwidth_nonlinear = 10**(-0.031+0.774*np.log10(center_frequencies))
order_nonlinear = 3
bandpass_nonlinear1 = bh.ApproximateGammatone(sound, center_frequencies_nonlinear,
                                              bandwidth_nonlinear,
                                              order=order_nonlinear)

# broken-stick compression: sign-preserving minimum of a linear gain a*|x| and a
# power law b*|x|**v, so small inputs are scaled linearly and large inputs are
# compressed
a = 10**(1.402+0.819*np.log10(center_frequencies))  # linear gain
b = 10**(1.619-0.818*np.log10(center_frequencies))
v = .2  # compression exponent
func_compression = lambda x: np.sign(x) * np.minimum(a*np.abs(x), b*np.abs(x)**v)
compression = bh.FunctionFilterbank(bandpass_nonlinear1, func_compression)

# second bandpass filter, same parameters as the first
bandpass_nonlinear2 = bh.ApproximateGammatone(compression,
                                              center_frequencies_nonlinear,
                                              bandwidth_nonlinear,
                                              order=order_nonlinear)

# low pass filter: bh.LowPass applied via bh.Cascade with n=3.
# NOTE(review): `order_lowpass_nonlinear` is defined but not passed to any filter.
cutoff_frequencies_nonlinear = center_frequencies_nonlinear
order_lowpass_nonlinear = 2
lp_nl = bh.LowPass(bandpass_nonlinear2, cutoff_frequencies_nonlinear)
lowpass_nonlinear = bh.Cascade(bandpass_nonlinear2, lp_nl, 3)

# adding the two pathways
dnrl_filter = lowpass_linear + lowpass_nonlinear
dnrl = dnrl_filter.process()

cochleogram(dnrl)

In [None]:
# DCGC; Compressive Gammachirp
# Setup: filter parameters, the passive gammachirp bank shared by both paths,
# the control path, and the constants read by the level controller below.
samplerate = sound.samplerate

c1 = -2.96 # glide slope of the first filterbank
b1 = 1.81  # factor determining the time constant of the first filterbank
c2 = 2.2   # glide slope of the second filterbank
b2 = 2.17  # factor determining the time constant of the second filterbank

# ERB-rate and ERB bandwidth of every channel, plus the mean ERB-rate spacing
# between adjacent channels (used to convert an ERB shift into a channel shift)
order_ERB = 4
ERBrate = 21.4*np.log10(4.37*cf/1000+1)
ERBwidth = 24.7*(4.37*cf/1000 + 1)
ERBspace = np.mean(np.diff(ERBrate))

# the filter coefficients are updated every update_interval (here in samples)
update_interval = 1

# bank of passive gammachirp filters. The control path uses the same passive
# filterbank as the signal path (but shifted in frequency), so this
# filterbank is shared by both pathways.
pGc = bh.LogGammachirp(sound, cf, b=b1, c=c1)

fp1 = cf + c1*ERBwidth*b1/order_ERB #centre frequency of the signal path

#### Control Path ####

# the first filterbank in the control path consists of gammachirp filters
# value of the shift in ERB frequencies of the control path with respect to the signal path
lct_ERB = 1.5
n_ch_shift = np.round(lct_ERB/ERBspace)  # value of the shift in channels
# index of the channel of the control path taken from pGc: 1-based channel
# number shifted by n_ch_shift, clamped to [1, n_cf], then made 0-based
indch1_control = np.minimum(np.maximum(1, np.arange(1, n_cf+1)+n_ch_shift), n_cf).astype(int)-1
fp1_control = fp1[indch1_control]
# the control path bandpass filter uses the channels of pGc indexed by indch1_control
pGc_control = bh.RestructureFilterbank(pGc, indexmapping=indch1_control)

# the second filterbank in the control path consists of fixed asymmetric compensation filters
frat_control = 1.08
fr2_control = frat_control*fp1_control
asym_comp_control = bh.AsymmetricCompensation(pGc_control, fr2_control, b=b2, c=c2)

# definition of the poles of the asymmetric compensation filters
p0 = 2
p1 = 1.7818*(1-0.0791*b2)*(1-0.1655*abs(c2))
p2 = 0.5689*(1-0.1620*b2)*(1-0.0857*abs(c2))
p3 = 0.2523*(1-0.0244*b2)*(1+0.0574*abs(c2))
p4 = 1.0724

# definition of the parameters used in the control path output levels computation
# (see IEEE paper for details)
decay_tcst = .5*br.ms
order = 1.  # NOTE(review): not used anywhere visible in this notebook
lev_weight = .5
level_ref = 50.
level_pwr1 = 1.5
level_pwr2 = .5
RMStoSPL = 30.
frat0 = .2330
frat1 = .005
# per-sample decay factor: a held level halves every `decay_tcst`
exp_deca_val = np.exp(-1/(decay_tcst*samplerate)*np.log(2))
# floor applied before taking log10 of the combined level
level_min = 10**(-RMStoSPL/20)

# Controller: combines the two control-path outputs into a per-channel level
# and uses it to retune the signal path's asymmetric compensation filters.
class CompensensationFilterUpdater(object):
    """Level-dependent updater for a ``bh.ControlFilterbank``.

    Each call receives the latest output buffers of the two control-path
    filterbanks, keeps a decaying peak-hold level per channel for each,
    converts their weighted combination to dB, and rewrites the target's
    ``filt_b`` / ``filt_a`` coefficients for the resulting center frequencies.
    """

    def __init__(self, target):
        # filterbank whose filt_b / filt_a are overwritten on every call
        self.target = target
        # previous held levels (scalars before the first call, arrays after)
        self.level1_prev = -100
        self.level2_prev = -100

    def __call__(self, *inputs):
        # only the most recent sample of each control-path buffer is used
        last1 = inputs[0][-1, :]
        last2 = inputs[1][-1, :]

        # peak hold with exponential decay: keep whichever is larger, the
        # rectified current sample or the decayed previous level
        decayed1 = self.level1_prev*exp_deca_val
        decayed2 = self.level2_prev*exp_deca_val
        level1 = np.maximum(np.maximum(last1, 0), decayed1)
        level2 = np.maximum(np.maximum(last2, 0), decayed2)
        self.level1_prev, self.level2_prev = level1, level2

        # weighted combination of the two levels, then conversion to dB
        part1 = lev_weight*level_ref*(level1/level_ref)**level_pwr1
        part2 = (1-lev_weight)*level_ref*(level2/level_ref)**level_pwr2
        level_total = part1 + part2
        level_dB = 20*np.log10(np.maximum(level_total, level_min))+RMStoSPL

        # level-dependent frequency ratio sets the new compensation-filter centres
        fr2 = fp1*(frat0 + frat1*level_dB)
        self.target.filt_b, self.target.filt_a = bh.asymmetric_compensation_coeffs(
            samplerate, fr2, self.target.filt_b, self.target.filt_a, b2, c2,
            p0, p1, p2, p3, p4)

#### Signal Path ####
# the signal path consists of the passive gammachirp filterbank pGc previously
# defined, followed by an asymmetric compensation filterbank whose center
# frequencies are retuned by `updater`
fr1 = fp1*frat0
varyingfilter_signal_path = bh.AsymmetricCompensation(pGc, fr1, b=b2, c=c2)
updater = CompensensationFilterUpdater(varyingfilter_signal_path)
# the controller takes the two filterbanks of the control path as inputs and
# the varying filter of the signal path as its target
control = bh.ControlFilterbank(varyingfilter_signal_path,
                               [pGc_control, asym_comp_control],
                               varyingfilter_signal_path, updater, update_interval)

# run the simulation
# The controller sits at the end of the chain, so the output of the whole
# path is taken from it
signal = control.process()
cochleogram(signal)

In [None]:
# Zilany -- unfortunately doesn't work right now, so it is skipped by default
# to keep "Restart & Run All" working. Set `try_zilany = True` to run it.
try_zilany = False
if try_zilany:
    from brian.hears.filtering.zilany import ZILANY

    zil = ZILANY(sound, cf, update_interval=1)
    zil_mon = zil.process()
    cochleogram(zil_mon)

### Hooking them up to Nengo

In [None]:
from brian.hears.filtering.tan_carney import ZhangSynapseRate
from scipy.io.wavfile import read as readwav
from scipy.signal import resample
from nengo.utils.compat import range


class FuncProcess(nengo.processes.Process):
    """Adapter that lets a plain function stand in for a Nengo Process.

    ``make_step`` ignores its sizes, dt and rng and simply hands back the
    wrapped callable, so functions and processes can be used interchangeably
    without special-casing at the call site.
    """
    def __init__(self, fn):
        self.fn = fn

    def make_step(self, size_in, size_out, dt, rng):
        step = self.fn
        return step


class WavFile(nengo.processes.Process):
    """Process that plays back a WAV file, resampled to the simulator's dt.

    Parameters
    ----------
    path : str
        WAV file to read with ``scipy.io.wavfile.read``.
    at_end : {'loop', 'stop'}
        Behaviour after the last sample: 'loop' starts again from the
        beginning, 'stop' outputs silence (0).
    """
    def __init__(self, path, at_end='loop'):
        self.default_size_out = 1

        self.path = path
        assert at_end in ('loop', 'stop')
        self.at_end = at_end

    def make_step(self, size_in, size_out, dt, rng):
        assert size_in == 0
        assert size_out == 1

        rate = 1. / dt

        orig_rate, orig = readwav(self.path)
        if orig.ndim > 1:
            # multi-channel file: mix down to mono so the output is scalar
            orig = orig.mean(axis=1)
        # resample() needs an integer number of output samples
        new_size = int(round(orig.shape[0] * (rate / orig_rate)))
        wave = resample(orig, new_size)
        wave -= wave.mean()
        wave *= 0.1  # arbitrary... should do this better

        if self.at_end == 'loop':

            def step_wavfileloop(t):
                idx = int(t * rate) % wave.size
                return wave[idx]
            return step_wavfileloop

        elif self.at_end == 'stop':

            def step_wavfilestop(t):
                idx = int(t * rate)
                # valid indices are 0 .. wave.size - 1; past that is silence
                if idx >= wave.size:
                    return 0.
                return wave[idx]
            return step_wavfilestop


class AuditoryFilterBank(nengo.processes.Process):
    """Nengo Process that runs a Brian Hears filterbank on a Nengo-driven sound.

    Parameters
    ----------
    freq : array
        Channel center frequencies; the process outputs one value per entry.
    sound_process : nengo Process
        Generates the input waveform; stepped at the audio sample rate.
    filterbank : brian.hears Filterbank
        Its ``source`` and ``buffersize`` are overwritten in ``make_step``
        (the object is mutated, so each built model should get its own).
    zhang_synapse : bool
        If True, output ZhangSynapseRate ``s`` values instead of IHC activity.
    samplerate : float or None
        Audio sample rate in Hz; defaults to ``1 / dt``.
    """
    def __init__(self, freq, sound_process, filterbank, zhang_synapse=False, samplerate=None):
        self.freq = freq
        self.sound_process = sound_process
        self.filterbank = filterbank
        self.zhang_synapse = zhang_synapse
        self.samplerate = samplerate

    # IHC activity
    @staticmethod
    def bm2ihc(x):
        """Half wave rectify and compress it with a 1/3 power law."""
        return 3 * np.clip(x, 0, np.inf) ** (1. / 3.)

    def make_step(self, size_in, size_out, dt, rng):
        assert size_in == 0
        assert size_out == self.freq.size

        # If samplerate isn't specified, we'll assume dt
        samplerate = 1. / dt if self.samplerate is None else self.samplerate
        sound_dt = 1. / samplerate

        # Audio samples generated per Nengo step. Round instead of truncating:
        # dt / sound_dt is often e.g. 24.999999 when 25 is intended.
        samples_per_step = int(round(dt / sound_dt))
        if samples_per_step < 1:
            raise ValueError("dt (%g s) must span at least one audio sample (%g s)"
                             % (dt, sound_dt))

        # Set up the sound
        step_f = self.sound_process.make_step(0, 1, sound_dt, rng)
        ns = NengoSound(step_f, size_out, samplerate)
        self.filterbank.source = ns

        self.filterbank.buffersize = samples_per_step
        ihc = bh.FunctionFilterbank(self.filterbank, self.bm2ihc)
        # Fails if we don't do this... (NOTE(review): reason not visible here)
        ihc.cached_buffer_end = 0

        if self.zhang_synapse:
            syn = ZhangSynapseRate(ihc, self.freq)
            s_mon = br.RecentStateMonitor(
                syn, 's', record=True, clock=syn.clock, duration=dt * br.second)
            net = br.Network(syn, s_mon)

            def step_synapse(t):
                net.run(dt * br.second)
                return s_mon.values[-1]
            return step_synapse
        else:
            def step_filterbank(t):
                # pull one Nengo step's worth of audio, filter it, and report
                # the IHC activity at the last audio sample
                chunk = ns.buffer_fetch_next(samples_per_step)
                result = ihc.func(self.filterbank.buffer_apply(chunk))
                return result[-1]
            return step_filterbank


class NengoSound(bh.BaseSound):
    """Brian Hears sound source whose samples come from a Nengo step function.

    Every generated sample advances an internal clock by one sample period
    and evaluates ``step_f`` at the new time; that value fills all channels
    of the corresponding buffer row.
    """
    def __init__(self, step_f, nchannels, samplerate):
        self.step_f = step_f
        self.nchannels = nchannels
        self.samplerate = samplerate
        self.t = 0.0
        self.dt = 1. / self.samplerate

    def buffer_init(self):
        # nothing to reset: samples are produced on demand
        pass

    def buffer_fetch(self, start, end):
        # `start` is not used; samples always continue from the internal clock
        return self.buffer_fetch_next(end - start)

    def buffer_fetch_next(self, samples):
        buf = np.empty((samples, self.nchannels))
        row = 0
        while row < samples:
            self.t += self.dt
            buf[row] = self.step_f(self.t)
            row += 1
        return buf

In [None]:
# Quick look at the speech waveform, sampled at 50 kHz
spnoise = WavFile('speech.wav')
dt = 1./50000
t_end = 0.668  # seconds of audio to show
plt.plot(spnoise.trange(t_end, dt=dt), spnoise.run(t_end, dt=dt))

In [None]:
# Nengo auditory periphery model
def periphery(freqs, noise, filterbank, neurons_per_freq=30,
              zhang_synapse=False, fs=50000.):
    """Add an auditory periphery (IHC node + auditory-nerve ensembles) to the current network.

    Parameters
    ----------
    freqs : array
        Center frequencies; one IHC output and one ensemble per entry.
    noise : nengo Process
        Sound source (e.g. ``WavFile``, ``FuncProcess``, ``WhiteNoise``).
    filterbank : brian.hears Filterbank
        Cochlear filterbank run inside ``AuditoryFilterBank``.
    neurons_per_freq : int
        Neurons in each auditory-nerve ensemble.
    zhang_synapse : bool
        Feed Zhang-synapse rates (scaled by 0.1, unfiltered) into the
        auditory nerve instead of IHC activity.
    fs : float
        Audio sample rate in Hz.

    Returns
    -------
    ihc : nengo.Node
    an : nengo.networks.EnsembleArray
    """
    # Inner hair cell activity, computed with the caller-supplied filterbank
    fb = AuditoryFilterBank(freqs, noise, filterbank,
                            samplerate=fs, zhang_synapse=zhang_synapse)
    ihc = nengo.Node(output=fb, size_out=freqs.size)

    # Cochlear neurons projecting down auditory nerve
    an = nengo.networks.EnsembleArray(neurons_per_freq, freqs.size,
                                      neuron_nodes=True,  # For plotting raster
                                      intercepts=nengo.dists.Uniform(0.4, 0.8),
                                      encoders=nengo.dists.Choice([[1]]))
    if zhang_synapse:
        nengo.Connection(ihc, an.input, transform=0.1, synapse=None)
    else:
        nengo.Connection(ihc, an.input)
    return ihc, an

# Filterbank built on an empty placeholder sound; AuditoryFilterBank replaces
# its source with the Nengo-driven sound when the simulator is built.
br_filterbank = bh.Gammatone(bh.Sound(np.zeros(0)), cf, b=1.019)

# Candidate inputs: white noise, a 250 Hz tone, or recorded speech
wnoise = nengo.processes.WhiteNoise(nengo.dists.Gaussian(mean=0, std=0.01))
tnoise = FuncProcess(lambda t: np.sin(2 * np.pi * t * 250))
spnoise = WavFile('speech.wav')

with nengo.Network() as net:
    ihc, an = periphery(cf, spnoise, br_filterbank, zhang_synapse=False)

    # Unfiltered probes on the IHC output, AN input, and AN neuron output
    ihc_p, an_in_p, an_p = [nengo.Probe(target, synapse=None)
                            for target in (ihc, an.input, an.neuron_output)]

In [None]:
from nengo.utils.matplotlib import rasterplot

dt = 1. / cf.max()
sim = nengo.Simulator(net, dt=dt*.5)
sim.run(0.1)

# Cochleograms of the IHC output and the auditory-nerve input
for probe in (ihc_p, an_in_p):
    plt.figure()
    cochleogram(sim.data[probe], sim.trange())
# To inspect AN spikes instead, uncomment:
#plt.figure()
#rasterplot(sim.trange(), sim.data[an_p])
#plt.ylim(0, an.n_neurons * an.n_ensembles)

## Preprocessing layer

### Temporal processing with delay networks

In [None]:
# Generic LTI stuff
from scipy.linalg import solve_lyapunov
from nengo.utils.filter_design import zpk2ss, tf2ss, ss2tf, cont2discrete


class LTI(object):
    """Linear time-invariant system in state-space form.

    dx = A x + B u,  y = C x + D u.  The four matrices are stored as numpy
    arrays in ``a``, ``b``, ``c`` and ``d``.
    """
    def __init__(self, a, b, c, d):
        self.a = np.array(a)
        self.b = np.array(b)
        self.c = np.array(c)
        self.d = np.array(d)

    @property
    def abcd(self):
        """The (A, B, C, D) matrices as a tuple."""
        return (self.a, self.b, self.c, self.d)

    @classmethod
    def from_synapse(cls, synapse):
        """Instantiate class from a Nengo synapse (any object with num/den)."""
        if not hasattr(synapse, 'num') or not hasattr(synapse, 'den'):
            raise ValueError("Must be a linear filter with 'num' and 'den'")
        return cls.from_tf(synapse.num, synapse.den)

    @classmethod
    def from_tf(cls, num, den):
        """Instantiate class from a transfer function."""
        return cls(*tf2ss(num, den))

    @classmethod
    def from_zpk(cls, z, p, k):
        """Instantiate class from a zero-pole-gain representation."""
        return cls(*zpk2ss(z, p, k))

    def copy(self):
        """Return an independent copy (the constructor copies the arrays)."""
        return LTI(*self.abcd)

    def scale_to(self, radii=1.0):
        """Rescale the state so that component i of x has effective radius radii[i].

        With x' = x / r:  A' = R^-1 A R,  B' = R^-1 B,  C' = C R  (R = diag(r)).
        `radii` may be a scalar (applied to every state) or one value per state.

        Raises
        ------
        ValueError
            If `radii` has more than one dimension or the wrong length.
        """
        r = np.asarray(radii, dtype=np.float64)
        if r.ndim > 1:
            raise ValueError("radii (%s) must be a 1-dim array or scalar" % (radii,))
        elif r.ndim == 0:
            r = np.ones(len(self.a)) * r
        elif len(r) != len(self.a):
            raise ValueError("radii (%s) must have one entry per state (%d)"
                             % (radii, len(self.a)))
        # Build new arrays: in-place division fails on integer matrices
        self.a = self.a / r[:, None] * r
        self.b = self.b / r[:, None]
        self.c = self.c * r

    def ab_norm(self):
        """Returns H2-norm of each component of x in the state-space.

        Equivalently, this is the H2-norm of each component of (A, B, I, 0).
        This gives the power of each component of x in response to white-noise
        input with uniform power.

        Useful for setting the radius of an ensemble array with continuous
        dynamics (A, B)
        """
        p = solve_lyapunov(self.a, -np.dot(self.b, self.b.T))  # AP + PA^H = Q
        assert np.allclose(np.dot(self.a, p) + np.dot(p, self.a.T) + np.dot(self.b, self.b.T), 0)
        c = np.eye(len(self.a))
        h2norm = np.dot(c, np.dot(p, c.T))
        # The H2 norm of (A, B, C) is sqrt(tr(CXC^T)), so if we want the norm of
        # each component in the state-space representation, we evaluate this for
        # each elementary vector C separately, which is equivalent to just picking
        # out the diagonals
        return np.sqrt(h2norm[np.diag_indices(len(h2norm))])

    def to_sim(self, synapse, dt=0, copy=True):
        """Map this system onto recurrent/input transforms for a Lowpass synapse.

        For ``dt == 0`` the continuous mapping A' = tau*A + I, B' = tau*B is
        used; otherwise the system is discretized (zero-order hold) with step
        `dt` and mapped onto the discrete lowpass. C and D are carried over
        (discretized when ``dt != 0``).

        Returns a new LTI when `copy` is True; otherwise updates self in place
        and returns None.

        Raises
        ------
        TypeError
            If `synapse` is not a ``nengo.Lowpass``.
        """
        if not isinstance(synapse, nengo.Lowpass):
            raise TypeError("synapse (%s) must be Lowpass" % (synapse,))
        c, d = self.c, self.d  # unchanged in the continuous case
        if dt == 0:
            a = synapse.tau * self.a + np.eye(len(self.a))
            b = synapse.tau * self.b
        else:
            a, b, c, d, _ = cont2discrete(self.abcd, dt=dt)
            aa = np.exp(-dt / synapse.tau)
            a = 1. / (1 - aa) * (a - aa * np.eye(len(a)))
            b = 1. / (1 - aa) * b
        if copy:
            return LTI(a, b, c, d)
        self.a, self.b, self.c, self.d = a, b, c, d


def exp_delay(p, q, c=1.0):
    """Returns F = p/q such that F(s) = e^(-sc).

    Uses the [p/q] Pade approximant of the Taylor series of e^(-sc); the
    result is a pair of ``numpy.poly1d`` (numerator of degree p, denominator
    of degree q).
    """
    # This leads to the same matrices used by the Delay LTISystem, except
    # this is numeric and the latter is the symbolic solution.
    # `scipy.misc.pade` / `scipy.misc.factorial` were removed from SciPy;
    # prefer their current homes and fall back for old installations.
    try:
        from scipy.interpolate import pade
    except ImportError:
        from scipy.misc import pade
    try:
        from scipy.special import factorial
    except ImportError:
        from scipy.misc import factorial
    i = np.arange(1, p+q+1)
    taylor = np.append([1.0], (-c)**i / factorial(i))
    return pade(taylor, q)

In [None]:
from nengo.utils.compat import is_number

# LTI in Nengo
def lti_net(n_neurons, lti, synapse=nengo.Lowpass(0.05),
            controlled=False, dt=0.001, radii=None, radius=1.0):
    """Implement the LTI system `lti` with a recurrently connected EnsembleArray.

    The state is rescaled to `radii` (default: the per-component H2 norm from
    ``lti.ab_norm()``) times `radius`, then mapped onto `synapse` for
    simulator step `dt` with ``LTI.to_sim``. `lti` itself is not modified.

    Returns
    -------
    inp, out : nengo.Node
        Input node (size ``lti.b.shape[1]``) and output node
        (size ``lti.c.shape[0]``).

    Raises
    ------
    NotImplementedError
        If `controlled` is True: the multiplicative control input was never
        connected or returned, so that mode is unfinished.
    """
    if controlled:
        raise NotImplementedError(
            "controlled=True is not implemented: the control input is never "
            "connected or returned")

    lti = lti.copy()
    if radii is None:
        radii = lti.ab_norm()
    # New array instead of `radii *= radius`, which fails for lists and
    # silently rescales a caller-supplied array in place.
    radii = np.asarray(radii, dtype=np.float64) * radius
    lti.scale_to(radii)  # Probably should require this outside of this function
    lti.to_sim(synapse, dt, copy=False)

    size_in = lti.b.shape[1]
    size_state = lti.a.shape[0]
    size_out = lti.c.shape[0]

    a, b, c, d = lti.abcd

    inp = nengo.Node(size_in=size_in, label="input")
    out = nengo.Node(size_in=size_out, label="output")
    x = nengo.networks.EnsembleArray(n_neurons, size_state)

    # recurrence and input go through `synapse`; the readout is instantaneous
    nengo.Connection(x.output, x.input, transform=a, synapse=synapse)
    nengo.Connection(inp, x.input, transform=b, synapse=synapse)
    nengo.Connection(x.output, out, transform=c, synapse=None)
    nengo.Connection(inp, out, transform=d, synapse=None)

    return inp, out


class Highpass(nengo.LinearFilter):
    """Differentiated lowpass, raised to a given power.

    Transfer function ``(tau*s / (tau*s + 1)) ** order``.

    Raises
    ------
    ValueError
        If `order` is not an integer >= 1.
    """
    def __init__(self, tau, order=1):
        # Check type first (comparing a non-number may raise), and reject
        # non-integral values such as 1.5 that the message says are invalid.
        if not is_number(order) or order < 1 or int(order) != order:
            raise ValueError("order (%s) must be integer >= 1" % (order,))
        order = int(order)  # poly1d ** needs a true int (range() inside)
        num, den = np.poly1d([tau, 0]), np.poly1d([tau, 1])
        super(Highpass, self).__init__(num=num**order, den=den**order)


# Differentiator
def deconv_net(n_neurons, tf, delay, degree=4, **lti_kwargs):
    """Approximate the inverse of a given transfer function using a delay.

    The inverse ``den/num`` of `tf` is multiplied by a Pade approximation of
    ``e^(-s*delay)`` (numerator degree ``degree - order``, denominator degree
    `degree`, where ``order`` is the relative degree of `tf`), which keeps the
    combined system realizable. Extra keyword arguments go to ``lti_net``.

    Returns ``(input_node, output_node, degree)``.
    """
    tf_num = np.poly1d(tf[0])
    tf_den = np.poly1d(tf[1])
    # len() of a poly1d is its degree, so this is the relative degree of tf;
    # a non-causal tf is fine as long as the delay's degree accounts for it
    rel_degree = len(tf_den) - len(tf_num)
    if rel_degree >= degree:
        raise ValueError("order (%s) must be < degree (%s)"
                         % (rel_degree, degree))
    delay_num, delay_den = exp_delay(degree - rel_degree, degree, delay)
    inverse = LTI.from_tf(np.polymul(delay_num, tf_den),
                          np.polymul(delay_den, tf_num))
    inp, out = lti_net(n_neurons, inverse, **lti_kwargs)
    return inp, out, degree


def diff_net(n_neurons, tau, delay, **deconv_kwargs):
    """Output a signal that is a (delayed) derivative of the input.

    Built by inverting the integrator ``1 / (tau*s)``, so the output is
    approximately ``tau * d/dt input(t - delay)``.
    """
    integrator_tf = ([1], [tau, 0])
    return deconv_net(n_neurons, integrator_tf, delay, **deconv_kwargs)

In [None]:
tau_highpass = 0.05
br_filterbank = bh.Gammatone(bh.Sound(np.zeros(0)), cf, b=1.019)
wnoise = nengo.processes.WhiteNoise(nengo.dists.Gaussian(mean=0, std=0.01))
tnoise = FuncProcess(lambda t: np.sin(2 * np.pi * t * 250))  # 250 Hz tone
spnoise = WavFile('speech.wav')


def _derivative_layer(source, delay):
    """Per-channel delayed differentiators of `source`, gathered in one Node."""
    layer = nengo.Node(None, size_in=cf.size)
    for ch in range(cf.size):
        d_in, d_out, _ = diff_net(50, tau=tau_highpass, delay=delay, radius=0.1)
        nengo.Connection(source[ch], d_in)
        nengo.Connection(d_out, layer[ch], synapse=tau_highpass)
    return layer


with nengo.Network() as net:
    # Input is auditory periphery layer
    ihc, an = periphery(cf, spnoise, br_filterbank, zhang_synapse=False)

    # Derivatives computed with 8 ms and 20 ms delays
    shortderiv = _derivative_layer(an.output, delay=0.008)
    longderiv = _derivative_layer(an.output, delay=0.02)

    # Probes
    ihc_p = nengo.Probe(ihc, synapse=None)
    an_p = nengo.Probe(an.output, synapse=0.01)
    short_p = nengo.Probe(shortderiv, synapse=0.01)
    long_p = nengo.Probe(longderiv, synapse=0.01)

In [None]:
dt = 1. / cf.max()
sim = nengo.Simulator(net, dt=dt*.5)
sim.run(0.667)

# One cochleogram per processing stage: IHC, auditory nerve, short/long derivatives
for probe in (ihc_p, an_p, short_p, long_p):
    plt.figure()
    cochleogram(sim.data[probe], sim.trange())

In [None]:
n_neurons_total = sum(ens.n_neurons for ens in net.all_ensembles)
print(n_neurons_total)