# Fully-connected model with auxiliary loss

In [1]:
import matplotlib.pyplot as plt
import torch

from src.dlc_practical_prologue import generate_pair_sets
from src.utils import load_class_data, load_target_data, load_all_data, print_param_count
from src.models import *
from src.trainer import Trainer

In [2]:
# Fetch the three objects returned by load_all_data (named for the train /
# validation / test splits), requesting normalization via normalize=True.
(dl_train_all,
 dl_val_all,
 dl_test_all) = load_all_data(normalize=True)

## Model size and repeated training runs

In [3]:
def _show_param_count(title, build):
    """Print `title`, create a module by calling `build()`, report its size.

    The module is created *after* the title is printed and *before*
    print_param_count is called on it, then returned to the caller.
    """
    print(title)
    module = build()
    print_param_count(module)
    return module


# Parameter counts for each sub-network, then for the Siamese model that
# wraps both of them.
# NOTE(review): the training cell below builds LinearBeta(label_encoded=False)
# and Siamese(..., weight_aux=0.5); if those arguments affect layer sizes, the
# counts printed here describe a different configuration — confirm.
alpha = _show_param_count("--- LinearAlpha ---", LinearAlpha)
beta = _show_param_count("\n--- LinearBeta ---", LinearBeta)
model = _show_param_count(
    "\n--- Full model ----",
    lambda: Siamese(alpha, beta, weight_aux=0),
)

--- LinearAlpha ---
Total number of parameters:     33950
Number of trainable parameters: 33950

--- LinearBeta ---
Total number of parameters:     382
Number of trainable parameters: 382

--- Full model ----
Total number of parameters:     34332
Number of trainable parameters: 34332


In [5]:
# Repeat the whole train / test cycle several times and collect the test
# accuracy of each repetition in `results`, so that the mean and spread over
# runs can be reported afterwards.
N_RUNS = 10      # number of independent repetitions
NB_EPOCHS = 25   # training epochs per repetition

results = []
for i in range(N_RUNS):
    print(f"Run {i+1}/{N_RUNS}")

    # A new Trainer is created for every repetition.
    # NOTE(review): the run tag says 'argmax' while the model below is built
    # with argmax=False — confirm the tag is the intended one.
    trainer = Trainer(nb_epochs=NB_EPOCHS, verbose=False, run='fc_aux_argmax')

    # Data is (re)loaded inside the loop on every repetition.
    # NOTE(review): presumably this yields a fresh dataset per run — confirm
    # against load_all_data before hoisting it out of the loop.
    dl_train_all, dl_val_all, dl_test_all = load_all_data(normalize=True)

    # New model instances are constructed on every iteration.
    alpha = LinearAlpha()
    beta = LinearBeta(label_encoded=False)
    model = Siamese(alpha, beta, weight_aux=0.5,
                    softmax=False,
                    argmax=False, strategy='sum')

    # Fit on the train loader, passing the validation loader alongside it.
    trainer.fit(model, dl_train_all, dl_val_all, verbose=False)

    # Evaluate on the test loader; return_acc=True makes test() hand back the
    # accuracy that is stored for the summary cell.
    acc = trainer.test(model, dl_test_all, test_verbose=True, return_acc=True)
    results.append(acc)

Run 1/10
loss_test=0.66	 acc_test=87.79
Run 2/10
loss_test=0.43	 acc_test=89.75
Run 3/10
loss_test=0.52	 acc_test=86.43
Run 4/10
loss_test=0.45	 acc_test=88.18
Run 5/10
loss_test=0.62	 acc_test=85.35
Run 6/10
loss_test=0.74	 acc_test=84.57
Run 7/10
loss_test=0.73	 acc_test=87.7
Run 8/10
loss_test=0.65	 acc_test=86.04
Run 9/10
loss_test=0.55	 acc_test=86.82
Run 10/10
loss_test=0.75	 acc_test=84.18


In [6]:
# Summarize the per-run test accuracies gathered in `results`.
# torch.as_tensor accepts either the Python list built by the training loop or
# an already-converted tensor; the explicit default dtype keeps the same
# floating-point type the legacy torch.Tensor(...) constructor would produce.
results = torch.as_tensor(results, dtype=torch.get_default_dtype())
print(f"Average accuracy:   {results.mean():.2f}")
# Tensor.std() is called with its default settings, as before.
print(f"Standard deviation: {results.std():.2f}")

Average accuracy:   86.68
Standard deviation: 1.73
