In [1]:
from scipy.io import loadmat
import numpy as np
import torch 
from torch.utils.data import random_split, DataLoader
from math import ceil
from torch import nn
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt

device = torch.device("cuda:1")
g_cuda = torch.Generator(device='cuda:1')
from mrnn7 import MilliesDataset, MilliesRNN
from hessianfree import HessianFree

# Load data
whole_dataset = MilliesDataset('monkey_data.mat')
dataset_size = len(whole_dataset)
train_dataset, test_dataset = random_split(whole_dataset, [401, 101])

in_dim, out_dim, trial_len = whole_dataset.dimensions() #  21  &  50
hid_dim = 100

train_dataloader = DataLoader(train_dataset, batch_size=ceil(len(train_dataset)/5), shuffle=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True, pin_memory=True)
# for generating output later
whole_dataloader = DataLoader(whole_dataset, batch_size = len(whole_dataset), shuffle=False) 



In [2]:
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import pickle 
from tqdm import tqdm

from mrnn7 import MilliesRNN
from hessianfree import HessianFree

hessian = False
hardcore = True
intermodule_connections_removed = .9

learning_rates = [0.01, 0.001, 0.0001]
batch_size = [3,4,5]
weight_decay = [0,1e-6,1e-5,1e-4]

num_epochs = 20
test_losses = {}

print("When i = 0, clipping, when i = 1, no clipping\n")

criterion1 = nn.MSELoss() 

for l in tqdm(learning_rates, desc="Searching over models..."):
    for b in batch_size:
        train_dataloader = DataLoader(train_dataset, batch_size=ceil(len(train_dataset)/b), shuffle=True)
        for w in weight_decay:
            for i in range(2):
                print("Learning rate = " + str(l) + ", batch size = " + str(1/b) + ", i = " + str(i) + ", weight decay = " + str(w) + "\n")
                model_type = f"mse_{l}_{b}_{i}_{w}"
                model = MilliesRNN(in_dim, hid_dim, out_dim, True)
                model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=l, weight_decay = w)
                module1 = model.h2o
                prune.random_unstructured(module1, name="weight", amount=intermodule_connections_removed)
                module2 = model.thal
                prune.random_unstructured(module2, name="weight", amount=intermodule_connections_removed)
                model.train()
                for epoch in range(num_epochs):
                    for i, (inp_batch, out_batch) in enumerate(train_dataloader):
                        inp_batch = inp_batch.to(device)
                        out_batch = out_batch.to(device)
                        optimizer.zero_grad()

                        outputs = model(inp_batch)   

                        loss = criterion1(outputs, out_batch)
                        
                        loss.backward()
                        optimizer.step()
                        
                        if(i == 0):
                            nn.utils.clip_grad_norm_(model.parameters(), 1)
                

                num_samples = len(test_dataset)
            
                model.eval()
                with torch.no_grad():
                    inp, out = next(iter(test_dataloader))
                    gen_out = model(inp.to(device))
                    loss1 = criterion1(gen_out, out.to(device))

                
                test_losses[model_type] = (loss1.item() / num_samples)
                print(f"Average test: {test_losses[model_type]}\n")

                with torch.no_grad():
                    inp, out_true = next(iter(whole_dataloader))
                    whole_out = model(inp.to(device))

                with open(f'tuning_outputs2/{model_type}.pickle', 'wb') as handle:
                    pickle.dump(whole_out, handle, protocol=pickle.HIGHEST_PROTOCOL)
        

When i = 0, clipping, when i = 1, no clipping



Searching over models...:   0%|                                                                                                                | 0/3 [00:00<?, ?it/s]

Learning rate = 0.01, batch size = 0.3333333333333333, i = 0, weight decay = 0

Average test: 0.00011547772896171797

Learning rate = 0.01, batch size = 0.3333333333333333, i = 1, weight decay = 0

Average test: 0.00012468100331797458

Learning rate = 0.01, batch size = 0.3333333333333333, i = 0, weight decay = 1e-06

Average test: 0.00011389508917190061

Learning rate = 0.01, batch size = 0.3333333333333333, i = 1, weight decay = 1e-06

Average test: 0.00012245388830652332

Learning rate = 0.01, batch size = 0.3333333333333333, i = 0, weight decay = 1e-05

Average test: 0.00012449977347756377

Learning rate = 0.01, batch size = 0.3333333333333333, i = 1, weight decay = 1e-05

Average test: 0.00011445352309706188

Learning rate = 0.01, batch size = 0.3333333333333333, i = 0, weight decay = 0.0001

Average test: 0.00012295301265940808

Learning rate = 0.01, batch size = 0.3333333333333333, i = 1, weight decay = 0.0001

Average test: 0.00013399424750616055

Learning rate = 0.01, batch si

Searching over models...:  33%|█████████████████████████████████▎                                                                  | 1/3 [34:31<1:09:03, 2071.75s/it]

Learning rate = 0.001, batch size = 0.3333333333333333, i = 0, weight decay = 0

Average test: 0.00011370313529035832

Learning rate = 0.001, batch size = 0.3333333333333333, i = 1, weight decay = 0

Average test: 7.960990550789503e-05

Learning rate = 0.001, batch size = 0.3333333333333333, i = 0, weight decay = 1e-06

Average test: 6.709510626474229e-05

Learning rate = 0.001, batch size = 0.3333333333333333, i = 1, weight decay = 1e-06

Average test: 6.687808966282571e-05

Learning rate = 0.001, batch size = 0.3333333333333333, i = 0, weight decay = 1e-05

Average test: 0.0001239498597708079

Learning rate = 0.001, batch size = 0.3333333333333333, i = 1, weight decay = 1e-05

Average test: 9.318949230531655e-05

Learning rate = 0.001, batch size = 0.3333333333333333, i = 0, weight decay = 0.0001

Average test: 0.00011842954203043834

Learning rate = 0.001, batch size = 0.3333333333333333, i = 1, weight decay = 0.0001

Average test: 0.00016237070581110396

Learning rate = 0.001, batc

Searching over models...:  67%|██████████████████████████████████████████████████████████████████▋                                 | 2/3 [1:08:46<34:21, 2061.57s/it]

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 0, weight decay = 0

Average test: 0.00017771377495609888

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 1, weight decay = 0

Average test: 0.00017872264627182837

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 0, weight decay = 1e-06

Average test: 0.00017957743441704475

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 1, weight decay = 1e-06

Average test: 0.000184317810995744

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 0, weight decay = 1e-05

Average test: 0.00017901633561837792

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 1, weight decay = 1e-05

Average test: 0.00018464036212109103

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 0, weight decay = 0.0001

Average test: 0.0001865277749181974

Learning rate = 0.0001, batch size = 0.3333333333333333, i = 1, weight decay = 0.0001

Average test: 0.00018219963306247597

Learning rate = 0

Searching over models...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [1:42:54<00:00, 2058.07s/it]


In [3]:
print(test_losses)

{'mse_0.01_3_0_0': 0.00011547772896171797, 'mse_0.01_3_1_0': 0.00012468100331797458, 'mse_0.01_3_0_1e-06': 0.00011389508917190061, 'mse_0.01_3_1_1e-06': 0.00012245388830652332, 'mse_0.01_3_0_1e-05': 0.00012449977347756377, 'mse_0.01_3_1_1e-05': 0.00011445352309706188, 'mse_0.01_3_0_0.0001': 0.00012295301265940808, 'mse_0.01_3_1_0.0001': 0.00013399424750616055, 'mse_0.01_4_0_0': 0.00015690824994356325, 'mse_0.01_4_1_0': 0.00011097081927674832, 'mse_0.01_4_0_1e-06': 0.0001203827098897188, 'mse_0.01_4_1_1e-06': 0.00011703552733553518, 'mse_0.01_4_0_1e-05': 0.00011513527888472718, 'mse_0.01_4_1_1e-05': 0.00011320513590137557, 'mse_0.01_4_0_0.0001': 0.00011400774231936672, 'mse_0.01_4_1_0.0001': 0.0001130196090676997, 'mse_0.01_5_0_0': 0.00011268816888332367, 'mse_0.01_5_1_0': 0.00011482547120292588, 'mse_0.01_5_0_1e-06': 0.00014910891209498492, 'mse_0.01_5_1_1e-06': 0.00011404695929867206, 'mse_0.01_5_0_1e-05': 0.00014103119178573683, 'mse_0.01_5_1_1e-05': 0.0001192190822693381, 'mse_0.01_

In [11]:
with open(f'tuning_outputs2/test_loss_mse.pickle', 'wb') as handle:
    pickle.dump(test_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
asc_sorted_losses = sorted(test_losses.items(), key=lambda x:x[1])
for i,k in enumerate(asc_sorted_losses):
    print(k)
    if i > 10:
        break

('mse_0.001_5_0_1e-05', 6.016962562162097e-05)
('mse_0.001_4_0_0', 6.263819292630299e-05)
('mse_0.001_5_1_0', 6.518416604635739e-05)
('mse_0.001_3_1_1e-06', 6.687808966282571e-05)
('mse_0.001_3_0_1e-06', 6.709510626474229e-05)
('mse_0.001_4_0_0.0001', 6.748079829434358e-05)
('mse_0.001_4_1_0.0001', 6.844522431492805e-05)
('mse_0.001_5_0_0', 7.662532094976689e-05)
('mse_0.001_3_1_0', 7.960990550789503e-05)
('mse_0.001_5_0_1e-06', 8.088161265200908e-05)
('mse_0.001_4_1_1e-06', 8.42728808817297e-05)
('mse_0.001_4_1_0', 9.043794125318527e-05)


In [13]:
desc_sorted_losses = sorted(test_losses.items(), key=lambda x:x[1], reverse=True)
for i,k in enumerate(desc_sorted_losses):
    print(k)
    if i > 10:
        break

('mse_0.0001_3_0_0.0001', 0.0001865277749181974)
('mse_0.0001_3_1_1e-05', 0.00018464036212109103)
('mse_0.0001_3_1_1e-06', 0.000184317810995744)
('mse_0.0001_3_1_0.0001', 0.00018219963306247597)
('mse_0.0001_3_0_1e-06', 0.00017957743441704475)
('mse_0.0001_3_0_1e-05', 0.00017901633561837792)
('mse_0.0001_3_1_0', 0.00017872264627182837)
('mse_0.0001_4_1_0.0001', 0.00017871300108952096)
('mse_0.0001_3_0_0', 0.00017771377495609888)
('mse_0.0001_5_1_0.0001', 0.0001774615617376743)
('mse_0.0001_4_0_1e-05', 0.00017670761461895291)
('mse_0.0001_4_0_0', 0.00017661514627461387)
