In [None]:
# imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
# AMP components (GradScaler, autocast) are removed
from torchvision import transforms, models, datasets
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    confusion_matrix,
)
from PIL import Image
import os
import random
import time
import logging # Logging
from tqdm import tqdm # Use standard tqdm for console
import gc

In [None]:
# --- Configuration ---
# Path to your FER2013 CSV file
#CSV_PATH = "../../fer2013.csv"
# Model/Checkpoint saving directory
# NOTE(review): local_time / current_date are not referenced anywhere else in the
# visible notebook; they are kept in case other cells rely on them.
local_time = time.localtime()
current_date = time.strftime("%Y-%m-%d", local_time)
MODEL_DIR = "models_checkpointed"
# Log file path
LOG_FILE = "training_log.log"
os.makedirs(MODEL_DIR, exist_ok=True)

# --- Basic Logging Setup ---
# Detach AND close handlers left over from a previous run of this cell.
# Removing a handler without closing it would keep the old FileHandler's
# file descriptor open on every re-run and could duplicate log output.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    handler.close()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler(LOG_FILE, mode='w'), # mode='w' to overwrite log file on each run
        logging.StreamHandler()
    ]
)
logging.info(f"Logging to console and {LOG_FILE}")

# --- Reproducibility ---
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs so that runs are repeatable.

    Args:
        seed: Integer seed handed to every random number generator.

    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with auto-tuning (benchmark) disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Seed the current device and every other visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn_backend = torch.backends.cudnn
        cudnn_backend.deterministic = True
        cudnn_backend.benchmark = False

    logging.info(f"Random seed set to {seed}")

SEED = 42
set_seed(SEED)

# --- Training Hyperparameters ---
BASE_MODEL_NAME = "mnetv3_s"
BATCH_SIZE = 32
EPOCHS_PHASE1 = 7   # head-only epochs (increased)
EPOCHS_PHASE2 = 20  # full fine-tuning epochs (increased)
K_FOLDS = 2
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_BACKBONE = 1e-5
WEIGHT_DECAY = 1e-2
SCHEDULER_PATIENCE = 3
SCHEDULER_FACTOR = 0.1
EARLY_STOPPING_PATIENCE = 7  # increased
DROPOUT_RATE = 0.5

# --- Log Hyperparameters ---
# One "<label>: <value>" line per setting, framed by a header and footer line.
logging.info("--- Hyperparameters ---")
for hp_label, hp_value in (
    ("Base Model", BASE_MODEL_NAME),
    ("Batch Size", BATCH_SIZE),
    ("Epochs Phase 1 (Head Training)", EPOCHS_PHASE1),
    ("Epochs Phase 2 (Full Model Fine-tuning)", EPOCHS_PHASE2),
    ("K-Folds for Cross-Validation", K_FOLDS),
    ("Learning Rate (Head)", LEARNING_RATE_HEAD),
    ("Learning Rate (Backbone)", LEARNING_RATE_BACKBONE),
    ("Weight Decay", WEIGHT_DECAY),
    ("Scheduler Patience", SCHEDULER_PATIENCE),
    ("Scheduler Factor", SCHEDULER_FACTOR),
    ("Early Stopping Patience", EARLY_STOPPING_PATIENCE),
    ("Dropout Rate", DROPOUT_RATE),
    ("Random Seed", SEED),
    ("AMP (Mixed Precision)", False),  # AMP is explicitly reported as off
):
    logging.info(f"{hp_label}: {hp_value}")
logging.info("-----------------------")

In [None]:
## --- Load Data Img ---
# Root folder of the image dataset; the "train" and "test" sub-folders are
# derived from it below.
dataset_dir = "fer2013/"
TRAIN_DIR, TEST_DIR = (
    os.path.join(dataset_dir, split_name) for split_name in ("train", "test")
)

logging.info(f"Loading data from {dataset_dir}")

def load_data_from_dirs(directory_path, positive_class_str="happy"):
    """Index an image folder laid out as ``<directory_path>/<emotion>/<file>``.

    Args:
        directory_path: Folder whose immediate sub-folders are named after
            emotions and contain the image files.
        positive_class_str: Name of the sub-folder whose images get label 1;
            images in every other sub-folder get label 0.

    Returns:
        A ``pandas.DataFrame`` with columns ``filepath``, ``emotion`` and
        ``label``. The columns are present even when the directory is missing
        or holds no images, so callers can always index ``df["label"]``.
        Rows are ordered by sorted emotion name, then sorted file name.
    """
    columns = ["filepath", "emotion", "label"]
    supported_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff")

    if not os.path.isdir(directory_path):
        logging.error(f"Directory not found: {directory_path}")
        return pd.DataFrame(columns=columns)  # Empty frame, schema preserved

    data = []
    # os.listdir() returns entries in arbitrary, filesystem-dependent order.
    # Sorting keeps the row order (and any seeded split built on top of it)
    # identical across machines and runs.
    for emotion_name in sorted(os.listdir(directory_path)):
        emotion_subdir_path = os.path.join(directory_path, emotion_name)
        if not os.path.isdir(emotion_subdir_path):
            continue
        # Binary label: 1 for the positive class folder, 0 for everything else.
        label = 1 if emotion_name == positive_class_str else 0
        for filename in sorted(os.listdir(emotion_subdir_path)):
            if filename.lower().endswith(supported_extensions):
                data.append(
                    {
                        "filepath": os.path.join(emotion_subdir_path, filename),
                        "emotion": emotion_name,  # Original emotion string
                        "label": label,  # Binary label
                    }
                )

    if not data:
        logging.warning(
            f"No image files found in {directory_path} or its subdirectories."
        )
    # Passing columns= keeps the schema stable when `data` is empty.
    return pd.DataFrame(data, columns=columns)

# Load training/validation data
df_train_val = load_data_from_dirs(TRAIN_DIR, positive_class_str="happy")
if df_train_val.empty and not os.path.isdir(TRAIN_DIR):
    logging.error(
        f"Critical error: Training data directory not found or inaccessible: {TRAIN_DIR}. Exiting."
    )
    # Raise instead of calling the interactive-only exit() helper: the failure
    # is then visible (non-zero status / failed cell) rather than a silent exit.
    raise FileNotFoundError(
        f"Training data directory not found or inaccessible: {TRAIN_DIR}"
    )
elif df_train_val.empty:
    logging.warning(
        f"No training/validation data loaded from {TRAIN_DIR}. Proceeding with an empty training/validation set."
    )

# Load test data
df_test = load_data_from_dirs(TEST_DIR, positive_class_str="happy")
if df_test.empty and not os.path.isdir(TEST_DIR):
    logging.error(
        f"Critical error: Test data directory not found or inaccessible: {TEST_DIR}. Exiting."
    )
    raise FileNotFoundError(
        f"Test data directory not found or inaccessible: {TEST_DIR}"
    )
elif df_test.empty:
    logging.warning(
        f"No test data loaded from {TEST_DIR}. Proceeding with an empty test set."
    )

logging.info(f"Loaded {len(df_train_val)} samples for Training/Validation")
logging.info(f"Loaded {len(df_test)} samples for Final Testing")

# Class weights calculation for df_train_val
# The 'label' column is already created by load_data_from_dirs
if not df_train_val.empty and "label" in df_train_val.columns:
    class_counts = df_train_val["label"].value_counts().sort_index()
    logging.info(
        f"Class counts (Train/Val): {class_counts.to_dict() if not class_counts.empty else 'No data for class counts'}"
    )

    # We expect two classes (0 and 1) for binary classification
    expected_classes = {0, 1}
    present_classes = set(class_counts.index)

    use_equal_weights = False
    if not expected_classes.issubset(present_classes):
        logging.warning(
            f"Expected classes {expected_classes} but found {present_classes} in training/validation data. "
            "Using equal class weights [1.0, 1.0]."
        )
        use_equal_weights = True
    elif class_counts.get(0, 0) == 0 or class_counts.get(1, 0) == 0:
        logging.warning(
            "One of the expected binary classes (0 or 1) has zero samples in training/validation data. "
            "Using equal class weights [1.0, 1.0]."
        )
        use_equal_weights = True

    if use_equal_weights:
        class_weights_tensor = torch.tensor([1.0, 1.0], dtype=torch.float32)
    else:
        total_samples = len(df_train_val)
        num_binary_classes = 2  # For classes 0 and 1

        # Balanced weighting: total_samples / (num_classes * count_for_class_i).
        # Both counts are guaranteed non-zero by the checks above.
        # Order is [weight_for_class_0, weight_for_class_1].
        weight_for_class_0 = total_samples / (num_binary_classes * class_counts[0])
        weight_for_class_1 = total_samples / (num_binary_classes * class_counts[1])
        class_weights_tensor = torch.tensor(
            [weight_for_class_0, weight_for_class_1], dtype=torch.float32
        )
    logging.info(f"Calculated class weights: {class_weights_tensor.numpy()}")

else:
    logging.warning(
        "Training/Validation DataFrame is empty or 'label' column is missing. "
        "Skipping class weight calculation and using default equal weights [1.0, 1.0]."
    )
    class_weights_tensor = torch.tensor(
        [1.0, 1.0], dtype=torch.float32
    ) # Default equal weights
    logging.info(f"Using default class weights: {class_weights_tensor.numpy()}")

In [None]:

# --- Dataset Class ---
class FERDataset(Dataset):
    """
    Dataset that loads images lazily from the file paths listed in a DataFrame.

    Expects DataFrame columns: 'filepath', 'label'.

    Args:
        df: DataFrame with one row per image.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        ValueError: If a required column is missing from ``df``.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        if "filepath" not in self.df.columns:
            raise ValueError(f"'filepath' column not found in DataFrame columns: {self.df.columns.tolist()}")
        # Fail at construction time rather than with a KeyError on the first
        # __getitem__ call (possibly inside a DataLoader worker).
        if "label" not in self.df.columns:
            raise ValueError(f"'label' column not found in DataFrame columns: {self.df.columns.tolist()}")

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at positional index ``idx``."""
        row = self.df.iloc[idx]
        # Image.open() is lazy and keeps the file handle open; the context
        # manager closes it once convert() has produced an independent copy.
        with Image.open(row["filepath"]) as source_image:
            img = source_image.convert("RGB")
        label = row["label"]
        if self.transform:
            img = self.transform(img)
        return img, label

# --- Transforms ---
def _compose_pipeline(*augmentations):
    """Build resize -> (optional augmentations) -> tensor -> normalise."""
    steps = [transforms.Resize((224, 224))]
    steps.extend(augmentations)
    steps.append(transforms.ToTensor())
    # Per-channel mean/std; presumably the ImageNet statistics that go with the
    # pretrained backbone -- confirm if the backbone changes.
    steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(steps)


# Training adds a random horizontal flip and a random rotation of up to 10 degrees.
train_transform = _compose_pipeline(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
)

# Validation/test images are only resized and normalised (no augmentation).
val_test_transform = _compose_pipeline()

In [None]:
def get_model(dropout_rate_param=DROPOUT_RATE):
    """Build a pretrained MobileNetV3-Small with a 2-class output layer.

    Args:
        dropout_rate_param: Probability assigned to every ``nn.Dropout`` layer
            found in the model's classifier.

    Returns:
        The torchvision model whose last classifier layer has been replaced by
        ``nn.Linear(in_features, 2)``.

    Raises:
        AttributeError: If the model has no ``classifier`` attribute or its
            last classifier layer is not ``nn.Linear``.
    """
    logging.info(f"Initializing {BASE_MODEL_NAME} model.")

    # ``pretrained=True`` is deprecated since torchvision 0.13 in favour of the
    # ``weights`` enum. IMAGENET1K_V1 is the checkpoint ``pretrained=True``
    # resolved to, so the loaded weights are unchanged. Older torchvision
    # versions without the enum keep using the legacy keyword.
    weights_enum = getattr(models, "MobileNet_V3_Small_Weights", None)
    if weights_enum is not None:
        model = models.mobilenet_v3_small(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.mobilenet_v3_small(pretrained=True)

    # MobileNetV3's classifier is usually: [Linear, Hardswish, Dropout, Linear]
    if not hasattr(model, 'classifier'):
        logging.error(f"Model {BASE_MODEL_NAME} does not have a 'classifier' attribute.")
        raise AttributeError(f"Model {BASE_MODEL_NAME} structure not as expected.")

    # Adjust dropout rate if present
    for layer in model.classifier:
        if isinstance(layer, nn.Dropout):
            layer.p = dropout_rate_param
            logging.info(f"Adjusted Dropout rate in classifier to {dropout_rate_param}")

    # Replace the final linear layer for binary classification
    if not isinstance(model.classifier[-1], nn.Linear):
        logging.error("Final layer of classifier is not Linear as expected.")
        raise AttributeError("Unexpected classifier structure in MobileNetV3.")

    in_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(in_features, 2)
    logging.info("Replaced final classifier layer for 2 output classes.")
    return model

In [None]:
# --- Training & Evaluation Functions ---
def train_one_epoch(
    model, loader, criterion, optimizer, device, phase="Head" # scaler removed
):
    """Run one training pass over ``loader`` and return the mean sample loss.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the trainable parameters.
        device: Device the batches are moved to.
        phase: Label shown in the progress-bar description.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by the number
        of samples that actually contributed; ``0.0`` if none did.

    Batches that raise during forward/backward, or whose loss is NaN/Inf, are
    logged and skipped without updating the weights.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0  # samples whose loss was accumulated
    pbar = tqdm(loader, desc=f"Training Epoch ({phase})", leave=False)
    for i, (imgs, labels) in enumerate(pbar):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)

        try:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        except Exception as e:
            logging.error(f"Error during forward/loss calculation in training batch {i}: {e}")
            continue

        # Check BEFORE backward/step: back-propagating a NaN/Inf loss would
        # write non-finite values into the weights and poison later batches.
        if not torch.isfinite(loss):
            logging.warning(f"NaN or Inf loss detected in training batch {i}. Skipping batch (no weight update).")
            pbar.set_postfix(loss=loss.item())
            continue

        try:
            loss.backward()
            optimizer.step()
        except Exception as e:
             logging.error(f"Error during backward/step in training batch {i}: {e}")
             optimizer.zero_grad(set_to_none=True)
             continue

        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
        pbar.set_postfix(loss=loss.item())

    if samples_seen == 0:
        logging.warning("No training batch contributed to the epoch loss; returning 0.0.")
        return 0.0
    # Divide by the samples actually used. len(loader.dataset) over-counts when
    # the loader uses drop_last=True or when batches were skipped.
    epoch_loss = running_loss / samples_seen
    return epoch_loss

def evaluate(model, loader, criterion, device): # scaler removed
    """Evaluate ``model`` on ``loader`` and compute binary classification metrics.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Dict with keys ``loss``, ``accuracy``, ``f1``, ``precision``,
        ``recall``, ``auc``, ``cm`` (always 2x2), ``report`` (dict form of the
        classification report), ``preds`` and ``targets`` (NumPy arrays).
        If no batch could be evaluated, ``loss`` is ``inf``, the other scalar
        metrics are ``0.0`` and the arrays are empty.
    """
    model.eval()
    all_preds, all_targets = [], []
    all_probs = []
    running_loss = 0.0
    loss_samples = 0  # samples whose (finite) loss was accumulated
    pbar = tqdm(loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(pbar):
            imgs, labels = imgs.to(device), labels.to(device)
            try:
                outputs = model(imgs)
                loss = criterion(outputs, labels)

                if torch.isfinite(loss):
                    running_loss += loss.item() * imgs.size(0)
                    loss_samples += imgs.size(0)
                else:
                    logging.warning(f"NaN or Inf loss detected in evaluation batch {i}. Skipping accumulation.")

                # Predicted class = arg-max logit; probability of class 1 feeds the AUC.
                _, predicted = torch.max(outputs, 1)
                probabilities = torch.softmax(outputs, dim=1)[:, 1]

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(labels.cpu().numpy())
                all_probs.extend(probabilities.cpu().numpy())
            except Exception as e:
                logging.error(f"Error during evaluation batch {i}: {e}")
                continue

    if not loader.dataset or not all_targets:
        logging.warning("Evaluation dataset empty or all batches failed.")
        return {
            "loss": float('inf'), "accuracy": 0.0, "f1": 0.0, "precision": 0.0,
            "recall": 0.0, "auc": 0.0, "cm": np.zeros((2, 2)), "report": {},
            "preds": np.array([]), "targets": np.array([])
        }

    # Average only over samples whose loss was accumulated, so batches with a
    # skipped NaN/Inf loss do not deflate the mean.
    val_loss = running_loss / loss_samples if loss_samples > 0 else float('inf')
    all_targets_np = np.array(all_targets)
    all_preds_np = np.array(all_preds)
    all_probs_np = np.array(all_probs)

    # From here on at least one sample is guaranteed by the early return above.
    val_acc = accuracy_score(all_targets_np, all_preds_np)
    val_f1 = f1_score(all_targets_np, all_preds_np, average="binary", zero_division=0)
    val_prec = precision_score(all_targets_np, all_preds_np, average="binary", zero_division=0)
    val_rec = recall_score(all_targets_np, all_preds_np, average="binary", zero_division=0)
    val_auc = 0.0
    try:
        if len(np.unique(all_targets_np)) > 1: # Check for at least two classes for AUC
             val_auc = roc_auc_score(all_targets_np, all_probs_np)
        else: # Only one class present
             logging.warning("AUC calculation skipped: only one class present in evaluation targets.")
    except ValueError as e: # Catch specific sklearn error if all predictions are one class
        logging.warning(f"AUC calculation failed (ValueError): {e}")
    except Exception as e: # Catch any other unexpected error
        logging.error(f"AUC calculation failed (General Error): {e}")

    # Pin labels to [0, 1]: without it, data containing a single class yields a
    # 1x1 confusion matrix and a target_names/label count mismatch in the report.
    binary_labels = [0, 1]
    cm = confusion_matrix(all_targets_np, all_preds_np, labels=binary_labels)
    report_dict = {}
    try:
        report_dict = classification_report(
            all_targets_np,
            all_preds_np,
            labels=binary_labels,
            target_names=["Not Happy", "Happy"],
            output_dict=True,
            zero_division=0,
        )
    except Exception as e:
        logging.error(f"Error generating classification report: {e}")

    metrics = {
        "loss": val_loss, "accuracy": val_acc, "f1": val_f1, "precision": val_prec,
        "recall": val_rec, "auc": val_auc, "cm": cm, "report": report_dict,
        "preds": all_preds_np, "targets": all_targets_np
    }
    return metrics

def evaluate_loss_only(model, loader, criterion, device):
    """Return the mean per-sample loss of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled. Batches whose
    loss is NaN or Inf are left out of both the sum and the sample count.
    Returns ``float('inf')`` when no batch contributed.
    """
    model.eval()
    loss_sum, n_counted = 0.0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            batch_loss = criterion(model(imgs), labels)
            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                continue
            batch_len = imgs.size(0)
            loss_sum += batch_loss.item() * batch_len
            n_counted += batch_len
    return loss_sum / n_counted if n_counted != 0 else float('inf')


# --- Main K-Fold Cross-Validation ---
# Device preference: CUDA, then Apple MPS (only when actually available),
# otherwise CPU. Falling back to "mps" unconditionally makes every later
# .to(device) call fail on machines that have neither CUDA nor MPS.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
logging.info(f"Using device: {device}")
# Worker processes are only used on non-Windows hosts with CUDA.
num_workers = 4 if os.name != "nt" and torch.cuda.is_available() else 0
logging.info(f"Using {num_workers} workers for DataLoaders")

# Stratified splits keep the class ratio of 'label' in every fold.
skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=SEED)
X_data = df_train_val.index.values
y_data = df_train_val["label"].values

# One metrics dict (as returned by evaluate) is appended per fold.
fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(X_data, y_data)):
    logging.info(f"--- Starting Fold {fold+1}/{K_FOLDS} ---")
    fold_start_time = time.time()

    # skf.split yields positional indices, hence .iloc.
    train_df = df_train_val.iloc[train_idx]
    val_df = df_train_val.iloc[val_idx]

    train_ds = FERDataset(train_df, transform=train_transform)
    val_ds = FERDataset(val_df, transform=val_test_transform)

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=False, drop_last=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=False
    )

    # Fresh model per fold so no weights leak between folds.
    model = get_model().to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor.to(device))

    # --- Phase 1: Train the Head ---
    # Backbone ('features') frozen, only the classifier receives gradients.
    logging.info("--- Phase 1: Training Head ---")
    if hasattr(model, 'features'):
        for param in model.features.parameters():
            param.requires_grad = False
    if hasattr(model, 'classifier'):
        for param in model.classifier.parameters():
            param.requires_grad = True

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LEARNING_RATE_HEAD,
        weight_decay=WEIGHT_DECAY
    )

    for epoch in range(EPOCHS_PHASE1):
        epoch_start_time = time.time()
        train_loss = train_one_epoch(
            model, train_loader, criterion, optimizer, device, phase="Head"
        )
        epoch_time = time.time() - epoch_start_time
        logging.info(
            f"Fold {fold+1} Phase 1 - Epoch {epoch+1}/{EPOCHS_PHASE1}, Train Loss: {train_loss:.4f}, Time: {epoch_time:.2f}s"
        )

    # --- Phase 2: Fine-tune the whole model ---
    logging.info("--- Phase 2: Fine-tuning Full Model ---")
    if hasattr(model, 'features'):
        for param in model.features.parameters():
            param.requires_grad = True

    # Two parameter groups: small LR for the backbone, larger LR for the head.
    optimizer = optim.AdamW(
        [
            {
                "params": model.features.parameters() if hasattr(model, 'features') else [],
                "lr": LEARNING_RATE_BACKBONE,
            },
            {
                "params": model.classifier.parameters() if hasattr(model, 'classifier') else model.parameters(),
                "lr": LEARNING_RATE_HEAD,
            },
        ],
        weight_decay=WEIGHT_DECAY,
    )

    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE
    )

    best_val_loss = float('inf')
    patience_counter = 0
    best_epoch = -1
    checkpoint_path = os.path.join(MODEL_DIR, f"best_model_fold_{fold+1}.pth")
    # True only after THIS run has written checkpoint_path. A file of the same
    # name left by an earlier run must not be mistaken for this fold's best model.
    checkpoint_saved = False

    for epoch in range(EPOCHS_PHASE2):
        epoch_start_time = time.time()
        current_epoch_total = EPOCHS_PHASE1 + epoch + 1

        train_loss = train_one_epoch(
            model, train_loader, criterion, optimizer, device, phase="Full"
        )
        val_loss = evaluate_loss_only(model, val_loader, criterion, device)

        epoch_time = time.time() - epoch_start_time
        current_lrs = [group['lr'] for group in optimizer.param_groups]
        logging.info(
            f"Fold {fold+1} Phase 2 - Epoch {epoch+1}/{EPOCHS_PHASE2} (Total: {current_epoch_total}), "
            f"LR: {current_lrs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
            f"Time: {epoch_time:.2f}s"
        )

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_epoch = current_epoch_total
            try:
                torch.save({
                    'epoch': current_epoch_total,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': best_val_loss,
                }, checkpoint_path)
                checkpoint_saved = True
                logging.info(
                    f"  -> New best Val Loss: {best_val_loss:.4f} at epoch {best_epoch}. Checkpoint saved to {checkpoint_path}"
                )
            except Exception as e:
                logging.error(f"Error saving checkpoint: {e}")
        else:
            patience_counter += 1
            logging.info(f"  -> Val Loss did not improve. Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")

        if patience_counter >= EARLY_STOPPING_PATIENCE:
            logging.info(f"  -> Early stopping triggered at epoch {current_epoch_total}.")
            break

    # Load best model and do full evaluation (all metrics) at the end of the fold
    if checkpoint_saved and os.path.exists(checkpoint_path):
        logging.info(f"Loading best model from {checkpoint_path} (Epoch {best_epoch}, Val Loss: {best_val_loss:.4f})")
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:
            logging.error(f"Error loading checkpoint: {e}. Using last model state.")
    else:
        logging.warning("No best model checkpoint found for this fold. Using last model state.")

    logging.info(f"--- Evaluating Best Model for Fold {fold+1} ---")
    final_fold_metrics = evaluate(model, val_loader, criterion, device)  # Full metrics only here

    logging.info(f"Fold {fold+1} Final Validation Results (Best Model):")
    logging.info(f"  Accuracy:  {final_fold_metrics['accuracy']:.4f}")
    logging.info(f"  F1 Score:  {final_fold_metrics['f1']:.4f}")
    logging.info(f"  Precision: {final_fold_metrics['precision']:.4f}")
    logging.info(f"  Recall:    {final_fold_metrics['recall']:.4f}")
    logging.info(f"  AUC:       {final_fold_metrics['auc']:.4f}")
    logging.info(f"  Loss:      {final_fold_metrics['loss']:.4f}")
    logging.info("  Confusion Matrix:")
    logging.info(f"\n{final_fold_metrics['cm']}")

    fold_results.append(final_fold_metrics)
    fold_time = time.time() - fold_start_time
    logging.info(f"--- Fold {fold+1} completed in {fold_time:.2f}s ---")
    # Release the fold's model before the next fold allocates a new one.
    del model
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
# --- Cross-Validation Summary ---
# Report mean ± standard deviation of each per-fold validation metric.
logging.info("--- Cross-Validation Summary ---")
if not fold_results:
    logging.warning("No fold results to summarize.")
else:
    # Pull one list per metric out of the per-fold metric dicts.
    accs, f1s, precs, recs, aucs = (
        [fold_metrics[metric_key] for fold_metrics in fold_results]
        for metric_key in ("accuracy", "f1", "precision", "recall", "auc")
    )

    # Prefixes are padded so the numbers line up in the log.
    for summary_prefix, fold_values in (
        ("Mean Accuracy:  ", accs),
        ("Mean F1 Score:  ", f1s),
        ("Mean Precision: ", precs),
        ("Mean Recall:    ", recs),
        ("Mean AUC:       ", aucs),
    ):
        logging.info(f"{summary_prefix}{np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")

# --- Training Final Model on Full Train/Val Data ---
logging.info("--- Training Final Model on Full Train/Val Dataset ---")
final_model = get_model().to(device)
final_ds = FERDataset(df_train_val, transform=train_transform)
final_loader = DataLoader(
    final_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=False, drop_last=True
)
criterion_final = nn.CrossEntropyLoss(weight=class_weights_tensor.to(device))
# scaler_final is removed

# --- Final Model - Phase 1: Train the Head ---
# Backbone ('features') frozen, only the classifier receives gradients.
logging.info("--- Final Model - Phase 1: Training Head ---")
if hasattr(final_model, 'features'):
    for param in final_model.features.parameters():
        param.requires_grad = False
if hasattr(final_model, 'classifier'):
    for param in final_model.classifier.parameters():
        param.requires_grad = True

optimizer_head_final = optim.AdamW(
    filter(lambda p: p.requires_grad, final_model.parameters()),
    lr=LEARNING_RATE_HEAD,
    weight_decay=WEIGHT_DECAY
)
for epoch in range(EPOCHS_PHASE1):
    epoch_start_time = time.time()
    train_loss = train_one_epoch(
        final_model, final_loader, criterion_final, optimizer_head_final, device, phase="Head" # scaler removed
    )
    epoch_time = time.time() - epoch_start_time
    logging.info(
        f"Final Phase 1 - Epoch {epoch+1}/{EPOCHS_PHASE1}, Train Loss: {train_loss:.4f}, Time: {epoch_time:.2f}s"
    )

# --- Final Model - Phase 2: Fine-tune the whole model ---
logging.info("--- Final Model - Phase 2: Fine-tuning Full Model ---")
if hasattr(final_model, 'features'):
    for param in final_model.features.parameters():
        param.requires_grad = True

# The optimizer must be built from final_model's parameters and bound to the
# name the training loop below uses. The cross-validation `model` has already
# been deleted at the end of the last fold and must not be referenced here.
optimizer_full_final = optim.AdamW(
    [
        {
            "params": final_model.features.parameters() if hasattr(final_model, 'features') else [],
            "lr": LEARNING_RATE_BACKBONE,
        },
        {
            "params": final_model.classifier.parameters() if hasattr(final_model, 'classifier') else final_model.parameters(),
            "lr": LEARNING_RATE_HEAD,
        },
    ],
    weight_decay=WEIGHT_DECAY,
)

# Optionally, you can use a validation split from df_train_val for early stopping,
# or just train for all epochs if you want to use all data.

for epoch in range(EPOCHS_PHASE2):
    epoch_start_time = time.time()
    train_loss = train_one_epoch(
        final_model, final_loader, criterion_final, optimizer_full_final, device, phase="Full"
    )
    epoch_time = time.time() - epoch_start_time
    logging.info(
        f"Final Phase 2 - Epoch {epoch+1}/{EPOCHS_PHASE2}, Train Loss: {train_loss:.4f}, Time: {epoch_time:.2f}s"
    )

final_model_save_path = os.path.join(MODEL_DIR, f"{BASE_MODEL_NAME.lower()}_fer2013_happy_final.pth")
try:
    torch.save(final_model.state_dict(), final_model_save_path)
    logging.info(f"Final model state_dict saved to {final_model_save_path}")
    # Free the trained model only after it has been written to disk.
    del final_model
    torch.cuda.empty_cache()
    gc.collect()
except Exception as e:
    logging.error(f"Error saving final model: {e}")

In [None]:
# --- Final Evaluation on Test Set ---
# Reloads the saved final model from disk, reports classification metrics on
# df_test, then measures forward-pass latency/throughput on the same loader.
torch.cuda.empty_cache()
gc.collect()

# Rebuilt here so this cell does not depend on the variable from the training cell.
final_model_save_path=os.path.join(MODEL_DIR, f"{BASE_MODEL_NAME.lower()}_fer2013_happy_final.pth")
logging.info("--- Evaluating Final Model on PrivateTest Set ---")
if len(df_test) > 0:
    test_ds = FERDataset(df_test, transform=val_test_transform)
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=False
    )

    try:
        eval_model = get_model().to(device) # Create a new model instance for evaluation
        eval_model.load_state_dict(torch.load(final_model_save_path, map_location=device))
        eval_model.eval() # Ensure model is in eval mode
        logging.info(f"Successfully loaded final model from {final_model_save_path}")

        # 1. Standard Metrics Evaluation
        # NOTE: criterion_final comes from the final-model training cell, so
        # that cell must have been run in this kernel session.
        logging.info("--- Calculating Performance Metrics on Test Set ---")
        test_metrics = evaluate(eval_model, test_loader, criterion_final, device)

        logging.info("Final Test Set Performance Metrics:")
        logging.info(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
        logging.info(f"  F1 Score:  {test_metrics['f1']:.4f}")
        logging.info(f"  Precision: {test_metrics['precision']:.4f}")
        logging.info(f"  Recall:    {test_metrics['recall']:.4f}")
        logging.info(f"  AUC:       {test_metrics['auc']:.4f}")
        logging.info(f"  Loss:      {test_metrics['loss']:.4f}")
        logging.info("  Confusion Matrix:")
        logging.info(f"\n{test_metrics['cm']}")
        logging.info("  Classification Report:")
        # The text report is recomputed from the returned targets/preds arrays
        # (test_metrics['report'] holds the dict form).
        if test_metrics['targets'].size > 0:
             logging.info(f"\n{classification_report(test_metrics['targets'], test_metrics['preds'], target_names=['Not Happy', 'Happy'], zero_division=0)}")
        else:
             logging.warning("No predictions generated during final test evaluation for report.")

        # 2. Inference Speed Test
        logging.info("--- Performing Inference Speed Test on Test Set ---")
        dummy_batch = None
        try:
            dummy_batch = next(iter(test_loader)) # Get one batch for warm-up
        except StopIteration:
            logging.warning("Test loader is empty, cannot perform warm-up for speed test.")

        if dummy_batch:
            dummy_imgs, _ = dummy_batch
            dummy_imgs = dummy_imgs.to(device)

            # Warm-up iterations
            logging.info("Performing warm-up inferences...")
            with torch.no_grad():
                for _ in range(5): # Number of warm-up iterations
                    _ = eval_model(dummy_imgs)
                # Wait for queued CUDA work so warm-up does not spill into the timed loop.
                if device.type == 'cuda':
                    torch.cuda.synchronize()
            logging.info("Warm-up complete.")

        total_inference_time = 0
        total_images_processed = 0
        inference_times_per_batch = []

        logging.info("Starting timed inferences...")
        with torch.no_grad():
            for imgs, _ in tqdm(test_loader, desc="Inference Speed Test"):
                # The host-to-device copy happens before the timer starts, so
                # only the forward pass is measured.
                imgs = imgs.to(device)

                # Synchronize before and after the forward pass on CUDA so the
                # timer brackets the completed work, not just its launch.
                if device.type == 'cuda':
                    torch.cuda.synchronize()
                start_time = time.perf_counter() # More precise timer

                _ = eval_model(imgs)

                if device.type == 'cuda':
                    torch.cuda.synchronize()
                end_time = time.perf_counter()

                batch_time = end_time - start_time
                inference_times_per_batch.append(batch_time)
                total_inference_time += batch_time
                total_images_processed += imgs.size(0)

        if total_images_processed > 0:
            avg_time_per_image = total_inference_time / total_images_processed
            images_per_second = total_images_processed / total_inference_time
            avg_time_per_batch = total_inference_time / len(test_loader)

            logging.info("Inference Speed Test Results:")
            logging.info(f"  Total images processed: {total_images_processed}")
            logging.info(f"  Total inference time: {total_inference_time:.4f} seconds")
            logging.info(f"  Average inference time per batch: {avg_time_per_batch:.6f} seconds (Batch Size: {BATCH_SIZE})")
            logging.info(f"  Average inference time per image: {avg_time_per_image:.6f} seconds")
            logging.info(f"  Images Per Second (IPS): {images_per_second:.2f}")
        else:
            logging.warning("No images were processed during the inference speed test (Test set might be empty or all batches failed).")

    except FileNotFoundError:
        logging.error(f"Final model file not found at {final_model_save_path}. Cannot evaluate on test set.")
    except Exception as e:
        # exc_info=True adds the traceback to the log entry.
        logging.error(f"Error during final evaluation or speed test on test set: {e}", exc_info=True)

else:
    logging.warning("PrivateTest set is empty or could not be loaded. Skipping final evaluation.")