In [1]:
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
from bindsnet.models import GPSModel
from bindsnet.network.monitors import Monitor
from bindsnet.encoding import poisson
import time as T
import os
# Default matplotlib figure size for this notebook.
plt.rcParams["figure.figsize"] = (20, 20)
# Whole-second Unix timestamp taken once when this cell runs; the log helpers
# use it as the prefix of every log file name.
start_time = str(int(T.time()))

In [2]:
# Create the directory the log helpers write into. exist_ok=True makes re-running
# this cell a no-op, while real failures (e.g. missing permissions, or 'logs'
# existing as a regular file) still raise instead of being silently swallowed.
os.makedirs('logs', exist_ok=True)

In [3]:
def _append_log(track_n, kind, msg):
    """Append ``msg`` to one of this run's per-track log files.

    The target is ``logs/<start_time>_track<NN>_<kind>.csv``, where ``NN`` is
    ``track_n + 1`` zero-padded to two digits. The file is opened in append
    mode, so it is created on first use; the ``logs`` directory itself must
    already exist.
    """
    lfname = f'logs/{start_time}_track{track_n+1:02d}_{kind}.csv'
    with open(lfname, 'a') as f:
        f.write(msg)


def logs(track_n, msg):
    """Append ``msg`` verbatim to the spikes CSV of track ``track_n`` (0-based)."""
    _append_log(track_n, 'spikes', msg)


def logv(track_n, msg):
    """Append ``msg`` verbatim to the voltage CSV of track ``track_n`` (0-based)."""
    _append_log(track_n, 'voltage', msg)

In [4]:
def vmax(voltages):
    """Return a 1-D tensor holding the largest value of each row of ``voltages``.

    ``voltages`` is any sequence of rows (a 2-D tensor or nested lists); the
    built-in ``max`` is applied to each row in turn.
    """
    peaks = [max(row) for row in voltages]
    return torch.tensor(peaks)

In [5]:
def relax(network, time, n_input=5000):
    """Run ``network`` for ``time`` on an input encoded from an all-zero vector.

    NOTE(review): presumably this lets network activity settle between
    samples — confirm against the model's dynamics.

    Args:
        network: object exposing ``run(inpts=..., time=...)``.
        time: forwarded unchanged to both ``poisson`` and ``network.run``.
        n_input: length of the zero vector fed to layer ``'X'``. Defaults to
            5000, the value that used to be hard-coded, so existing calls
            behave exactly as before.
    """
    silent = torch.zeros(n_input)
    network.run(inpts={'X': poisson(silent, time)}, time=time)

In [6]:
# Value passed as `time` to poisson(...) and network.run(...) for every training
# sample, and to the monitor below as its `time` argument.
time = 500
# When True, the training loop draws the recorded spikes and the input after each sample.
plot = False
# The first argument (5000) matches the length of the vectors fed to layer 'X'
# elsewhere in this notebook. NOTE(review): the meaning of dt, norm and inh is
# defined by GPSModel and is not visible here — confirm before tuning.
network = GPSModel(5000, dt=1.0, norm=48.95, inh=3)
# Record the 'v' and 's' state variables of layer 'Ae'; the training loop reads
# them back through exc_monitor.get(...).
exc_monitor = Monitor(network.layers['Ae'], ['v', 's'], time=time)
network.add_monitor(exc_monitor, name='exc')

In [7]:
# Encoded training files found in 'encoding/', in name order.
train_files = sorted(x for x in os.listdir('encoding/') if '_train.p' in x)
if not train_files:
    # Fail with a clear message instead of an IndexError on train_files[0].
    raise FileNotFoundError("no '_train.p' files found in 'encoding/'")

# How many times the first track is presented in a row.
# To train on every track once instead, use: enc_files = train_files
n_repeats = 60
enc_files = [train_files[0]] * n_repeats

# torch.load is given the path so it opens and closes the file itself (the
# previous open(...) handles were never closed). Each entry is loaded
# separately, so list items do not share storage.
# NOTE: torch.load unpickles its input — only load files you trust.
tracks = [255 * torch.load('encoding/' + fn) for fn in enc_files]

In [8]:
# Show the file names the tracks were loaded from, in training order.
print(enc_files)

['track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_train.p', 'track01_tr

In [None]:
# Train on every loaded track in turn, logging per-neuron spike counts and peak
# voltages for each sample, then save the network.
for track_n, track in enumerate(tracks):
    # Start each track's CSV logs with a header row.
    logs(track_n, 'Iteration,Neuron,Spikes\n')
    logv(track_n, 'Iteration,Neuron,Voltage\n')

    # Slide a five-frame window over the track; each window is one sample.
    for i in range(4, len(track)):
        # Progress is printed 1-based (i-3) while the CSV rows below are 0-based (i-4).
        print(f'Track {track_n}   Iteration {i-3}')

        # Concatenate frames i-4 .. i into a single input vector.
        orig = torch.cat(tuple(track[i-4:i+1]))
        pt = poisson(orig, time)

        network.run(inpts={'X': pt}, time=time)
        spikes = exc_monitor.get('s')
        voltage = exc_monitor.get('v')

        # Build each sample's rows in memory and append them with a single write,
        # instead of reopening the log file once per neuron.
        spike_rows = ''.join(
            f'{i-4},{neuron},{int(value)}\n'
            for neuron, value in enumerate(torch.sum(spikes, dim=1))
        )
        logs(track_n, spike_rows)

        # NOTE(review): int() truncates the peak voltage to a whole number —
        # confirm that this loss of precision is intended.
        voltage_rows = ''.join(
            f'{i-4},{neuron},{int(value)}\n'
            for neuron, value in enumerate(vmax(voltage))
        )
        logv(track_n, voltage_rows)

        if plot:
            fig = plt.figure(figsize=(20, 20))
            plt.subplot(2, 2, 1)
            plt.imshow(spikes, cmap='binary')
            plt.subplot(2, 2, 2)
            plt.imshow(orig.view(125, 40), cmap='gist_gray')
            plt.show()
            # Release the figure so a long run does not accumulate open figures.
            plt.close(fig)

        # Relax only after the readings above have been consumed.
        # NOTE(review): whether exc_monitor.get() returns copies is not visible
        # here; this order is safe either way because the extra steps recorded
        # during relax() can no longer reach the logged data.
        relax(network, 50)

        clear_output(wait=True)
network.save('trained.net')

Track 0   Iteration 174
