In [None]:
import copy
import pickle
from pprint import pprint
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps as cm
from thalamocortex.utils import create_data_loaders, train_thalreadout
from thalamocortex.models import CTCNet, CTCNetThalReadout
from sklearn.decomposition import PCA
from scipy.special import softmax

  from .autonotebook import tqdm as notebook_tqdm


load trained model

In [2]:
model_path = Path("/Users/patmccarthy/Documents/thalamocortex/results/21_02_24_mod1_leftrightmnist_gridsearch2/0_CTCNet_TC_multi_pre_activation_reciprocal/model.pth")

In [3]:
model = torch.load(model_path)

load model and training report

In [4]:
# training progress and hyperparameter paths
learning_path = Path("/Users/patmccarthy/Documents/thalamocortex/results/21_02_24_mod1_leftrightmnist_gridsearch2/0_CTCNet_TC_multi_pre_activation_reciprocal/learning.pkl")
params_path = Path("/Users/patmccarthy/Documents/thalamocortex/results/21_02_24_mod1_leftrightmnist_gridsearch2/0_CTCNet_TC_multi_pre_activation_reciprocal/hyperparams.pkl")

In [5]:
# load training progress
with open(learning_path, "rb") as handle:
    learning = pickle.load(handle)
    print(f"{learning.keys()=}")
results = {"val_losses": learning["val_losses"],
                        "train_losses": learning["train_losses"],
                        "train_time": learning["train_time"],
                        "state_dicts": learning["state_dicts"]}

learning.keys()=dict_keys(['train_losses', 'val_losses', 'final_val_losses', 'state_dicts', 'train_time'])


In [6]:
# load hyperparams
with open(params_path, "rb") as handle:
    hp = pickle.load(handle)

In [7]:
hp

{'norm': 'normalise',
 'dataset': 'LeftRightMNIST',
 'save_path': '/Users/patmccarthy/Documents/thalamocortex/data',
 'batch_size': 32,
 'input_size': 1568,
 'output_size': 10,
 'ctx_layer_size': 64,
 'thal_layer_size': 16,
 'thalamocortical_type': 'multi_pre_activation',
 'thal_reciprocal': True,
 'thal_to_readout': False,
 'thal_per_layer': False,
 'lr': 5e-06,
 'loss': CrossEntropyLoss(),
 'epochs': 800,
 'ohe_targets': True,
 'track_loss_step': 50}

In [8]:
# instantiate model
model = CTCNet(input_size=hp["input_size"],
                output_size=hp["output_size"],
                ctx_layer_size=hp["ctx_layer_size"],
                thal_layer_size=hp["thal_layer_size"],
                thalamocortical_type=hp["thalamocortical_type"],
                thal_reciprocal=hp["thal_reciprocal"],
                thal_to_readout=hp["thal_to_readout"], 
                thal_per_layer=hp["thal_per_layer"])

In [9]:
# get model trained to specified epoch
epoch = 200
weights = results["state_dicts"][0]

In [10]:
weights

OrderedDict([('thal.0.weight',
              tensor([[-0.0165, -0.0100,  0.0353,  ...,  0.0268,  0.0541, -0.0005],
                      [ 0.0515, -0.0023,  0.0343,  ..., -0.0538,  0.0523,  0.0534],
                      [ 0.0336,  0.0677,  0.0768,  ...,  0.0459, -0.0273, -0.0043],
                      ...,
                      [-0.0495,  0.0871, -0.0480,  ...,  0.0597,  0.0214,  0.0852],
                      [ 0.0778,  0.0143,  0.1628,  ...,  0.0449, -0.0094,  0.0383],
                      [-0.0070,  0.0341,  0.0703,  ...,  0.0381, -0.0492, -0.0921]])),
             ('thal.0.bias',
              tensor([ 0.0593,  0.0063, -0.0402,  0.0679,  0.2337, -0.0203,  0.1267,  0.0372,
                      -0.0227, -0.0600,  0.0614,  0.0982,  0.0622, -0.0260,  0.1111,  0.0560])),
             ('thal_to_ctx1_projections.weight',
              tensor([[-0.2014, -0.2006, -0.2079,  ...,  0.0336, -0.1844,  0.1968],
                      [ 0.1322, -0.2155,  0.0631,  ..., -0.2208,  0.1441,  0.0158]

In [11]:
# set model weights
model.load_state_dict(weights)

<All keys matched successfully>

create new model with thalamic readout

In [12]:
hp_model_thal = copy.deepcopy(hp)
hp_model_thal["thal_output_size"] = 2

In [13]:
hp_model_thal

{'norm': 'normalise',
 'dataset': 'LeftRightMNIST',
 'save_path': '/Users/patmccarthy/Documents/thalamocortex/data',
 'batch_size': 32,
 'input_size': 1568,
 'output_size': 10,
 'ctx_layer_size': 64,
 'thal_layer_size': 16,
 'thalamocortical_type': 'multi_pre_activation',
 'thal_reciprocal': True,
 'thal_to_readout': False,
 'thal_per_layer': False,
 'lr': 5e-06,
 'loss': CrossEntropyLoss(),
 'epochs': 800,
 'ohe_targets': True,
 'track_loss_step': 50,
 'thal_output_size': 2}

In [17]:
model_thal = CTCNetThalReadout(input_size=hp_model_thal["input_size"],
                               ctx_output_size=hp_model_thal["output_size"],
                               thal_output_size=hp_model_thal["thal_output_size"],
                               ctx_layer_size=hp_model_thal["ctx_layer_size"],
                               thal_layer_size=hp_model_thal["thal_layer_size"],
                               thalamocortical_type=hp_model_thal["thalamocortical_type"],
                               thal_reciprocal=hp_model_thal["thal_reciprocal"],
                               thal_to_readout=hp_model_thal["thal_to_readout"], 
                               thal_per_layer=hp_model_thal["thal_per_layer"])

get training dataset

In [18]:
dataset_path = "/Users/patmccarthy/Documents/thalamocortex/data/BinaryMNIST/train.pkl"

In [16]:
trainset_loader, testset_loader, metadata = create_data_loaders("BinaryMNIST", "normalise", 32, "/Users/patmccarthy/Documents/thalamocortex/data")

initialise backbone with pretrained weights

In [19]:
for name, params in model_thal.named_parameters():
    print(f"{name=}, {params=}")

name='thal.0.weight', params=Parameter containing:
tensor([[ 0.0226,  0.0611,  0.0532,  ...,  0.0735,  0.0776, -0.0179],
        [-0.0211,  0.0039, -0.0306,  ..., -0.0505, -0.0006,  0.0108],
        [ 0.0328,  0.0589,  0.0102,  ..., -0.0352, -0.0563,  0.0786],
        ...,
        [-0.0606,  0.0755,  0.0572,  ...,  0.0082,  0.0316, -0.0551],
        [ 0.0319,  0.0520,  0.0201,  ...,  0.0106,  0.0340, -0.0004],
        [ 0.0473,  0.0835,  0.0253,  ...,  0.0035, -0.0504, -0.0875]],
       requires_grad=True)
name='thal.0.bias', params=Parameter containing:
tensor([ 0.0014, -0.0253,  0.0332, -0.0415, -0.0770, -0.0382,  0.0804, -0.0061,
         0.0575,  0.0275, -0.0569, -0.0362,  0.0389,  0.0732,  0.0527,  0.0107],
       requires_grad=True)
name='thal_to_ctx1_projections.weight', params=Parameter containing:
tensor([[-0.1390,  0.0477, -0.0298,  ...,  0.1766, -0.0886,  0.1557],
        [ 0.0081,  0.0120, -0.2251,  ..., -0.2272,  0.0362,  0.0099],
        [ 0.0028, -0.1992, -0.1320,  ..., 

fine tune

In [20]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

In [21]:
# define loss and optimiser
loss_fn = copy.deepcopy(hp_model_thal["loss"])
optimizer = torch.optim.Adam(model.parameters(),
                            lr = hp_model_thal["lr"])

In [22]:
# train model
train_losses, val_losses, state_dicts, train_time = train_thalreadout(model=model,
                                                                                trainset_loader=trainset_loader,
                                                                                valset_loader=testset_loader,
                                                                                optimizer=optimizer,
                                                                                loss_fn=loss_fn,
                                                                                ohe_targets=hp_model_thal["ohe_targets"],
                                                                                num_classes=len(metadata["classes"]),
                                                                                num_epochs=hp_model_thal["epochs"],
                                                                                device=device,
                                                                                loss_track_step=hp_model_thal["track_loss_step"],
                                                                                get_state_dict=True)

TypeError: train_one_epoch_thalreadout() got an unexpected keyword argument 'trainset_loader'