# In [None]:
# %% [markdown]
# <p align="center">
#   <img src="https://upload.wikimedia.org/wikipedia/commons/7/78/Eindhoven_University_of_Technology_logo_new.png?20231008195526" alt="TU/e Logo" width="200px"/>
# </p>
# 
# # Assignment 2: CNN, AutoML & Hyperparameter Optimization (with CV Loss & Accuracy Tracking)
# 
# **Course:** 1BM120 – Decision Making with Artificial Intelligence  
# **Date:** *Q4 - 2025*  
# 
# ---
# 
# ## Group 3
# 
# - **Sadra Moosavi Lar**  
#   ✉️ [s.s.moosavi.lar@student.tue.nl](mailto:s.s.moosavi.lar@student.tue.nl)
# 
# - **Floris van Hasselt**  
#   ✉️ [f.j.p.v.hasselt@student.tue.nl](mailto:f.j.p.v.hasselt@student.tue.nl)
# 
# - **Sam Fiers**  
#   ✉️ [s.s.w.fiers@student.tue.nl](mailto:s.s.w.fiers@student.tue.nl)
# 
# ---
# 
# **Repository:** [GitHub – Group 3 Repo](https://github.com/sadra-hub/1BM120-decisiondecisionMakingWithAI)
# 
# ---
# 
# ## Description
# 
# We build three workflows:
# 1. **Baseline CNN** (no tuning): train on the full training set, track training & test accuracy and loss per epoch, plot learning curves.
# 2. **Hyperparameter Tuning** (Random Search vs. TPE Search): run stratified 5-fold cross-validation inside each trial, training for 20 epochs per fold; the mean CV accuracy selects the best trial, and the mean CV loss is stored in the trial metadata. Plot validation curves (per-trial CV accuracy and CV loss) for both searches.
# 3. **Retrain Best Models**: using the best hyperparameters from Random Search and TPE Search (selected by CV accuracy), retrain each model on the entire training set, track training & test accuracy and loss per epoch, plot learning curves, highlight the highest accuracies and lowest losses, save the hyperparameters alongside the model weights, and save the plots at high resolution (300 dpi).

# %%
# -----------------------------
#           IMPORTS
# -----------------------------
# Imports come first: DEVICE (below) calls torch.device(), so torch must
# already be imported. The previous "parameters first" order raised NameError.
import os
import random

# Workaround for the duplicate-OpenMP-runtime abort that can occur when several
# numerical libraries each bundle their own runtime. It is set before those
# libraries are imported so the flag is already in the environment when they load.
# NOTE(review): confirm this workaround is still needed where the notebook runs.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

import matplotlib.pyplot as plt  # noqa: E402  (must follow the env var above)
import numpy as np  # noqa: E402
import optuna  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.optim as optim  # noqa: E402
from sklearn.model_selection import StratifiedKFold  # noqa: E402
from torch.utils.data import DataLoader, Subset  # noqa: E402

from support import load_dataset  # noqa: E402

# -----------------------------
#        INPUT PARAMETERS
# -----------------------------
# Random seed
SEED = 18

# DataLoader settings
NUM_WORKERS = 4
# Pinned host memory only speeds up host->GPU copies. Without CUDA, recent
# PyTorch versions just warn and skip pinning, so enable it only when useful.
PIN_MEMORY = torch.cuda.is_available()

# Baseline CNN settings
BATCH_SIZE_BASELINE = 16
NUM_EPOCHS_BASELINE = 10
LR_BASELINE = 0.001

# Cross-Validation (CV) settings
CV_FOLDS = 5
CV_EPOCHS = 20

# Optuna hyperparameter tuning settings
OPTUNA_TRIALS = 5

# Final model training settings
NUM_EPOCHS_FINAL = 20
BATCH_SIZE_TEST = 16  # For test DataLoader

# Other constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VAL_CURVES_FILENAME = "validation_curves.png"
VAL_LOSS_CURVES_FILENAME = "validation_loss_curves.png"
BASELINE_CURVE_FILENAME = "baseline_learning_curve.png"
RANDOM_CURVE_FILENAME = "random_learning_curve.png"
TPE_CURVE_FILENAME = "tpe_learning_curve.png"
RANDOM_MODEL_FILENAME = "cnn_random_search_best.pth"
TPE_MODEL_FILENAME = "cnn_tpe_search_best.pth"

# -----------------------------
#        END PARAMETERS
# -----------------------------

# %%
# Set seed for reproducibility
def set_seed(seed: int = SEED):
    """Seed every RNG the script relies on and force deterministic cuDNN.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices. cuDNN autotuning (``benchmark``) is switched off
    and deterministic kernels are requested, trading some speed for
    run-to-run repeatability.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed()

# %%
# Load datasets using the provided function in support.py.
# load_dataset() yields a (train, test) pair. Later code indexes items as
# dataset[i][1] for the label, so each item is presumably an (image, label)
# pair -- the exact types are defined in support.py (not visible here).
train_dataset, test_dataset = load_dataset()

# Simple helper to display balanced samples from a DataLoader
def show_balanced_samples(loader, classes=("OK", "Defective"), n_per_class: int = 4):
    """Plot a 2-row grid of images alternating label 0 and label 1.

    Images are taken in loader order: the first ``n_per_class`` images with
    label 0 and the first ``n_per_class`` with label 1. Reading stops as soon
    as both are collected. The grid has 2 rows and ``n_per_class`` columns,
    filled row-major as (0, 1, 0, 1, ...).

    Args:
        loader: Iterable of ``(images, labels)`` batches. Images are
            channel-first tensors ``(C, H, W)``; labels are 0/1.
        classes: Display names indexed by label (0, 1).
        n_per_class: Number of images to show per class (default 4).

    Raises:
        ValueError: If ``n_per_class < 1``, or if the loader yields fewer than
            ``n_per_class`` images of either class.
    """
    if n_per_class < 1:
        raise ValueError(f"n_per_class must be >= 1, got {n_per_class}")

    # Collect the first n_per_class images of each label, stopping early once both are full.
    buckets = {0: [], 1: []}
    for images, labels in loader:
        for img, label in zip(images, labels):
            bucket = buckets.get(int(label))
            if bucket is not None and len(bucket) < n_per_class:
                bucket.append(img)
        if all(len(b) == n_per_class for b in buckets.values()):
            break

    ok_imgs, def_imgs = buckets[0], buckets[1]
    # Previously a short class surfaced as a bare IndexError; report it clearly instead.
    if len(ok_imgs) < n_per_class or len(def_imgs) < n_per_class:
        raise ValueError(
            f"Loader yielded {len(ok_imgs)} '{classes[0]}' and {len(def_imgs)} "
            f"'{classes[1]}' images; need {n_per_class} of each."
        )

    # Interleave so consecutive grid cells alternate label 0, label 1.
    ordered_imgs = [img for pair in zip(ok_imgs, def_imgs) for img in pair]
    ordered_labels = [0, 1] * n_per_class

    _, height, width = ordered_imgs[0].shape
    aspect = width / height

    # squeeze=False keeps `axes` 2-D even for a single column, so flatten() always works.
    _, axes = plt.subplots(2, n_per_class, figsize=(3 * n_per_class, 6), squeeze=False)
    for ax, img, label in zip(axes.flatten(), ordered_imgs, ordered_labels):
        arr = img.permute(1, 2, 0).numpy().astype(np.float64)
        lo, hi = arr.min(), arr.max()
        # Min-max scale to [0, 1] for imshow; a constant image would otherwise divide by zero.
        arr = (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)
        ax.imshow(arr)
        ax.set_title(classes[label])
        ax.axis("off")
        ax.set_aspect(aspect)
    plt.tight_layout()
    plt.show()

# Preview a class-balanced grid of training images, drawn from a shuffled loader.
train_loader_for_display = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE_BASELINE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
show_balanced_samples(loader=train_loader_for_display)

# Define a single CNN class, parameterized by num_filters (for first conv layer) and dropout rate
class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Architecture:
    - ``[Conv3x3 -> ReLU -> MaxPool2] x 2``, with ``num_filters`` then
      ``2 * num_filters`` channels;
    - flatten;
    - ``Linear(->128) -> ReLU -> Dropout -> Linear(->num_classes)``.

    Args:
        num_filters: Output channels of the first conv layer (the second uses twice as many).
        dropout: Dropout probability applied before the output layer.
        input_size: ``(height, width)`` of the input images. Used only to size
            the first fully connected layer. Defaults to the 60x30 inputs the
            model was originally written for.
        in_channels: Number of input channels (default 3).
        num_classes: Number of output logits (default 2).
    """

    def __init__(
        self,
        num_filters: int = 16,
        dropout: float = 0.5,
        input_size: tuple = (60, 30),
        in_channels: int = 3,
        num_classes: int = 2,
    ):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # Output: (num_filters, H/2, W/2)
            nn.Conv2d(num_filters, num_filters * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                   # Output: (num_filters*2, H/4, W/4)
        )
        # Size the first Linear layer by pushing a dummy batch through the conv
        # stack. This makes odd input sizes (floored by MaxPool) come out exactly right.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_size)
            flat_size = self.features(dummy).numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return raw class logits for a batch ``x`` of shape ``(N, in_channels, H, W)``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

# %%
# Helper functions: train for one epoch, evaluate on accuracy and loss

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over ``len(loader.dataset)``
        samples. Each batch's loss is weighted by its size, which gives a true
        per-sample mean when ``criterion`` returns a per-batch mean (e.g.
        ``CrossEntropyLoss`` with its default reduction).
    """
    model.train()
    loss_sum, hits = 0.0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item() * batch_x.size(0)
        hits += (logits.argmax(dim=1) == batch_y).sum().item()
    n_samples = len(loader.dataset)
    return loss_sum / n_samples, hits / n_samples

def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Runs in eval mode without gradient tracking. The result is the number of
    correct argmax predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    with torch.no_grad():
        hits = sum(
            (model(x.to(device)).argmax(dim=1) == y.to(device)).sum().item()
            for x, y in loader
        )
    return hits / len(loader.dataset)

def evaluate_loss(model, loader, criterion, device):
    """Return the mean of ``criterion`` over ``loader``, weighting each batch by its size.

    Runs in eval mode without gradient tracking. The divisor is the number of
    samples actually seen.
    """
    model.eval()
    weighted_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            batch_n = x.size(0)
            weighted_sum += criterion(model(x), y).item() * batch_n
            n_seen += batch_n
    return weighted_sum / n_seen

# %%
# Helper to train a model for multiple epochs, tracking training & test accuracy and loss per epoch
def _evaluate_loss_and_accuracy(model, loader, criterion, device):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` in a single pass.

    Runs in eval mode without gradient tracking. Each batch's loss is weighted
    by its size and divided by the number of samples seen.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            hits += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += inputs.size(0)
    return loss_sum / n_seen, hits / n_seen


def train_full_model(model, train_loader, test_loader, criterion, optimizer, device, num_epochs: int):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``test_loader`` after each one.

    Prints a one-line summary per epoch. The test metrics come from ONE pass
    over ``test_loader`` per epoch. The previous version ran separate loss and
    accuracy passes, which doubled evaluation cost and, if the test dataset
    applies random transforms, scored the two metrics on different inputs.

    Returns:
        ``(train_losses, train_accs, test_losses, test_accs)``: four lists with
        one entry per epoch, in epoch order.
    """
    train_accs = []
    train_losses = []
    test_accs = []
    test_losses = []
    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = _evaluate_loss_and_accuracy(model, test_loader, criterion, device)

        train_accs.append(train_acc)
        train_losses.append(train_loss)
        test_accs.append(test_acc)
        test_losses.append(test_loss)

        print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    return train_losses, train_accs, test_losses, test_accs

# %%
# Function to plot accuracy and loss learning curves and save in high quality
def _plot_metric_panel(epochs, train_vals, test_vals, *, metric, short, extreme, pick, title, filename):
    """Draw one train-vs-test figure, mark each series' best epoch, save it at 300 dpi and show it.

    ``pick`` is ``np.argmax`` or ``np.argmin`` and selects the "best" epoch.
    ``extreme``/``short`` only build the annotation text, e.g. "Max Train Acc".
    """
    epochs = np.asarray(epochs)
    plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_vals, marker='o', label=f"Train {metric}")
    plt.plot(epochs, test_vals, marker='s', label=f"Test {metric}")
    for split, values, color in (("Train", train_vals, 'blue'), ("Test", test_vals, 'orange')):
        idx = int(pick(values))
        best = values[idx]
        plt.scatter(epochs[idx], best, color=color)
        plt.text(epochs[idx], best + 0.01, f"{extreme} {split} {short}: {best:.2f}", color=color)
    plt.xlabel("Epoch")
    plt.ylabel(metric)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()
    # Release the figure once displayed so repeated calls do not accumulate open figures.
    plt.close()


def plot_learning_curves(
    epochs,
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    acc_title: str,
    loss_title: str,
    acc_filename: str,
    loss_filename: str
):
    """Plot and save per-epoch accuracy and loss curves (train vs. test) as two figures.

    The accuracy figure marks the highest train/test accuracy; the loss figure
    marks the lowest train/test loss. Both are saved at 300 dpi to
    ``acc_filename`` / ``loss_filename`` and shown.

    Args:
        epochs: Epoch numbers (array-like), aligned with the metric sequences.
        train_accs, test_accs: Per-epoch accuracies.
        train_losses, test_losses: Per-epoch losses.
        acc_title, loss_title: Figure titles.
        acc_filename, loss_filename: Output image paths.
    """
    _plot_metric_panel(
        epochs, train_accs, test_accs,
        metric="Accuracy", short="Acc", extreme="Max", pick=np.argmax,
        title=acc_title, filename=acc_filename,
    )
    _plot_metric_panel(
        epochs, train_losses, test_losses,
        metric="Loss", short="Loss", extreme="Min", pick=np.argmin,
        title=loss_title, filename=loss_filename,
    )

# %%
print("Using device:", DEVICE)  # report which torch device (CPU or CUDA) training will use

# %%
# --------- 1. BASELINE CNN (No Hyperparameter Tuning) ---------
# Train on entire training set, evaluate on test set each epoch, plot learning curves

# Loaders for the baseline run: shuffled training batches, fixed-order test batches.
baseline_train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE_BASELINE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
baseline_test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE_BASELINE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Fixed architecture (16 filters, dropout 0.5) trained with Adam at LR_BASELINE.
baseline_model = CNN(num_filters=16, dropout=0.5).to(DEVICE)
criterion_baseline = nn.CrossEntropyLoss()
optimizer_baseline = optim.Adam(params=baseline_model.parameters(), lr=LR_BASELINE)

print("\nTraining baseline CNN (no hyperparameter tuning)...")
(
    baseline_train_losses,
    baseline_train_accs,
    baseline_test_losses,
    baseline_test_accs,
) = train_full_model(
    model=baseline_model,
    train_loader=baseline_train_loader,
    test_loader=baseline_test_loader,
    criterion=criterion_baseline,
    optimizer=optimizer_baseline,
    device=DEVICE,
    num_epochs=NUM_EPOCHS_BASELINE,
)

epochs_baseline = np.arange(1, NUM_EPOCHS_BASELINE + 1)
plot_learning_curves(
    epochs=epochs_baseline,
    train_accs=baseline_train_accs,
    test_accs=baseline_test_accs,
    train_losses=baseline_train_losses,
    test_losses=baseline_test_losses,
    acc_title="Baseline CNN: Train & Test Accuracy per Epoch",
    loss_title="Baseline CNN: Train & Test Loss per Epoch",
    acc_filename=BASELINE_CURVE_FILENAME,
    loss_filename="baseline_loss_curve.png",
)

# Best value per metric and the (1-based) epoch at which it occurred.
print(
    f"\nBaseline highest training accuracy: {max(baseline_train_accs):.4f} "
    f"(Epoch {np.argmax(baseline_train_accs) + 1})",
    f"Baseline highest test accuracy:     {max(baseline_test_accs):.4f} "
    f"(Epoch {np.argmax(baseline_test_accs) + 1})",
    f"Baseline lowest training loss:      {min(baseline_train_losses):.4f} "
    f"(Epoch {np.argmin(baseline_train_losses) + 1})",
    f"Baseline lowest test loss:          {min(baseline_test_losses):.4f} "
    f"(Epoch {np.argmin(baseline_test_losses) + 1})",
    sep="\n",
)

# %%
# --------- 2. HYPERPARAMETER TUNING (Random Search vs. TPE Search) ---------
# We perform CV_FOLDS-fold stratified cross-validation inside each trial, training CV_EPOCHS per fold.
# We track both mean CV accuracy (for selecting best trial) and mean CV loss (stored in trial metadata).

# Precompute labels for StratifiedKFold: one integer label per training
# sample, indexed in the same order as all_indices.
all_indices = list(range(len(train_dataset)))
labels_list = [int(train_dataset[idx][1]) for idx in all_indices]

def objective(trial):
    """Optuna objective: mean stratified k-fold CV accuracy for one hyperparameter sample.

    Samples lr, dropout, optimizer, batch_size and num_filters. For each of the
    ``CV_FOLDS`` stratified folds, trains a fresh CNN for ``CV_EPOCHS`` epochs
    and scores it on the held-out fold. The splitter uses ``random_state=SEED``,
    so every trial sees the same splits. The mean validation loss is stored as
    the trial user attribute ``"mean_cv_loss"``; the mean validation accuracy
    is returned.
    """
    # Search space. Keep the suggest order stable, since the sampler draws in this order.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    dropout = trial.suggest_float("dropout", 0.1, 0.7)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
    batch_size = trial.suggest_categorical("batch_size", [16, 32])
    num_filters = trial.suggest_int("num_filters", 16, 64, step=16)

    def loader_for(indices, shuffle):
        """DataLoader over the given subset of train_dataset with this trial's batch size."""
        return DataLoader(
            Subset(train_dataset, indices),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
        )

    def score_fold(model, loader, criterion):
        """Return (sample-weighted mean loss, accuracy) of ``model`` on ``loader``."""
        model.eval()
        loss_sum, hits, seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                logits = model(inputs)
                loss_sum += criterion(logits, labels).item() * inputs.size(0)
                hits += (logits.argmax(dim=1) == labels).sum().item()
                seen += inputs.size(0)
        return loss_sum / seen, hits / seen

    splitter = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=SEED)
    per_fold = []  # one (validation loss, validation accuracy) pair per fold
    for train_idx, val_idx in splitter.split(all_indices, labels_list):
        train_loader = loader_for(train_idx, shuffle=True)
        val_loader = loader_for(val_idx, shuffle=False)

        model = CNN(num_filters=num_filters, dropout=dropout).to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer_cls = optim.Adam if optimizer_name == "Adam" else optim.SGD
        optimizer = optimizer_cls(model.parameters(), lr=lr)

        for _ in range(CV_EPOCHS):
            train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)

        per_fold.append(score_fold(model, val_loader, criterion))

    fold_losses, fold_accuracies = zip(*per_fold)
    trial.set_user_attr("mean_cv_loss", float(np.mean(fold_losses)))
    return float(np.mean(fold_accuracies))

# Run Random Search. The sampler is seeded so tuning is reproducible like the
# rest of the script. The two studies get different seeds so they do not
# replay the same random stream.
random_study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.RandomSampler(seed=SEED),
)
random_study.optimize(objective, n_trials=OPTUNA_TRIALS)

# Run TPE Search. TPESampler samples purely at random for its first
# `n_startup_trials` trials (default 10). With OPTUNA_TRIALS below that, the
# "TPE" study would just be a second random search, so only about half the
# budget is spent on random warm-up.
tpe_study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(
        seed=SEED + 1,
        n_startup_trials=min(10, max(1, OPTUNA_TRIALS // 2)),
    ),
)
tpe_study.optimize(objective, n_trials=OPTUNA_TRIALS)

print(f"\nBest hyperparameters (Random Search): {random_study.best_params}")
print(f"Best CV accuracy (Random Search):   {random_study.best_value:.4f}")
print(f"Mean CV loss for that trial:        {random_study.best_trial.user_attrs['mean_cv_loss']:.4f}")

print(f"\nBest hyperparameters (TPE Search):  {tpe_study.best_params}")
print(f"Best CV accuracy (TPE Search):      {tpe_study.best_value:.4f}")
print(f"Mean CV loss for that trial:        {tpe_study.best_trial.user_attrs['mean_cv_loss']:.4f}")


def _completed_trials(study):
    """Return the study's COMPLETE trials (failed/pruned trials carry no value or mean_cv_loss)."""
    return study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))


def _plot_search_comparison(rand_nums, rand_vals, tpe_nums, tpe_vals, *,
                            short, extreme, pick, ylabel, title_metric, filename):
    """Plot one per-trial CV metric for both searches and mark each search's best trial.

    ``pick`` is ``np.argmax`` or ``np.argmin``. The figure is saved at 300 dpi,
    shown, then closed.
    """
    plt.figure(figsize=(8, 5))
    plt.plot(rand_nums, rand_vals, marker='o', label=f"Random Search CV {short}")
    plt.plot(tpe_nums, tpe_vals, marker='s', label=f"TPE Search CV {short}")
    for search, nums, vals, color in (
        ("Random", rand_nums, rand_vals, 'blue'),
        ("TPE", tpe_nums, tpe_vals, 'orange'),
    ):
        pos = int(pick(vals))
        plt.scatter(nums[pos], vals[pos], color=color)
        plt.text(nums[pos], vals[pos] + 0.005, f"{extreme} {search} {short}: {vals[pos]:.2f}", color=color)
    plt.xlabel("Trial Number")
    plt.ylabel(ylabel)
    plt.title(
        f"Hyperparameter Tuning ({title_metric}) "
        f"({CV_FOLDS}-Fold CV, {CV_EPOCHS} Epochs/Fold)"
    )
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(filename, dpi=300)
    plt.show()
    plt.close()


# Extract trial numbers, CV accuracies, and CV losses for plotting
_rand_done = _completed_trials(random_study)
rand_trial_nums = [t.number for t in _rand_done]
rand_trial_accs = [t.value for t in _rand_done]
rand_trial_losses = [t.user_attrs["mean_cv_loss"] for t in _rand_done]

_tpe_done = _completed_trials(tpe_study)
tpe_trial_nums = [t.number for t in _tpe_done]
tpe_trial_accs = [t.value for t in _tpe_done]
tpe_trial_losses = [t.user_attrs["mean_cv_loss"] for t in _tpe_done]

# Validation curves: CV accuracy per trial (best = max) ...
_plot_search_comparison(
    rand_trial_nums, rand_trial_accs, tpe_trial_nums, tpe_trial_accs,
    short="Acc", extreme="Max", pick=np.argmax,
    ylabel="Mean CV Accuracy", title_metric="Accuracy", filename=VAL_CURVES_FILENAME,
)
# ... and CV loss per trial (best = min).
_plot_search_comparison(
    rand_trial_nums, rand_trial_losses, tpe_trial_nums, tpe_trial_losses,
    short="Loss", extreme="Min", pick=np.argmin,
    ylabel="Mean CV Loss", title_metric="Loss", filename=VAL_LOSS_CURVES_FILENAME,
)

# %%
# --------- 3. RETRAIN BEST MODELS ON FULL TRAINING SET & EVALUATE ---------
# We use best hyperparameters from Random Search and TPE Search (chosen by CV accuracy), train each model for multiple epochs,
# track training & test accuracy and loss per epoch, plot learning curves, and save hyperparameters alongside model weights.

# Loss shared by both retrained models.
criterion_final = nn.CrossEntropyLoss()

# Fixed-order test-set loader used to evaluate both retrained models.
final_test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE_TEST,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Helper to build model + optimizer given hyperparameters
def build_model_optimizer(params: dict):
    """Instantiate a CNN on DEVICE and its optimizer from an Optuna ``best_params`` dict.

    Reads ``num_filters``, ``dropout``, ``lr`` and ``optimizer`` from ``params``.
    An optimizer name of exactly "Adam" selects ``optim.Adam``; any other value
    falls back to ``optim.SGD``. Both use library defaults apart from ``lr``.

    Returns:
        ``(model, optimizer)``.
    """
    net = CNN(num_filters=params["num_filters"], dropout=params["dropout"]).to(DEVICE)
    if params["optimizer"] == "Adam":
        opt_cls = optim.Adam
    else:
        opt_cls = optim.SGD
    return net, opt_cls(net.parameters(), lr=params["lr"])

# 3a) Random Search best model: retrain on the full training set with the
# hyperparameters that achieved the highest mean CV accuracy.
best_params_random = random_study.best_params
print(f"\nRetraining Random-Search-best model with params: {best_params_random}")

model_random, optimizer_random = build_model_optimizer(best_params_random)
full_train_loader_random = DataLoader(
    dataset=train_dataset,
    batch_size=best_params_random["batch_size"],
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

print("Training Random-Search best model on full training set...")
(
    rand_train_losses,
    rand_train_accs,
    rand_test_losses,
    rand_test_accs,
) = train_full_model(
    model=model_random,
    train_loader=full_train_loader_random,
    test_loader=final_test_loader,
    criterion=criterion_final,
    optimizer=optimizer_random,
    device=DEVICE,
    num_epochs=NUM_EPOCHS_FINAL,
)

# Epoch axis shared by both final models' plots.
epochs_final = np.arange(1, NUM_EPOCHS_FINAL + 1)
plot_learning_curves(
    epochs=epochs_final,
    train_accs=rand_train_accs,
    test_accs=rand_test_accs,
    train_losses=rand_train_losses,
    test_losses=rand_test_losses,
    acc_title="Random-Search Best Model: Train & Test Accuracy per Epoch",
    loss_title="Random-Search Best Model: Train & Test Loss per Epoch",
    acc_filename=RANDOM_CURVE_FILENAME,
    loss_filename="random_loss_curve.png",
)

# Best value per metric and the (1-based) epoch at which it occurred.
print(
    f"\nRandom-Search Best Model highest training accuracy: {max(rand_train_accs):.4f} "
    f"(Epoch {np.argmax(rand_train_accs) + 1})",
    f"Random-Search Best Model highest test accuracy:     {max(rand_test_accs):.4f} "
    f"(Epoch {np.argmax(rand_test_accs) + 1})",
    f"Random-Search Best Model lowest training loss:      {min(rand_train_losses):.4f} "
    f"(Epoch {np.argmin(rand_train_losses) + 1})",
    f"Random-Search Best Model lowest test loss:          {min(rand_test_losses):.4f} "
    f"(Epoch {np.argmin(rand_test_losses) + 1})",
    sep="\n",
)

# Checkpoint: final-epoch weights, optimizer state and the hyperparameters used.
torch.save(
    dict(
        model_state_dict=model_random.state_dict(),
        optimizer_state_dict=optimizer_random.state_dict(),
        hyperparameters=best_params_random,
    ),
    RANDOM_MODEL_FILENAME,
)

# 3b) TPE Search best model: retrain on the full training set with the
# hyperparameters that achieved the highest mean CV accuracy.
best_params_tpe = tpe_study.best_params
print(f"\nRetraining TPE-Search-best model with params: {best_params_tpe}")

model_tpe, optimizer_tpe = build_model_optimizer(best_params_tpe)
full_train_loader_tpe = DataLoader(
    dataset=train_dataset,
    batch_size=best_params_tpe["batch_size"],
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

print("Training TPE-Search best model on full training set...")
(
    tpe_train_losses,
    tpe_train_accs,
    tpe_test_losses,
    tpe_test_accs,
) = train_full_model(
    model=model_tpe,
    train_loader=full_train_loader_tpe,
    test_loader=final_test_loader,
    criterion=criterion_final,
    optimizer=optimizer_tpe,
    device=DEVICE,
    num_epochs=NUM_EPOCHS_FINAL,
)

plot_learning_curves(
    epochs=epochs_final,
    train_accs=tpe_train_accs,
    test_accs=tpe_test_accs,
    train_losses=tpe_train_losses,
    test_losses=tpe_test_losses,
    acc_title="TPE-Search Best Model: Train & Test Accuracy per Epoch",
    loss_title="TPE-Search Best Model: Train & Test Loss per Epoch",
    acc_filename=TPE_CURVE_FILENAME,
    loss_filename="tpe_loss_curve.png",
)

# Best value per metric and the (1-based) epoch at which it occurred.
print(
    f"\nTPE-Search Best Model highest training accuracy: {max(tpe_train_accs):.4f} "
    f"(Epoch {np.argmax(tpe_train_accs) + 1})",
    f"TPE-Search Best Model highest test accuracy:     {max(tpe_test_accs):.4f} "
    f"(Epoch {np.argmax(tpe_test_accs) + 1})",
    f"TPE-Search Best Model lowest training loss:      {min(tpe_train_losses):.4f} "
    f"(Epoch {np.argmin(tpe_train_losses) + 1})",
    f"TPE-Search Best Model lowest test loss:          {min(tpe_test_losses):.4f} "
    f"(Epoch {np.argmin(tpe_test_losses) + 1})",
    sep="\n",
)

# Checkpoint: final-epoch weights, optimizer state and the hyperparameters used.
torch.save(
    dict(
        model_state_dict=model_tpe.state_dict(),
        optimizer_state_dict=optimizer_tpe.state_dict(),
        hyperparameters=best_params_tpe,
    ),
    TPE_MODEL_FILENAME,
)

# List every checkpoint and figure written by this script, one bullet per line.
print(
    "\nFinal models and plots saved:",
    *(
        f"  • {artifact}"
        for artifact in (
            RANDOM_MODEL_FILENAME,
            TPE_MODEL_FILENAME,
            BASELINE_CURVE_FILENAME,
            "baseline_loss_curve.png",
            VAL_CURVES_FILENAME,
            VAL_LOSS_CURVES_FILENAME,
            RANDOM_CURVE_FILENAME,
            "random_loss_curve.png",
            TPE_CURVE_FILENAME,
            "tpe_loss_curve.png",
        )
    ),
    sep="\n",
)

# End of script