# Fully-connected model without auxiliary loss

In [1]:
from src.dlc_practical_prologue import generate_pair_sets
from src.utils import load_class_data, load_target_data, load_all_data, print_param_count
from src.models import *
from src.trainer import Trainer
import matplotlib.pyplot as plt

## Let's run it

In [2]:
# Check the number of parameters of the model
print("--- LinearAlpha ---")
alpha = LinearAlpha()
print_param_count(alpha)

print("\n--- LinearBeta ---")
beta = LinearBeta()
print_param_count(beta)

print("\n--- Full model ----")
model = Siamese(alpha, beta, weight_aux=0)
print_param_count(model)

--- LinearAlpha ---
Total number of parameters:     33950
Number of trainable parameters: 33950

--- LinearBeta ---
Total number of parameters:     382
Number of trainable parameters: 382

--- Full model ----
Total number of parameters:     34332
Number of trainable parameters: 34332


In [6]:
results = []
for i in range(1):
    print(f"Run {i+1}/10")

    # Import data
    dl_train_all, dl_val_all, dl_test_all = load_all_data(normalize=True)

    # Model
    alpha = LinearAlpha()
    beta = LinearBeta(label_encoded=True)
    model = Siamese(alpha, beta, weight_aux=0, softmax=False, strategy='sum')

    # Trainer
    trainer = Trainer(nb_epochs=25, verbose=False, run='fc_noAux_argmax')

    # Training
    trainer.fit(model, dl_train_all, dl_val_all, verbose=True)

    # Testing
    acc = trainer.test(model, dl_test_all, test_verbose=True, return_acc=True)
    results.extend([acc])

Run 1/10
# Epoch 1/25:	 loss=0.69	 loss_val=0.68	 acc_val=59.38
# Epoch 2/25:	 loss=0.69	 loss_val=0.68	 acc_val=59.38
# Epoch 3/25:	 loss=0.69	 loss_val=0.68	 acc_val=58.48
# Epoch 4/25:	 loss=0.69	 loss_val=0.68	 acc_val=58.48
# Epoch 5/25:	 loss=0.69	 loss_val=0.68	 acc_val=58.48
# Epoch 6/25:	 loss=0.69	 loss_val=0.68	 acc_val=58.48


KeyboardInterrupt: 

In [4]:
# Print results
results = torch.Tensor(results)
print("Average accuracy:   {:.2f}".format(results.mean()))
print("Standard deviation: {:.2f}".format(results.std()))

Average accuracy:   54.88
Standard deviation: nan
