### Use this notebook to load saved model checkpoints and evaluate them on the test set
### Install required packages

In [None]:
!pip install matplotlib pandas torch torchmetrics scikit-learn

### Import all libraries and models

In [None]:
# Matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# Numpy
import numpy as np
# Pandas
import pandas as pd
# Torch
import torch
import torch.nn as nn
import json
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import Accuracy
from models import ResNet50, ResNet50BiLSTMAttention, ResNet34BiLSTMAttention

import torch.optim as optim

import pickle
import random
from sklearn.model_selection import train_test_split
import os

# Seed every random number generator used here so runs are repeatable.
def seed_functions(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Seeds NumPy, Python's ``random`` and PyTorch (CPU, current CUDA device
    and all CUDA devices). Disables cuDNN benchmark autotuning, enables
    deterministic cuDNN, and finally exports ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here does not change hashing in the current process; it only affects
    child processes launched afterwards.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    random.seed(seed)
    for seed_torch_rng in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_torch_rng(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(seed)


SEED = 37
seed_functions(SEED)

### Helper function to load the best checkpoint and evaluate it on the test set (no changes needed)

In [None]:
def load_best_model_and_test(model_dir, model, test_loader, num_classes):
    """Restore the best saved checkpoint into ``model`` and evaluate it.

    Reads ``best_model.pkl`` from ``model_dir``. The file is a pickled dict
    whose ``"model_state"``, ``"epoch"`` and ``"val_loss"`` entries are used
    here. The function restores the weights, then reports cross-entropy loss
    and multiclass accuracy over the whole of ``test_loader``.

    Args:
        model_dir: Directory containing ``best_model.pkl``.
        model: ``nn.Module`` whose architecture matches the saved state dict.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        num_classes: Number of classes for the accuracy metric.

    Returns:
        Tuple ``(avg_loss, test_accuracy)``. ``avg_loss`` is the mean loss
        per target element over the full test set.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # SECURITY: pickle.load can execute arbitrary code from the file; only
    # load checkpoints you created yourself.
    # NOTE(review): if the checkpoint was pickled with CUDA tensors it may
    # fail to load on a CPU-only machine -- confirm how it was saved.
    with open(os.path.join(model_dir, "best_model.pkl"), "rb") as f:
        saved_data = pickle.load(f)
    model.load_state_dict(saved_data["model_state"])
    print(f"Best Model Achieved at Epoch: {saved_data['epoch']} with Validation Loss: {saved_data['val_loss']:.4f}")

    accuracy_metric = Accuracy(task="multiclass", num_classes=num_classes).to(device)
    # Sum losses and divide by the sample count at the end. Averaging
    # per-batch means over len(test_loader) over-weights a smaller final batch.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            # For class-index targets this is the same denominator that
            # reduction="mean" would use (one element per sample for (N,) targets).
            num_samples += targets.numel()
            accuracy_metric.update(outputs, targets)

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")

    avg_loss = total_loss / num_samples
    test_accuracy = accuracy_metric.compute().item()
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    return avg_loss, test_accuracy

### Load and test ResNet34BiLSTMAttention

In [None]:
# Evaluate the saved ResNet34BiLSTMAttention checkpoint on the test set.
# NOTE(review): `test_loader` and `num_languages` are not defined in any cell
# shown in this notebook; an earlier cell must create them or this raises NameError.
model_dir = "checkpoints/ResNet34BiLSTMAttentionlr0001"
model = ResNet34BiLSTMAttention(classes=num_languages)
load_best_model_and_test(model_dir, model, test_loader, num_classes=num_languages)

### Load and test ResNet50BiLSTMAttention

In [None]:
# Evaluate the saved ResNet50BiLSTMAttention checkpoint on the test set.
# NOTE(review): `test_loader` and `num_languages` are not defined in any cell
# shown in this notebook; an earlier cell must create them or this raises NameError.
model_dir = "checkpoints/ResNet50BiLSTMAttentionlr0001"
model = ResNet50BiLSTMAttention(classes=num_languages)
load_best_model_and_test(model_dir, model, test_loader, num_classes=num_languages)

### Load and test ResNet50

In [None]:
# Evaluate the saved ResNet50 checkpoint on the test set.
# NOTE(review): `test_loader` and `num_languages` are not defined in any cell
# shown in this notebook; an earlier cell must create them or this raises NameError.
# NOTE(review): the other checkpoint dirs end in "lr0001"; this one is
# "ResNet500001" -- confirm the directory name is intended.
model_dir = "checkpoints/ResNet500001"
model = ResNet50(classes=num_languages)
load_best_model_and_test(model_dir, model, test_loader, num_classes=num_languages)