In [None]:
import kagglehub
# Authenticate this session with Kaggle before any dataset download
# (presumably prompts for credentials interactively — confirm in the kagglehub docs).
kagglehub.login()


In [None]:
# Fetch the ASL alphabet dataset and keep the value kagglehub returns
# (presumably the local directory of the download — TODO confirm).
grassknoted_asl_alphabet_path = kagglehub.dataset_download('grassknoted/asl-alphabet')

# Fetch the second dataset ('jadecw/try-custom2') the same way.
jadecw_try_custom2_path = kagglehub.dataset_download('jadecw/try-custom2')

print('Data source import complete.')


In [None]:
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import datasets, transforms, models
from torchvision.models import resnet18, ResNet18_Weights

In [None]:
# ============================================================================
# SECTION 0: SETUP & REPRODUCIBILITY
# ============================================================================

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

# Seed every random number generator this notebook draws from.
SEED = 429
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# cuDNN autotuning (benchmark=True) may pick different convolution algorithms
# from run to run, which defeats the seeding above. Trade some throughput for
# repeatable results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# ===== DATASET PATHS =====
import kagglehub

# Resolve the local copy of the ASL alphabet dataset.
path = kagglehub.dataset_download("grassknoted/asl-alphabet")
print(f"Dataset downloaded to: {path}\n")

# Each split sits in a directory nested inside a directory of the same name.
DATA_ROOT_TRAIN, DATA_ROOT_TEST = (
    os.path.join(path, split_dir, split_dir)
    for split_dir in ("asl_alphabet_train", "asl_alphabet_test")
)

print(f"Training path: {DATA_ROOT_TRAIN}")
print(f"Test path: {DATA_ROOT_TEST}\n")

In [None]:
# ============================================================================
# SECTION 1: FLAT TEST DATASET CLASS
# ============================================================================

class FlatTestDataset(Dataset):
    """Dataset over a flat directory of test images labelled by filename.

    Each item is ``(image, filename, true_label)`` where ``true_label`` is a
    class-name string (not an index).

    Args:
        image_dir: Directory that directly contains the image files.
        class_names: Known class names. The filename's leading token (text
            before the first underscore) is matched against them,
            case-insensitively, to resolve the label.
        transform: Optional callable applied to the loaded RGB PIL image.
        valid_extensions: Lower-case file extensions that count as images.

    Raises:
        FileNotFoundError: If ``image_dir`` holds no file with a valid extension.
    """

    def __init__(self, image_dir, class_names, transform=None,
                 valid_extensions=('.jpg', '.jpeg', '.png')):
        self.image_dir = image_dir
        self.class_names = class_names
        self.transform = transform
        self.valid_extensions = valid_extensions

        # Sorted so that item order is stable across runs and machines.
        suffixes = tuple(valid_extensions)
        self.image_files = [
            filename
            for filename in sorted(os.listdir(image_dir))
            if filename.lower().endswith(suffixes)
        ]

        if len(self.image_files) == 0:
            raise FileNotFoundError(f"No images found in {image_dir}")

        print(f"Found {len(self.image_files)} test images")

    def __len__(self):
        """Number of image files discovered in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, filename, true_label)`` for the file at ``idx``."""
        filename = self.image_files[idx]
        img_path = os.path.join(self.image_dir, filename)

        try:
            # convert() returns a new image, so the file handle can be closed
            # as soon as the pixels have been read.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
        except Exception as e:
            # Report which file failed, then let the error propagate.
            print(f"Error loading {img_path}: {e}")
            raise

        if self.transform:
            img = self.transform(img)

        true_label = self.get_class_from_filename(filename)
        return img, filename, true_label

    def get_class_from_filename(self, filename):
        """Extract the class label encoded at the start of ``filename``.

        The token before the first underscore is looked up in ``class_names``
        first, so multi-character classes (e.g. ``del_test.jpg`` -> ``del``)
        are not collapsed to their first letter. Names that match no known
        class fall back to the legacy rules: ``nothing*`` -> ``"nothing"``,
        ``space*`` -> ``"space"``, otherwise the upper-cased first character.
        """
        known = {str(name).lower(): name for name in (self.class_names or ())}
        stem = os.path.splitext(filename)[0]
        token = stem.split("_", 1)[0].lower()
        if token in known:
            return known[token]

        lowered = filename.lower()
        if lowered.startswith("nothing"):
            return "nothing"
        if lowered.startswith("space"):
            return "space"
        return filename[0].upper()


In [None]:
# ============================================================================
# SECTION 2: DATA LOADING WITH ADVANCED AUGMENTATION
# ============================================================================

# ===== LOAD TRAINING SET =====
# Loaded without a transform: in this notebook this instance is only read for
# its class list and integer targets.
full_train_dataset = datasets.ImageFolder(root=DATA_ROOT_TRAIN, transform=None)
num_classes = len(full_train_dataset.classes)

print(f"✓ Training set loaded: {len(full_train_dataset)} images")
print(f"✓ Number of classes: {num_classes}")
print(f"✓ Classes: {full_train_dataset.classes}\n")

# ===== CREATE 80/20 STRATIFIED SPLIT (seeded with SEED) =====
indices = np.arange(len(full_train_dataset))
labels = np.array(full_train_dataset.targets)

# Use the notebook-wide SEED rather than repeating the literal, so changing
# the seed in one place changes the split as well.
train_idx, val_idx = train_test_split(
    indices, test_size=0.2, stratify=labels, random_state=SEED
)

# The two index sets must partition the dataset without overlap.
assert len(train_idx) + len(val_idx) == len(full_train_dataset)
assert np.intersect1d(train_idx, val_idx).size == 0

print("✓ Train/Val split created:")
print(f"  Train: {len(train_idx)} images")
print(f"  Val: {len(val_idx)} images\n")

# ===== DEFINE TRANSFORMS WITH DATA AUGMENTATION =====
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# TRAINING: resize, then random geometric / photometric perturbations, then
# tensor conversion and normalization.
train_transform = transforms.Compose(
    [transforms.Resize((224, 224))]
    # Geometric augmentations
    + [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    # Color/brightness augmentations, followed by blur
    + [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
    ]
    # Final conversions
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# VALIDATION & TEST: same resize and normalization, no random steps.
val_test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Human-readable summary of the pipelines defined above.
print("=" * 80)
print("DATA AUGMENTATION STRATEGY")
print("=" * 80)
print("""
Training Set Augmentations:
  1. Geometric:
     - Random horizontal flip (50%)
     - Random rotation (±15°)
     - Random translation (±10%)

  2. Color/Brightness:
     - Brightness (±30%)
     - Contrast (±30%)
     - Saturation (±20%)
     - Hue (±10%)

  3. Blur: Gaussian blur for camera simulation

Validation & Test: No augmentation (consistent evaluation)
""")
print("=" * 80 + "\n")

# ===== CREATE DATASETS WITH TRANSFORMS =====
# Two views over the same image folder: one augmented, one deterministic.
train_full, val_full = (
    datasets.ImageFolder(root=DATA_ROOT_TRAIN, transform=pipeline)
    for pipeline in (train_transform, val_test_transform)
)

# Restrict each view to its side of the stratified split.
train_dataset, val_dataset = Subset(train_full, train_idx), Subset(val_full, val_idx)

# ===== LOAD TEST SET =====
test_dataset_original = FlatTestDataset(
    image_dir=DATA_ROOT_TEST,
    class_names=full_train_dataset.classes,
    transform=val_test_transform,
)
print(f"✓ Test set loaded: {len(test_dataset_original)} images\n")


In [None]:
# ============================================================================
# SECTION 3: HYPERPARAMETER CONFIGURATION FOR ALL POLICIES
# ============================================================================

# One entry per training policy. Every inner dict lists its keys in the same
# order (num_epochs, batch_size, lr, weight_decay, optimizer), which is also
# the order they are printed in below.
HYPERPARAMS = {
    "T-A": dict(num_epochs=5, batch_size=128, lr=1e-2, weight_decay=1e-5, optimizer="Adam"),
    "T-B": dict(num_epochs=5, batch_size=64, lr=5e-4, weight_decay=1e-4, optimizer="Adam"),
    "T-C": dict(num_epochs=5, batch_size=64, lr=1e-4, weight_decay=1e-4, optimizer="Adam"),
    "S-A": dict(num_epochs=5, batch_size=32, lr=1e-3, weight_decay=1e-4, optimizer="SGD"),
}

# Checkpoints written by the training cells go here.
os.makedirs("checkpoints", exist_ok=True)

print("=" * 80)
print("HYPERPARAMETER CONFIGURATION")
print("=" * 80 + "\n")

for policy, params in HYPERPARAMS.items():
    print(f"{policy}:")
    print("\n".join(f"  {key:15s}: {value}" for key, value in params.items()))
    print()

# ===== HELPER: Create DataLoaders with Specific Batch Size =====
def create_dataloaders(batch_size, train_idx, val_idx):
    """Build the (train_loader, val_loader) pair for one batch size.

    Both loaders read from DATA_ROOT_TRAIN: the training loader applies
    ``train_transform`` to the ``train_idx`` subset and shuffles, the
    validation loader applies ``val_test_transform`` to the ``val_idx``
    subset in fixed order.
    """
    # Page-locked host memory is only requested when running on the GPU.
    use_pinned_memory = device.type == "cuda"

    def _make_loader(transform, subset_indices, shuffle):
        # A fresh ImageFolder per loader so each one carries its own transform.
        folder = datasets.ImageFolder(root=DATA_ROOT_TRAIN, transform=transform)
        return DataLoader(
            Subset(folder, subset_indices),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=2,
            pin_memory=use_pinned_memory,
        )

    return (
        _make_loader(train_transform, train_idx, True),
        _make_loader(val_test_transform, val_idx, False),
    )


In [None]:
# ============================================================================
# SECTION 4: MODEL CREATION & FREEZING POLICIES
# ============================================================================

# Loss shared by the training and evaluation helpers.
criterion = nn.CrossEntropyLoss()

def create_resnet18(num_classes, pretrained=True):
    """Return a ResNet-18 whose final layer outputs ``num_classes`` logits.

    With ``pretrained`` truthy the backbone starts from the IMAGENET1K_V1
    weights; otherwise it is built with ``weights=None``. In both cases
    ``model.fc`` is replaced by a new ``nn.Linear``.
    """
    if not pretrained:
        model = resnet18(weights=None)
        print("✓ Created ResNet-18 from scratch")
    else:
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        print("✓ Created ResNet-18 with ImageNet weights")

    # Swap the original classifier for one sized to this task.
    head_inputs = model.fc.in_features
    model.fc = nn.Linear(head_inputs, num_classes)
    return model

def apply_policy(model, policy):
    """Set ``requires_grad`` on ``model``'s parameters according to ``policy``.

    Policies:
        "T-A": only ``model.fc`` is trainable.
        "T-B": ``model.layer4`` and ``model.fc`` are trainable.
        "T-C": ``model.layer3``, ``model.layer4`` and ``model.fc`` are trainable.
        "S-A": every parameter is trainable.

    The model is modified in place; nothing is returned. A one-line summary
    of trainable vs. total parameter counts is printed.

    Raises:
        ValueError: If ``policy`` is not one of the names above. Raised before
            any parameter is touched, so a typo cannot silently leave the
            whole model frozen.
    """
    trainable_modules = {
        "T-A": ("fc",),
        "T-B": ("layer4", "fc"),
        "T-C": ("layer3", "layer4", "fc"),
    }
    train_everything = policy == "S-A"
    if not train_everything and policy not in trainable_modules:
        known = ", ".join(list(trainable_modules) + ["S-A"])
        raise ValueError(f"Unknown policy: {policy!r} (expected one of: {known})")

    # Start from a uniform state: all trainable for S-A, all frozen otherwise.
    for param in model.parameters():
        param.requires_grad = train_everything

    # Then re-enable gradients for the sub-modules this policy fine-tunes.
    for module_name in trainable_modules.get(policy, ()):
        for param in getattr(model, module_name).parameters():
            param.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Policy {policy}: {trainable:,} / {total:,} parameters trainable")



In [None]:
# ============================================================================
# SECTION 5: TRAINING FUNCTIONS
# ============================================================================

def get_optimizer(model, lr=1e-3, weight_decay=1e-4, optimizer_name="Adam"):
    """Build an optimizer over the parameters of ``model`` that require gradients.

    ``optimizer_name`` selects ``"Adam"`` or ``"SGD"`` (the latter with
    momentum 0.9); any other value raises ``ValueError``.
    """
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    shared = {"lr": lr, "weight_decay": weight_decay}

    if optimizer_name == "SGD":
        return optim.SGD(trainable, momentum=0.9, **shared)
    if optimizer_name == "Adam":
        return optim.Adam(trainable, **shared)
    raise ValueError(f"Unknown optimizer: {optimizer_name}")

def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Returns ``(loss, accuracy, macro_f1)`` where the loss is the sample-weighted
    mean over ``len(dataloader.dataset)`` and the metrics are computed from the
    predictions made during the pass.
    """
    model.train()
    loss_sum = 0.0
    pred_chunks, target_chunks = [], []

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch mean by batch size so the epoch figure is per-sample.
        loss_sum += batch_loss.item() * images.size(0)
        pred_chunks.append(logits.argmax(dim=1).detach().cpu())
        target_chunks.append(labels.detach().cpu())

    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(target_chunks).numpy()

    return (
        loss_sum / len(dataloader.dataset),
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average="macro", zero_division=0),
    )

def evaluate(model, dataloader, device, class_names=None):
    """Evaluate ``model`` on a validation/test loader without gradient tracking.

    Batches may be ``(images, labels)`` or ``(images, filenames, labels)``.
    Labels are normally an integer tensor. Loaders over datasets that yield
    class-name strings (as ``FlatTestDataset`` does) deliver them as a plain
    sequence; those are mapped to indices through ``class_names``.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable of batches exposing ``.dataset`` for its length.
        device: Device the batches are moved to.
        class_names: Optional ordered class names, needed only when the labels
            arrive as strings.

    Returns:
        ``(loss, accuracy, macro_f1, predictions, targets)``; the last two are
        numpy arrays.

    Raises:
        TypeError: If labels are not a tensor and ``class_names`` is None.
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_targets = []

    class_to_idx = None
    if class_names is not None:
        class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    with torch.no_grad():
        for batch in dataloader:
            images = batch[0].to(device)

            # Two-element batches carry labels second; three-element batches
            # carry filenames second and labels third.
            raw_labels = batch[1] if len(batch) == 2 else batch[2]

            if not torch.is_tensor(raw_labels):
                # String labels cannot be moved to a device; translate first.
                if class_to_idx is None:
                    raise TypeError(
                        "Batch labels are not a tensor; pass class_names so "
                        "string labels can be mapped to class indices."
                    )
                raw_labels = torch.tensor(
                    [class_to_idx[name] for name in raw_labels], dtype=torch.long
                )
            labels = raw_labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch mean by batch size so the total is per-sample.
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            all_preds.append(preds.detach().cpu())
            all_targets.append(labels.detach().cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = accuracy_score(all_targets, all_preds)
    epoch_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)

    return epoch_loss, epoch_acc, epoch_f1, all_preds, all_targets

def train_model(model, train_loader, val_loader, device, num_epochs=10,
                lr=1e-3, weight_decay=1e-4, optimizer_name="Adam",
                experiment_name="exp"):
    """Train ``model`` and keep a snapshot of the best epoch by validation macro-F1.

    Args:
        model: Network to train (trained in place).
        train_loader, val_loader: Loaders passed to ``train_one_epoch`` / ``evaluate``.
        device: Device forwarded to the per-epoch helpers.
        num_epochs: Number of epochs to run.
        lr, weight_decay, optimizer_name: Forwarded to ``get_optimizer``.
        experiment_name: Prefix for the progress lines.

    Returns:
        ``(model, history, best_state)``. ``model`` holds the weights after the
        final epoch. ``history`` maps each metric name to its per-epoch list.
        ``best_state`` is a dict with ``model_state_dict``,
        ``optimizer_state_dict``, ``epoch``, ``val_f1`` and ``val_acc`` from the
        best epoch, or ``None`` if no epoch ran.
    """
    optimizer = get_optimizer(model, lr=lr, weight_decay=weight_decay,
                             optimizer_name=optimizer_name)

    history = {
        "train_loss": [], "train_acc": [], "train_f1": [],
        "val_loss": [], "val_acc": [], "val_f1": [],
    }

    best_val_f1 = -1.0
    best_state = None

    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc, train_f1 = train_one_epoch(
            model, train_loader, optimizer, device
        )
        val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader, device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["train_f1"].append(train_f1)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_f1"].append(val_f1)

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            # state_dict() hands back references to the live tensors, which
            # later epochs keep updating in place. Deep-copy so the snapshot
            # really holds this epoch's weights and optimizer state.
            best_state = {
                "model_state_dict": copy.deepcopy(model.state_dict()),
                "optimizer_state_dict": copy.deepcopy(optimizer.state_dict()),
                "epoch": epoch,
                "val_f1": val_f1,
                "val_acc": val_acc,
            }

        print(
            f"[{experiment_name}] Epoch {epoch:02d}/{num_epochs:02d} | "
            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}"
        )

    print(f"\n[{experiment_name}] Best val macro-F1: {best_val_f1:.4f}\n")
    return model, history, best_state


In [None]:
# ============================================================================
# SECTION 6: TRAIN ALL FOUR MODELS WITH OPTIMIZED HYPERPARAMETERS
# ============================================================================

# Filled in by the per-policy training cells, keyed by policy name:
# results_summary holds each run's best-checkpoint dict, all_histories its
# per-epoch metric history.
results_summary = {}
all_histories = {}

In [None]:
# ===== T-A: HEAD ONLY =====
print("=" * 80)
print("TRAINING: T-A (Head Only)")
print("=" * 80)
print_params = HYPERPARAMS["T-A"]
print(f"LR={print_params['lr']}, BatchSize={print_params['batch_size']}, "
      f"Optimizer={print_params['optimizer']}, Epochs={print_params['num_epochs']}\n")

# Loaders sized for this policy's batch size.
train_loader_TA, val_loader_TA = create_dataloaders(
    batch_size=print_params["batch_size"], train_idx=train_idx, val_idx=val_idx
)

# ImageNet-initialised network; the policy is applied before moving to the device.
model_TA = create_resnet18(num_classes=num_classes, pretrained=True)
apply_policy(model_TA, "T-A")
model_TA = model_TA.to(device)

model_TA, history_TA, best_TA = train_model(
    model_TA,
    train_loader_TA,
    val_loader_TA,
    device,
    num_epochs=print_params["num_epochs"],
    lr=print_params["lr"],
    weight_decay=print_params["weight_decay"],
    optimizer_name=print_params["optimizer"],
    experiment_name="T-A",
)

# Persist the best snapshot and register the run for the later comparison cells.
torch.save(best_TA, "checkpoints/resnet18_T-A_best.pth")
results_summary["T-A"] = best_TA
all_histories["T-A"] = history_TA
print("✓ Saved T-A checkpoint\n")

In [None]:
# ===== T-B: LAST BLOCK + HEAD =====
print("=" * 80)
print("TRAINING: T-B (Last Block + Head)")
print("=" * 80)
print_params = HYPERPARAMS["T-B"]
print(f"LR={print_params['lr']}, BatchSize={print_params['batch_size']}, "
      f"Optimizer={print_params['optimizer']}, Epochs={print_params['num_epochs']}\n")

# Loaders sized for this policy's batch size.
train_loader_TB, val_loader_TB = create_dataloaders(
    batch_size=print_params["batch_size"], train_idx=train_idx, val_idx=val_idx
)

# ImageNet-initialised network; the policy is applied before moving to the device.
model_TB = create_resnet18(num_classes=num_classes, pretrained=True)
apply_policy(model_TB, "T-B")
model_TB = model_TB.to(device)

model_TB, history_TB, best_TB = train_model(
    model_TB,
    train_loader_TB,
    val_loader_TB,
    device,
    num_epochs=print_params["num_epochs"],
    lr=print_params["lr"],
    weight_decay=print_params["weight_decay"],
    optimizer_name=print_params["optimizer"],
    experiment_name="T-B",
)

# Persist the best snapshot; the T-C cell reads this file back from disk.
torch.save(best_TB, "checkpoints/resnet18_T-B_best.pth")
results_summary["T-B"] = best_TB
all_histories["T-B"] = history_TB
print("✓ Saved T-B checkpoint\n")

In [None]:
# ===== T-C: PROGRESSIVE UNFREEZING =====
print("=" * 80)
print("TRAINING: T-C (Progressive Unfreezing)")
print("=" * 80)
print_params = HYPERPARAMS["T-C"]
print(f"LR={print_params['lr']}, BatchSize={print_params['batch_size']}, "
      f"Optimizer={print_params['optimizer']}, Epochs={print_params['num_epochs']}\n")

train_loader_TC, val_loader_TC = create_dataloaders(
    HYPERPARAMS["T-C"]["batch_size"], train_idx, val_idx
)

# T-C starts from the T-B result: build the network, then overwrite its
# weights with the saved T-B checkpoint.
model_TC = create_resnet18(num_classes=num_classes, pretrained=True)
# map_location="cpu": model_TC is still on the CPU here, and the checkpoint may
# have been written from a different device than the one in this session.
# weights_only=False unpickles arbitrary objects; acceptable only because the
# file is written by this notebook's own T-B cell.
tb_checkpoint = torch.load(
    "checkpoints/resnet18_T-B_best.pth", map_location="cpu", weights_only=False
)
model_TC.load_state_dict(tb_checkpoint["model_state_dict"])
print(f"Loaded T-B checkpoint from epoch {tb_checkpoint['epoch']} "
      f"(val F1: {tb_checkpoint['val_f1']:.4f})\n")

apply_policy(model_TC, "T-C")
model_TC = model_TC.to(device)

model_TC, history_TC, best_TC = train_model(
    model_TC, train_loader_TC, val_loader_TC, device,
    num_epochs=HYPERPARAMS["T-C"]["num_epochs"],
    lr=HYPERPARAMS["T-C"]["lr"],
    weight_decay=HYPERPARAMS["T-C"]["weight_decay"],
    optimizer_name=HYPERPARAMS["T-C"]["optimizer"],
    experiment_name="T-C"
)

# Persist the best snapshot and register the run for the later comparison cells.
torch.save(best_TC, "checkpoints/resnet18_T-C_best.pth")
results_summary["T-C"] = best_TC
all_histories["T-C"] = history_TC
print("✓ Saved T-C checkpoint\n")

In [None]:
# ===== S-A: TRAIN FROM SCRATCH =====
print("=" * 80)
print("TRAINING: S-A (Train from Scratch)")
print("=" * 80)
print_params = HYPERPARAMS["S-A"]
print(f"LR={print_params['lr']}, BatchSize={print_params['batch_size']}, "
      f"Optimizer={print_params['optimizer']}, Epochs={print_params['num_epochs']}\n")

# Loaders sized for this policy's batch size.
train_loader_SA, val_loader_SA = create_dataloaders(
    batch_size=print_params["batch_size"], train_idx=train_idx, val_idx=val_idx
)

# No pretrained weights for this run; S-A leaves every parameter trainable.
model_SA = create_resnet18(num_classes=num_classes, pretrained=False)
apply_policy(model_SA, "S-A")
model_SA = model_SA.to(device)

model_SA, history_SA, best_SA = train_model(
    model_SA,
    train_loader_SA,
    val_loader_SA,
    device,
    num_epochs=print_params["num_epochs"],
    lr=print_params["lr"],
    weight_decay=print_params["weight_decay"],
    optimizer_name=print_params["optimizer"],
    experiment_name="S-A",
)

# Persist the best snapshot and register the run for the later comparison cells.
torch.save(best_SA, "checkpoints/resnet18_S-A_best.pth")
results_summary["S-A"] = best_SA
all_histories["S-A"] = history_SA
print("✓ Saved S-A checkpoint\n")



In [None]:
# ============================================================================
# SECTION 7: ABLATION STUDY - VALIDATION SET COMPARISON
# ============================================================================

print("=" * 80)
print("ABLATION STUDY: Comparing All 4 Models on Validation Set")
print("=" * 80 + "\n")

# Use standard dataloader for ablation: one batch size for every policy.
pin = device.type == "cuda"
val_loader_standard = DataLoader(
    dataset=val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=pin,
)

ablation_results = {}

for policy_name in ("T-A", "T-B", "T-C", "S-A"):
    checkpoint = results_summary[policy_name]

    # Rebuild the architecture, then restore this policy's saved weights.
    model = create_resnet18(num_classes=num_classes, pretrained=policy_name != "S-A")
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)

    val_loss, val_acc, val_f1, preds, targets = evaluate(model, val_loader_standard, device)
    cm = confusion_matrix(targets, preds)

    ablation_results[policy_name] = dict(
        val_loss=val_loss, val_acc=val_acc, val_f1=val_f1, cm=cm
    )

    print(f"{policy_name:5s} | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")

print()

# The winner is the policy with the highest validation macro-F1.
best_policy = max(ablation_results, key=lambda name: ablation_results[name]["val_f1"])
print(f"✓ Best model: {best_policy} (F1: {ablation_results[best_policy]['val_f1']:.4f})\n")



In [None]:
# ============================================================================
# SECTION 8: PLOT TRAINING CURVES
# ============================================================================

def plot_training_curves(histories, experiment_names, save_path="training_curves.png"):
    """Plot loss, accuracy and macro-F1 curves plus a final-epoch summary table.

    Args:
        histories: Sequence of history dicts with the keys ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc``, ``train_f1`` and
            ``val_f1``, each a per-epoch list.
        experiment_names: Labels, parallel to ``histories``.
        save_path: File the figure is written to.
    """
    # Pair each name with its history once; avoids looking histories up by
    # name index, which is quadratic and wrong if a name repeats.
    runs = list(zip(experiment_names, histories))

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle("Training Curves - Phase 1 Ablation Study", fontsize=16, fontweight="bold")

    # (axis, history-key suffix, axis label) for the three metric panels.
    panels = (
        (axes[0, 0], "loss", "Loss"),
        (axes[0, 1], "acc", "Accuracy"),
        (axes[1, 0], "f1", "Macro-F1"),
    )
    for ax, metric, pretty in panels:
        for name, hist in runs:
            for split, marker in (("train", "o"), ("val", "s")):
                values = hist[f"{split}_{metric}"]
                # Epochs are numbered from 1, matching the training log.
                epochs = range(1, len(values) + 1)
                ax.plot(epochs, values, label=f"{name} ({split})", marker=marker, alpha=0.7)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(pretty)
        ax.set_title(f"{pretty} vs Epoch")
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Summary table: validation metrics after the last epoch of each run.
    ax = axes[1, 1]
    ax.axis("off")

    table_data = [
        [
            name,
            f"{hist['val_loss'][-1]:.4f}",
            f"{hist['val_acc'][-1]:.4f}",
            f"{hist['val_f1'][-1]:.4f}",
        ]
        for name, hist in runs
        if hist["val_loss"]  # a run with no completed epoch has no last value
    ]

    if table_data:
        table = ax.table(
            cellText=table_data,
            colLabels=["Policy", "Val Loss", "Val Acc", "Val F1"],
            cellLoc="center",
            loc="center",
        )
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()
    print(f"✓ Training curves saved to {save_path}\n")

# The two lists are parallel: histories[i] belongs to experiment_names[i].
histories = [history_TA, history_TB, history_TC, history_SA]
experiment_names = ["T-A", "T-B", "T-C", "S-A"]
plot_training_curves(histories, experiment_names)


In [None]:
# ============================================================================
# SECTION 9: TEST SET EVALUATION
# ============================================================================

print("=" * 80)
# Report the actual size of the loaded test set instead of a fixed number.
print(f"EVALUATING BEST MODEL ON ORIGINAL TEST SET ({len(test_dataset_original)} IMAGES)")
print("=" * 80 + "\n")

def evaluate_test_set_with_filenames(model, test_loader, device, class_names):
    """Evaluate ``model`` on a loader yielding ``(images, filenames, label_names)``.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: Iterable of 3-element batches whose third element is a
            sequence of class-name strings.
        device: Device the images are moved to.
        class_names: Ordered class names; position ``i`` is class index ``i``.

    Returns:
        ``(accuracy, macro_f1, cm, preds, targets, filenames)``. ``cm`` always
        has one row and column per entry of ``class_names``, in that order, so
        it can be plotted with ``class_names`` as tick labels even when some
        classes never occur in the test set.

    Raises:
        ValueError: If a label name is not in ``class_names``.
    """
    model.eval()

    # First occurrence wins, as it would with list.index().
    class_to_idx = {}
    for idx, name in enumerate(class_names):
        class_to_idx.setdefault(name, idx)

    all_preds = []
    all_targets = []
    all_filenames = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch[0].to(device)
            filenames = batch[1]
            true_labels_str = batch[2]

            unknown = [label for label in true_labels_str if label not in class_to_idx]
            if unknown:
                raise ValueError(f"Test labels not in class_names: {unknown}")

            # Targets are only compared on the CPU, so they are built there.
            true_labels = torch.tensor(
                [class_to_idx[label] for label in true_labels_str], dtype=torch.long
            )

            outputs = model(images)
            preds = outputs.argmax(dim=1)

            all_preds.append(preds.cpu())
            all_targets.append(true_labels)
            all_filenames.extend(filenames)

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    accuracy = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average="macro", zero_division=0)
    # Without labels=, classes absent from both targets and predictions are
    # dropped and the matrix no longer lines up with class_names.
    cm = confusion_matrix(all_targets, all_preds, labels=list(range(len(class_names))))

    return accuracy, macro_f1, cm, all_preds, all_targets, all_filenames

# Create test dataloader (fixed order; pinned memory only when on the GPU)
pin = device.type == "cuda"
test_loader_original = DataLoader(
    dataset=test_dataset_original,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=pin,
)

# Check which checkpoint files exist
checkpoint_dir = "checkpoints"
existing_checkpoints = {}

for policy in ["T-A", "T-B", "T-C", "S-A"]:
    checkpoint_path = os.path.join(checkpoint_dir, f"resnet18_{policy}_best.pth")
    if not os.path.exists(checkpoint_path):
        print(f"✗ Missing checkpoint: {policy}")
        continue
    existing_checkpoints[policy] = checkpoint_path
    print(f"✓ Found checkpoint: {policy}")

print()

# Determine best policy from existing checkpoints and ablation results
available_policies = list(existing_checkpoints)

if not available_policies:
    # Nothing on disk: fall back to the ablation ranking and warn loudly.
    print("ERROR: No checkpoints found!")
    print("Available policies in ablation_results:")
    for policy in ablation_results:
        print(f"  {policy}: F1={ablation_results[policy]['val_f1']:.4f}")
    print("\nPlease check if training completed successfully.")
    best_policy = max(ablation_results, key=lambda x: ablation_results[x]["val_f1"])
    print(f"Using {best_policy} from ablation results (but checkpoint may not exist)\n")
else:
    # Among policies with a checkpoint on disk, take the highest validation F1.
    best_policy_available, best_f1_available = None, -1

    for policy in available_policies:
        if policy not in ablation_results:
            continue
        candidate_f1 = ablation_results[policy]["val_f1"]
        if candidate_f1 > best_f1_available:
            best_f1_available = candidate_f1
            best_policy_available = policy

    # No scored candidate: settle for the first checkpoint that exists.
    if best_policy_available is None:
        best_policy_available = available_policies[0]

    best_policy = best_policy_available
    print(f"Using {best_policy} as best model\n")

# Load best model
checkpoint_path = os.path.join(checkpoint_dir, f"resnet18_{best_policy}_best.pth")

if not os.path.exists(checkpoint_path):
    # Show what is actually on disk to make the failure easy to diagnose.
    print(f"ERROR: Checkpoint {checkpoint_path} does not exist!")
    print(f"Available files in {checkpoint_dir}:")
    if os.path.exists(checkpoint_dir):
        for f in os.listdir(checkpoint_dir):
            print(f"  - {f}")
    else:
        print(f"  Directory {checkpoint_dir} does not exist!")

    # Try to use any available checkpoint
    if existing_checkpoints:
        best_policy = list(existing_checkpoints.keys())[0]
        checkpoint_path = existing_checkpoints[best_policy]
        print(f"\nUsing fallback: {best_policy}\n")
    else:
        print("\nCannot proceed without checkpoint!")
        print("Please ensure training completed successfully.")
        raise FileNotFoundError(f"No checkpoint found for {best_policy}")

print(f"Loading checkpoint: {checkpoint_path}\n")

try:
    # map_location=device: the checkpoint may have been written in a session
    # that ran on a different device than this one.
    # weights_only=False unpickles arbitrary objects; acceptable only because
    # the file is written by this notebook's own training cells.
    best_checkpoint = torch.load(
        checkpoint_path, map_location=device, weights_only=False
    )
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    raise

# Rebuild the architecture, then restore the saved weights into it.
best_model = create_resnet18(num_classes=num_classes,
                            pretrained=(best_policy != "S-A"))
best_model.load_state_dict(best_checkpoint["model_state_dict"])
best_model = best_model.to(device)

print(f"Loaded {best_policy} model from epoch {best_checkpoint['epoch']}")
print(f"(Val F1 at checkpoint: {best_checkpoint['val_f1']:.4f})\n")

# Evaluate on original test set
(test_acc, test_f1, test_cm,
 test_preds, test_targets, test_fnames) = evaluate_test_set_with_filenames(
    best_model, test_loader_original, device, full_train_dataset.classes
)

print(f"Test Set Results ({best_policy}):")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  Macro-F1: {test_f1:.4f}")
print(f"  Images tested: {len(test_fnames)}")
print(f"  Confusion Matrix shape: {test_cm.shape}\n")

# Show individual predictions, one line per test image
print("Individual Test Predictions:")
print("-" * 70)
for fname, pred, true in zip(test_fnames, test_preds, test_targets):
    # Translate both class indices back to their names for display.
    pred_class, true_class = (full_train_dataset.classes[i] for i in (pred, true))
    status = "✗ WRONG" if pred != true else "✓ CORRECT"
    print(f"{status:12s} | File: {fname:20s} | Predicted: {pred_class:10s} | True: {true_class:10s}")

print("\n" + "-" * 70)
correct_count = int((test_preds == test_targets).sum())
print(f"Summary: {correct_count}/{len(test_fnames)} correct ({100*test_acc:.1f}%)\n")

# Plot confusion matrix
# A confusion matrix computed without an explicit label list only has rows and
# columns for classes that occur in the targets or predictions (in sorted
# order). Derive the tick labels to match the matrix that was actually built,
# so names cannot shift against their rows.
cm_class_names = full_train_dataset.classes
if test_cm.shape[0] != len(cm_class_names):
    cm_class_names = [
        full_train_dataset.classes[i] for i in np.union1d(test_targets, test_preds)
    ]

fig, ax = plt.subplots(figsize=(14, 12))
sns.heatmap(test_cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=cm_class_names,
            yticklabels=cm_class_names,
            ax=ax, cbar_kws={"label": "Count"})
# Report the number of images actually evaluated rather than a fixed figure.
ax.set_title(f"Test Set Confusion Matrix - {best_policy} ({len(test_fnames)} Images)\nAccuracy: {test_acc:.4f} | Macro-F1: {test_f1:.4f}")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
plt.tight_layout()
plt.savefig(f"confusion_matrix_{best_policy}_test.png", dpi=150, bbox_inches="tight")
plt.show()
print("✓ Test confusion matrix saved\n")

In [None]:
# ============================================================================
# SECTION 10: CUSTOM TEST SET EVALUATION
# ============================================================================

print("=" * 80)
print("CUSTOM TEST SET EVALUATION")
print("=" * 80 + "\n")

# Check if we have the best model loaded
if 'best_model' in locals() and 'best_policy' in locals():
    try:
        # Candidate locations, most specific first: the directory kagglehub
        # returned in this session (when that cell ran), then the Kaggle
        # input mounts.
        kaggle_default_dir = "/kaggle/input/try-custom2/custom_test_set"
        custom_dir_candidates = []
        if 'jadecw_try_custom2_path' in locals():
            custom_dir_candidates.append(
                os.path.join(jadecw_try_custom2_path, "custom_test_set")
            )
        for kaggle_dir in (
            kaggle_default_dir,
            "/kaggle/input/jadecw-try-custom2/custom_test_set",
        ):
            if kaggle_dir not in custom_dir_candidates:
                custom_dir_candidates.append(kaggle_dir)

        # Falls back to the historical default so the name is always a string.
        CUSTOM_TEST_DIR = next(
            (d for d in custom_dir_candidates if os.path.isdir(d)),
            kaggle_default_dir,
        )

        if os.path.isdir(CUSTOM_TEST_DIR):
            print(f"✓ Custom test set found at: {CUSTOM_TEST_DIR}\n")

            custom_test_dataset = datasets.ImageFolder(
                root=CUSTOM_TEST_DIR, transform=val_test_transform
            )

            # ImageFolder numbers classes by the folders it finds, so the
            # custom indices only equal the training indices when both sets
            # have identical folders. Remap every sample to the TRAINING index
            # of its class name (case-insensitive) so the targets line up with
            # the model's outputs.
            train_classes = full_train_dataset.classes
            train_lookup = {name.lower(): idx for idx, name in enumerate(train_classes)}
            unknown_classes = [
                name for name in custom_test_dataset.classes
                if name.lower() not in train_lookup
            ]
            if unknown_classes:
                raise ValueError(
                    f"Custom test set has classes unknown to the model: {unknown_classes}"
                )
            custom_to_train = {
                custom_idx: train_lookup[name.lower()]
                for name, custom_idx in custom_test_dataset.class_to_idx.items()
            }
            custom_test_dataset.samples = [
                (img_path, custom_to_train[custom_idx])
                for img_path, custom_idx in custom_test_dataset.samples
            ]
            custom_test_dataset.imgs = custom_test_dataset.samples
            custom_test_dataset.targets = [
                label_id for _, label_id in custom_test_dataset.samples
            ]

            # Computed here rather than inherited from an earlier cell.
            pin = device.type == "cuda"
            custom_test_loader = DataLoader(
                custom_test_dataset, batch_size=64, shuffle=False,
                num_workers=2, pin_memory=pin
            )

            print("Custom test set loaded:")
            print(f"  Total images: {len(custom_test_dataset)}")
            print(f"  Classes: {custom_test_dataset.classes}\n")

            # Evaluate best model on custom test set
            custom_loss, custom_acc, custom_f1, custom_preds, custom_targets = \
                evaluate(best_model, custom_test_loader, device)
            # One row/column per training class, so the matrix always matches
            # the tick labels, even if the model predicts a class that the
            # custom set does not contain.
            custom_cm = confusion_matrix(
                custom_targets, custom_preds,
                labels=list(range(len(train_classes)))
            )

            print(f"Custom Test Set Results ({best_policy}):")
            print(f"  Accuracy: {custom_acc:.4f}")
            print(f"  Macro-F1: {custom_f1:.4f}")
            print(f"  Images tested: {len(custom_test_dataset)}\n")

            # Per-class results
            print("Per-Class Results:")
            print("-" * 70)
            for custom_idx, class_name in enumerate(custom_test_dataset.classes):
                label_id = custom_to_train[custom_idx]
                class_mask = custom_targets == label_id
                class_total = int(class_mask.sum())
                if class_total > 0:
                    class_correct = int((custom_preds[class_mask] == label_id).sum())
                    class_acc = class_correct / class_total
                    status = "✓" if class_acc == 1.0 else "✗" if class_acc == 0.0 else "~"
                    print(f"{status} {class_name:15s}: {class_acc:.4f} ({class_correct}/{class_total})")

            print()

            # Plot confusion matrix
            fig, ax = plt.subplots(figsize=(12, 10))
            sns.heatmap(custom_cm, annot=True, fmt="d", cmap="Greens",
                        xticklabels=train_classes,
                        yticklabels=train_classes,
                        ax=ax, cbar_kws={"label": "Count"})
            ax.set_title(f"Custom Test Confusion Matrix - {best_policy}\nAccuracy: {custom_acc:.4f} | Macro-F1: {custom_f1:.4f}")
            ax.set_xlabel("Predicted")
            ax.set_ylabel("True")
            plt.tight_layout()
            plt.savefig(f"confusion_matrix_{best_policy}_custom.png", dpi=150, bbox_inches="tight")
            plt.show()
            print("✓ Custom test confusion matrix saved\n")

            # Comparison with original test set
            print("=" * 70)
            print("COMPARISON: Original vs Custom Test Set")
            print("=" * 70)
            print(f"Original Test ({len(test_fnames)} images):  Acc={test_acc:.4f} | F1={test_f1:.4f}")
            print(f"Custom Test ({len(custom_test_dataset)} images):   Acc={custom_acc:.4f} | F1={custom_f1:.4f}")

            acc_diff = custom_acc - test_acc
            f1_diff = custom_f1 - test_f1
            print(f"\nDifference (Custom - Original):")
            print(f"  Accuracy: {acc_diff:+.4f}")
            print(f"  Macro-F1: {f1_diff:+.4f}")

            if acc_diff > 0 and f1_diff > 0:
                print(f"  ✓ Custom test performs BETTER than original")
            elif acc_diff < 0 or f1_diff < 0:
                print(f"  ⚠️  Custom test performs WORSE than original (may indicate overfitting)")
            else:
                print(f"  ~ Similar performance")
            print()

        else:
            # List every location that was really checked.
            print("⚠️  Custom test set not found. Tried paths:")
            for tried_dir in custom_dir_candidates:
                print(f"  - {tried_dir}")
            print()

    except Exception as e:
        # Best-effort cell: report the failure (loading or evaluation) and
        # keep the notebook running.
        print(f"⚠️  Custom test evaluation failed ({type(e).__name__}): {e}\n")

else:
    print("⚠️  Best model not loaded - skipping custom test evaluation\n")


In [None]:
# ============================================================================
# SECTION 11: FINAL SUMMARY
# ============================================================================

print("=" * 80)
print("FINAL SUMMARY - PHASE 1 COMPLETE")
print("=" * 80 + "\n")

# Check if we have results
have_results = 'ablation_results' in locals() and bool(ablation_results)

if have_results:

    # Summary table
    print("Ablation Study Results (Validation Set):")
    print("-" * 70)
    print(f"{'Policy':<8} {'Val Loss':<12} {'Val Acc':<12} {'Val F1':<12}")
    print("-" * 70)

    for policy_name in ["T-A", "T-B", "T-C", "S-A"]:
        if policy_name in ablation_results:
            result = ablation_results[policy_name]
            print(f"{policy_name:<8} {result['val_loss']:<12.4f} "
                  f"{result['val_acc']:<12.4f} {result['val_f1']:<12.4f}")

    print()

    # Best model summary
    best_val_result = ablation_results[best_policy]
    print(f"\n{'='*70}")
    print(f"BEST MODEL: {best_policy}")
    print(f"{'='*70}")
    print(f"  Val Loss: {best_val_result['val_loss']:.4f}")
    print(f"  Val Accuracy: {best_val_result['val_acc']:.4f}")
    print(f"  Val Macro-F1: {best_val_result['val_f1']:.4f}\n")

    print("Test Set Performance:")
    print("-" * 70)
    # Number of images actually evaluated, not a fixed figure.
    print(f"  Original Test ({len(test_fnames)} images):")
    print(f"    - Accuracy: {test_acc:.4f}")
    print(f"    - Macro-F1: {test_f1:.4f}\n")

else:
    print("⚠️  No training results available\n")

print("=" * 80)
# Only claim completion when there were results to summarize.
if have_results:
    print("✓ PHASE 1 COMPLETE - All Models Trained & Evaluated!\n")
else:
    print("✗ PHASE 1 INCOMPLETE - no results to summarize; re-run the training cells.\n")