<a href="https://colab.research.google.com/github/sazio/NMAs/blob/main/Data_Loader.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Exploratory Data Analysis of Stringer Dataset 
@authors: Simone Azeglio, Chetan Dhulipalla , Khalid Saifullah 


Part of the code here has been taken from [Neuromatch Academy's Computational Neuroscience Course](https://compneuro.neuromatch.io/projects/neurons/README.html), and specifically from [this notebook](https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/projects/neurons/load_stringer_spontaneous.ipynb)

# to do list

1. custom normalization: dividing by mean value per neuron
1a. downsampling: convolve then downsample by 5
2. training validation split: withhold last 20 percent of time series for testing
3. RNN for each layer: a way to capture the dynamics inside each layer instead of capturing extra dynamics from inter-layer interactions. it will be OK to compare the different RNNs. maintain same neuron count in each layer to reduce potential bias 
4. layer weight regularization: L2 
5. early stopping , dropout?

## Loading of Stringer spontaneous data



In [28]:
#@title Data retrieval
import os, requests

fname = "stringer_spontaneous.npy"
url = "https://osf.io/dpqaj/download"

if not os.path.isfile(fname):
    try:
        r = requests.get(url)
    except requests.ConnectionError:
        print("!!! Failed to download data !!!")
    else:
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        else:
            with open(fname, "wb") as fid:
                fid.write(r.content)

In [29]:
#@title Import matplotlib and set defaults
from matplotlib import rcParams 
from matplotlib import pyplot as plt
rcParams['figure.figsize'] = [20, 4]
rcParams['font.size'] =15
rcParams['axes.spines.top'] = False
rcParams['axes.spines.right'] = False
rcParams['figure.autolayout'] = True

## Exploratory Data Analysis (EDA)

In [30]:
#@title Data loading
import numpy as np
dat = np.load('stringer_spontaneous.npy', allow_pickle=True).item()
print(dat.keys())

dict_keys(['sresp', 'run', 'beh_svd_time', 'beh_svd_mask', 'stat', 'pupilArea', 'pupilCOM', 'xyz'])


In [31]:
# functions 

def moving_avg(array, factor = 5):
    """Reducing the number of compontents by averaging of N = factor
    subsequent elements of array"""
    zeros_ = np.zeros((array.shape[0], 2))
    array = np.hstack((array, zeros_))

    array = np.reshape(array, (array.shape[0],  int(array.shape[1]/factor), factor))
    array = np.mean(array, axis = 2)

    return array

## Extracting Data for RNN (or LFADS)
The first problem to address is that for each layer we don't have the exact same number of neurons. We'd like to have a single RNN encoding all the different layers activities, to make it easier we can take the number of neurons ($N_{neurons} = 1131$ of the least represented class (layer) and level out each remaining class. 

In [32]:
# Extract labels from z - coordinate
from sklearn import preprocessing
x, y, z = dat['xyz']

le = preprocessing.LabelEncoder()
labels = le.fit_transform(z)
### least represented class (layer with less neurons)
n_samples = np.histogram(labels, bins=9)[0][-1]

In [33]:
### Data for LFADS / RNN 
import pandas as pd 
dataSet = pd.DataFrame(dat["sresp"])
dataSet["label"] = labels 

In [34]:
# it can be done in one loop ... 
data_ = []
for i in range(0, 9):
    data_.append(dataSet[dataSet["label"] == i].sample(n = n_samples).iloc[:,:-1].values)

dataRNN = np.zeros((n_samples*9, dataSet.shape[1]-1))
for i in range(0,9):
    
    # dataRNN[n_samples*i:n_samples*(i+1), :] = data_[i]
    ## normalized by layer
    dataRNN[n_samples*i:n_samples*(i+1), :] = data_[i]/np.mean(np.asarray(data_)[i,:,:], axis = 0)

## shuffling for training purposes

#np.random.shuffle(dataRNN)

In [35]:
#unshuffled = np.array(data_)

In [36]:
#@title Convolutions code

# convolution moving average

# kernel_length = 50
# averaging_kernel = np.ones(kernel_length) / kernel_length

# dataRNN.shape

# avgd_dataRNN = list()

# for neuron in dataRNN:
#   avgd_dataRNN.append(np.convolve(neuron, averaging_kernel))

# avg_dataRNN = np.array(avgd_dataRNN)

# print(avg_dataRNN.shape)

In [37]:
# @title Z Score Code 


# from scipy.stats import zscore


# neuron = 500

# scaled_all = zscore(avg_dataRNN)
# scaled_per_neuron = zscore(avg_dataRNN[neuron, :])

# scaled_per_layer = list()

# for layer in unshuffled:
#   scaled_per_layer.append(zscore(layer))

# scaled_per_layer = np.array(scaled_per_layer)



# plt.plot(avg_dataRNN[neuron, :])
# plt.plot(avg_dataRNN[2500, :])
# plt.figure()
# plt.plot(dataRNN[neuron, :])
# plt.figure()
# plt.plot(scaled_all[neuron, :])
# plt.plot(scaled_per_neuron)
# plt.figure()
# plt.plot(scaled_per_layer[0,neuron,:])


In [38]:
# custom normalization

normed_dataRNN = list()
for neuron in dataRNN:
    normed_dataRNN.append(neuron)#/ neuron.mean())
normed_dataRNN = np.array(normed_dataRNN)

# downsampling and averaging 
#avgd_normed_dataRNN = dataRNN#
avgd_normed_dataRNN = moving_avg(dataRNN, factor=2)

issue: does the individual scaling by layer introduce bias that may artificially increase performance of the network?

## Data Loader 


In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.cuda.empty_cache()

In [41]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [42]:
# set the seed
np.random.seed(42)

# number of neurons 
NN = dataRNN.shape[0]

In [43]:
# swapping the axes to maintain consistency with seq2seq notebook in the following code - the network takes all the neurons at a time step as input, not just one neuron

# avgd_normed_dataRNN = np.swapaxes(avgd_normed_dataRNN, 0, 1)
avgd_normed_dataRNN.shape

(10179, 3510)

In [82]:
frac = 4/5

#x1 = torch.from_numpy(dataRNN[:,:int(frac*dataRNN.shape[1])]).to(device).float().unsqueeze(0)
#x2 = torch.from_numpy(dataRNN[:,int(frac*dataRNN.shape[1]):]).to(device).float().unsqueeze(0)
#x1 = torch.from_numpy(avgd_normed_dataRNN[:1131,:]).to(device).float().unsqueeze(2)
#x2 = torch.from_numpy(avgd_normed_dataRNN[:1131,:]).to(device).float().unsqueeze(2)

n_neurs = 1131
# let's use n_neurs/10 latent components
ncomp = int(n_neurs/10)

x1_train = torch.from_numpy(avgd_normed_dataRNN[:n_neurs,:int(frac*avgd_normed_dataRNN.shape[1])]).to(device).float().unsqueeze(2)
x2_train = torch.from_numpy(avgd_normed_dataRNN[:n_neurs,:int(frac*avgd_normed_dataRNN.shape[1])]).to(device).float().unsqueeze(2)

x1_valid = torch.from_numpy(avgd_normed_dataRNN[:n_neurs,int(frac*avgd_normed_dataRNN.shape[1]):]).to(device).float().unsqueeze(2)
x2_valid = torch.from_numpy(avgd_normed_dataRNN[:n_neurs,int(frac*avgd_normed_dataRNN.shape[1]):]).to(device).float().unsqueeze(2)

NN1 = x1_train.shape[0]
NN2 = x2_train.shape[0]

In [83]:
class Net(nn.Module):
    def __init__(self, ncomp, NN1, NN2, bidi=True):
        super(Net, self).__init__()

        
        self.rnn = nn.LSTM(NN1, ncomp, num_layers = 1, dropout = 0.,
                         bidirectional = bidi)
    
        
    
        self.feedback = nn.Sequential(
                    nn.Linear(ncomp, NN2*3),
                    nn.Mish(),
                    nn.Dropout(),
                    nn.Linear(NN2*3, NN2*3),
                    nn.Mish(),
                    nn.Dropout(),
                    nn.Linear(NN2*3, NN2))

        

        
        self.rnn2 = nn.LSTM(NN1*2, ncomp, num_layers = 1, dropout = 0.,
                         bidirectional = bidi)
        """
        self.rnn = nn.RNN(NN1, ncomp, num_layers = 1, dropout = 0,
                    bidirectional = bidi, nonlinearity = 'tanh')
        self.rnn = nn.GRU(NN1, ncomp, num_layers = 1, dropout = 0,
                         bidirectional = bidi)
        """
        
        self.mlp = nn.Sequential(
                    nn.Linear(ncomp, ncomp*5),
                    nn.Mish(),
                    nn.Dropout(),
                    nn.Linear(ncomp*5, ncomp*5),
                    nn.Mish(),
                    nn.Dropout(),
                    nn.Linear(ncomp*5, ncomp), 
                    nn.Mish())
        
        self.fc = nn.Linear(ncomp, NN2)

    def forward(self, x):
        x = x.permute(1, 2, 0)
        
        #print(x.shape)
        # h_0 = torch.zeros(2, x.size()[1], self.ncomp).to(device)
        
        y, h_n = self.rnn(x)

        #print(y.shape)
        #print(h_n.shape)
        
        
        
        if self.rnn.bidirectional:
          # if the rnn is bidirectional, it concatenates the activations from the forward and backward pass
          # we want to add them instead, so as to enforce the latents to match between the forward and backward pass
            q = (y[:, :, :ncomp] + y[:, :, ncomp:])/2
        else:
            q = y
        
        #q = self.mlp(q)
        
        q = self.feedback(q)#.permute(2, 0, 1)
        #print(q.shape)
        #print(x.shape)
       
        context = torch.cat((x, q), dim = 2)#.permute(2,0,1)
        #print(context.shape)
        #print(context.permute(1,2,0).shape)
        
        y1, h_n1 = self.rnn2(context)
        
        if self.rnn2.bidirectional:
          # if the rnn is bidirectional, it concatenates the activations from the forward and backward pass
          # we want to add them instead, so as to enforce the latents to match between the forward and backward pass
            q1 = (y1[:, :, :ncomp] + y1[:, :, ncomp:])/2
        else:
            q1 = y1
        
        #q1 = self.mlp(q1)

        # the softplus function is just like a relu but it's smoothed out so we can't predict 0
        # if we predict 0 and there was a spike, that's an instant Inf in the Poisson log-likelihood which leads to failure
        #z = F.softplus(self.fc(q), 10)
        #print(q.shape)
        z = self.fc(q1).permute(2, 0, 1)
        #print(z.shape)
        return z, q1

In [84]:
# we initialize the neural network
net = Net(ncomp, NN1, NN2, bidi = True).to(device)

# special thing:  we initialize the biases of the last layer in the neural network
# we set them as the mean firing rates of the neurons.
# this should make the initial predictions close to the mean, because the latents don't contribute much
net.fc.bias.data[:] = x1_train.mean(axis = (0,1))

# we set up the optimizer. Adjust the learning rate if the training is slow or if it explodes.
optimizer1 = torch.optim.Adam(net.parameters(), lr=.0001)
# optimizer2 = torch.optim.SGD(net.parameters(), lr = 0.0001, momentum = 0.9, weight_decay = 0.01, )
# optimizer3 = torch.optim.

In [85]:
# forward check 
# net(x1)
net(x1_train)

(tensor([[[0.8645],
          [0.8551],
          [0.8731],
          ...,
          [0.9462],
          [0.9902],
          [1.0154]],
 
         [[1.0890],
          [0.9859],
          [1.0541],
          ...,
          [0.8474],
          [1.0853],
          [0.9355]],
 
         [[0.9512],
          [0.8004],
          [0.8590],
          ...,
          [1.0671],
          [1.0460],
          [0.7970]],
 
         ...,
 
         [[1.1925],
          [1.1176],
          [1.0192],
          ...,
          [1.3090],
          [1.1207],
          [1.0711]],
 
         [[1.1815],
          [0.8538],
          [1.0385],
          ...,
          [1.0004],
          [0.8761],
          [0.8845]],
 
         [[1.1223],
          [1.1991],
          [1.2283],
          ...,
          [1.0412],
          [0.9367],
          [0.9327]]], device='cuda:0', grad_fn=<PermuteBackward>),
 tensor([[[-0.0292,  0.0085, -0.0215,  ...,  0.0076, -0.2420,  0.1812]],
 
         [[ 0.1665,  0.1342, -0.1763,

## Training 

In [86]:
from tqdm import tqdm

In [87]:
from sam import SAM

base_optimizer = torch.optim.Adam  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(net.parameters(), base_optimizer, lr=0.001, weight_decay = 1e-6)#, momentum=0.9)

# you can keep re-running this cell if you think the cost might decrease further

cost = nn.MSELoss()

#loss_save = []
#valid_save = []

niter =  5500 # + 5800
# rnn_loss = 0.2372, lstm_loss = 0.2340, gru_lstm = 0.2370
for k in tqdm(range(niter)):
    net.train()
    # the network outputs the single-neuron prediction and the latents
    z, y = net(x1_train)

    # our cost
    loss = cost(z, x2_train)

    # train the network as usual
    loss.backward()
    optimizer.first_step(zero_grad = True)
    
    cost(net(x1_train)[0],x2_train).backward()
    
    optimizer.second_step(zero_grad=True)
    
    with torch.no_grad():
        net.eval()
        loss_save.append(loss.item())
        valid_loss = cost(net(x1_valid)[0], x2_valid)
        valid_save.append(valid_loss.item())

    if k % 25== 0:
        with torch.no_grad():
            net.eval()
            valid_loss = cost(net(x1_valid)[0], x2_valid)
            
            print(f' iteration {k}, train cost {loss.item():.4f}, valid cost {valid_loss.item():.4f}')

  0%|          | 1/5500 [00:00<1:09:01,  1.33it/s]

 iteration 0, train cost 2.7321, valid cost 2.7653


  0%|          | 26/5500 [00:18<1:05:28,  1.39it/s]

 iteration 25, train cost 2.1122, valid cost 2.2014


  1%|          | 51/5500 [00:36<1:05:04,  1.40it/s]

 iteration 50, train cost 2.0674, valid cost 2.1554


  1%|▏         | 76/5500 [00:53<1:04:46,  1.40it/s]

 iteration 75, train cost 2.0586, valid cost 2.1449


  2%|▏         | 101/5500 [01:11<1:05:06,  1.38it/s]

 iteration 100, train cost 2.0299, valid cost 2.1141


  2%|▏         | 126/5500 [01:29<1:04:03,  1.40it/s]

 iteration 125, train cost 1.9802, valid cost 2.0612


  3%|▎         | 151/5500 [01:46<1:04:04,  1.39it/s]

 iteration 150, train cost 1.9214, valid cost 2.0035


  3%|▎         | 176/5500 [02:04<1:03:49,  1.39it/s]

 iteration 175, train cost 1.8551, valid cost 1.9428


  4%|▎         | 201/5500 [02:22<1:03:03,  1.40it/s]

 iteration 200, train cost 1.7859, valid cost 1.8777


  4%|▍         | 226/5500 [02:39<1:03:25,  1.39it/s]

 iteration 225, train cost 1.7186, valid cost 1.8200


  5%|▍         | 251/5500 [02:57<1:02:40,  1.40it/s]

 iteration 250, train cost 1.6555, valid cost 1.7660


  5%|▌         | 276/5500 [03:15<1:02:34,  1.39it/s]

 iteration 275, train cost 1.5962, valid cost 1.7167


  5%|▌         | 301/5500 [03:32<1:02:05,  1.40it/s]

 iteration 300, train cost 1.5437, valid cost 1.6729


  6%|▌         | 326/5500 [03:50<1:01:44,  1.40it/s]

 iteration 325, train cost 1.4973, valid cost 1.6347


  6%|▋         | 351/5500 [04:08<1:01:47,  1.39it/s]

 iteration 350, train cost 1.4548, valid cost 1.6002


  7%|▋         | 376/5500 [04:26<1:01:27,  1.39it/s]

 iteration 375, train cost 1.4165, valid cost 1.5698


  7%|▋         | 401/5500 [04:43<1:01:19,  1.39it/s]

 iteration 400, train cost 1.3809, valid cost 1.5412


  8%|▊         | 426/5500 [05:01<1:01:06,  1.38it/s]

 iteration 425, train cost 1.3479, valid cost 1.5155


  8%|▊         | 451/5500 [05:19<1:00:59,  1.38it/s]

 iteration 450, train cost 1.3174, valid cost 1.4912


  9%|▊         | 476/5500 [05:37<1:00:13,  1.39it/s]

 iteration 475, train cost 1.2882, valid cost 1.4683


  9%|▉         | 501/5500 [05:54<59:42,  1.40it/s]  

 iteration 500, train cost 1.2584, valid cost 1.4447


 10%|▉         | 526/5500 [06:12<59:25,  1.40it/s]

 iteration 525, train cost 1.2308, valid cost 1.4230


 10%|█         | 551/5500 [06:30<59:30,  1.39it/s]

 iteration 550, train cost 1.2057, valid cost 1.4036


 10%|█         | 576/5500 [06:47<58:34,  1.40it/s]

 iteration 575, train cost 1.1826, valid cost 1.3865


 11%|█         | 601/5500 [07:05<59:01,  1.38it/s]

 iteration 600, train cost 1.1614, valid cost 1.3694


 11%|█▏        | 626/5500 [07:23<59:05,  1.37it/s]

 iteration 625, train cost 1.1408, valid cost 1.3529


 12%|█▏        | 651/5500 [07:41<58:47,  1.37it/s]

 iteration 650, train cost 1.1223, valid cost 1.3393


 12%|█▏        | 676/5500 [07:59<58:37,  1.37it/s]

 iteration 675, train cost 1.1044, valid cost 1.3256


 13%|█▎        | 701/5500 [08:16<57:21,  1.39it/s]

 iteration 700, train cost 1.0878, valid cost 1.3127


 13%|█▎        | 726/5500 [08:34<57:47,  1.38it/s]

 iteration 725, train cost 1.0720, valid cost 1.3008


 14%|█▎        | 751/5500 [08:52<57:02,  1.39it/s]

 iteration 750, train cost 1.0572, valid cost 1.2900


 14%|█▍        | 776/5500 [09:10<56:45,  1.39it/s]

 iteration 775, train cost 1.0425, valid cost 1.2771


 15%|█▍        | 801/5500 [09:27<56:28,  1.39it/s]

 iteration 800, train cost 1.0288, valid cost 1.2687


 15%|█▌        | 826/5500 [09:45<55:52,  1.39it/s]

 iteration 825, train cost 1.0159, valid cost 1.2603


 15%|█▌        | 851/5500 [10:03<55:57,  1.38it/s]

 iteration 850, train cost 1.0037, valid cost 1.2495


 16%|█▌        | 876/5500 [10:20<55:11,  1.40it/s]

 iteration 875, train cost 0.9917, valid cost 1.2415


 16%|█▋        | 901/5500 [10:38<55:16,  1.39it/s]

 iteration 900, train cost 0.9806, valid cost 1.2321


 17%|█▋        | 926/5500 [10:56<54:56,  1.39it/s]

 iteration 925, train cost 0.9695, valid cost 1.2255


 17%|█▋        | 951/5500 [11:14<54:17,  1.40it/s]

 iteration 950, train cost 0.9590, valid cost 1.2182


 18%|█▊        | 976/5500 [11:31<54:29,  1.38it/s]

 iteration 975, train cost 0.9493, valid cost 1.2114


 18%|█▊        | 1001/5500 [11:49<53:50,  1.39it/s]

 iteration 1000, train cost 0.9398, valid cost 1.2049


 19%|█▊        | 1026/5500 [12:07<54:55,  1.36it/s]

 iteration 1025, train cost 0.9307, valid cost 1.1987


 19%|█▉        | 1051/5500 [12:25<54:31,  1.36it/s]

 iteration 1050, train cost 0.9220, valid cost 1.1918


 20%|█▉        | 1076/5500 [12:43<54:03,  1.36it/s]

 iteration 1075, train cost 0.9137, valid cost 1.1856


 20%|██        | 1101/5500 [13:01<53:59,  1.36it/s]

 iteration 1100, train cost 0.9057, valid cost 1.1807


 20%|██        | 1126/5500 [13:20<53:50,  1.35it/s]

 iteration 1125, train cost 0.8980, valid cost 1.1760


 21%|██        | 1151/5500 [13:38<53:17,  1.36it/s]

 iteration 1150, train cost 0.8907, valid cost 1.1702


 21%|██▏       | 1176/5500 [13:56<53:03,  1.36it/s]

 iteration 1175, train cost 0.8836, valid cost 1.1656


 22%|██▏       | 1201/5500 [14:14<52:31,  1.36it/s]

 iteration 1200, train cost 0.8769, valid cost 1.1615


 22%|██▏       | 1226/5500 [14:32<52:31,  1.36it/s]

 iteration 1225, train cost 0.8705, valid cost 1.1580


 23%|██▎       | 1251/5500 [14:50<51:10,  1.38it/s]

 iteration 1250, train cost 0.8641, valid cost 1.1538


 23%|██▎       | 1276/5500 [15:08<50:29,  1.39it/s]

 iteration 1275, train cost 0.8581, valid cost 1.1502


 24%|██▎       | 1301/5500 [15:26<50:27,  1.39it/s]

 iteration 1300, train cost 0.8524, valid cost 1.1459


 24%|██▍       | 1326/5500 [15:43<49:45,  1.40it/s]

 iteration 1325, train cost 0.8476, valid cost 1.1427


 25%|██▍       | 1351/5500 [16:01<49:52,  1.39it/s]

 iteration 1350, train cost 0.8418, valid cost 1.1385


 25%|██▌       | 1376/5500 [16:19<49:17,  1.39it/s]

 iteration 1375, train cost 0.8363, valid cost 1.1349


 25%|██▌       | 1401/5500 [16:37<49:05,  1.39it/s]

 iteration 1400, train cost 0.8327, valid cost 1.1329


 26%|██▌       | 1426/5500 [16:54<48:43,  1.39it/s]

 iteration 1425, train cost 0.8265, valid cost 1.1294


 26%|██▋       | 1451/5500 [17:12<48:27,  1.39it/s]

 iteration 1450, train cost 0.8216, valid cost 1.1265


 27%|██▋       | 1476/5500 [17:30<48:17,  1.39it/s]

 iteration 1475, train cost 0.8169, valid cost 1.1231


 27%|██▋       | 1501/5500 [17:47<48:00,  1.39it/s]

 iteration 1500, train cost 0.8124, valid cost 1.1212


 28%|██▊       | 1526/5500 [18:05<47:14,  1.40it/s]

 iteration 1525, train cost 0.8081, valid cost 1.1183


 28%|██▊       | 1551/5500 [18:23<47:13,  1.39it/s]

 iteration 1550, train cost 0.8040, valid cost 1.1158


 29%|██▊       | 1576/5500 [18:40<46:49,  1.40it/s]

 iteration 1575, train cost 0.8002, valid cost 1.1149


 29%|██▉       | 1601/5500 [18:58<46:36,  1.39it/s]

 iteration 1600, train cost 0.7966, valid cost 1.1123


 30%|██▉       | 1626/5500 [19:16<46:25,  1.39it/s]

 iteration 1625, train cost 0.7928, valid cost 1.1110


 30%|███       | 1651/5500 [19:33<46:11,  1.39it/s]

 iteration 1650, train cost 0.7892, valid cost 1.1081


 30%|███       | 1676/5500 [19:51<45:56,  1.39it/s]

 iteration 1675, train cost 0.7862, valid cost 1.1068


 31%|███       | 1701/5500 [20:09<45:22,  1.40it/s]

 iteration 1700, train cost 0.7827, valid cost 1.1041


 31%|███▏      | 1726/5500 [20:26<45:14,  1.39it/s]

 iteration 1725, train cost 0.7795, valid cost 1.1029


 32%|███▏      | 1751/5500 [20:44<44:58,  1.39it/s]

 iteration 1750, train cost 0.7765, valid cost 1.1016


 32%|███▏      | 1776/5500 [21:02<44:35,  1.39it/s]

 iteration 1775, train cost 0.7735, valid cost 1.1003


 33%|███▎      | 1801/5500 [21:20<44:25,  1.39it/s]

 iteration 1800, train cost 0.7705, valid cost 1.0983


 33%|███▎      | 1826/5500 [21:37<43:56,  1.39it/s]

 iteration 1825, train cost 0.7679, valid cost 1.0965


 34%|███▎      | 1851/5500 [21:55<43:35,  1.40it/s]

 iteration 1850, train cost 0.7652, valid cost 1.0957


 34%|███▍      | 1876/5500 [22:13<43:51,  1.38it/s]

 iteration 1875, train cost 0.7625, valid cost 1.0947


 35%|███▍      | 1901/5500 [22:31<42:46,  1.40it/s]

 iteration 1900, train cost 0.7599, valid cost 1.0931


 35%|███▌      | 1926/5500 [22:48<42:56,  1.39it/s]

 iteration 1925, train cost 0.7576, valid cost 1.0910


 35%|███▌      | 1951/5500 [23:06<42:14,  1.40it/s]

 iteration 1950, train cost 0.7551, valid cost 1.0906


 36%|███▌      | 1976/5500 [23:24<42:17,  1.39it/s]

 iteration 1975, train cost 0.7527, valid cost 1.0894


 36%|███▋      | 2001/5500 [23:41<41:52,  1.39it/s]

 iteration 2000, train cost 0.7504, valid cost 1.0876


 37%|███▋      | 2026/5500 [23:59<41:30,  1.39it/s]

 iteration 2025, train cost 0.7483, valid cost 1.0866


 37%|███▋      | 2051/5500 [24:17<41:22,  1.39it/s]

 iteration 2050, train cost 0.7463, valid cost 1.0854


 38%|███▊      | 2076/5500 [24:34<40:51,  1.40it/s]

 iteration 2075, train cost 0.7441, valid cost 1.0847


 38%|███▊      | 2101/5500 [24:52<40:46,  1.39it/s]

 iteration 2100, train cost 0.7420, valid cost 1.0838


 39%|███▊      | 2126/5500 [25:08<25:10,  2.23it/s]

 iteration 2125, train cost 0.7401, valid cost 1.0828


 39%|███▉      | 2151/5500 [25:17<18:58,  2.94it/s]

 iteration 2150, train cost 0.7382, valid cost 1.0824


 40%|███▉      | 2176/5500 [25:25<18:49,  2.94it/s]

 iteration 2175, train cost 0.7364, valid cost 1.0804


 40%|████      | 2201/5500 [25:34<18:41,  2.94it/s]

 iteration 2200, train cost 0.7347, valid cost 1.0810


 40%|████      | 2226/5500 [25:42<18:31,  2.94it/s]

 iteration 2225, train cost 0.7329, valid cost 1.0791


 41%|████      | 2251/5500 [25:50<18:20,  2.95it/s]

 iteration 2250, train cost 0.7311, valid cost 1.0798


 41%|████▏     | 2276/5500 [25:59<18:10,  2.96it/s]

 iteration 2275, train cost 0.7296, valid cost 1.0781


 42%|████▏     | 2301/5500 [26:07<18:03,  2.95it/s]

 iteration 2300, train cost 0.7277, valid cost 1.0769


 42%|████▏     | 2326/5500 [26:15<17:57,  2.94it/s]

 iteration 2325, train cost 0.7262, valid cost 1.0770


 43%|████▎     | 2351/5500 [26:24<17:45,  2.96it/s]

 iteration 2350, train cost 0.7248, valid cost 1.0755


 43%|████▎     | 2376/5500 [26:32<17:37,  2.95it/s]

 iteration 2375, train cost 0.7233, valid cost 1.0748


 44%|████▎     | 2401/5500 [26:40<17:29,  2.95it/s]

 iteration 2400, train cost 0.7218, valid cost 1.0745


 44%|████▍     | 2426/5500 [26:49<17:21,  2.95it/s]

 iteration 2425, train cost 0.7204, valid cost 1.0743


 45%|████▍     | 2451/5500 [26:57<17:13,  2.95it/s]

 iteration 2450, train cost 0.7193, valid cost 1.0739


 45%|████▌     | 2476/5500 [27:05<17:03,  2.95it/s]

 iteration 2475, train cost 0.7177, valid cost 1.0733


 45%|████▌     | 2501/5500 [27:14<16:55,  2.95it/s]

 iteration 2500, train cost 0.7162, valid cost 1.0721


 46%|████▌     | 2526/5500 [27:22<16:50,  2.94it/s]

 iteration 2525, train cost 0.7150, valid cost 1.0719


 46%|████▋     | 2551/5500 [27:30<16:41,  2.94it/s]

 iteration 2550, train cost 0.7136, valid cost 1.0715


 47%|████▋     | 2576/5500 [27:41<31:17,  1.56it/s]

 iteration 2575, train cost 0.7124, valid cost 1.0710


 47%|████▋     | 2601/5500 [27:58<34:55,  1.38it/s]

 iteration 2600, train cost 0.7114, valid cost 1.0701


 48%|████▊     | 2626/5500 [28:16<34:24,  1.39it/s]

 iteration 2625, train cost 0.7100, valid cost 1.0695


 48%|████▊     | 2651/5500 [28:34<34:07,  1.39it/s]

 iteration 2650, train cost 0.7089, valid cost 1.0696


 49%|████▊     | 2676/5500 [28:51<33:53,  1.39it/s]

 iteration 2675, train cost 0.7078, valid cost 1.0695


 49%|████▉     | 2701/5500 [29:09<33:33,  1.39it/s]

 iteration 2700, train cost 0.7066, valid cost 1.0689


 50%|████▉     | 2726/5500 [29:27<33:14,  1.39it/s]

 iteration 2725, train cost 0.7057, valid cost 1.0674


 50%|█████     | 2751/5500 [29:45<33:06,  1.38it/s]

 iteration 2750, train cost 0.7046, valid cost 1.0671


 50%|█████     | 2776/5500 [30:02<32:37,  1.39it/s]

 iteration 2775, train cost 0.7039, valid cost 1.0668


 51%|█████     | 2801/5500 [30:20<32:16,  1.39it/s]

 iteration 2800, train cost 0.7026, valid cost 1.0674


 51%|█████▏    | 2826/5500 [30:38<32:02,  1.39it/s]

 iteration 2825, train cost 0.7032, valid cost 1.0675


 52%|█████▏    | 2851/5500 [30:55<31:37,  1.40it/s]

 iteration 2850, train cost 0.7009, valid cost 1.0669


 52%|█████▏    | 2876/5500 [31:13<31:21,  1.39it/s]

 iteration 2875, train cost 0.6997, valid cost 1.0658


 53%|█████▎    | 2901/5500 [31:31<31:06,  1.39it/s]

 iteration 2900, train cost 0.6989, valid cost 1.0654


 53%|█████▎    | 2926/5500 [31:49<31:01,  1.38it/s]

 iteration 2925, train cost 0.6979, valid cost 1.0657


 54%|█████▎    | 2951/5500 [32:06<30:30,  1.39it/s]

 iteration 2950, train cost 0.6971, valid cost 1.0658


 54%|█████▍    | 2976/5500 [32:24<30:15,  1.39it/s]

 iteration 2975, train cost 0.6961, valid cost 1.0650


 55%|█████▍    | 3001/5500 [32:42<29:52,  1.39it/s]

 iteration 3000, train cost 0.6953, valid cost 1.0647


 55%|█████▌    | 3026/5500 [32:59<29:41,  1.39it/s]

 iteration 3025, train cost 0.6947, valid cost 1.0634


 55%|█████▌    | 3051/5500 [33:17<29:27,  1.39it/s]

 iteration 3050, train cost 0.6938, valid cost 1.0645


 56%|█████▌    | 3076/5500 [33:35<29:13,  1.38it/s]

 iteration 3075, train cost 0.6931, valid cost 1.0633


 56%|█████▋    | 3101/5500 [33:52<28:40,  1.39it/s]

 iteration 3100, train cost 0.6922, valid cost 1.0630


 57%|█████▋    | 3126/5500 [34:10<28:15,  1.40it/s]

 iteration 3125, train cost 0.6914, valid cost 1.0642


 57%|█████▋    | 3151/5500 [34:28<27:59,  1.40it/s]

 iteration 3150, train cost 0.6907, valid cost 1.0640


 58%|█████▊    | 3176/5500 [34:45<27:42,  1.40it/s]

 iteration 3175, train cost 0.6901, valid cost 1.0633


 58%|█████▊    | 3201/5500 [35:03<27:30,  1.39it/s]

 iteration 3200, train cost 0.6893, valid cost 1.0631


 59%|█████▊    | 3226/5500 [35:21<27:16,  1.39it/s]

 iteration 3225, train cost 0.6886, valid cost 1.0622


 59%|█████▉    | 3251/5500 [35:39<26:56,  1.39it/s]

 iteration 3250, train cost 0.6878, valid cost 1.0623


 60%|█████▉    | 3276/5500 [35:56<26:39,  1.39it/s]

 iteration 3275, train cost 0.6871, valid cost 1.0622


 60%|██████    | 3301/5500 [36:14<26:21,  1.39it/s]

 iteration 3300, train cost 0.6865, valid cost 1.0615


 60%|██████    | 3326/5500 [36:32<26:02,  1.39it/s]

 iteration 3325, train cost 0.6859, valid cost 1.0620


 61%|██████    | 3351/5500 [36:50<25:43,  1.39it/s]

 iteration 3350, train cost 0.6854, valid cost 1.0614


 61%|██████▏   | 3376/5500 [37:07<25:26,  1.39it/s]

 iteration 3375, train cost 0.6846, valid cost 1.0612


 62%|██████▏   | 3401/5500 [37:25<25:30,  1.37it/s]

 iteration 3400, train cost 0.6842, valid cost 1.0609


 62%|██████▏   | 3426/5500 [37:43<25:23,  1.36it/s]

 iteration 3425, train cost 0.6834, valid cost 1.0613


 63%|██████▎   | 3451/5500 [38:02<25:22,  1.35it/s]

 iteration 3450, train cost 0.6828, valid cost 1.0616


 63%|██████▎   | 3476/5500 [38:19<24:15,  1.39it/s]

 iteration 3475, train cost 0.6822, valid cost 1.0613


 64%|██████▎   | 3501/5500 [38:37<23:59,  1.39it/s]

 iteration 3500, train cost 0.6823, valid cost 1.0600


 64%|██████▍   | 3526/5500 [38:55<23:44,  1.39it/s]

 iteration 3525, train cost 0.6811, valid cost 1.0606


 65%|██████▍   | 3551/5500 [39:13<23:17,  1.39it/s]

 iteration 3550, train cost 0.6807, valid cost 1.0604


 65%|██████▌   | 3576/5500 [39:30<22:58,  1.40it/s]

 iteration 3575, train cost 0.6802, valid cost 1.0605


 65%|██████▌   | 3601/5500 [39:48<22:41,  1.39it/s]

 iteration 3600, train cost 0.6795, valid cost 1.0606


 66%|██████▌   | 3626/5500 [40:06<22:24,  1.39it/s]

 iteration 3625, train cost 0.6791, valid cost 1.0596


 66%|██████▋   | 3651/5500 [40:23<22:01,  1.40it/s]

 iteration 3650, train cost 0.6785, valid cost 1.0598


 67%|██████▋   | 3676/5500 [40:41<21:45,  1.40it/s]

 iteration 3675, train cost 0.6780, valid cost 1.0597


 67%|██████▋   | 3701/5500 [40:59<21:28,  1.40it/s]

 iteration 3700, train cost 0.6776, valid cost 1.0591


 68%|██████▊   | 3726/5500 [41:16<21:17,  1.39it/s]

 iteration 3725, train cost 0.6772, valid cost 1.0598


 68%|██████▊   | 3751/5500 [41:34<21:07,  1.38it/s]

 iteration 3750, train cost 0.6765, valid cost 1.0598


 69%|██████▊   | 3776/5500 [41:52<20:44,  1.39it/s]

 iteration 3775, train cost 0.6764, valid cost 1.0594


 69%|██████▉   | 3801/5500 [42:09<20:18,  1.39it/s]

 iteration 3800, train cost 0.6772, valid cost 1.0593


 70%|██████▉   | 3826/5500 [42:27<19:59,  1.40it/s]

 iteration 3825, train cost 0.6756, valid cost 1.0591


 70%|███████   | 3851/5500 [42:45<19:45,  1.39it/s]

 iteration 3850, train cost 0.6747, valid cost 1.0589


 70%|███████   | 3876/5500 [43:02<19:37,  1.38it/s]

 iteration 3875, train cost 0.6745, valid cost 1.0584


 71%|███████   | 3901/5500 [43:20<19:10,  1.39it/s]

 iteration 3900, train cost 0.6740, valid cost 1.0588


 71%|███████▏  | 3926/5500 [43:38<18:52,  1.39it/s]

 iteration 3925, train cost 0.6735, valid cost 1.0591


 72%|███████▏  | 3951/5500 [43:56<18:30,  1.40it/s]

 iteration 3950, train cost 0.6732, valid cost 1.0589


 72%|███████▏  | 3976/5500 [44:13<18:14,  1.39it/s]

 iteration 3975, train cost 0.6730, valid cost 1.0588


 73%|███████▎  | 3988/5500 [44:22<16:49,  1.50it/s]


KeyboardInterrupt: 

In [None]:
"""
### Original training
# you can keep re-running this cell if you think the cost might decrease further

cost = nn.MSELoss()

niter = 5800
# rnn_loss = 0.2372, lstm_loss = 0.2340, gru_lstm = 0.2370
for k in tqdm(range(niter)):
    net.train()
    # the network outputs the single-neuron prediction and the latents
    z, y = net(x1_train)

    # our cost
    loss = cost(z, x2_train)

    # train the network as usual
    loss.backward()
    optimizer1.step()
    optimizer1.zero_grad()
    

    if k % 50 == 0:
        with torch.no_grad():
            net.eval()
            valid_loss = cost(net(x1_valid)[0], x2_valid)
            
            print(f' iteration {k}, train cost {loss.item():.4f}, valid cost {valid_loss.item():.4f}')

""";  

## Validation from same neurons

In [None]:
test, hidden = net(x1_valid)

In [None]:
plt.plot(x2_valid[6,:,0].cpu().detach().numpy())
plt.plot(test[6,:,0].cpu().detach().numpy())

## Testing neurons from same layer

In [None]:
test, hidden = net(torch.from_numpy(avgd_normed_dataRNN[n_neurs:2*n_neurs,:]).unsqueeze(2).to(device).float())

In [None]:
test.shape

In [None]:
n_n = 15
plt.plot(test[n_n,:,0].cpu().detach().numpy())
plt.plot(avgd_normed_dataRNN[n_neurs + n_n,:])

## Testing neurons from another layer (#9)

In [None]:
test, hidden = net(torch.from_numpy(avgd_normed_dataRNN[10000:10100,:]).unsqueeze(2).to(device).float())

In [None]:
plt.plot(test[10,:,0].cpu().detach().numpy())
plt.plot(avgd_normed_dataRNN[10010,:])

# Training 9 Networks 
Each Network corresponds to a different layer in V1

In [None]:
# you can keep re-running this cell if you think the cost might decrease further

cost = nn.MSELoss()

niter = 10000
# rnn_loss = 0.2372, lstm_loss = 0.2340, gru_lstm = 0.2370
for k in tqdm(range(niter)):
    net.train()
    # the network outputs the single-neuron prediction and the latents
    z, y = net(x1_train)

    # our cost
    loss = cost(z, x2_train)

    # train the network as usual
    loss.backward()
    optimizer1.step()
    optimizer1.zero_grad()

    if k % 50 == 0:
        with torch.no_grad():
            net.eval()
            valid_loss = cost(net(x1_valid)[0], x2_valid)
            
            print(f' iteration {k}, train cost {loss.item():.4f}, valid cost {valid_loss.item():.4f}')

    