
First of all, we will need to install Norse. Please run the cell below. Read on while it's running.
Run this if you want to install
```!pip install --quiet norse```


# Lower Dimension Quiescent Single Neuron Optimization Via SNN


Down the bottom of the [ipython Norse notebook](https://github.com/norse/notebooks/blob/master/single-neuron-experiments.ipynb) there is a remark that the optimization job actually doesn't perform that well.

I want to show that Spiking Neural Networks should be able to easily fit spike times in a single cell model if the dimensionality of the problem is reduced a lot.

To this end, I have remade the notebook so that:
* Weights of the outer SNN model, change the single-cell models' parameters. 
* When weights of the SNN model change single-cell model parameters, it is a simple way to elicit  the model to produce the right spike times.
* The single neuron model is more realistic Adexp, or Izhikevich models (using my own model code for dense Euler simulations, not tensor flow versions).
* The single neuron model is subject to unrealistic quiescent conditions. It has no synaptic input, its just a fixed current injection experiment.
* The right current amplitude to cause the target number of spikes is found via brute force each time inside the optimization loop, as this is surprisingly inexpensive.

 
If I can make this code succeeed I can evaluate the whole optimization job, and start to answer questions like:

* Does a more realistic single-cell model (with spike-timing adaptation) help?

* Does reducing the dimensionality of the optimization problem help? 


The Permission error can be resolved with $ sudo chmod -R a+r /sys/class/powercap/intel-rapl or with the $ conda activate my-rapl-env, whether you are root or an ordinary user.

In [1]:
!pip install pymongo>=3.9.0
!pip install pandas>=0.25.1
!pip install pyRAPL
!pip 
import pyRAPL
pyRAPL.setup()

import torch
import norse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['ytick.left'] = True
mpl.rcParams['ytick.labelleft'] = True
mpl.rcParams['axes.spines.left'] = True
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.bottom'] = True
mpl.rcParams['legend.frameon'] = False




Usage:   
  pip <command> [options]

Commands:
  install                     Install packages.
  download                    Download packages.
  uninstall                   Uninstall packages.
  freeze                      Output installed packages in requirements format.
  list                        List installed packages.
  show                        Show information about installed packages.
  check                       Verify installed packages have compatible dependencies.
  config                      Manage local and global configuration.
  search                      Search PyPI for packages.
  wheel                       Build wheels from your requirements.
  hash                        Compute hashes of package archives.
  completion                  A helper command used for command completion.
  debug                       Show information useful for debugging.
  help                        Show help for commands.

General Options:
  -h, --help                  Show h

In [11]:
import norse
from jithub.models import model_classes
model = model_classes.ADEXPModel()
from neuronunit.optimization.model_parameters import MODEL_PARAMS
from neuronunit.tests.target_spike_current import SpikeCountSearch
import quantities as qt
from norse.torch.functional.lif import lif_step, LIFParameters, LIFState

class Neurons():
    def __init__(self, weights, alpha, nspikes):
        #super(Neurons, self).__init__()
        self.hidden_size = hidden_size
        self.w_in = torch.nn.Parameter(torch.tensor(weights).float())
        self.w_rec = torch.zeros(hidden_size, hidden_size) # no recurrent connections
        self.nspikes = nspikes
        self.lambda_vs = []
        self.lambda_is = []
        self.p = LIFParameters(alpha=alpha)

    def reset_lambda_recording(self):
        self.lambda_vs = []
        self.lambda_is = []
        
    def inner_model_eval(self):

        params = {}
        i = 0
        for k,j in MODEL_PARAMS['ADEXP'].items():
            wv = self.w_in.detach().numpy()[0]
            params[k] = np.mean([wv*j[0],wv*j[1]])
            i += 1
        model = model_classes.ADEXPModel()
        model.attrs = params
        
        observation = {}
        observation["value"] = self.nspikes
        scs = SpikeCountSearch(observation)
        target_current = scs.generate_prediction(model)
        inject_param = {
            "padding":0 * qt.ms,
            "delay": 0 * qt.ms,
            "amplitude": 900 * qt.pA,
            "duration": 1000 * qt.ms,
            "dt":0.25
        }
        try:
            inject_param["amplitude"] = target_current["value"]* qt.pA
            model.inject_square_current(**inject_param)
            vm = model.get_membrane_potential()
            plt.plot(vm.times,vm)
            plt.show()

            concise_spk_time = model.get_spike_train()
        except:
            # some model parameters will cause try block to fail.
            concise_spk_time = []
        return concise_spk_time

## Step 1: A simple neuron model

The point neuron models supported by Norse are almost all variants of the Leaky-Integrate and Fire neuron model. It is however relatively easy to implement your own model. The library
is build in layers, here I show an example of how to use the functional API directly. To
build large scale machine learning models, you should check out the tutorial on [PyTorch
lightning + Norse](high-performance-computing.ipynb).


## Step 2.1: Optimizing for a fixed number of spikes

A simple task to consider is a single neuron stimulated at different times by $k$ fixed poisson distributed spike trains, with synaptic weights distributed according to a gaussian distribution. The goal is for the neuron to respond to these fixed spike trains with a certain number of spikes $n_\text{target}$ within a time $T$. The loss in this case is
$$
l = -n_\text{target}/T + \sum_i \delta(t - t_i) 
$$
so
$$
S = \int_0^T (-n_\text{target}/T + \sum_i \delta(t - t_i)) dt = n_\text{actual} - n_\text{target}
$$

In [24]:

def run_training(
    w_in,
    z_in,
    alpha=100.0,
    max_epochs=100,
    target_spikes=6,
    target_spike_offset=10
):
    

    neurons = Neurons(w_in, alpha=torch.tensor(alpha),nspikes=target_spikes)
    optim = torch.optim.SGD(torch.tensor(w_in), lr=0.1)

    pbar = trange(max_epochs)
    for e in pbar:
        optim.zero_grad()
        concise_spk_time = neurons(z_in)
        # compute the loss according to the formula above
        concise_spk_time-target_spikes
        #loss = torch.sum(torch.abs((torch.sum(concise_spk_time, axis=0) - target_spikes)))
        loss.backward()

        pbar.set_postfix({"spike difference": loss.detach().item()})
        if loss.data == torch.tensor([0.0]):
            break

        # do a gradient optimisation step
        optim.step()

    return concise_spk_time

In [25]:
seq_length = 1000
input_size = 20
hidden_size = 1
batch_size = 1
epochs = 100
alpha = 100.0

spikes = torch.distributions.bernoulli.Bernoulli(probs=0.04*torch.ones(seq_length, batch_size, input_size))
spikes_in = spikes.sample()
w_in = np.random.randn(hidden_size,input_size) * np.sqrt(2/hidden_size)


### Force all weights to be positive?
## Questionable move?
Model parameters can only be positive.

Yet on the other hand, SNNs need some negative weights to inhibit bad optimization learnings.

In [26]:
from tqdm.notebook import trange
report = pyRAPL.outputs.DataFrameOutput()
with pyRAPL.Measurement('bar',output=report):

    w_in = [abs(i)/2.0 for i in w_in]
    w_in
report.data.head()

Unnamed: 0,label,timestamp,duration,pkg,dram,socket
0,bar,Sat Jan 1 20:54:30 2022,96.396,2503.0,1221.0,0


In [27]:
from tqdm.notebook import trange
report = pyRAPL.outputs.DataFrameOutput()

with pyRAPL.Measurement('bar',output=report):
    spikes, vs, cs, lambda_vs, lambda_is = run_training(z_in=spikes_in, w_in=w_in, alpha=alpha, target_spikes=6, max_epochs=epochs)
report.data.head()

TypeError: params argument given to the optimizer should be an iterable of Tensors or dicts, but got torch.DoubleTensor

Don't worry that the progress bar turned red, in this case it means that the optimisation
finished early. We can plot the error signals that are propagated backwards in time as follows. At each spike that reaches the neuron at synapse the variable $\lambda_i$ is accumulated to the gradient
of the synaptic weight.

Exercises:
- Change the epoch_from_last variable to plot the error traces at different times in the optimisation
  procedure.
- Change the value alpha. What do you observe?
- Repeat the experiment with more biologically realistic parameters


## Step 2.2: Learning target spike times

Another task is for one neuron to spike at specific spike times $t_0, \ldots, t_N$ given that it stimulated 
by a fixed set of poisson distributed spikes. We can choose as a loss in this case
$$
l = \sum_i \lvert v - v_{\text{th}} \rvert^2 \delta(t - t_i) + l_N
$$
that is we require the membrane voltages to be close to the threshold $v_{th}$ at the required spike times $t_i$
and penalise the neuron if it spikes more or less than the required number of times.

In [16]:
from tqdm.notebook import trange

def run_target_spike_time_training(
    w_in,
    z_in,
    alpha=100.0,
    epochs=40000,
    target_times=[100, 300, 500, 700]
):
    neurons = Neurons(w_in, alpha=torch.tensor(alpha),nspikes=len(target_times))
    params = neurons.parameters()

    optim = torch.optim.SGD(params, lr=0.1)



    v_target = torch.zeros(seq_length, batch_size, hidden_size)
    target_spikes = len(target_times)

    for time in target_times:  
        v_target[time,:] = 1.1 * torch.ones(hidden_size)
    loss=0
    pbar = trange(epochs)
    losses= []
    for e in pbar:
        optim.zero_grad()
        concise_spk_times = neurons(z_in)        
        concise_spk_times = neuron.inner_model_eval()
        target_times=[100, 300, 500, 700]
        target_times = [ gt/1000.0 for gt in target_times]
        try:
            loss = torch.nn.functional.mse_loss(torch.as_tensor(concise_spk_times),torch.as_tensor(target_times))
            #print(loss,"mine")
        except:
            loss = torch.as_tensor(1000.0)
        optimizer.zero_grad()
        optimizer.step()
        losses.append(loss1.detach())
    return losses

In [17]:
"""
neurons = Neurons(w_in, alpha=torch.tensor(alpha),nspikes=target_spikes)
params = neurons.parameters()
print([p for p in params])
params = neurons.parameters()

optim = torch.optim.SGD(params, lr=0.1)
print(neurons.p)
print(w_in)
"""

'\nneurons = Neurons(w_in, alpha=torch.tensor(alpha),nspikes=target_spikes)\nparams = neurons.parameters()\nprint([p for p in params])\nparams = neurons.parameters()\n\noptim = torch.optim.SGD(params, lr=0.1)\nprint(neurons.p)\nprint(w_in)\n'

In [18]:
seq_length = 1000
input_size = 150
hidden_size = 1
batch_size = 1
epochs = 100
alpha = 100.0
target_times = [100, 300, 500, 700]


#w_in = np.random.randn(hidden_size,input_size) * np.sqrt(2/hidden_size)

w_in = np.random.randn(hidden_size,input_size)* np.sqrt(2/hidden_size)
w_in = [abs(i)/2.0 for i in w_in]

In [19]:


spikes = torch.distributions.bernoulli.Bernoulli(probs=0.04*torch.ones(seq_length, batch_size, input_size))
z_in = spikes.sample()
result = run_target_spike_time_training(
    w_in=w_in, 
    z_in=z_in,
    alpha=alpha, 
    epochs=epochs, 
    target_times=target_times
)

AttributeError: 'Neurons' object has no attribute 'parameters'

In [None]:
spikes, vs, cs, lambda_vs, lambda_is = result

actual_times = spikes[-1][:,0,0].to_sparse().indices()[0]


for ts in target_times:
    plt.axvline(x=ts, color='red', linestyle='--')

for ts in list(actual_times):
    plt.axvline(x=ts, color='blue', linestyle='-')

plt.plot(vs[-1][:,0], color='grey', label='$v$')
plt.xlabel('Time [ms]')
plt.legend()

We again visualise the error traces over time.

In [None]:
plt.plot(lambda_vs[-2][:,0], label='$\lambda_v$')
plt.xlabel('Time [ms]')
plt.legend()

Exercises:
- This task doesn't actually do great, can you think of ways to improve it?
- What additions to the loss could one consider to make the task more stable?
- Explore different values for alpha, target_times and input size, what do you observe?
- Consider a different optimiser
- Consider using biologically plausible neuron parameters