In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from support import load_dataset
import matplotlib.pyplot as plt

# -----------------------------
#        INPUT PARAMETERS
# -----------------------------
SEARCH_METHOD = "random"  # or "TPE" (case-insensitive)
SEED = 18

# Fail fast on an unsupported method. The name is baked into the SQLite file
# name and the plot directory below. Otherwise it would only be rejected in the
# tuning cell, after the dataset load and after a bogus plots/ directory exists.
if SEARCH_METHOD.lower() not in ("random", "tpe"):
    raise ValueError("SEARCH_METHOD must be either 'random' or 'TPE'")

# CV settings
CV_FOLDS = 5    # stratified K-fold splits per trial
CV_EPOCHS = 5   # training epochs per fold

# Optuna settings
OPTUNA_TRIALS = 5  # trials added per run of the tuning cell

# Hyperparameter search spaces
LR_LOW, LR_HIGH = 1e-5, 1e-2            # sampled log-uniformly
DROPOUT_LOW, DROPOUT_HIGH = 0.1, 0.7
OPTIMIZER_CHOICES = ["Adam", "SGD"]
BATCH_SIZE_CHOICES = [16, 32]
NUM_FILTERS_CHOICES = [16, 32, 48, 64]

# Final training settings
NUM_EPOCHS_FINAL = 5
BATCH_SIZE_TEST = 16

# Storage URL depends on method
STORAGE_URL = f"sqlite:///{SEARCH_METHOD.lower()}_search.db"

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure plot directory exists
os.makedirs(f"plots/{SEARCH_METHOD.lower()}", exist_ok=True)

# -----------------------------
#   Set seed (reproducibility)
# -----------------------------
def set_seed(seed: int = SEED):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs.

    Also pins cuDNN to deterministic kernels with autotuning disabled, trading
    some speed for repeatable convolution results.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Seed every RNG once, before any model or data object is created.
set_seed()

# -----------------------------
#   Load datasets
# -----------------------------
print("Device in use: {}".format(DEVICE))
print("Loading datasets…")
train_dataset, test_dataset = load_dataset()
print(
    "  Training set size: {}, Test set size: {}".format(
        len(train_dataset), len(test_dataset)
    )
)

# -----------------------------
#   Define CNN
# -----------------------------
class CNN(nn.Module):
    """Small image classifier: two conv-ReLU-maxpool stages, then an MLP head.

    Args:
        num_filters: output channels of the first conv layer; the second conv
            layer uses ``2 * num_filters``.
        dropout: dropout probability applied before the final linear layer.
        in_channels: channels of the input images (default 3, as before).
        input_hw: ``(height, width)`` of the input images. It is used only to
            size the first linear layer. Defaults to ``(60, 30)``, the size that
            was previously hard-coded here.
        num_classes: number of output logits (default 2, as before).
    """

    def __init__(
        self,
        num_filters: int = 16,
        dropout: float = 0.5,
        in_channels: int = 3,
        input_hw: tuple = (60, 30),
        num_classes: int = 2,
    ):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(num_filters, num_filters * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Infer the flattened feature size with a dummy forward pass instead of
        # hand-computing the pooling arithmetic (floor division per pool).
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_hw)
            flat_size = self.features(dummy).numel()
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

# -----------------------------
#   Helper functions
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run a single optimisation pass over ``loader``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch-size-weighted mean loss and the
        fraction of correctly classified samples, both normalised by
        ``len(loader.dataset)``.
    """
    model.train()
    loss_sum, n_correct = 0.0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Re-weight by batch size (assumes criterion returns a per-batch mean).
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    n_samples = len(loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples

def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    Accuracy is normalised by the number of samples the loader actually
    yields, not ``len(loader.dataset)``. This keeps it correct when the loader
    uses a sampler or ``drop_last=True``. Returns ``nan`` for a loader that
    yields no samples instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = torch.argmax(model(inputs), dim=1)
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        return float("nan")
    return correct / seen

def evaluate_loss(model, loader, criterion, device):
    """Return the sample-weighted mean ``criterion`` loss of ``model`` over ``loader``.

    Each batch loss is scaled by its batch size before averaging. This assumes
    ``criterion`` uses mean reduction (the ``nn.CrossEntropyLoss()`` default).
    Returns ``nan`` for a loader that yields no samples instead of raising
    ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    count = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            loss = criterion(model(inputs), labels)
            total_loss += loss.item() * inputs.size(0)
            count += inputs.size(0)
    if count == 0:
        return float("nan")
    return total_loss / count

def _train_pass(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``.

    Both are averaged over the samples actually seen; ``(nan, nan)`` if none.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * inputs.size(0)
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        seen += inputs.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def _eval_pass(model, loader, criterion, device):
    """Run a gradient-free pass over ``loader``; return ``(mean loss, accuracy)`` or nans if empty."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += inputs.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def train_full_model(model, train_loader, test_loader, criterion, optimizer, device, num_epochs: int):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``test_loader`` after each.

    Losses are batch-size-weighted means (assumes ``criterion`` uses mean
    reduction). An empty loader yields ``nan`` metrics instead of a
    ``ZeroDivisionError``. Progress is printed once per epoch.

    Returns:
        ``(train_losses, train_accs, test_losses, test_accs)``: one entry per epoch.
    """
    train_losses, train_accs, test_losses, test_accs = [], [], [], []
    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc = _train_pass(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = _eval_pass(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        print(
            f"Epoch {epoch:02d}/{num_epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}"
        )

    return train_losses, train_accs, test_losses, test_accs

def _plot_train_test_metric(
    epochs, train_vals, test_vals, *, ylabel, tag, pick_best, best_word, title, path
):
    """Draw one train-vs-test figure, mark each curve's best epoch, save it to ``path`` and show it."""
    fig = plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_vals, marker='o', label=f"Train {ylabel}")
    plt.plot(epochs, test_vals, marker='s', label=f"Test {ylabel}")
    for vals, split, color in ((train_vals, "Train", 'blue'), (test_vals, "Test", 'orange')):
        if len(vals) == 0:
            continue
        best_idx = int(pick_best(vals))
        best_val = vals[best_idx]
        plt.scatter(epochs[best_idx], best_val, color=color)
        plt.text(
            epochs[best_idx],
            best_val + 0.01,
            f"{best_word} {split} {tag}: {best_val:.2f}",
            color=color
        )
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.show()
    # Release the figure. Under non-interactive backends show() does nothing,
    # so open figures would otherwise pile up across calls.
    plt.close(fig)


def plot_learning_curves(
    epochs,
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    acc_title: str,
    loss_title: str,
    acc_filename: str,
    loss_filename: str,
    out_dir=None
):
    """Save and show per-epoch accuracy and loss curves for train vs. test.

    The best epoch of each curve is annotated: max for accuracy, min for loss.

    Args:
        epochs: x-axis values, indexable by position (list or 1-D array).
        train_accs, test_accs, train_losses, test_losses: per-epoch metrics.
        acc_title, loss_title: figure titles.
        acc_filename, loss_filename: PNG file names written inside ``out_dir``.
        out_dir: target directory. Defaults to ``plots/<search method>``, the
            location previously hard-coded here. The directory must already exist.
    """
    if out_dir is None:
        out_dir = f"plots/{SEARCH_METHOD.lower()}"
    _plot_train_test_metric(
        epochs, train_accs, test_accs,
        ylabel="Accuracy", tag="Acc", pick_best=np.argmax, best_word="Max",
        title=acc_title, path=f"{out_dir}/{acc_filename}",
    )
    _plot_train_test_metric(
        epochs, train_losses, test_losses,
        ylabel="Loss", tag="Loss", pick_best=np.argmin, best_word="Min",
        title=loss_title, path=f"{out_dir}/{loss_filename}",
    )


In [None]:
# -----------------------------
#   Cell 1: Hyperparameter Tuning
# -----------------------------

print(f"\n### Starting {SEARCH_METHOD.upper()} Search hyperparameter tuning ###")

# One entry per training sample: its index and its integer label. StratifiedKFold
# below uses the labels to keep class proportions equal across folds.
# NOTE(review): this indexes every sample (presumably loading it) just to read
# the label — check whether load_dataset exposes labels more cheaply.
all_indices = list(range(len(train_dataset)))
labels_list = [int(train_dataset[idx][1]) for idx in all_indices]

def _validate_fold(model, loader, criterion, device):
    """Evaluate ``model`` on one validation fold without gradients.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples in ``loader``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            total += inputs.size(0)
    return total_loss / total, correct / total


def objective(trial):
    """Optuna objective: mean stratified K-fold validation accuracy of one configuration.

    Samples lr (log scale), dropout, optimizer, batch size and filter count.
    For each of ``CV_FOLDS`` folds it trains a fresh ``CNN`` for ``CV_EPOCHS``
    epochs and scores it on the held-out fold. Returns the mean fold accuracy,
    which the study maximises. The mean fold loss is stored as the
    ``mean_cv_loss`` user attribute.
    """
    lr = trial.suggest_float("lr", LR_LOW, LR_HIGH, log=True)
    dropout = trial.suggest_float("dropout", DROPOUT_LOW, DROPOUT_HIGH)
    optimizer_name = trial.suggest_categorical("optimizer", OPTIMIZER_CHOICES)
    batch_size = trial.suggest_categorical("batch_size", BATCH_SIZE_CHOICES)
    num_filters = trial.suggest_categorical("num_filters", NUM_FILTERS_CHOICES)

    print(
        f"Trial {trial.number}: "
        f"lr={lr:.2e}, dropout={dropout:.2f}, "
        f"optimizer={optimizer_name}, batch_size={batch_size}, num_filters={num_filters}"
    )

    # Fixed random_state => every trial is scored on the same fold assignment.
    skf = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=SEED)
    # Pinned host memory only helps host->CUDA copies. On a CPU-only device it
    # gives no benefit, and PyTorch warns about it for every loader.
    pin_memory = DEVICE.type == "cuda"
    fold_accuracies = []
    fold_losses = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(all_indices, labels_list), start=1):
        train_loader = DataLoader(
            Subset(train_dataset, train_idx),
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=pin_memory
        )
        val_loader = DataLoader(
            Subset(train_dataset, val_idx),
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=pin_memory
        )

        model = CNN(num_filters=num_filters, dropout=dropout).to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = (
            optim.Adam(model.parameters(), lr=lr)
            if optimizer_name == "Adam"
            else optim.SGD(model.parameters(), lr=lr)
        )

        for epoch in range(1, CV_EPOCHS + 1):
            train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
            if epoch % 5 == 0 or epoch == CV_EPOCHS:
                print(f"  Trial {trial.number} • Fold {fold_idx} Epoch {epoch}/{CV_EPOCHS}")

        mean_loss_fold, acc_fold = _validate_fold(model, val_loader, criterion, DEVICE)

        print(
            f"  Trial {trial.number} Fold {fold_idx} → "
            f"Acc: {acc_fold:.4f}, Loss: {mean_loss_fold:.4f}"
        )

        fold_losses.append(mean_loss_fold)
        fold_accuracies.append(acc_fold)

    mean_cv_acc = float(np.mean(fold_accuracies))
    mean_cv_loss = float(np.mean(fold_losses))
    trial.set_user_attr("mean_cv_loss", mean_cv_loss)
    print(
        f"Trial {trial.number} mean CV Acc: {mean_cv_acc:.4f}, "
        f"mean CV Loss: {mean_cv_loss:.4f}\n"
    )
    return mean_cv_acc

if SEARCH_METHOD.lower() not in ("random", "tpe"):
    raise ValueError("SEARCH_METHOD must be either 'random' or 'TPE'")

sampler = (
    optuna.samplers.RandomSampler()
    if SEARCH_METHOD.lower() == "random"
    else optuna.samplers.TPESampler()
)
study_name = f"{SEARCH_METHOD.lower()}_search"

print(f"Starting {SEARCH_METHOD.upper()} Search…")
# load_if_exists=True: re-running this cell appends trials to the stored study
# instead of starting over, so the "best" trial below is best across all runs.
study = optuna.create_study(
    storage=STORAGE_URL,
    study_name=study_name,
    sampler=sampler,
    direction="maximize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=OPTUNA_TRIALS)

best_trial = study.best_trial
best_params, best_cv_acc = best_trial.params, best_trial.value
best_cv_loss = best_trial.user_attrs.get("mean_cv_loss", float("nan"))

print(
    f"\n{SEARCH_METHOD.upper()} Search best trial #{best_trial.number} | "
    f"Hyperparams: {best_params} | CV Acc: {best_cv_acc:.4f} | CV Loss: {best_cv_loss:.4f}"
)


In [None]:
# -----------------------------
#   Cell 2: Plot & Full Retrain
# -----------------------------

import os
import optuna
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# -----------------------------
#   Re-use constants from Cell 0 when available
# -----------------------------
# Cell 0 already defines these. Re-assigning them unconditionally would
# silently override an edited Cell 0: for example, SEARCH_METHOD = "TPE" there
# would still make this cell load the *random* study. Fall back to the defaults
# only when a name is not yet defined in this kernel.
SEARCH_METHOD = globals().get("SEARCH_METHOD", "random")   # or "TPE"
SEED = globals().get("SEED", 18)

CV_FOLDS = globals().get("CV_FOLDS", 5)
CV_EPOCHS = globals().get("CV_EPOCHS", 5)
NUM_EPOCHS_FINAL = globals().get("NUM_EPOCHS_FINAL", 5)
BATCH_SIZE_TEST = globals().get("BATCH_SIZE_TEST", 16)

VAL_CURVES_FILENAME = "validation_curves.png"
VAL_LOSS_CURVES_FILENAME = "validation_loss_curves.png"

os.makedirs(f"plots/{SEARCH_METHOD.lower()}", exist_ok=True)

# -----------------------------
#   Re‐define helpers from Cell 0
# -----------------------------
def _train_pass(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy)``.

    Both are averaged over the samples actually seen; ``(nan, nan)`` if none.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * inputs.size(0)
        correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
        seen += inputs.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def _eval_pass(model, loader, criterion, device):
    """Run a gradient-free pass over ``loader``; return ``(mean loss, accuracy)`` or nans if empty."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, labels).item() * inputs.size(0)
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
            seen += inputs.size(0)
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def train_full_model(model, train_loader, test_loader, criterion, optimizer, device, num_epochs: int):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on ``test_loader`` after each.

    Losses are batch-size-weighted means (assumes ``criterion`` uses mean
    reduction). An empty loader yields ``nan`` metrics instead of a
    ``ZeroDivisionError``. Progress is printed once per epoch.

    Returns:
        ``(train_losses, train_accs, test_losses, test_accs)``: one entry per epoch.
    """
    train_losses, train_accs, test_losses, test_accs = [], [], [], []
    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc = _train_pass(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = _eval_pass(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        print(
            f"Epoch {epoch:02d}/{num_epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}"
        )

    return train_losses, train_accs, test_losses, test_accs

def _plot_train_test_metric(
    epochs, train_vals, test_vals, *, ylabel, tag, pick_best, best_word, title, path
):
    """Draw one train-vs-test figure, mark each curve's best epoch, save it to ``path`` and show it."""
    fig = plt.figure(figsize=(8, 5))
    plt.plot(epochs, train_vals, marker='o', label=f"Train {ylabel}")
    plt.plot(epochs, test_vals, marker='s', label=f"Test {ylabel}")
    for vals, split, color in ((train_vals, "Train", 'blue'), (test_vals, "Test", 'orange')):
        if len(vals) == 0:
            continue
        best_idx = int(pick_best(vals))
        best_val = vals[best_idx]
        plt.scatter(epochs[best_idx], best_val, color=color)
        plt.text(
            epochs[best_idx],
            best_val + 0.01,
            f"{best_word} {split} {tag}: {best_val:.2f}",
            color=color
        )
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.show()
    # Release the figure. Under non-interactive backends show() does nothing,
    # so open figures would otherwise pile up across calls.
    plt.close(fig)


def plot_learning_curves(
    epochs,
    train_accs,
    test_accs,
    train_losses,
    test_losses,
    acc_title: str,
    loss_title: str,
    acc_filename: str,
    loss_filename: str,
    out_dir=None
):
    """Save and show per-epoch accuracy and loss curves for train vs. test.

    The best epoch of each curve is annotated: max for accuracy, min for loss.

    Args:
        epochs: x-axis values, indexable by position (list or 1-D array).
        train_accs, test_accs, train_losses, test_losses: per-epoch metrics.
        acc_title, loss_title: figure titles.
        acc_filename, loss_filename: PNG file names written inside ``out_dir``.
        out_dir: target directory. Defaults to ``plots/<search method>``, the
            location previously hard-coded here. The directory must already exist.
    """
    if out_dir is None:
        out_dir = f"plots/{SEARCH_METHOD.lower()}"
    _plot_train_test_metric(
        epochs, train_accs, test_accs,
        ylabel="Accuracy", tag="Acc", pick_best=np.argmax, best_word="Max",
        title=acc_title, path=f"{out_dir}/{acc_filename}",
    )
    _plot_train_test_metric(
        epochs, train_losses, test_losses,
        ylabel="Loss", tag="Loss", pick_best=np.argmin, best_word="Min",
        title=loss_title, path=f"{out_dir}/{loss_filename}",
    )

# -----------------------------
#   2) Load the single study, print summaries, scatter-plots, CV curves
# -----------------------------
print("Loading the saved study from disk…")

study_name = f"{SEARCH_METHOD.lower()}_search"
try:
    study = optuna.load_study(
        storage=f"sqlite:///{SEARCH_METHOD.lower()}_search.db",
        study_name=study_name,
    )
except KeyError:
    # Optuna signals "no study with this name in the storage" with KeyError.
    print(f"Warning: '{study_name}' study not found. Nothing to plot or retrain.")
    exists = False
else:
    exists = True

# Stop the rest of this cell: there is nothing to plot or retrain.
if not exists:
    raise SystemExit()

print(f"\n{SEARCH_METHOD.upper()} Search trial summary:")
for frozen in study.trials:
    cv_acc = "N/A" if frozen.value is None else frozen.value
    cv_loss = frozen.user_attrs.get("mean_cv_loss", "N/A")
    print(
        f"  Trial #{frozen.number:2d} | Params={frozen.params} | "
        f"CV Acc={cv_acc} | CV Loss={cv_loss}"
    )

# Parallel per-trial series for the plots below, limited to trials that have a
# recorded objective value. A missing loss attribute becomes NaN.
valid_trials = [t for t in study.trials if t.value is not None]
trial_nums, trial_accs, trial_losses = [], [], []
for frozen in valid_trials:
    trial_nums.append(frozen.number)
    trial_accs.append(frozen.value)
    trial_losses.append(frozen.user_attrs.get("mean_cv_loss", np.nan))

print(f"\nPlotting scatter-plots of {SEARCH_METHOD.upper()} hyperparameters vs. CV accuracy…")
# Take the parameter names from the scored trials themselves, in first-seen
# order, rather than from study.best_params (which raises if no trial has
# completed). Also tolerate trials lacking a parameter, e.g. when the search
# space changed between resumed runs of the same stored study.
param_names = list(dict.fromkeys(name for t in valid_trials for name in t.params))
for param in param_names:
    with_param = [t for t in valid_trials if param in t.params]
    x_vals = [t.params[param] for t in with_param]
    y_vals = [t.value for t in with_param]
    plt.figure(figsize=(6, 4))
    plt.scatter(x_vals, y_vals, marker='o', edgecolor='k')
    # Log-sampled parameters (lr is suggested with log=True) bunch up at one end
    # of a linear axis, so give them a log x-axis to match how they were sampled.
    if getattr(with_param[0].distributions.get(param), "log", False):
        plt.xscale("log")
    plt.xlabel(param)
    plt.ylabel("Mean CV Accuracy")
    plt.title(f"{SEARCH_METHOD.capitalize()} Search: {param} vs. CV Accuracy")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"plots/{SEARCH_METHOD.lower()}/scatter_{SEARCH_METHOD.lower()}_{param}.png", dpi=300)
    plt.show()

print("\nPlotting validation accuracy curve…")
plt.figure(figsize=(8, 5))
plt.plot(trial_nums, trial_accs, marker='o', label=f"{SEARCH_METHOD.capitalize()} CV Acc")
if trial_accs:
    # Mark the best trial. The x coordinate is the trial *number*, not a list position.
    best_pos = int(np.argmax(trial_accs))
    best_idx, best_val = trial_nums[best_pos], trial_accs[best_pos]
    plt.scatter(best_idx, best_val, color='blue')
    plt.text(best_idx, best_val + 0.005, f"Max CV Acc: {best_val:.2f}", color='blue')
plt.title(f"{SEARCH_METHOD.capitalize()} Search: Validation Accuracy ({CV_FOLDS}-Fold, {CV_EPOCHS} Epochs)")
plt.xlabel("Trial Number")
plt.ylabel("Mean CV Accuracy")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.savefig(f"plots/{SEARCH_METHOD.lower()}/{VAL_CURVES_FILENAME}", dpi=300)
plt.show()

print("\nPlotting validation loss curve…")
plt.figure(figsize=(8, 5))
plt.plot(trial_nums, trial_losses, marker='s', label=f"{SEARCH_METHOD.capitalize()} CV Loss")
# Trials stored without a "mean_cv_loss" attribute appear as NaN. Mark the best
# finite loss only; np.nanargmin raises ValueError when every entry is NaN.
finite_losses = [(num, val) for num, val in zip(trial_nums, trial_losses) if not np.isnan(val)]
if finite_losses:
    best_loss_idx, best_loss_val = min(finite_losses, key=lambda pair: pair[1])
    plt.scatter(best_loss_idx, best_loss_val, color='orange')
    plt.text(best_loss_idx, best_loss_val + 0.005, f"Min CV Loss: {best_loss_val:.2f}", color='orange')
plt.xlabel("Trial Number")
plt.ylabel("Mean CV Loss")
plt.title(f"{SEARCH_METHOD.capitalize()} Search: Validation Loss ({CV_FOLDS}-Fold, {CV_EPOCHS} Epochs)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(f"plots/{SEARCH_METHOD.lower()}/{VAL_LOSS_CURVES_FILENAME}", dpi=300)
plt.show()

# -----------------------------
#   3) Retrain best-so-far trial & plot learning curves
# -----------------------------
criterion_final = nn.CrossEntropyLoss()

def build_model_and_optimizer(params):
    """Instantiate the CNN (on DEVICE) and its optimiser from an Optuna ``params`` dict.

    Reads the keys ``num_filters``, ``dropout``, ``lr`` and ``optimizer``. Any
    optimizer name other than ``"Adam"`` falls back to plain SGD.
    """
    net = CNN(num_filters=params["num_filters"], dropout=params["dropout"]).to(DEVICE)
    optimizer_cls = optim.Adam if params["optimizer"] == "Adam" else optim.SGD
    return net, optimizer_cls(net.parameters(), lr=params["lr"])

best_trial = study.best_trial
best_params, best_cv_acc = best_trial.params, best_trial.value
# Trials saved without the attribute report NaN rather than failing.
best_cv_loss = best_trial.user_attrs.get("mean_cv_loss", float("nan"))
print(
    f"\nBest-so-far {SEARCH_METHOD.upper()} trial #{best_trial.number} | "
    f"Hyperparams: {best_params} | "
    f"CV Acc: {best_cv_acc:.4f} | CV Loss: {best_cv_loss:.4f}"
)

# Re-seed right before building the final model. Otherwise its weight init and
# data shuffling depend on how much RNG state the search consumed, which varies
# with the number of trials run (and with resumed studies).
set_seed(SEED)
model_best, opt_best = build_model_and_optimizer(best_params)

# Pinned memory only helps host->CUDA copies, so request it only on CUDA.
train_loader_best = DataLoader(
    train_dataset,
    batch_size=best_params["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=DEVICE.type == "cuda"
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_TEST,
    shuffle=False,
    num_workers=4,
    pin_memory=DEVICE.type == "cuda"
)

print(f"Training {SEARCH_METHOD.capitalize()}-best model on full training set…")
best_train_losses, best_train_accs, best_test_losses, best_test_accs = train_full_model(
    model_best,
    train_loader_best,
    test_loader,
    criterion_final,
    opt_best,
    DEVICE,
    NUM_EPOCHS_FINAL
)

# 1-based epoch axis: one point per final-training epoch.
epochs_final = np.arange(1, NUM_EPOCHS_FINAL + 1)
print("Plotting learning curves for the best-so-far model…")
plot_learning_curves(
    epochs_final,
    best_train_accs,
    best_test_accs,
    best_train_losses,
    best_test_losses,
    acc_filename=f"best_{SEARCH_METHOD.lower()}_learning_curve.png",
    loss_filename=f"best_{SEARCH_METHOD.lower()}_loss_curve.png",
    acc_title=f"{SEARCH_METHOD.capitalize()} Best Model: Train & Test Accuracy",
    loss_title=f"{SEARCH_METHOD.capitalize()} Best Model: Train & Test Loss",
)
