In [1]:
# # header
import sys
sys.path.append(r"../")

%load_ext autoreload
%autoreload 2

In [2]:
# # built-in modules
import os
from pprint import pformat
from collections import OrderedDict
# # Torch modules
import torch
from torch.utils.data import random_split
from torchvision import transforms, datasets
# # internal imports
from prelude import startup_folders, get_device, load_dicts
from src.composer import IOR_DS, Arrow_DS, Cue_DS, Recognition_DS
from src.composer import Search_DS, Tracking_DS, Popout_DS
from src.model import AttentionModel
from src.utils import plot_all
from src.utils import build_loaders
from src.conductor import AttentionTrain


In [3]:
# Paths: the pretrained run to load, and where MNIST is stored/downloaded.
start_folder = r"../pretrained/mnist_v2"
data_path = r"../data"

# `startup_folders` returns a results folder and a logger. Per the cell output
# it creates a new (apparently timestamped) sub-folder under `start_folder`.
results_folder, logger = startup_folders(start_folder, name="exp_mnist")

../pretrained/mnist_v2/1753528396 was created!


In [4]:
# Restore the configuration saved with the pretrained run. `model_params` is
# used below to rebuild the architecture that matches the saved checkpoint.
model_params = load_dicts(start_folder, "model_params")
tasks = load_dicts(start_folder, "tasks")
train_params = load_dicts(start_folder, "train_params")
DeVice, num_workers, pin_memory = get_device()

# The config dicts are long and nested. pformat wraps them over several lines
# instead of one huge line, keeping the saved (insertion) key order.
print(f"model_params:\n{pformat(model_params, sort_dicts=False)}")
print(f"tasks:\n{pformat(tasks, sort_dicts=False)}")
print(f"train_params:\n{pformat(train_params, sort_dicts=False)}")

Device set to mps
model_params: {'in_dims': [3, 96, 96], 'n_classes': 10, 'out_dim': 20, 'normalize': True, 'softness': 0.5, 'channels': [3, 16, 32, 64, 128, 128], 'residuals': False, 'kernels': 3, 'strides': 1, 'paddings': 1, 'conv_bias': True, 'conv_norms': [None, 'layer', 'layer', 'layer', 'layer'], 'conv_dropouts': 0.0, 'conv_funs': ReLU(), 'deconv_funs': Tanh(), 'deconv_norms': [None, 'layer', 'layer', 'layer', 'layer'], 'pools': [2, 2, 2, 2, 3], 'rnn_dims': [128, 64], 'rnn_bias': True, 'rnn_dropouts': 0.0, 'rnn_funs': ReLU(), 'n_tasks': 7, 'task_layers': 1, 'task_weight': True, 'task_bias': True, 'task_funs': Tanh(), 'norm_mean': [0.5, 0.5, 0.5], 'norm_std': [1.0, 1.0, 1.0], 'rnn_to_fc': False, 'trans_fun': ReLU()}
tasks: {'IOR': {'composer': 'IOR_DS', 'key': 0, 'params': {'n_digits': 3, 'n_attend': 2, 'noise': 0.25, 'overlap': 1.0}, 'datasets': ['IOR_DS', 'IOR_DS', 'IOR_DS'], 'dataloaders': [None, None, None], 'loss_w': [1.0, 1.0, 0.0], 'loss_s': [None, None], 'has_prompt': Fals

In [5]:
# # setting up the tasks
# The saved `tasks` dict stores composer *names* (strings) and placeholder
# datasets/dataloaders. Swap in the actual composer classes and give every
# task fresh, empty containers that the next cell fills.
TASK_COMPOSERS = {
    'IOR': IOR_DS,
    'Arrow': Arrow_DS,
    'Cue': Cue_DS,
    'Tracking': Tracking_DS,
    'Recognition': Recognition_DS,
    'Search': Search_DS,
    'Popout': Popout_DS,
}

# Fail early and clearly if the saved config has a task we have no composer for.
# Otherwise its composer would stay a string and dataset construction would fail later.
_unknown_tasks = set(tasks) - set(TASK_COMPOSERS)
if _unknown_tasks:
    raise KeyError(f"No composer registered for task(s): {sorted(_unknown_tasks)}")

for task_name, composer in TASK_COMPOSERS.items():
    task = tasks[task_name]
    task["composer"] = composer
    # A new list per task, so tasks never share (alias) the same container.
    task["datasets"] = []
    task["dataloaders"] = []
    # Same override for every task. Its meaning is defined by the training/eval
    # code (AttentionTrain) — TODO document there.
    task["loss_s"] = ((-1, ), (-1, ))

# Only the Arrow task is given the data directory here.
tasks['Arrow']["params"]["directory"] = data_path

In [6]:
# datasets and dataloaders
# Fixed seed so the train/validation split is reproducible across runs.
SPLIT_SEED = 0
# Validation size carved out of the MNIST training set (rest is training).
N_VALID = 10000

tralid_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
train_ds, valid_ds = random_split(
    tralid_ds,
    (len(tralid_ds) - N_VALID, N_VALID),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# num_workers / pin_memory come from the get_device() call in the config-loading cell.
splits = (train_ds, valid_ds, test_ds)
for task_name, task in tasks.items():
    composer = task["composer"]
    # Build a fresh list (rather than appending) so re-running this cell
    # always yields exactly [train, valid, test] for each task.
    task["datasets"] = [composer(ds, **task["params"]) for ds in splits]
    task["datasets"][1].build_valid_test()
    task["datasets"][2].build_valid_test()
    task["dataloaders"] = build_loaders(
        task["datasets"],
        batch_size=train_params["batch_size"],
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

Device set to mps


In [7]:
# create a blank model
model = AttentionModel(**model_params)
# NOTE(review): the two None arguments are not explained here (presumably
# optimizer/scheduler, unused for evaluation) — confirm in src.conductor.
# The conductor presumably keeps a reference to `model`, so the in-place
# weight load below is visible to it.
conductor = AttentionTrain(model, None, None, tasks, logger, results_folder)

# load states into the model
model_path = os.path.join(start_folder, "model.pth")
if not os.path.isfile(model_path):
    # Raise instead of `assert`: asserts are stripped when Python runs with -O.
    raise FileNotFoundError(f"Could not find the model.pth in the given dir: {model_path}")
# weights_only=True restricts unpickling to tensors/plain containers, which is
# all a state_dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load(model_path, map_location=DeVice, weights_only=True))

<All keys matched successfully>

In [8]:
# plotting...
# Produce plots for the loaded model on the "test" split, presumably saved
# under results_folder with the "_test" suffix.
# NOTE(review): the positional arguments are opaque from here. 10 is presumably
# the number of samples to plot, and the False flag's meaning is not visible.
# Confirm against src.utils.plot_all and consider passing them as keywords.
plot_all(10, model, tasks, results_folder, "_test", DeVice, logger, False, "test")

In [9]:
# evaluating...
# Evaluate the loaded model on the "test" split. Per-task losses, pixel error,
# and attention/classification accuracy are printed (see output below).
# NOTE(review): the meaning of the trailing False flag is not visible here —
# TODO confirm in src.conductor.AttentionTrain.eval.
conductor.eval(DeVice, "test", False)

testing...
  Task IOR:
    CEi Loss: 0.133    CEe Loss: 0.000    Pix Err: 0.004    Att Acc: 0.831    Cls Acc: 9536/10000
  Task Arrow:
    CEi Loss: 1.192    CEe Loss: 0.023    Pix Err: 0.021    Att Acc: 0.926    Cls Acc: 9927/10000
  Task Cue:
    CEi Loss: 0.101    CEe Loss: 0.038    Pix Err: 0.002    Att Acc: 0.870    Cls Acc: 9877/10000
  Task Tracking:
    CEi Loss: 0.023    CEe Loss: 0.024    Pix Err: 0.009    Att Acc: 0.741    Cls Acc: 9945/10000
  Task Recognition:
    CEi Loss: 0.031    CEe Loss: 0.032    Pix Err: 0.009    Att Acc: 0.941    Cls Acc: 9908/10000
  Task Search:
    CEi Loss: 1.660    CEe Loss: 0.563    Pix Err: 0.002    Att Acc: 0.830    Cls Acc: 8967/10000
  Task Popout:
    CEi Loss: 0.634    CEe Loss: 0.028    Pix Err: 0.019    Att Acc: 0.911    Cls Acc: 9912/10000
