In [33]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
import pickle
import json
import os
from copy import deepcopy
from pathlib import Path
import torch
from models import CTCNet
from utils import create_data_loaders, train, evaluate

In [35]:
# Set backend
print("Setting backend.")
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device.")

Setting backend.
Using cpu device.


In [38]:
save_path = "/Users/patmccarthy/Documents/thalamocortex/data"
hyperparams = {
    # data hyperparams
    "norm" : "normalise",
    "dataset" : "MNIST",
    "save_path" : "/Users/patmccarthy/Documents/thalamocortex/data",
    "batch_size" : 32,
    # model hyperparams
    "input_size" : 28 * 28,
    "output_size" : 10,
    "ctx_layer_size" : 128,
    "thal_layer_size" : 64,
    "thalamocortical_type" : "add",
    "thal_reciprocal" : False, # True or False
    "thal_to_readout" : True, # True or False
    "thal_per_layer" : False, # if no, mixing from cortical layers
    # training hyperparams
    "lr" : 0.001,
    "loss" : torch.nn.CrossEntropyLoss(),
    "epochs": 1,
    "ohe_targets": True,
    "loss_track_step": 50,
}

In [39]:
# create data loaders
trainset_loader, testset_loader, metadata = create_data_loaders(dataset=hyperparams["dataset"],
                                                                        norm=hyperparams["norm"],
                                                                        save_path=hyperparams["save_path"],
                                                                        batch_size=hyperparams["batch_size"])

In [40]:
# create model
model = CTCNet(input_size=hyperparams["input_size"],
                output_size=hyperparams["output_size"],
                ctx_layer_size=hyperparams["ctx_layer_size"],
                thal_layer_size=hyperparams["thal_layer_size"],
                thalamocortical_type=hyperparams["thalamocortical_type"],
                thal_reciprocal=hyperparams["thal_reciprocal"],
                thal_to_readout=hyperparams["thal_to_readout"], 
                thal_per_layer=hyperparams["thal_per_layer"])
model.summary()

Layer (type:depth-idx)                   Param #
├─Sequential: 1-1                        --
|    └─Linear: 2-1                       16,448
|    └─ReLU: 2-2                         --
├─Linear: 1-2                            50,960
├─Linear: 1-3                            8,320
├─Sequential: 1-4                        --
|    └─Linear: 2-3                       100,480
|    └─ReLU: 2-4                         --
├─Sequential: 1-5                        --
|    └─Linear: 2-5                       16,512
|    └─ReLU: 2-6                         --
├─Sequential: 1-6                        --
|    └─Linear: 2-7                       1,290
Total params: 194,010
Trainable params: 194,010
Non-trainable params: 0


In [41]:
# define loss and optimiser
loss_fn = deepcopy(hyperparams["loss"])
optimizer = torch.optim.Adam(model.parameters(),
                             lr = hyperparams["lr"])

In [42]:
# train model
train_losses, val_losses, train_time = train(model=model,
                                 trainset_loader=trainset_loader,
                                 valset_loader=testset_loader,
                                 optimizer=optimizer,
                                 loss_fn=loss_fn,
                                 ohe_targets=hyperparams["ohe_targets"],
                                 num_classes=len(metadata["classes"]),
                                 num_epochs=hyperparams["epochs"],
                                 device=device,
                                 loss_track_step=hyperparams["loss_track_step"])

Training...
Beginning epoch 1/1
training batch 1, loss: 2.288, 32/60000 datapoints
training batch 51, loss: 0.941, 1632/60000 datapoints
training batch 101, loss: 0.936, 3232/60000 datapoints
training batch 151, loss: 0.212, 4832/60000 datapoints
training batch 201, loss: 0.446, 6432/60000 datapoints
training batch 251, loss: 0.313, 8032/60000 datapoints
training batch 301, loss: 0.438, 9632/60000 datapoints
training batch 351, loss: 0.500, 11232/60000 datapoints


KeyboardInterrupt: 

In [9]:
# evaluate model
losses = evaluate(model=model,
                  data_loader=testset_loader,
                  optimizer=optimizer,
                  loss_fn=loss_fn,
                  ohe_targets=hyperparams["ohe_targets"],
                  num_classes=len(metadata["classes"]),
                  device=device,
                  loss_track_step=200)

validation batch 1, loss: 0.183, 32/10000 datapoints
validation batch 201, loss: 0.225, 6432/10000 datapoints


In [37]:
# Save model
save_path_this_model = Path(save_path, "model0_05_01_24")
if not os.path.exists(save_path_this_model):
    os.mkdir(save_path_this_model)
print("Saving...")
# model
torch.save(model.state_dict(), Path(f"{save_path_this_model}", "model.pth"))
# hyperparams
with open(Path(f"{save_path_this_model}", "hyperparams.pkl"), "wb") as handle:
    pickle.dump(hyperparams, handle)
# learning progress
training_stats = {"train_losses": train_losses,
                  "val_losses": val_losses,
                  "final_val_losses": losses,
                  "train_time": train_time}
with open(Path(f"{save_path_this_model}", "learning.pkl"), "wb") as handle:
    pickle.dump(training_stats, handle)
print("Done saving.")

Saving...
Done saving.


In [None]:
model()

In [18]:
X, y = next(iter(trainset_loader))
print(X.shape)
print(y.shape)

torch.Size([32, 1, 28, 28])
torch.Size([32])


In [23]:
y_est = model(X).detach().numpy()

In [20]:
y

tensor([1, 0, 1, 2, 0, 6, 7, 1, 3, 4, 5, 5, 6, 9, 8, 5, 4, 9, 5, 4, 9, 2, 7, 1,
        6, 3, 1, 1, 7, 9, 1, 2])

In [27]:
import numpy as np
np.argmax(y_est, axis=1)

array([1, 0, 1, 4, 0, 6, 7, 1, 3, 4, 7, 7, 6, 9, 8, 5, 4, 9, 5, 4, 9, 2,
       7, 1, 3, 3, 1, 1, 7, 9, 1, 2])

In [28]:
output.shape

NameError: name 'output' is not defined

In [24]:
np.sum(output[0, :].detach().numpy())

1.0