In [None]:
import os
# Run the rest of the notebook from the parent directory; the `src` imports and
# the relative `generated_data/` paths in later cells are resolved from there.
# The starting directory is remembered on first execution so that re-running
# this cell is idempotent instead of climbing one level higher every time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "../"))

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import Adam
from src.utils import train, compute_accuracy, set_device, plot_performance_over_time
from src.models import SimpleMLP, AttentionMLP, ConjugationRNN

In [None]:
# Fixed seed for torch's global RNG; it is re-applied before each model is built.
SEED = 265
torch.manual_seed(SEED)

# Ask the project helper for a CUDA device and report what it returned.
DEVICE = set_device("cuda")
print("Using device: {}".format(DEVICE))

In [None]:
# Load the three data splits plus the vocabulary artefacts in one pass.
# NOTE: torch.load unpickles the files, so only point it at files you trust.
data_train, data_val, data_test, mapping, embedding, vocab = (
    torch.load(path)
    for path in (
        "generated_data/data_train.pt",
        "generated_data/data_val.pt",
        "generated_data/data_test.pt",
        "generated_data/mapping.pt",
        "generated_data/embedding_matrix.pt",
        "generated_data/vocabulary.pt",
    )
)

In [None]:
# Keep only the samples whose target is one of these verb forms and relabel
# them with the form's position in `target_filter` (0..11).
# `vocab` is indexed by word and `mapping` by vocabulary index, as used below.
target_filter  = ["be", "am", "are", "is", "was", "were", "been", "being", "have", "has", "had", "having"]
target_map = {vocab[word]: label for label, word in enumerate(target_filter)}

datasets = [data_train, data_val, data_test]
for split_idx, split in enumerate(datasets):
    # (context, class-label) pairs for the samples that survive the filter.
    kept_pairs = [
        (context, torch.tensor(target_map[int(target)]))
        for context, target in split
        if mapping[int(target)] in target_filter
    ]
    # Transpose the pairs into columns and stack each column into one tensor.
    stacked_columns = [torch.stack(column) for column in zip(*kept_pairs)]
    datasets[split_idx] = TensorDataset(*stacked_columns)

In [None]:
# Rebind the split names to their filtered versions and report how many
# samples each one holds.
data_train, data_val, data_test = datasets
for size_label, split in (
    ("Size of training data: ", data_train),
    ("Size of validation data: ", data_val),
    ("Size of test data: ", data_test),
):
    print(size_label, len(split))

In [None]:
# Context length is read off the first training sample (size of dimension 0).
first_context = data_train[0][0]
context_size = int(first_context.shape[0])

# Settings shared by every training run.
batch_size = 64
n_epochs = 15
loss_fn = nn.CrossEntropyLoss()

print(
    "-- Global Parameters --",
    f"{batch_size=}",
    f"{n_epochs=}",
    f"{context_size=}",
    sep="\n",
)

# Search space: the architectures to train and the optimizer settings to try.
model_architectures = [SimpleMLP, AttentionMLP, ConjugationRNN]
parameter_search = [{"lr": 0.008}]

In [None]:
# Shuffle the training samples every epoch. The filtered dataset keeps the
# order in which the samples were stored, and without shuffling each epoch
# would replay exactly the same batches in that same order.
# The sampler gets its own seeded generator so the shuffle sequence is
# reproducible from a fresh kernel.
train_loader = DataLoader(
    data_train,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Validation is evaluation only, so its order is left untouched.
val_loader = DataLoader(data_val, batch_size=batch_size)

In [None]:
# Histories collected per trained model. Entry k of every list belongs to the
# k-th run, in (parameter set, architecture) iteration order.
train_losses = []
val_losses = []
train_accs = []
val_accs = []
val_perfs = []  # last recorded validation accuracy of each run
models = []

# Moving the embedding matrix to the device does not depend on the loop
# variables, so do it once up front instead of once per trained model.
# NOTE(review): every model receives this same tensor object; if an
# architecture updates it in place during training, later runs would start
# from modified embeddings — confirm against src.models.
embedding = embedding.to(DEVICE)

for params in parameter_search:
    print("\n-- Training with following parameters --:")
    for architecture in model_architectures:
        print("\nModel architecture: ", architecture)
        for name, value in params.items():
            print(f"{name}: {value}")

        # Re-seed right before construction so every model is built from the
        # same RNG state.
        torch.manual_seed(SEED)
        model = architecture(embedding, max_len=context_size)
        model.to(DEVICE)
        optimizer = Adam(model.parameters(), lr=params["lr"])

        train_loss, val_loss, train_acc, val_acc = train(n_epochs, model, optimizer, loss_fn, train_loader, val_loader, DEVICE)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        # The last entry of the validation-accuracy history is this run's score.
        val_perfs.append(val_acc[-1])
        models.append(model)
        print(f"Train accuracy: {train_acc[-1]*100:.3f}%")
        print(f"Validation accuracy: {val_acc[-1]*100:.3f}%")

In [None]:
# Pick the run with the highest final validation accuracy; on ties the
# earliest run wins.
chosen_index = max(range(len(val_perfs)), key=val_perfs.__getitem__)
chosen_model = models[chosen_index]
print(chosen_model)

In [None]:
# Loss curves of the selected run (training first, validation second).
chosen_loss_curves = (train_losses[chosen_index], val_losses[chosen_index])
plot_performance_over_time(*chosen_loss_curves, "Training and Validation loss of chosen model", "loss")

# Accuracy curves of the selected run, in the same order.
chosen_acc_curves = (train_accs[chosen_index], val_accs[chosen_index])
plot_performance_over_time(*chosen_acc_curves, "Training and Validation accuracy of chosen model", "accuracy")

In [None]:
# Batch the held-out test split with the same batch size used for training.
test_loader = DataLoader(dataset=data_test, batch_size=batch_size)

In [None]:
# Score the selected model on the test split and report it as a percentage.
test_acc = compute_accuracy(chosen_model, test_loader, device=DEVICE)
test_acc_percent = test_acc*100
print(f"Test accuracy: {test_acc_percent:.3f}%")