# Fully-connected model without auxiliary loss

In [4]:
from src.dlc_practical_prologue import generate_pair_sets
from src.utils import load_class_data, load_target_data, load_all_data, print_param_count
from src.models import *
from src.trainer import Trainer
import matplotlib.pyplot as plt

## Let's run it

In [5]:
# Check the number of parameters of the model
print("--- LinearAlpha ---")
alpha = LinearAlpha()
print_param_count(alpha)

print("\n--- LinearBeta ---")
beta = LinearBeta()
print_param_count(beta)

print("\n--- Full model ----")
model = Siamese(alpha, beta, weight_aux=0)
print_param_count(model)

--- LinearAlpha ---
Total number of parameters:     33950
Number of trainable parameters: 33950

--- LinearBeta ---
Total number of parameters:     382
Number of trainable parameters: 382

--- Full model ----
Total number of parameters:     34332
Number of trainable parameters: 34332


In [6]:
results = []
for i in range(1):
    print(f"Run {i+1}/10")

    # Import data
    dl_train_all, dl_val_all, dl_test_all = load_all_data(normalize=True)

    # Model
    alpha = LinearAlpha()
    beta = LinearBeta(label_encoded=False)
    model = Siamese(alpha, beta, weight_aux=0, 
                    softmax=False, argmax=False, 
                    strategy='sum')

    # Trainer
    trainer = Trainer(nb_epochs=25, verbose=False, run='fc_noAux_argmax')

    # Training
    trainer.fit(model, dl_train_all, dl_val_all, verbose=True)

    # Testing
    acc = trainer.test(model, dl_test_all, test_verbose=True, return_acc=True)
    results.extend([acc])

Run 1/10
# Epoch 1/25:	 loss=0.69	 loss_val=0.69	 acc_val=56.7
# Epoch 2/25:	 loss=0.69	 loss_val=0.68	 acc_val=56.7
# Epoch 3/25:	 loss=0.68	 loss_val=0.67	 acc_val=56.7
# Epoch 4/25:	 loss=0.66	 loss_val=0.62	 acc_val=56.7
# Epoch 5/25:	 loss=0.58	 loss_val=0.52	 acc_val=79.91
# Epoch 6/25:	 loss=0.5	 loss_val=0.51	 acc_val=79.02
# Epoch 7/25:	 loss=0.44	 loss_val=0.46	 acc_val=80.36
# Epoch 8/25:	 loss=0.35	 loss_val=0.43	 acc_val=80.8
# Epoch 9/25:	 loss=0.29	 loss_val=0.5	 acc_val=78.57
# Epoch 10/25:	 loss=0.23	 loss_val=0.43	 acc_val=82.14
# Epoch 11/25:	 loss=0.18	 loss_val=0.52	 acc_val=81.25
# Epoch 12/25:	 loss=0.17	 loss_val=0.48	 acc_val=80.36
# Epoch 13/25:	 loss=0.1	 loss_val=0.58	 acc_val=80.8
# Epoch 14/25:	 loss=0.08	 loss_val=0.61	 acc_val=82.14
# Epoch 15/25:	 loss=0.09	 loss_val=0.6	 acc_val=81.25
# Epoch 16/25:	 loss=0.08	 loss_val=0.64	 acc_val=79.91
# Epoch 17/25:	 loss=0.06	 loss_val=0.62	 acc_val=80.36
# Epoch 18/25:	 loss=0.05	 loss_val=0.77	 acc_val=79.91
# 

In [7]:
# Print results
results = torch.Tensor(results)
print("Average accuracy:   {:.2f}".format(results.mean()))
print("Standard deviation: {:.2f}".format(results.std()))

Average accuracy:   81.93
Standard deviation: nan
