In [None]:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import networkx as nx
from quantities import ms
from ebdataset.vision import NMnist
from torch.utils import data
from tqdm import tqdm

from ebdataset.vision.transforms import Compose, ToDense, Flatten

from module_utils import graph_to_tensor

# Select the "spawn" start method for any worker processes created later.
# The import lives here because nothing earlier in the file brings
# `set_start_method` into scope (the bare call raised NameError, which the
# `except RuntimeError` below does not catch).
from multiprocessing import set_start_method

try:
    set_start_method("spawn")
except RuntimeError:
    # The start method can only be chosen once per process; when the cell is
    # re-run it is already fixed, so keep whatever is in effect.
    pass

import sys
import torch.utils.data
from torch.autograd import Variable
import torchvision.datasets
import echotorch.nn.reservoir
#from echotorch.nn.reservoir.ESN import ESN



class ESN(nn.Module):
    '''
    Leaky echo state network with a linear readout.

    The input and recurrent projections are evaluated under ``torch.no_grad()``,
    so only the readout (``output_layer``) and, when ``decaying_output`` is set,
    the output recurrence (``w_rec_out``) can receive gradients.

    Adopted from -
    N. Schaetti, M. Salomon, and R. Couturier,
    Echo state networks-based reservoir computing for mnist handwritten digits recognition
    https://ieeexplore-ieee-org.ezproxy.weizmann.ac.il/abstract/document/7982291

    Parameters
    ----------
    input_dim : number of input features per time step.
    reservoir_size : number of reservoir units.
    output_dim : number of readout units.
    batch_size : number of rows in the initial state tensors.
    leaky_rate : accepted for interface compatibility; not used by this class
        (the leak is controlled by ``tau``).
    w_in : optional weight for ``input_layer``, shaped (reservoir_size, input_dim).
    w_rec : optional recurrent matrix, shaped (reservoir_size, reservoir_size).
        When omitted a sparse random matrix is drawn.
    w_out : optional weight for ``output_layer``, shaped (output_dim, reservoir_size).
    w_bias : optional bias added to the recurrent projection.
    tau : leak factor of the state update; also the initial diagonal of ``w_rec_out``.
    noise : falsy to disable; otherwise the scale of zero-mean Gaussian noise
        added to the pre-activation (``True`` means scale 1).
    decaying_output : when True the previous output is fed back through
        ``w_rec_out`` and added to the readout.
    '''
    def __init__(self, input_dim,
                 reservoir_size,
                 output_dim,
                 batch_size,
                 leaky_rate,
                 w_in = None,
                 w_rec = None,
                 w_out = None,
                 w_bias = None,
                 tau = 0.2,
                 noise = False,
                 decaying_output = False):
        super().__init__()

        self.w_bias = w_bias
        self.non_linearity = torch.tanh
        self.noise = noise
        self.tau = tau
        self.reservoir_size = reservoir_size
        self.output_dim = output_dim
        self.decaying_output = decaying_output

        # Input layer: input -> reservoir
        self.input_layer = nn.Linear(input_dim, reservoir_size)
        if w_in is not None:
            self.input_layer.weight = nn.Parameter(w_in)

        # Recurrent connections inside the reservoir
        if w_rec is None:
            # Some numbers from -
            # Goudar and Buonomano. eLife 2018;7:e31134. DOI: https://doi.org/10.7554/eLife.31134
            # Random normal weights (std 1.6) with roughly 20% of the entries kept.
            dense = torch.normal(0, 1.6, size = (reservoir_size, reservoir_size))
            mask = torch.rand([reservoir_size, reservoir_size]) < 0.2
            w_rec = dense * mask
        # forward() only uses w_rec under no_grad, so it never receives a
        # gradient; registering it as a frozen Parameter in both cases keeps it
        # in state_dict() and lets it follow .to(device).
        self.w_rec = nn.Parameter(w_rec, requires_grad=False)

        # Readout layer: reservoir -> output
        self.output_layer = nn.Linear(reservoir_size, output_dim)
        if w_out is not None:
            self.output_layer.weight = nn.Parameter(w_out)

        # Output-to-output recurrence used by the decaying output option,
        # initialised to tau * identity.
        self.w_rec_out = nn.Linear(self.output_dim, self.output_dim)
        self.w_rec_out.weight = nn.Parameter(self.tau * torch.eye(self.output_dim))

        # Reservoir state and last output, one row per example in the batch,
        # both starting at 0.
        self.reset_reservoir(batch_size)

    def forward(self, x):
        '''
        Advance the reservoir by one time step and return the readout.

        x is the input of the current time step; it is passed to
        ``input_layer`` and must have as many rows as ``self.echo_state``.
        Returns ``self.output`` (the same tensor object that is stored on the
        module).
        '''
        with torch.no_grad():
            # The recurrent projection is computed in double precision; cast
            # the bias as well so F.linear sees a single dtype.
            bias = None if self.w_bias is None else self.w_bias.double()
            recurrent = F.linear(self.echo_state.double(), self.w_rec.double(), bias)
            next_state = self.input_layer(x) + recurrent

            # Add noise if defined: zero-mean Gaussian scaled by `noise`
            # rather than a constant offset.
            if self.noise:
                next_state = next_state + self.noise * torch.randn_like(next_state)

            # Add non-linearity
            next_state = self.non_linearity(next_state)

            # Leaky update with tau. A new tensor is created (instead of an
            # in-place +=) so references to earlier states stay valid; the
            # state keeps its original dtype.
            state = self.echo_state
            self.echo_state = (state + self.tau * (next_state - state)).to(state.dtype)

        readout = self.output_layer(self.echo_state)
        if self.decaying_output:
            # Feed the previous output back and ADD the new readout. The
            # previous output is detached so every step owns a self-contained
            # graph and can be back-propagated on its own.
            self.output = self.w_rec_out(self.output.detach()) + readout
        else:
            self.output = readout

        return self.output

    def reset_reservoir(self, batch_size):
        '''
        Reset the reservoir state and the stored output to 0 for a batch of
        `batch_size` examples.
        This is helpful but redundant, I guess we can look at the state
        after some delta t where the past states are no longer relevant
        NEED to experiment with this.
        '''
        # Allocate next to the readout weights so the state follows the module
        # when it is moved to another device.
        device = self.output_layer.weight.device
        self.echo_state = torch.zeros((batch_size, self.reservoir_size), device=device)
        self.output = torch.zeros((batch_size, self.output_dim), device=device)
    


# Experiment parameters

# -- Reservoir / weight-generation settings --
reservoir_size = 500
connectivity = 0.1
spectral_radius = 1.3
leaky_rate = 0.2
input_scaling = 0.6
bias_scaling = 1.0
ridge_param = 0.0

# -- Data dimensions --
image_time = 15  # How many time steps for each image
input_size = 2 * 34 * 34  # The image is flattened to a vector
n_digits = 10

# -- Run configuration --
batch_size = 60
training_size = 60000
test_size = 10000
use_cuda = False  # and torch.cuda.is_available()

import os

def collate_fn(samples):
    """Zero-pad variable-length samples along the last axis and stack them.

    Args:
        samples: non-empty sequence whose items expose the data at index 0 and
            the label at index 1. All data items must share every dimension
            except the last one.

    Returns:
        A ``(batch, labels)`` pair: ``batch`` is a float tensor shaped
        ``(len(samples), *feature_shape, max_duration)`` in which shorter
        samples are padded with zeros at the end, and ``labels`` is
        ``torch.tensor`` of the collected labels.

    Raises:
        ValueError: if ``samples`` is empty.
    """
    if not samples:
        raise ValueError("collate_fn() received an empty list of samples")

    # Take the feature dimensions from the data instead of hard-coding
    # 34 * 34 * 2, so other input sizes collate too.
    feature_shape = tuple(samples[0][0].shape[:-1])
    max_duration = max(s[0].shape[-1] for s in samples)

    batch = torch.zeros(len(samples), *feature_shape, max_duration)
    labels = []
    for i, s in enumerate(samples):
        batch[i, ..., : s[0].shape[-1]] = s[0]
        labels.append(s[1])
    return batch, torch.tensor(labels)


In [None]:
#%%
# Import the sub-package explicitly instead of relying on it being loaded as a
# side effect of `import echotorch.nn.reservoir`.
import echotorch.utils.matrix_generation

# Internal (recurrent) matrix
# NOTE: the keyword was previously misspelled `spetral_radius`; presumably the
# generator did not recognise it, so the requested radius was never applied —
# verify against the installed EchoTorch version.
w_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
    connectivity=connectivity,
    spectral_radius=spectral_radius
)

# Input weights
win_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
    connectivity=connectivity,
    scale=input_scaling,
    apply_spectral_radius=False
    )

# Bias vector
wbias_generator = echotorch.utils.matrix_generation.NormalMatrixGenerator(
    connectivity=connectivity,
    scale=bias_scaling,
    apply_spectral_radius=False
     )

# New ESN-JS module
esn = ESN(
    #Input w, num of columns, the ESN is fed row by row so the number of
    #columns is the number of inputs to the ESN in each ts.
    input_dim=input_size,
    #Input h, num of rows
    #image_size=image_size,
    reservoir_size=reservoir_size,
    batch_size = batch_size,
    leaky_rate=leaky_rate,
    #ridge_param=ridge_param,
    # One readout unit per digit class (n_digits is 10).
    output_dim=n_digits,
    # Generated as (input_size, reservoir_size) and transposed to the
    # (out_features, in_features) layout that nn.Linear weights use.
    w_in=win_generator.generate(size=(input_size, reservoir_size)).T.float(),
    w_rec=w_generator.generate(size=(reservoir_size, reservoir_size)).T,
    w_bias=wbias_generator.generate(size=(reservoir_size)).T,
    decaying_output=False
     )

# Show the model

print(esn)


In [None]:

# Time resolution handed to the ToDense transform.
dt = 1 * ms

# Per-sample preprocessing: ToDense followed by Flatten.
transforms = Compose([ToDense(dt=dt), Flatten()])

# Dataset location and output directory.
NMNIST_PATH = "/home/orram/Documents/NMNIST"
OUT_DIR = "/home/orram/Documents/NMNIST/nmnist_output"

# Keyword arguments shared by the train and test DataLoaders.
params = dict(
    batch_size=batch_size,
    collate_fn=collate_fn,
    shuffle=True,
    num_workers=0,
)

training_set = NMnist(NMNIST_PATH, is_train=True, transforms=transforms)
test_set = NMnist(NMNIST_PATH, is_train=False, transforms=transforms)

train_loader = data.DataLoader(training_set, **params)
test_loader = data.DataLoader(test_set, **params)



In [None]:
# Online training of the ESN readout on N-MNIST, plus a separate linear
# classifier trained on three concatenated reservoir snapshots (MTS).
lr = 5e-2
optimizer = optim.SGD(esn.parameters(), lr = lr)
loss_func = nn.CrossEntropyLoss()
# NOTE: anomaly detection is a debugging aid and slows autograd down.
torch.autograd.set_detect_anomaly(True)

linear_regression = nn.Sequential(nn.Linear(reservoir_size*3,10))
linear_loss = nn.CrossEntropyLoss()
linear_optimizer = optim.SGD(linear_regression.parameters(), lr)

epochs = 1
# Progress is counted in samples; use the real dataset size instead of
# (number of batches) * 60 so a short final batch is not over-counted.
training_size = len(train_loader.dataset) * epochs
train_loss = []
accur_vote = []
accur_max = []
accur_mts = []
echo_all_memory = []
with tqdm(total=training_size, position=0, leave=True) as pbar:
    for epoch in range(epochs):
        batch_loss = []
        joint_loss = []
        max_loss = []
        mts_loss_vec = []
        for i, (batch, labels) in enumerate(train_loader):
            batch = batch.squeeze(1)
            n_samples = batch.shape[0]
            n_steps = batch.shape[-1]
            esn.reset_reservoir(batch_size = n_samples)
            labels = labels.reshape(n_samples)
            spikes = []
            t_loss = []
            #A place holder to store the classification from each time step;
            #after the run over the image we take the max class, the one who got
            #the most votes! This we call equal voting: each step donates
            #all its outputs for classes 1 to 10 and the decision happens
            #after the whole image is processed.
            vote = torch.zeros((n_samples, 10))
            #Another option is to take only the class with the highest output
            #at each step, sum these one-hot predictions and pick
            #the class that got the most votes. This is the max-only vote.
            #The last option - MTS - is described below.
            only_max_vote = torch.zeros((n_samples, 10))
            #Zero-initialised so that a sequence too short to yield three
            #distinct snapshots never feeds uninitialised memory to the
            #classifier.
            state_memory = torch.zeros([n_samples, 3*reservoir_size])
            s = []
            ind = 0
            #Mixed three state (MTS):
            #keep the reservoir state after t = T/3, 2T/3 and T and classify
            #from their concatenation (3*reservoir_size features) with a
            #separate linear classifier.
            snapshot_steps = {
                int(np.round(n_steps/3)),
                int(np.round(n_steps*2/3)),
                n_steps - 1,
            }
            for t in range(n_steps):

                #insert the inputs time step by time step
                optimizer.zero_grad()
                output = esn(batch[:, :, t])
                #In-place scaling: `output` is the same tensor as esn.output,
                #so the loss below sees the scaled values too.
                esn.output *= esn.tau
                loss = loss_func(output, labels)
                #Every step builds a fresh graph, so it does not need to be
                #retained after the backward pass.
                loss.backward()
                optimizer.step()
                #Store a copy: the reservoir state may be updated in place, in
                #which case a bare reference would alias the final state.
                spikes.append(esn.echo_state.detach().clone())
                t_loss.append(loss.item())
                #Detach before accumulating so the per-step autograd graphs
                #are released instead of chaining up inside `vote`.
                step_output = output.detach()
                vote += step_output
                one_hot = torch.nn.functional.one_hot(torch.argmax(step_output, dim = 1), num_classes = 10)
                only_max_vote += one_hot

                if t in snapshot_steps:
                    snapshot = esn.echo_state.detach().clone()
                    state_memory[:, ind:ind+reservoir_size] = snapshot
                    s.append(snapshot.numpy())
                    ind += reservoir_size

            #Run the optimizer over MTS
            linear_optimizer.zero_grad()
            mts_output = linear_regression(state_memory)
            mts_loss = linear_loss(mts_output, labels)
            mts_loss.backward()
            linear_optimizer.step()
            mts_loss_vec.append(mts_loss.item())
            echo_all_memory.append((state_memory, labels))

            joint_loss.append(loss_func(vote, labels).item())
            #NOTE(review): the vote counts are divided by the batch size
            #(len of the first dimension), not by the number of time steps -
            #confirm this normalisation is intended.
            max_loss.append(loss_func(only_max_vote/len(only_max_vote), labels).item())
            batch_loss.append(np.mean(t_loss))
            spikes = torch.stack(spikes, dim=2)
            #joblib.dump(
            #    (spikes.cpu().numpy(), labels.cpu().numpy()),
            #    os.path.join(OUT_DIR, "%s_batch_%i" % ("TRAIN", i)),
            #    compress=3,
            #    )
            get_vote = torch.argmax(vote, dim = 1)
            get_max = torch.argmax(only_max_vote, dim = 1)
            correct_vote = (get_vote == labels).sum().item()
            correct_max = (get_max == labels).sum().item()
            accur_vote.append(correct_vote/n_samples)
            accur_max.append(correct_max/n_samples)

            #MTS accuracy is measured after the classifier update; no graph is
            #needed for this evaluation pass.
            with torch.no_grad():
                mts_output = linear_regression(state_memory)
            get_mts = torch.argmax(mts_output, dim = 1)
            #.item() so the history holds plain floats rather than tensors.
            correct_mts = (get_mts == labels).sum().item()
            accur_mts.append(correct_mts/n_samples)
            pbar.update(n_samples)

        train_loss.append(np.mean(batch_loss))

print('train loss = equal vote = {} max vote only = {} and mts = {} '.format\
      (joint_loss[-1], max_loss[-1], mts_loss_vec[-1]))