In [1]:
import copy

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from dlc_practical_prologue import generate_pair_sets
# NOTE(review): star import supplies the model classes (Baseline, LeNet, ResNet,
# CombinedNet); prefer importing them by name.
from models import *
from trainer import Trainer

In [2]:
# Load data
# Seed torch's global RNG before anything random happens, so that DataLoader
# shuffling and model initialisation in later cells are reproducible under
# Restart & Run All. (Presumably generate_pair_sets also draws from torch's RNG
# — TODO confirm; if it uses another RNG, seed that one too.)
SEED = 0
BATCH_SIZE = 32
NUM_WORKERS = 4
torch.manual_seed(SEED)

train_input, train_target, train_classes, test_input, test_target, test_classes = generate_pair_sets(nb=1000)


def _split_pairs(inputs, classes):
    """Turn a dataset of image pairs into a dataset of single images.

    Stacks all first images (``inputs[:, 0:1]``) followed by all second images
    (``inputs[:, 1:2]``) along dim 0, together with the matching labels
    ``classes[:, 0]`` followed by ``classes[:, 1]``.

    Returns:
        (images, labels) tensors with twice as many rows as ``inputs``.
    """
    images = torch.cat((inputs[:, 0:1], inputs[:, 1:2]), dim=0)
    labels = torch.cat((classes[:, 0], classes[:, 1]), dim=0)
    return images, labels


def _make_loaders(ds_train, ds_test):
    """Wrap a train/test dataset pair in DataLoaders (train shuffled, test in order)."""
    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
    return dl_train, dl_test


# Single images labelled with their digit class (used to test auxiliary networks)
ds_train_class = TensorDataset(*_split_pairs(train_input, train_classes))
ds_test_class = TensorDataset(*_split_pairs(test_input, test_classes))
dl_train_class, dl_test_class = _make_loaders(ds_train_class, ds_test_class)

# Pairs with the target variable (used for networks without auxiliary loss)
ds_train_target = TensorDataset(train_input, train_target)
ds_test_target = TensorDataset(test_input, test_target)
dl_train_target, dl_test_target = _make_loaders(ds_train_target, ds_test_target)

# Pairs with both digit classes and target (used for networks with auxiliary loss)
ds_train_all = TensorDataset(train_input, train_classes, train_target)
ds_test_all = TensorDataset(test_input, test_classes, test_target)
dl_train_all, dl_test_all = _make_loaders(ds_train_all, ds_test_all)

In [3]:
# Baseline classifier: trained on (pair, target) batches only, with no
# auxiliary digit-class loss.
# NOTE(review): this run uses a single epoch, while the from-scratch
# LeNet/ResNet/CombinedNet runs below use 10, so the accuracies are not a
# like-for-like comparison. The "validation" loader is the test set.
model = Baseline()
trainer = Trainer(nb_epochs=1)
trainer.fit(model, dl_train_target, dl_test_target)

# Epoch 1/1:	 loss=3.56	 loss_val=0.72	 acc_val=65.33


In [14]:
# Auxiliary network on its own: LeNet fitted on single images labelled with
# their digit class.
# `lenet` is deep-copied as a pretrained auxiliary further down, so this cell
# must run to completion first. If it is interrupted (as in the saved output
# below), `lenet` is presumably left only partially trained.
lenet = LeNet()
trainer = Trainer(nb_epochs=10)
trainer.fit(lenet, dl_train_class, dl_test_class)

# Epoch 1/10:	 loss=2.64	 loss_val=0.71	 acc_val=78.87
# Epoch 2/10:	 loss=0.37	 loss_val=0.31	 acc_val=89.58
# Epoch 3/10:	 loss=0.22	 loss_val=0.33	 acc_val=90.38
# Epoch 4/10:	 loss=0.15	 loss_val=0.27	 acc_val=91.27
# Epoch 5/10:	 loss=0.11	 loss_val=0.22	 acc_val=93.65
# Epoch 6/10:	 loss=0.05	 loss_val=0.22	 acc_val=93.75
# Epoch 7/10:	 loss=0.04	 loss_val=0.23	 acc_val=93.9


KeyboardInterrupt: 

In [8]:
# Auxiliary network on its own: ResNet fitted on single images labelled with
# their digit class.
# NOTE(review): the positional arguments (32, 3, 25) are presumably defined by
# ResNet in `models`. Pass them as keywords once their meaning is confirmed.
# NOTE(review): the saved output reports "Epoch 1/1" although nb_epochs=10,
# so it is stale; re-run to refresh it.
resnet = ResNet(32, 3, 25)
trainer = Trainer(nb_epochs=10)
trainer.fit(resnet, dl_train_class, dl_test_class)

# Epoch 1/1:	 loss=2.62	 loss_val=1.19	 acc_val=58.68


In [4]:
# Classifier with auxiliary loss: a freshly constructed LeNet is wrapped by
# CombinedNet and trained on (pair, digit classes, target) batches.
# `model` and `trainer` are reused by the evaluation cell that follows.
auxiliary = LeNet()
model = CombinedNet(auxiliary)
trainer = Trainer(nb_epochs=10)
trainer.fit(model, dl_train_all, dl_test_all)

# Epoch 1/10:	 loss=4.65	 loss_val=0.67	 acc_val=62.99
# Epoch 2/10:	 loss=1.23	 loss_val=0.55	 acc_val=71.58
# Epoch 3/10:	 loss=0.83	 loss_val=0.47	 acc_val=74.71
# Epoch 4/10:	 loss=0.62	 loss_val=0.42	 acc_val=79.79
# Epoch 5/10:	 loss=0.51	 loss_val=0.44	 acc_val=79.69
# Epoch 6/10:	 loss=0.42	 loss_val=0.36	 acc_val=83.98
# Epoch 7/10:	 loss=0.31	 loss_val=0.37	 acc_val=83.5
# Epoch 8/10:	 loss=0.23	 loss_val=0.33	 acc_val=86.33
# Epoch 9/10:	 loss=0.19	 loss_val=0.37	 acc_val=84.38
# Epoch 10/10:	 loss=0.14	 loss_val=0.38	 acc_val=85.25


In [5]:
# Evaluate the `model`/`trainer` pair from the preceding CombinedNet cell.
# NOTE(review): dl_test_all is the same loader that fit() received as its
# validation set, so this reproduces the last epoch's validation metrics
# (compare the saved output with the "Epoch 10/10" line) rather than giving a
# held-out score. A separate validation split would be needed for an unbiased
# test estimate.
trainer.test(model, dl_test_all)

loss_test=0.38	 acc_test=85.25


In [11]:
# Classifier with auxiliary loss, where the auxiliary starts from the LeNet
# trained on single digits earlier (`lenet`). The deep copy means training here
# cannot modify `lenet` itself, so this cell can be re-run safely.
# NOTE(review): this runs a single epoch vs 10 for the from-scratch CombinedNet.
model = CombinedNet(
    copy.deepcopy(lenet)
)
trainer = Trainer(nb_epochs=1)
trainer.fit(model, dl_train_all, dl_test_all)

# Epoch 1/1:	 loss=1.19	 loss_val=0.45	 acc_val=77.93
