In [1]:
from scipy.io import loadmat
import numpy as np
import torch 
from torch.utils.data import random_split, DataLoader
from math import ceil
from torch import nn
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt

torch.set_default_device('cuda')
from mrnn7 import MilliesDataset, MilliesRNN
from hessianfree import HessianFree



# Load data
whole_dataset = MilliesDataset('monkey_data.mat')
dataset_size = len(whole_dataset)
train_dataset, test_dataset = random_split(whole_dataset, [401, 101])

in_dim, out_dim, trial_len = whole_dataset.dimensions() #  21  &  50
hid_dim = 100

train_dataloader = DataLoader(train_dataset, batch_size=ceil(len(train_dataset)/5), shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)

# for generating output later
whole_dataloader = DataLoader(whole_dataset, batch_size = len(whole_dataset), shuffle=False) 



In [4]:
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import pickle 
from tqdm import tqdm

from mrnn7 import MilliesRNN
from hessianfree import HessianFree

hessian = False
hardcore = True
intermodule_connections_removed = .9

learning_rates = [0.01, 0.001, 0.0001]
batch_size = [3,4,5]
weight_decay = [0,1e-6,1e-5,1e-4]

num_epochs = 20
test_losses = {}

print("When i = 0, clipping, when i = 1, no clipping\n")

criterion1 = nn.MSELoss() 

for l in tqdm(learning_rates, desc="Searching over models..."):
    for b in batch_size:
        train_dataloader = DataLoader(train_dataset, batch_size=ceil(len(train_dataset)/b), shuffle=True)
        for w in weight_decay:
            for i in range(2):
                print("Learning rate = " + str(l) + ", batch size = " + str(1/b) + ", i = " + str(i) + ", weight decay = " + str(w) + "\n")
                model_type = f"mse_{l}_{b}_{i}_{w}"
                model = MilliesRNN(in_dim, hid_dim, out_dim, True)
                optimizer = torch.optim.Adam(model.parameters(), lr=l, weight_decay = w)
                module1 = model.h2o
                prune.random_unstructured(module1, name="weight", amount=intermodule_connections_removed)
                module2 = model.thal
                prune.random_unstructured(module2, name="weight", amount=intermodule_connections_removed)
                model.train()
                for epoch in range(num_epochs):
                    for i, (inp_batch, out_batch) in enumerate(train_dataloader):
                        optimizer.zero_grad()

                        outputs = model(inp_batch)   

                        loss = criterion1(outputs, out_batch)
                        
                        loss.backward()
                        optimizer.step()
                        
                        if(i == 0):
                            nn.utils.clip_grad_norm_(model.parameters(), 1)
                

                num_samples = len(test_dataset)
            
                model.eval()
                with torch.no_grad():
                    inp, out = next(iter(test_dataset))
                    gen_out = model(inp)
                    loss1 = criterion1(gen_out, out)

                print(f"Average test: {loss1 / num_samples}\n")
                test_losses[model_type] = loss1 / num_samples

                with torch.no_grad():
                    inp, out_true = next(iter(whole_dataloader))
                    whole_out = model(inp)

                with open(f'outputs/{model_type}.pickle', 'wb') as handle:
                    pickle.dump(whole_out, handle, protocol=pickle.HIGHEST_PROTOCOL)
        

When i = 0, clipping, when i = 1, no clipping



Searching over models...:   0%|          | 0/3 [00:00<?, ?it/s]

Learning rate = 0.01, batch size = 0.3333333333333333, i = 0, weight decay = 0



Searching over models...:   0%|          | 0/3 [05:46<?, ?it/s]


IndexError: too many indices for tensor of dimension 2

In [None]:
print(test_losses)

{'mse_0.01_3_0_0': {tensor(0.0002)}, 'mse_0.01_3_1_0': {tensor(0.0002)}, 'mse_0.01_3_0_1e-06': {tensor(0.0002)}, 'mse_0.01_3_1_1e-06': {tensor(0.0002)}, 'mse_0.01_3_0_1e-05': {tensor(0.0002)}, 'mse_0.01_3_1_1e-05': {tensor(0.0002)}, 'mse_0.01_3_0_0.0001': {tensor(0.0002)}, 'mse_0.01_3_1_0.0001': {tensor(0.0002)}, 'mse_0.01_4_0_0': {tensor(0.0002)}, 'mse_0.01_4_1_0': {tensor(0.0002)}, 'mse_0.01_4_0_1e-06': {tensor(0.0002)}, 'mse_0.01_4_1_1e-06': {tensor(0.0002)}, 'mse_0.01_4_0_1e-05': {tensor(0.0002)}, 'mse_0.01_4_1_1e-05': {tensor(0.0002)}, 'mse_0.01_4_0_0.0001': {tensor(0.0002)}, 'mse_0.01_4_1_0.0001': {tensor(0.0002)}, 'mse_0.01_5_0_0': {tensor(0.0002)}, 'mse_0.01_5_1_0': {tensor(0.0002)}, 'mse_0.01_5_0_1e-06': {tensor(0.0002)}, 'mse_0.01_5_1_1e-06': {tensor(0.0002)}, 'mse_0.01_5_0_1e-05': {tensor(0.0002)}, 'mse_0.01_5_1_1e-05': {tensor(0.0002)}, 'mse_0.01_5_0_0.0001': {tensor(0.0002)}, 'mse_0.01_5_1_0.0001': {tensor(0.0002)}, 'mse_0.001_3_0_0': {tensor(0.0002)}, 'mse_0.001_3_1_0': {

In [None]:
with open(f'outputs/test_loss_mse.pickle', 'wb') as handle:
    pickle.dump(test_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
print(max(test_losses.values()))
print(min(test_losses.values()))

{tensor(0.0002)}
{tensor(0.0002)}
