In [1]:
# Automatically reload edited modules (e.g. the thalamocortex package) before
# executing each cell, so library changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import logging
import argparse
import traceback
from datetime import date
from copy import copy, deepcopy
from pathlib import Path

import numpy as np
import torch
import pickle

from thalamocortex.models import CTCNet
from thalamocortex.utils import make_grid, create_data_loaders, train, evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Custom class to redirect stdout to logger
class LoggerWriter:
    """File-like adapter that forwards text written to it to a logging callable.

    Meant to stand in for ``sys.stdout`` / ``sys.stderr`` so that ``print``
    output is routed into a logger instead of the console.

    Args:
        level: Callable taking one string, e.g. ``logger.info`` or
            ``logger.error``. Despite its name this is the bound logging
            method, not a numeric logging level.
    """

    def __init__(self, level):
        self.level = level  # logging method to call with each message

    def write(self, message):
        """Forward ``message`` (whitespace-stripped) to ``self.level``.

        Whitespace-only chunks (e.g. a lone newline) are dropped so they do not
        turn into empty log records.

        Returns:
            int: ``len(message)``, following the ``io.TextIOBase.write``
            contract, which code writing to a stream may rely on.
        """
        stripped = message.strip()
        if stripped:
            self.level(stripped)
        return len(message)

    def flush(self):
        """No-op; present because code writing to ``sys.stdout`` may call it."""
        pass

In [4]:
# Hyperparameter grid for the driver-type model. Every value is a list of
# candidate settings; make_grid turns the grid into the individual
# combinations that the training loop iterates over.

# Directory passed to create_data_loaders as `save_path`. Set the
# THALAMOCORTEX_DATA_DIR environment variable to use another location instead
# of editing this machine-specific default.
DATA_DIR = os.environ.get(
    "THALAMOCORTEX_DATA_DIR", "/Users/patmccarthy/Documents/thalamocortex/data"
)

hyperparam_grid = {
    # data hyperparams
    "norm": ["normalise"],
    "dataset": ["MNIST"],
    "save_path": [DATA_DIR],
    "batch_size": [32],
    # model hyperparams
    "input_size": [28 * 28],  # flattened 28x28 MNIST image
    "output_size": [10],
    "ctx_layer_size": [32, 64],
    "thal_layer_size": [16],
    "thalamocortical_type": [None],
    "thal_reciprocal": [False],
    "thal_to_readout": [False],
    "thal_per_layer": [False],
    # training hyperparams
    # NOTE(review): 1e-6 is very small for Adam; in the recorded run the loss
    # stays near ln(10) ~= 2.30 (chance level for 10 classes) -- confirm intended.
    "lr": [1e-6],
    # A single loss instance shared by all combinations; the training loop
    # deep-copies it for each run.
    "loss": [torch.nn.CrossEntropyLoss()],
    "epochs": [800],
    "ohe_targets": [True],  # presumably one-hot encode targets -- confirm in train()
    "track_loss_step": [50],  # training loss is reported every N batches
}

In [5]:
# Make parameter grid: expand hyperparam_grid into the individual combinations
# iterated over by the training loop. Each entry is a mapping indexed by
# hyperparameter name (e.g. hyperparams["lr"]).
# NOTE(review): presumably the Cartesian product of the candidate lists --
# confirm against thalamocortex.utils.make_grid.
model_param_grid = make_grid(hyperparam_grid)

In [6]:
# Choose the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [11]:
def _make_tag(comb_idx, hyperparams):
    """Build a readable, sortable identifier for one hyperparameter combination.

    The index is zero-padded to two digits so tags sort correctly for up to 100
    combinations. It is followed by the model name, the thalamocortical type and
    a suffix for each thalamic option that is switched on.

    Args:
        comb_idx: Position of the combination in ``model_param_grid``.
        hyperparams: One combination; reads the ``thalamocortical_type``,
            ``thal_reciprocal``, ``thal_to_readout`` and ``thal_per_layer`` keys.

    Returns:
        str, e.g. ``"00_CTCNet_TC_none"``.
    """
    tc_type = hyperparams["thalamocortical_type"]
    tc_label = "none" if tc_type is None else tc_type
    tag = f"{comb_idx:02d}_CTCNet_TC_{tc_label}"
    for flag, suffix in (("thal_reciprocal", "_reciprocal"),
                         ("thal_to_readout", "_readout"),
                         ("thal_per_layer", "_per_layer")):
        if hyperparams[flag]:
            tag += suffix
    return tag


# number of parameter combinations
num_comb = len(model_param_grid)

# Outputs of every run, keyed by tag. Without this, each iteration overwrote the
# previous one's results and only the final combination survived the loop.
# NOTE(review): state_dicts may be large; consider persisting each run to disk
# if memory becomes an issue.
results = {}

for hp_comb_idx, hyperparams in enumerate(model_param_grid):

    # create readable tag for saving
    tag = _make_tag(hp_comb_idx, hyperparams)
    print(f"Combination {hp_comb_idx + 1}/{num_comb}: {tag}")

    # create data loaders
    trainset_loader, testset_loader, metadata = create_data_loaders(dataset=hyperparams["dataset"],
                                                                    norm=hyperparams["norm"],
                                                                    save_path=hyperparams["save_path"],
                                                                    batch_size=hyperparams["batch_size"])

    # create model
    model = CTCNet(input_size=hyperparams["input_size"],
                   output_size=hyperparams["output_size"],
                   ctx_layer_size=hyperparams["ctx_layer_size"],
                   thal_layer_size=hyperparams["thal_layer_size"],
                   thalamocortical_type=hyperparams["thalamocortical_type"],
                   thal_reciprocal=hyperparams["thal_reciprocal"],
                   thal_to_readout=hyperparams["thal_to_readout"],
                   thal_per_layer=hyperparams["thal_per_layer"])
    model.summary()

    # define loss and optimiser; the loss is deep-copied so the instance stored
    # in the grid is never shared between runs
    loss_fn = deepcopy(hyperparams["loss"])
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hyperparams["lr"])

    # train model (the test split doubles as the validation set here)
    (train_losses_epochs, val_losses_epochs, train_topk_accs,
     val_topk_accs, state_dicts, train_time) = train(model=model,
                                                     trainset_loader=trainset_loader,
                                                     valset_loader=testset_loader,
                                                     optimizer=optimizer,
                                                     loss_fn=loss_fn,
                                                     ohe_targets=hyperparams["ohe_targets"],
                                                     num_classes=len(metadata["classes"]),
                                                     num_epochs=hyperparams["epochs"],
                                                     device=device,
                                                     loss_track_step=hyperparams["track_loss_step"])

    # keep this combination's outputs instead of letting the next run overwrite them
    results[tag] = {
        "hyperparams": hyperparams,
        "train_losses_epochs": train_losses_epochs,
        "val_losses_epochs": val_losses_epochs,
        "train_topk_accs": train_topk_accs,
        "val_topk_accs": val_topk_accs,
        "state_dicts": state_dicts,
        "train_time": train_time,
    }

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Linear: 2-1                       1,040
|    └─ReLU: 2-2                         --
├─Sequential: 1-2                        --
|    └─Linear: 2-3                       25,120
|    └─ReLU: 2-4                         --
├─Sequential: 1-3                        --
|    └─Linear: 2-5                       1,056
|    └─ReLU: 2-6                         --
├─Sequential: 1-4                        --
|    └─Linear: 2-7                       330
Total params: 27,546
Trainable params: 27,546
Non-trainable params: 0
Training...
Beginning epoch 1/800
training batch 1, loss: 2.337, 32/60000 datapoints
training batch 51, loss: 2.324, 1632/60000 datapoints
training batch 101, loss: 2.324, 3232/60000 datapoints
training batch 151, loss: 2.356, 4832/60000 datapoints
training batch 201, loss: 2.332, 6432/60000 datapoints
training batch 251, loss: 2.293, 8032/60000 datapoints
training batch 301, loss: 2

KeyboardInterrupt: 