# Fully-connected model with auxiliary loss

In [1]:
from src.dlc_practical_prologue import generate_pair_sets
from src.utils import load_class_data, load_target_data, load_all_data, print_param_count
from src.models import *
from src.trainer import Trainer
import matplotlib.pyplot as plt

In [2]:
# Import data
dl_train_all, dl_val_all, dl_test_all = load_all_data(normalize=True)

## Let's run it

In [3]:
# Check the number of parameters of the model
print("--- LinearAlpha ---")
alpha = LinearAlpha()
print_param_count(alpha)

print("\n--- LinearBeta ---")
beta = LinearBeta()
print_param_count(beta)

print("\n--- Full model ----")
model = Siamese(alpha, beta, weight_aux=0)
print_param_count(model)

--- LinearAlpha ---
Total number of parameters:     33950
Number of trainable parameters: 33950

--- LinearBeta ---
Total number of parameters:     382
Number of trainable parameters: 382

--- Full model ----
Total number of parameters:     34332
Number of trainable parameters: 34332


In [4]:
# Trainer
trainer = Trainer(nb_epochs=25, verbose=False)

results = []
for i in range(1):
    print(f"Run {i+1}/10")

    # Model
    lr = 0.001
    alpha = LinearAlpha(lr=lr)
    beta = LinearBeta(lr=lr)
    model = Siamese(alpha, beta, weight_aux=0.2, lr=lr)

    # Training
    trainer.fit(model, dl_train_all, dl_val_all, verbose=True)

    # Testing
    acc = trainer.test(model, dl_test_all, test_verbose=False, return_acc=True)
    results.extend([acc])

Run 1/10
# Epoch 1/25:	 loss=1.7	 loss_val=0.69	 acc_val=51.34
# Epoch 2/25:	 loss=1.68	 loss_val=0.7	 acc_val=51.34
# Epoch 3/25:	 loss=1.66	 loss_val=0.7	 acc_val=51.34
# Epoch 4/25:	 loss=1.65	 loss_val=0.69	 acc_val=51.34
# Epoch 5/25:	 loss=1.65	 loss_val=0.7	 acc_val=51.34
# Epoch 6/25:	 loss=1.63	 loss_val=0.69	 acc_val=51.34
# Epoch 7/25:	 loss=1.62	 loss_val=0.69	 acc_val=51.34
# Epoch 8/25:	 loss=1.63	 loss_val=0.69	 acc_val=51.34
# Epoch 9/25:	 loss=1.61	 loss_val=0.7	 acc_val=51.34
# Epoch 10/25:	 loss=1.6	 loss_val=0.69	 acc_val=51.34
# Epoch 11/25:	 loss=1.6	 loss_val=0.69	 acc_val=51.34
# Epoch 12/25:	 loss=1.58	 loss_val=0.69	 acc_val=51.34
# Epoch 13/25:	 loss=1.59	 loss_val=0.7	 acc_val=51.34
# Epoch 14/25:	 loss=1.57	 loss_val=0.69	 acc_val=51.34
# Epoch 15/25:	 loss=1.56	 loss_val=0.69	 acc_val=51.34
# Epoch 16/25:	 loss=1.56	 loss_val=0.69	 acc_val=51.34
# Epoch 17/25:	 loss=1.55	 loss_val=0.69	 acc_val=51.34
# Epoch 18/25:	 loss=1.52	 loss_val=0.69	 acc_val=51.34


In [5]:
# Print results
results = torch.Tensor(results)
print("Average accuracy:   {:.2f}".format(results.mean()))
print("Standard deviation: {:.2f}".format(results.std()))

Average accuracy:   52.64
Standard deviation: nan


## Training both parts separately

In [8]:
# Trainer
trainer = Trainer(nb_epochs=25, verbose=False)

# Data
dl_train_class, dl_val_class, dl_test_class = load_class_data(normalize=True)

results = []
for i in range(1):
    print(f"Run {i+1}/10")

    # Model
    lr = 0.001
    model = LinearAlpha(lr=lr)

    # Training
    trainer.fit(model, dl_train_class, dl_val_class, verbose=True)

    # Testing
    acc = trainer.test(model, dl_test_class, test_verbose=False, return_acc=True)
    results.extend([acc])

Run 1/10
# Epoch 1/25:	 loss=2.4	 loss_val=2.26	 acc_val=18.75
# Epoch 2/25:	 loss=2.34	 loss_val=2.21	 acc_val=22.36
# Epoch 3/25:	 loss=2.3	 loss_val=2.18	 acc_val=35.1
# Epoch 4/25:	 loss=2.26	 loss_val=2.13	 acc_val=39.66
# Epoch 5/25:	 loss=2.22	 loss_val=2.05	 acc_val=49.52
# Epoch 6/25:	 loss=2.16	 loss_val=1.98	 acc_val=51.44
# Epoch 7/25:	 loss=2.11	 loss_val=1.91	 acc_val=53.85
# Epoch 8/25:	 loss=2.04	 loss_val=1.82	 acc_val=56.73
# Epoch 9/25:	 loss=1.98	 loss_val=1.75	 acc_val=52.88
# Epoch 10/25:	 loss=1.91	 loss_val=1.69	 acc_val=52.64
# Epoch 11/25:	 loss=1.86	 loss_val=1.57	 acc_val=52.4
# Epoch 12/25:	 loss=1.8	 loss_val=1.54	 acc_val=54.81
# Epoch 13/25:	 loss=1.74	 loss_val=1.45	 acc_val=57.21
# Epoch 14/25:	 loss=1.7	 loss_val=1.42	 acc_val=55.29
# Epoch 15/25:	 loss=1.65	 loss_val=1.36	 acc_val=56.73
# Epoch 16/25:	 loss=1.63	 loss_val=1.29	 acc_val=57.93
# Epoch 17/25:	 loss=1.59	 loss_val=1.25	 acc_val=60.82
# Epoch 18/25:	 loss=1.56	 loss_val=1.23	 acc_val=60.5