In [None]:

from scipy.io import loadmat
import numpy as np
import torch 
from torch.utils.data import random_split, DataLoader
from math import ceil
from torch import nn
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt

from mrnn7 import MilliesDataset, MilliesRNN
from hessianfree import HessianFree




In [None]:
# Loss function including firing rate regularization and weight regularization
def hardcore_loss(output, target, model_params, firing_reg=1e-1, weight_reg=1e-5):
    """Sum-of-squares error plus L2 penalties on the output and on the weights.

    Args:
        output: Tensor produced by the model.
        target: Tensor with a shape broadcastable against ``output``.
        model_params: Iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``. Only entries whose name contains
            the substring ``"weight"`` contribute to the weight penalty.
        firing_reg: Coefficient of the ``sum(output ** 2)`` penalty.
        weight_reg: Coefficient of the summed squared-weight penalty.

    Returns:
        Tensor of shape ``(1,)`` containing the total loss.
    """
    # Allocate the accumulator alongside `output` so the sum also works when
    # the model lives on a non-CPU device or uses a non-default float dtype.
    weight_sum = torch.zeros(1, device=output.device, dtype=output.dtype)
    for name, param in model_params:
        if "weight" in name:
            # Out-of-place add keeps the autograd graph free of in-place ops.
            weight_sum = weight_sum + torch.sum(param ** 2)

    squared_error = torch.sum((output - target) ** 2)
    firing_penalty = firing_reg * torch.sum(output ** 2)
    loss = squared_error + firing_penalty + weight_reg * weight_sum
    return loss

In [None]:
# Per-muscle lengths (50 entries) and a weight vector inversely proportional to those lengths
muscle_length_vec = torch.tensor([
    9.8, 10.8, 13.7, 6.8, 7.6, 8.7, 7.4, 16.2, 14.4, 13.8,
    13.8, 25.4, 23.2, 27.9, 9.3, 13.4, 11.4, 11.4, 2.7, 3.3,
    11.6, 13.2, 8.6, 17.3, 8.1, 5.9, 6.2, 6.3, 5.1, 6.4,
    4.9, 2.8, 5.2, 7.4, 7.5, 8.4, 7.5, 8.0, 8.4, 7.5,
    6.5, 6.3, 7.2, 7.0, 6.8, 5.9, 5.4, 6.8, 5.5, 7.1,
])
# The scale factor 2.5 * min(length) gives the shortest muscle a weight of ~2.5;
# every longer muscle gets a proportionally smaller weight.
muscle_weight_vec = (2.5 * torch.min(muscle_length_vec)) * (1.0 / muscle_length_vec)

# Same as hardcore loss but weights the loss assigned to each muscle by the previously defined vector
def hardcore_loss_weighted(output, target, model_params, firing_reg=1e-1, weight_reg=1e-5,
                           muscle_weights=None):
    """Muscle-weighted sum-of-squares error plus output and weight L2 penalties.

    The per-element error ``output - target`` is multiplied by
    ``muscle_weights`` (broadcast against it) before being squared and summed.

    Args:
        output: Tensor produced by the model.
        target: Tensor with a shape broadcastable against ``output``.
        model_params: Iterable of ``(name, tensor)`` pairs, e.g.
            ``model.named_parameters()``. Only entries whose name contains
            the substring ``"weight"`` contribute to the weight penalty.
        firing_reg: Coefficient of the ``sum(output ** 2)`` penalty.
        weight_reg: Coefficient of the summed squared-weight penalty.
        muscle_weights: Optional tensor of per-muscle weights. Defaults to the
            module-level ``muscle_weight_vec``.

    Returns:
        Tensor of shape ``(1,)`` containing the total loss.
    """
    if muscle_weights is None:
        muscle_weights = muscle_weight_vec
    # Keep the weights on the same device as the data they scale.
    muscle_weights = muscle_weights.to(output.device)

    # Allocate the accumulator alongside `output` so the sum also works when
    # the model lives on a non-CPU device or uses a non-default float dtype.
    weight_sum = torch.zeros(1, device=output.device, dtype=output.dtype)
    for name, param in model_params:
        if "weight" in name:
            weight_sum = weight_sum + torch.sum(param ** 2)

    target_diff_sum = torch.sum((muscle_weights * (output - target)) ** 2)

    loss = target_diff_sum + firing_reg * torch.sum(output ** 2) + weight_reg * weight_sum
    return loss

In [None]:
# Load data
whole_dataset = MilliesDataset('monkey_data.mat')
dataset_size = len(whole_dataset)

# 80/20 train/test split derived from the actual dataset size, so the cell no
# longer fails when the dataset does not contain exactly 502 trials
# (502 trials still yield the original 401 / 101 split).
train_size = (4 * dataset_size) // 5
test_size = dataset_size - train_size
# A fixed seed keeps the split stable when this cell is re-run; otherwise a
# model trained earlier could be evaluated on trials it was trained on.
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    whole_dataset, [train_size, test_size], generator=split_generator
)

# NOTE(review): the original comment here read "21 & 50" — presumably the
# values of in_dim and out_dim; confirm against MilliesDataset.dimensions().
in_dim, out_dim, trial_len = whole_dataset.dimensions()
hid_dim = 100

# Training data is served in 5 shuffled mini-batches per epoch.
train_dataloader = DataLoader(train_dataset, batch_size=ceil(len(train_dataset)/5), shuffle=True)
# The whole test set is served as a single batch.
test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)

# for generating output later
whole_dataloader = DataLoader(whole_dataset, batch_size = len(whole_dataset), shuffle=False) 


In [None]:
# Training pipeline


# Flags for different hyperparameters (Current settings were used to train the final model)
learning_rate = 0.0005
hessian = False      # True: HessianFree optimizer, False: Adam
hardcore = False     # True: hardcore_loss (with regularization), False: plain nn.MSELoss
intermodule_connections_removed = .9   # fraction of weights randomly pruned in h2o / thal

num_epochs = 50
training_loss = []   # one entry per mini-batch, in training order


# Add sparsity between layers
# NOTE(review): the meaning of the 4th positional argument (True) is defined in
# mrnn7.MilliesRNN and is not visible here — confirm.
model = MilliesRNN(in_dim, hid_dim, out_dim, True)
module1 = model.h2o
prune.random_unstructured(module1, name="weight", amount=intermodule_connections_removed)
module2 = model.thal
prune.random_unstructured(module2, name="weight", amount=intermodule_connections_removed)


if hardcore:
    criterion1 = hardcore_loss
else:
    criterion1 = nn.MSELoss() 


def compute_loss(outputs, targets):
    """Evaluate ``criterion1`` with the argument list the selected loss expects.

    ``hardcore_loss`` takes the model's named parameters as a third argument,
    whereas ``nn.MSELoss`` accepts only ``(input, target)``. Routing every loss
    evaluation through this helper keeps all flag combinations valid
    (previously ``hessian=True`` with ``hardcore=False`` raised a ``TypeError``
    inside the closure).
    """
    if hardcore:
        return criterion1(outputs, targets, model.named_parameters())
    return criterion1(outputs, targets)


if hessian:
    optimizer = HessianFree(model.parameters(), use_gnm=True, verbose=True)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



model.train()
for epoch in range(num_epochs):
    
    for i, (inp_batch, out_batch) in enumerate(train_dataloader):

        optimizer.zero_grad()

        # Forward pass; this loss is the one that gets logged for the batch.
        outputs = model(inp_batch)   
        loss = compute_loss(outputs, out_batch)

        def closure():
            # Passed to the HessianFree optimizer: re-evaluates the model on
            # the current batch and returns both the loss and the model output.
            gen_output = model(inp_batch)
            closure_loss = compute_loss(gen_output, out_batch)
            closure_loss.backward(create_graph=True)
            return closure_loss, gen_output
        
        if hessian:
            optimizer.step(closure, M_inv=None)
        else: # gradient descent
            loss.backward()
            optimizer.step()

        training_loss.append(loss.item())
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], "
            f"Batch [{i + 1}], "
            f"Loss: {loss.item():.4f}"
        )
        



In [None]:
# Plot training loss.
# training_loss holds one value per mini-batch (it is appended to inside the
# batch loop), so the x-axis counts batches, not epochs.
plt.plot(training_loss)
plt.xlabel("Training batch")
plt.ylabel("Loss")
plt.title("Training loss per batch")

In [None]:
# Evaluates the trained model on the held-out test set with the training criterion

num_samples = len(test_dataset)

model.eval()
with torch.no_grad():
    # test_dataloader was built with batch_size=len(test_dataset), so a single
    # batch contains the entire test set.
    inp, out_true = next(iter(test_dataloader))
    # Run the model on the test inputs (the original cell used `out` without
    # ever computing it, which raises NameError on a fresh kernel).
    out = model(inp)
    if hardcore:
        # hardcore_loss is a sum over all elements and needs the model's
        # named parameters; normalise it by the number of test samples.
        loss1 = criterion1(out, out_true, model.named_parameters()) / num_samples
    else:
        # nn.MSELoss() uses reduction='mean' by default, i.e. it is already an
        # average; dividing by num_samples again would under-report the loss.
        loss1 = criterion1(out, out_true)

print(f"Average test loss (mse): {loss1.item()}")

In [None]:
# JUST USED TO SAVE DATA TO LOAD INTO GUI LATER

import pickle 
import os
import pandas as pd

LESION_OUTPUT_DIR = '../model_outputs/lesions'


def save_lesion_output(result, stem):
    """Pickle ``result`` to ``<LESION_OUTPUT_DIR>/<stem>.pickle``."""
    with open(f'{LESION_OUTPUT_DIR}/{stem}.pickle', 'wb') as handle:
        pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)


# Make sure the target directory exists so the first dump does not fail.
os.makedirs(LESION_OUTPUT_DIR, exist_ok=True)

with torch.no_grad():
    # whole_dataloader serves the complete, unshuffled dataset as one batch.
    inp, out_true = next(iter(whole_dataloader))

    # Model outputs for different values of the model's `d` argument.
    # NOTE(review): `d` is interpreted inside MilliesRNN (not visible here).
    # The values d=(.25, .5) for "out_25" and d=(1.0, 1.0) for "out_75" do not
    # obviously match their file names — confirm they are intended.
    out_neg = model(inp, d = (-1.0, 1.0))
    save_lesion_output(out_neg, 'out_neg')
    out_0 = model(inp, d = (0.0, 0.0))
    save_lesion_output(out_0, 'out_0')
    out_25 = model(inp, d = (.25, .5))
    save_lesion_output(out_25, 'out_25')
    out_50 = model(inp, d = (.5, .5))
    save_lesion_output(out_50, 'out_50')
    out_75 = model(inp, d = (1.0, 1.0))
    save_lesion_output(out_75, 'out_75')

    # Progressive pruning of model.thal. Pruning is applied on top of the
    # masks already attached to the module, so each call removes half of the
    # connections that are still present. This permanently modifies `model`:
    # re-running this cell prunes further instead of reproducing the files.
    module2 = model.thal
    prune.random_unstructured(module2, name="weight", amount=.5)
    pruned_05 = model(inp)
    save_lesion_output(pruned_05, 'pruned_05')

    prune.random_unstructured(module2, name="weight", amount=.5)
    pruned_025 = model(inp)
    # Previously this file received pruned_05 by mistake.
    save_lesion_output(pruned_025, 'pruned_025')

    prune.random_unstructured(module2, name="weight", amount=.5)
    pruned_0125 = model(inp)
    # Previously this file received pruned_05 by mistake.
    save_lesion_output(pruned_0125, 'pruned_0125')