In [None]:
import sys
sys.path.append('..')  # Add the parent directory to sys.path

In [None]:
import time, os, torch
import numpy as np
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
from src.utils import train, compute_accuracy, set_device, plot_performance_over_time, plot_training_times
from src.models import SimpleMLP, AttentionMLP, ConjugationRNN

In [None]:
# Seed torch's global RNG up front; the training loop re-seeds with the same
# SEED before building each model.
SEED = 265
torch.manual_seed(SEED)

# `set_device` comes from src.utils; presumably it falls back to CPU when CUDA
# is unavailable — TODO confirm.
DEVICE = set_device("cuda")
print("Using device:", DEVICE)

In [None]:
PATH_GENERATED = "../generated_data/"

# Artifacts written by the preprocessing step, loaded in this fixed order:
# the three data splits, the index -> word mapping, the embedding matrix and
# the vocabulary.
data_train, data_val, data_test, mapping, embedding, vocab = (
    torch.load(PATH_GENERATED + file_name)
    for file_name in (
        "data_train.pt",
        "data_val.pt",
        "data_test.pt",
        "mapping.pt",
        "embedding_matrix.pt",
        "vocabulary.pt",
    )
)

In [None]:
# Restrict every split to examples whose target word is a form of "to be" or
# "to have", relabelling each target as its position in `target_filter`
# (class indices 0..len(target_filter)-1).
datasets = [data_train, data_val, data_test]
target_filter = ["be", "am", "are", "is", "was", "were", "been", "being", "have", "has", "had", "having"]
# Vocabulary index -> class index. Assumes `vocab` maps word -> index and
# `mapping` maps index -> word (i.e. they are inverses) — TODO confirm.
target_map = {vocab[w]: i for i, w in enumerate(target_filter)}
conjugation_cache = PATH_GENERATED + "conjugation_data.pt"

if os.path.isfile(conjugation_cache):
    # The cache holds pickled TensorDataset objects, which torch.load only
    # accepts with weights_only=False on PyTorch >= 2.6. The file is written
    # by this cell itself; never point this path at untrusted data.
    datasets = torch.load(conjugation_cache, weights_only=False)
else:
    for i, dataset in enumerate(datasets):
        filtered_tensors = []
        for context, target in dataset:
            target_id = int(target)
            if mapping[target_id] in target_filter:
                filtered_tensors.append((context, torch.tensor(target_map[target_id])))
        if not filtered_tensors:
            # An empty TensorDataset would only fail later, on len().
            raise ValueError(f"Split {i} has no examples whose target is in target_filter")
        contexts, labels = zip(*filtered_tensors)
        datasets[i] = TensorDataset(torch.stack(contexts), torch.stack(labels))
    # Persist the filtered splits so later runs load them. Re-running this cell
    # after the next cell has rebound data_train/val/test to the filtered
    # splits then reads the cache instead of re-filtering relabelled targets.
    torch.save(datasets, conjugation_cache)

In [None]:
data_train, data_val, data_test = datasets

# Report how many examples each (filtered) split contains.
for _caption, _split in (
    ("Size of training data: ", data_train),
    ("Size of validation data: ", data_val),
    ("Size of test data: ", data_test),
):
    print(_caption, len(_split))

In [None]:
# Shared training configuration.
context_size = int(data_train[0][0].shape[0])  # length of one input context
batch_size = 64
n_epochs = 30
loss_fn = nn.CrossEntropyLoss()

print("-- Global Parameters --")
print(f"{batch_size=}\n{n_epochs=}\n{context_size=}")

# Candidate architectures; model_parameters[k] lists the constructor kwargs
# tried for model_architectures[k].
model_architectures = [SimpleMLP, AttentionMLP, ConjugationRNN]
model_parameters = [
    # SimpleMLP
    [dict(l1=128, l2=32), dict(l1=256, l2=64), dict(l1=256, l2=256)],
    # AttentionMLP
    [dict(n_heads=4, w_size=8), dict(n_heads=8, w_size=16), dict(n_heads=16, w_size=20)],
    # ConjugationRNN
    [
        dict(num_hiddens=8, num_layers=4, dropout=0),
        dict(num_hiddens=16, num_layers=8, dropout=0.1),
        dict(num_hiddens=20, num_layers=16, dropout=0.01),
    ],
]
# Learning rates swept for every (architecture, model-kwargs) pair.
parameter_search = [{"lr": lr} for lr in (0.008, 0.001, 0.0005)]

In [None]:
# Reshuffle the training split every epoch so the mini-batch order is not
# fixed (the examples are presumably context windows taken from running text,
# so neighbouring ones are correlated — TODO confirm). The shuffle draws from
# torch's global RNG, which the training loop re-seeds with SEED before each
# model, so runs stay reproducible.
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
# Validation is evaluation-only; its order stays fixed.
val_loader = DataLoader(data_val, batch_size=batch_size)

In [None]:
if os.path.isfile(PATH_GENERATED + "conjugation_model.pt"):
    print("Skipping training, loading existing model...")
else:
    # Grid search: architecture x learning rate x architecture-specific kwargs.
    # Every list below receives exactly one entry per run, in the same order,
    # so a single run index identifies a run across all of them.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    val_perfs = []           # last-epoch validation accuracy of each run
    hyper_params = []        # architecture + lr + model kwargs of each run
    models = []              # trained model object of each run
    architecture_times = []  # mean training time (s) per architecture

    # Same embedding for every run: move it to the device once, not per run.
    # NOTE(review): every model receives this same tensor. If an architecture
    # wraps it without copying (nn.Embedding.from_pretrained shares storage)
    # and does not freeze it, one run's training would change the starting
    # weights of the next — confirm the models freeze or copy it.
    embedding = embedding.to(DEVICE)

    for architecture, m_params in zip(model_architectures, model_parameters):
        model_times = []
        for params in parameter_search:
            print("\n-- Training with following parameters --:")
            print("\nModel architecture: ", architecture)
            for name, val in params.items():
                print(f"{name}: {val}")
            for m_param in m_params:
                print(m_param)

                # Re-seed so every run starts from the same RNG state.
                torch.manual_seed(SEED)
                model = architecture(embedding, max_len=context_size, **m_param)
                model.to(DEVICE)
                optimizer = Adam(model.parameters(), lr=params["lr"])

                start_time = time.time()
                train_loss, val_loss, train_acc, val_acc = train(
                    n_epochs, model, optimizer, loss_fn, train_loader, val_loader, DEVICE
                )
                end_time = time.time()
                training_time = end_time - start_time
                model_times.append(training_time)

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                train_accs.append(train_acc)
                val_accs.append(val_acc)
                val_perfs.append(val_acc[-1])
                hyper_params.append({"architecture": architecture, **params, **m_param})
                models.append(model)
                print(f"Training time: {training_time:.3f}s")
                print(f"Train accuracy: {train_acc[-1]*100:.3f}%")
                print(f"Validation accuracy: {val_acc[-1]*100:.3f}%\n")
        # One averaged time per architecture, in model_architectures order.
        architecture_times.append(np.average(model_times))
            

In [None]:
model_path = PATH_GENERATED + "conjugation_model.pt"
plots_path = PATH_GENERATED + "conjugation_plots.pt"

if os.path.isfile(model_path):
    # Both files hold pickled Python objects: a whole nn.Module, and a tuple
    # whose hyper_params contain architecture classes. PyTorch >= 2.6 rejects
    # these under its default weights_only=True. They are written by this
    # notebook; never point these paths at untrusted files.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine and
    # puts the model on DEVICE.
    chosen_model = torch.load(model_path, map_location=DEVICE, weights_only=False)
    # Histories are only printed/plotted, so CPU is the right place for them.
    (chosen_index, architecture_times, train_losses, val_losses,
     train_accs, val_accs, hyper_params) = torch.load(plots_path, map_location="cpu", weights_only=False)
else:
    # Pick the run with the highest last-epoch validation accuracy
    # (ties resolve to the earliest run).
    chosen_index = val_perfs.index(max(val_perfs))
    chosen_model = models[chosen_index]
    torch.save(chosen_model, model_path)
    torch.save((chosen_index, architecture_times, train_losses, val_losses, train_accs, val_accs, hyper_params), plots_path)

In [None]:
# Summarise the selected run: its hyper-parameters, then the model's repr.
print("Chosen parameters: ", hyper_params[chosen_index], sep="\n")
print("\nChosen model: ", chosen_model, sep="\n")

In [None]:
# Bar labels are taken from the architectures themselves, so they always match
# the order of `architecture_times` and cannot be misspelled.
architecture_names = [arch.__name__ for arch in model_architectures]
plot_training_times(architecture_times, architecture_names,
                    f_name="../images/conjugation_training_times.png", save=True)
# Learning curves of the selected run only.
plot_performance_over_time(train_losses[chosen_index], val_losses[chosen_index],
                           "Training and Validation loss of chosen model", "loss",
                           f_name="../images/conjugation_loss.png", save=True)
plot_performance_over_time(train_accs[chosen_index], val_accs[chosen_index],
                           "Training and Validation accuracy of chosen model", "accuracy",
                           f_name="../images/conjugation_accuracy.png", save=True)

In [None]:
# Held-out split, batched like the others; evaluation only, so order stays fixed.
test_loader = DataLoader(dataset=data_test, batch_size=batch_size, shuffle=False)

In [None]:
# One-off final evaluation of the selected model on the held-out test split.
test_acc = compute_accuracy(chosen_model, test_loader, device=DEVICE)
print("Test accuracy: {:.3f}%".format(test_acc * 100))