In [None]:
import torch

# Report whether PyTorch can see a CUDA GPU and, if so, which device is active.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)

if cuda_available:
    print("Number of CUDA devices:", torch.cuda.device_count())
    # Query the active device index once and reuse it for the name lookup.
    active_device = torch.cuda.current_device()
    print("Current CUDA device:", active_device)
    print("Device name:", torch.cuda.get_device_name(active_device))


In [None]:
from models.cnn import CNN
from models.resnet_vanilla import LightweightResNet
from models.resnet18 import ResNet18Emotion
from utils.data_loader import get_loaders
from utils.train import train_model

# 1. Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATASET_PATH = "./dataset_emotion"
EPOCHS = 6
SEED = 42

# Seed before building loaders and models so shuffling, weight initialisation
# and any randomness inside get_loaders repeat across Restart & Run All.
# torch.manual_seed seeds the CPU and all CUDA generators; `random` is seeded
# in case the data pipeline uses it.
# NOTE(review): also seed numpy if utils.data_loader relies on it.
import random

random.seed(SEED)
torch.manual_seed(SEED)

# 2. Prepare Loaders
# Two loader sets over the same dataset directory.
# - model_type="cnn" feeds CNN and LightweightResNet.
# - model_type="resnet18" feeds ResNet18Emotion.
# Presumably model_type selects the preprocessing; confirm in get_loaders.
train_std, val_std, test_std, num_classes = get_loaders(DATASET_PATH, model_type="cnn")
train_ftune, val_ftune, test_ftune, num_classes_ftune = get_loaders(DATASET_PATH, model_type="resnet18")
# Both pipelines read the same directory, so they must agree on the label space.
assert num_classes_ftune == num_classes, (num_classes, num_classes_ftune)

# 3. Training Dictionary
import gc  # used below to release each model before emptying the CUDA cache

# Each entry maps a run name to (model factory, train, val, test loaders).
# Factories rather than instances keep only one model alive at a time.
# Building all three up front left every model referenced by this dict, so the
# `del` in the loop could never release it and empty_cache() had nothing to free.
models_to_train = {
    "cnn": (lambda: CNN(num_classes), train_std, val_std, test_std),
    "resnet_vanilla": (lambda: LightweightResNet(num_classes=num_classes), train_std, val_std, test_std),
    "resnet18": (lambda: ResNet18Emotion(num_classes=num_classes), train_ftune, val_ftune, test_ftune),
}

# 4. Train Models
# The test loader (_ts) is not used during training; evaluation happens in a later cell.
for name, (build_model, tr, vl, _ts) in models_to_train.items():
    model = build_model()
    # `name` is handed to train_model. The evaluation cell later loads
    # ./checkpoints_<name>/best.pth, so presumably train_model writes there. Confirm.
    train_model(model, name, tr, vl, DEVICE, num_epochs=EPOCHS)
    # Drop the only reference and collect it. Then hand cached CUDA blocks back
    # so the next model starts from a clean GPU.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
from utils.eval import get_preds_and_cm, print_report
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# Evaluation targets. Each entry holds:
# - a fresh architecture instance,
# - the checkpoint to load into it,
# - the test loader from the same get_loaders variant used to train that model.
# num_classes comes from get_loaders in the training cell. A hard-coded 4 would
# make load_state_dict fail (classifier-head shape mismatch) whenever the
# dataset does not have exactly 4 classes.
eval_config = {
    "CNN": {
        "model": CNN(num_classes=num_classes),
        "ckpt": "./checkpoints_cnn/best.pth",
        "loader": test_std
    },
    "ResNet Vanilla": {
        "model": LightweightResNet(num_classes=num_classes),
        "ckpt": "./checkpoints_resnet_vanilla/best.pth",
        "loader": test_std
    },
    "ResNet18": {
        "model": ResNet18Emotion(num_classes=num_classes),
        "ckpt": "./checkpoints_resnet18/best.pth",
        "loader": test_ftune
    }
}

# 4. Run Evaluation for Every Model
# For each model:
# - load its best checkpoint,
# - predict on its test loader,
# - print the metrics report,
# - plot the confusion matrix.
for model_name, config in eval_config.items():
    print("\n" + "=" * 30)
    print(f" Evaluating: {model_name} ")
    print("=" * 30)

    # Load weights. weights_only=True makes torch.load refuse arbitrary pickled
    # objects; the checkpoint is expected to be a dict with a "model_state" entry.
    checkpoint = torch.load(config["ckpt"], map_location=DEVICE, weights_only=True)
    model = config["model"]
    model.load_state_dict(checkpoint["model_state"])
    model.to(DEVICE)
    # Put dropout / batch-norm layers in inference mode. get_preds_and_cm may
    # already do this (not visible here); calling eval() twice is harmless.
    model.eval()

    # Get predictions and ground truth using eval.py utility
    y_true, y_pred, cm = get_preds_and_cm(model, config["loader"], DEVICE)

    # Print metrics report.
    # Assumes the loader's dataset exposes `.classes` (e.g. an ImageFolder).
    # Confirm if get_loaders ever wraps it in a Subset.
    class_names = config["loader"].dataset.classes
    print_report(y_true, y_pred, class_names)

    # --- Display the Confusion Matrix ---
    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

    # Plotting with a blue color map and integer formatting
    disp.plot(cmap="Blues", values_format="d", ax=ax, xticks_rotation=45)
    ax.set_title(f"Confusion Matrix: {model_name}")
    plt.show()
    # Release the figure so figures don't accumulate across loop iterations.
    # This is harmless if the backend already closed it on show().
    plt.close(fig)
    

