In [1]:
from bindsnet.network import load_network
from bindsnet.network.nodes import LIFNodes
from bindsnet.encoding import poisson
from bindsnet.analysis.plotting import plot_weights
from IPython.display import clear_output
import matplotlib.pyplot as plt
import pickle
import torch
import time as T
# Default size for every matplotlib figure created in this notebook.
plt.rcParams["figure.figsize"] = (20, 20)
# Whole-second Unix timestamp captured once at start-up, kept as a string;
# the log helpers embed it in their file names.
start_time = str(int(T.time()))

In [2]:
# Load the pre-encoded track and scale every value by 255.
# The file is opened in a `with` block so the handle is closed as soon as
# loading finishes instead of being left to the garbage collector.
# NOTE(review): torch.load unpickles its input, which can execute arbitrary
# code -- only point this at files from a trusted source.
with open('encoding/track08_numenta.p', 'rb') as track_file:
    track = 255 * torch.load(track_file)

In [3]:
# Load the saved network from disk.
network = load_network('trained.net')
# Clear the update rule on the X -> Ae connection
# (presumably to freeze its weights during evaluation -- confirm).
network.connections[('X', 'Ae')].update_rule = None
# Keep a handle on the monitor registered under the name 'exc'.
# NOTE(review): this handle is taken before the 'Ae' layer is replaced below;
# confirm the monitor is not still bound to the old layer object.
exc_monitor = network.monitors['exc']
# Swap the existing 'Ae' layer for a freshly constructed LIFNodes layer of
# 100 neurons registered under the same name.
del network.layers['Ae']
network.add_layer(LIFNodes(n=100, traces=True, rest=-65.0, reset=-60.0, thresh=-52.0, refrac=5, decay=1e-2, trace_tc=5e-2), name='Ae')
# Re-point the three connections that touch 'Ae' at the new layer object.
network.connections[('X', 'Ae')].target = network.layers['Ae']
network.connections[('Ae', 'Ai')].source = network.layers['Ae']
network.connections[('Ai', 'Ae')].target = network.layers['Ae']

In [4]:
# Run settings for the evaluation loop.
# `time` is handed to both the Poisson encoder and network.run().
time = 500
# When True, the loop draws the spike raster and the input window each iteration.
plot = False

In [5]:
def _append_log(track_n, suffix, msg):
    """Append ``msg`` to ``logs/<start_time>_track<NN>_<suffix>.csv``.

    ``NN`` is ``track_n + 1`` zero-padded to two digits. The target directory
    is created on demand so a fresh checkout does not fail with
    ``FileNotFoundError`` on the first write.
    """
    import os  # local import keeps this cell self-contained

    lfname = f'logs/{start_time}_track{track_n+1:02d}_{suffix}.csv'
    os.makedirs(os.path.dirname(lfname), exist_ok=True)
    with open(lfname, 'a') as f:
        f.write(msg)


def logs(track_n, msg):
    """Append ``msg`` to the spikes CSV log of track ``track_n`` (zero-based)."""
    _append_log(track_n, 'spikes', msg)


def logv(track_n, msg):
    """Append ``msg`` to the voltage CSV log of track ``track_n`` (zero-based)."""
    _append_log(track_n, 'voltage', msg)

In [6]:
def vmax(voltages):
    """Return a 1-D tensor holding the maximum of each row of ``voltages``.

    Args:
        voltages: a 2-D tensor, or any sequence of non-empty iterables of
            comparable values.

    Returns:
        A tensor with one entry per row. For a non-empty 2-D tensor the
        result keeps the input's dtype and device.
    """
    # Fast path: one vectorised reduction instead of comparing tensor
    # elements one by one with Python's builtin max().
    if isinstance(voltages, torch.Tensor) and voltages.dim() == 2 and voltages.numel() > 0:
        # Index [0] selects the values from the (values, indices) pair.
        return voltages.max(dim=1)[0]
    # Generic path for nested sequences and degenerate (empty) tensors.
    return torch.tensor([max(row) for row in voltages])

In [7]:
# Slide a five-frame window over the track, run the network on each window,
# and log per-neuron spike counts and peak voltages.
track_n = 7  # zero-based track index; the log helpers format it as track08

logs(track_n, 'Iteration,Neuron,Spikes\n')
logv(track_n, 'Iteration,Neuron,Voltage\n')
for i in range(4, len(track)):
    # The progress display is 1-based; the rows written to the logs are 0-based.
    print(f'Iteration {i-3}')

    # Concatenate frames i-4 .. i into one input and Poisson-encode it.
    orig = torch.cat([track[j] for j in range(i - 4, i + 1)])
    pt = poisson(orig, time)

    inpts = {'X': pt}
    network.run(inpts=inpts, time=time)
    # Read the recordings before reset_() is called
    # (presumably reset_() also clears monitor state -- confirm).
    spikes = exc_monitor.get('s')
    voltage = exc_monitor.get('v')
    network.reset_()

    # Build each CSV chunk in memory and append it with a single call, so each
    # log file is opened once per iteration instead of once per neuron.
    spike_rows = [
        f'{i-4},{neuron},{int(value)}\n'
        for neuron, value in enumerate(torch.sum(spikes, dim=1).tolist())
    ]
    logs(track_n, ''.join(spike_rows))

    # NOTE(review): int() truncates the peak voltage toward zero; kept as-is
    # so the log format is unchanged -- confirm the precision loss is intended.
    voltage_rows = [
        f'{i-4},{neuron},{int(value)}\n'
        for neuron, value in enumerate(vmax(voltage).tolist())
    ]
    logv(track_n, ''.join(voltage_rows))

    if plot:
        fig = plt.figure(figsize=(20, 20))
        plt.subplot(2, 2, 1)
        plt.imshow(spikes, cmap='binary')
        plt.subplot(2, 2, 2)
        # view(125, 40) requires the window to hold exactly 5000 values.
        plt.imshow(orig.view(125, 40), cmap='gist_gray')
        plt.show()
        # Release the figure so a long run does not accumulate open figures.
        plt.close(fig)

    clear_output(wait=True)

Iteration 1071
