In [None]:
# Check and install required packages (supercomputer-optimized)
import subprocess
import sys
import os
from pathlib import Path

def check_and_install_package(package_name, import_name=None):
    """Ensure a module is importable, installing its distribution with pip if not.

    Args:
        package_name: Distribution name passed to ``pip install`` (e.g. "scikit-learn").
        import_name: Module name used to test importability (e.g. "sklearn").
            Defaults to ``package_name``.

    Returns:
        True if the module was already importable, False if it had to be installed.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    import importlib
    import site

    if import_name is None:
        import_name = package_name
    try:
        __import__(import_name)
        print(f"✓ {package_name} is already installed")
        return True
    except ImportError:
        print(f"⚠ {package_name} not found, installing...")
        # Use --user flag for installations (works in read-only systems)
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--user", package_name])
        # A --user install can create the user site-packages directory after the
        # interpreter started; it is then missing from sys.path and the package
        # would still not be importable in this kernel. Add it, and drop stale
        # import-finder caches so the freshly installed files are found.
        user_site = site.getusersitepackages()
        if user_site not in sys.path:
            site.addsitedir(user_site)
        importlib.invalidate_caches()
        return False

# Setup scratch directory for caches and configs.
# Candidates are probed in order; the first one that can be created and written
# to becomes scratch_base (None if none of them works).
scratch_base = None
scratch_paths = [
    Path("/scratch") / os.getenv("USER", "user"),
    Path("/tmp") / os.getenv("USER", "user"),
    Path.home() / "scratch",
    Path.home() / "work",
    Path.home() / "tmp",
]

for scratch_path in scratch_paths:
    try:
        scratch_path.mkdir(parents=True, exist_ok=True)
        test_file = scratch_path / ".write_test"
        test_file.write_text("test")
        test_file.unlink()
    except OSError:
        # Not creatable / not writable (PermissionError, read-only FS, ...):
        # try the next candidate. Catching OSError rather than everything keeps
        # KeyboardInterrupt and genuine programming errors visible.
        continue
    scratch_base = scratch_path
    break

if scratch_base:
    # Set matplotlib config directory to scratch (read when matplotlib is imported)
    mpl_config_dir = scratch_base / ".matplotlib"
    mpl_config_dir.mkdir(exist_ok=True)
    os.environ['MPLCONFIGDIR'] = str(mpl_config_dir)

    # Redirect the XDG cache/config roots to scratch. fontconfig keeps its cache
    # in $XDG_CACHE_HOME/fontconfig, so that is the directory to create and
    # report; a separate <scratch>/.fontconfig directory would never be used.
    xdg_cache_home = scratch_base / ".cache"
    xdg_config_home = scratch_base / ".config"
    fontconfig_cache = xdg_cache_home / "fontconfig"
    fontconfig_cache.mkdir(parents=True, exist_ok=True)
    xdg_config_home.mkdir(exist_ok=True)
    os.environ['XDG_CACHE_HOME'] = str(xdg_cache_home)
    os.environ['XDG_CONFIG_HOME'] = str(xdg_config_home)

    # Set pip cache to scratch (os.environ is inherited by the pip subprocesses)
    pip_cache_dir = scratch_base / ".pip_cache"
    pip_cache_dir.mkdir(exist_ok=True)
    os.environ['PIP_CACHE_DIR'] = str(pip_cache_dir)

    print(f"✓ Using scratch directory: {scratch_base}")
    print(f"✓ Matplotlib config: {mpl_config_dir}")
    print(f"✓ Fontconfig cache: {fontconfig_cache}")
    print(f"✓ Pip cache: {pip_cache_dir}")
else:
    print("⚠ Scratch not available, using standard paths")

# Packages the notebook needs, as (pip distribution name, importable module name).
required_packages = [
    ("torch", "torch"),
    ("torchvision", "torchvision"),
    ("timm", "timm"),
    ("albumentations", "albumentations"),
    ("pillow", "PIL"),
    ("numpy", "numpy"),
    ("scikit-learn", "sklearn"),
    ("pandas", "pandas"),
    ("tqdm", "tqdm"),
    ("datasets", "datasets"),
    ("peft", "peft"),  # For LoRA training
    ("transformers", "transformers"),  # Required by peft
    ("matplotlib", "matplotlib"),
]

section_rule = "=" * 60
print("\n" + section_rule)
print("CHECKING AND INSTALLING PACKAGES")
print(section_rule)
for dist_name, module_name in required_packages:
    check_and_install_package(dist_name, import_name=module_name)

# Verify GPU availability and system info
import torch
print("\n" + "="*60)
print("SYSTEM INFORMATION")
print("="*60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"CUDA version: {torch.version.cuda}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")
    print(f"GPU Compute Capability: {gpu_props.major}.{gpu_props.minor}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

    # RTX 8000 specific optimizations ("8000" also matches "Quadro RTX 8000")
    if "8000" in gpu_name:
        print("\n✓ RTX 8000 detected - Optimizing for 48GB VRAM")
        alloc_conf = "max_split_size_mb:512"
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = alloc_conf
        # The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when CUDA is
        # initialised, and the device queries above have already done that, so
        # the env var alone does not affect this process. Apply it at runtime
        # where this PyTorch build exposes the (private) setter.
        set_allocator_settings = getattr(torch.cuda.memory, "_set_allocator_settings", None)
        try:
            if set_allocator_settings is None:
                raise RuntimeError("no runtime allocator-settings hook in this PyTorch build")
            set_allocator_settings(alloc_conf)
        except RuntimeError as exc:
            print(f"⚠ Could not apply {alloc_conf} to the running allocator ({exc}); "
                  "export PYTORCH_CUDA_ALLOC_CONF before starting the kernel instead")
else:
    print("⚠ No GPU detected - will use CPU")

# Check available memory. psutil is optional (it is not in required_packages),
# so a missing install or an unreadable system is reported, never fatal.
try:
    import psutil
    mem = psutil.virtual_memory()
except ImportError:
    print("\nSystem RAM: unknown (psutil not installed)")
except Exception as exc:  # best-effort diagnostic only; must not abort setup
    print(f"\nSystem RAM: unknown ({exc})")
else:
    print(f"\nSystem RAM: {mem.total / 1e9:.2f} GB")
    print(f"Available RAM: {mem.available / 1e9:.2f} GB")

print("="*60)

In [None]:
# Load Hugging Face dataset
from datasets import load_dataset
import shutil
from pathlib import Path

# ======================
# Setup Scratch Directory for HF Dataset
# ======================
# Prefer the scratch directory chosen in Cell 0; probe the candidates again only
# when this cell runs on its own or Cell 0 found nothing writable.
if globals().get('scratch_base') is not None:
    print("="*60)
    print("USING SCRATCH DIRECTORY FROM CELL 0")
    print("="*60)
    print(f"✓ Scratch directory: {scratch_base}")
else:
    scratch_base = None
    user_name = os.getenv("USER", "user")
    scratch_paths = [
        Path("/scratch") / user_name,
        Path("/tmp") / user_name,
        *(Path.home() / sub for sub in ("scratch", "work", "tmp")),
    ]

    print("="*60)
    print("SETTING UP SCRATCH DIRECTORY FOR HF DATASET")
    print("="*60)

    for candidate in scratch_paths:
        probe = candidate / ".write_test"
        try:
            candidate.mkdir(parents=True, exist_ok=True)
            probe.write_text("test")
            probe.unlink()
        except (PermissionError, OSError) as e:
            print(f"✗ Cannot use {candidate}: {e}")
            continue
        scratch_base = candidate
        print(f"✓ Using scratch directory: {scratch_base}")
        break

    if scratch_base is None:
        scratch_base = Path.cwd()
        print(f"⚠ Scratch not available, using current directory: {scratch_base}")

# Create HF dataset directory in scratch
hf_data_dir = scratch_base / "hf_dataset_temp"
hf_data_dir.mkdir(parents=True, exist_ok=True)
print(f"✓ HF dataset will be stored in: {hf_data_dir}")
print("="*60)

# Load Hugging Face dataset: HamdiJr/Egyptian_hieroglyphs
# This dataset will be merged with the local dataset in Cell 4
HF_DATASET_ID = "HamdiJr/Egyptian_hieroglyphs"
print(f"\nLoading Hugging Face dataset: {HF_DATASET_ID}...")
hf_dataset = load_dataset(HF_DATASET_ID)

print(f"✓ Loaded dataset with splits: {list(hf_dataset.keys())}")
for split_title, split_key in (("Train", "train"), ("Test", "test")):
    print(f"  {split_title} samples: {len(hf_dataset[split_key])}")

print(f"\nExtracting images from Hugging Face dataset...")
# tqdm is otherwise first imported in Cell 2, which runs after this cell; bind it
# here so the cell also works on a fresh kernel.
from tqdm.auto import tqdm

hf_data = []


def _hf_label_name(split_ds, key, value):
    """Return the class-name string for an HF label value.

    If the split declares ``key`` as a ClassLabel-like feature (one exposing
    ``int2str``) and ``value`` is an integer id, the id is converted to its name;
    any other value is returned unchanged. Local labels (Cell 4) are class-name
    strings such as "A1"; mixing them with integer ids would make the later
    ``sorted(all_df['label'].unique())`` raise TypeError.
    """
    features = split_ds.features or {}
    feature = features.get(key)
    if isinstance(value, int) and hasattr(feature, "int2str"):
        return feature.int2str(value)
    return value


def _extract_hf_split(split_name, name_template, desc, missing_template, error_template):
    """Save every image of ``hf_dataset[split_name]`` into ``hf_data_dir``.

    Appends one ``{'file_name', 'label', 'source': 'hf'}`` record per saved image
    to the module-level ``hf_data`` list. Examples without a 'label'/'labels'
    field are skipped with a warning; per-example errors are printed and skipped
    so a single bad image does not abort the extraction.

    Args:
        split_name: split key in ``hf_dataset`` ('train' or 'test').
        name_template: output filename, formatted with ``label`` and ``idx``.
        desc: progress-bar description.
        missing_template: missing-label warning, formatted with ``idx``.
        error_template: error message, formatted with ``idx`` and ``e``.
    """
    split_ds = hf_dataset[split_name]
    for idx, example in enumerate(tqdm(split_ds, desc=desc)):
        try:
            image = example['image']
            # Handle different possible label field names
            if 'label' in example:
                key = 'label'
            elif 'labels' in example:
                key = 'labels'
            else:
                print(missing_template.format(idx=idx))
                continue
            label = _hf_label_name(split_ds, key, example[key])

            # Save image with label in filename (filter e.g. label == 'UNKNOWN'
            # out of hf_data afterwards if those samples are not wanted)
            filename = name_template.format(label=label, idx=idx)
            image.save(str(hf_data_dir / filename))
            hf_data.append({'file_name': filename, 'label': label, 'source': 'hf'})
        except Exception as e:
            print(error_template.format(idx=idx, e=e))


# Process train split
_extract_hf_split(
    'train', "{label}_{idx}.png", "Processing HF train",
    "Warning: No label found for example {idx}, skipping...",
    "Error processing HF train example {idx}: {e}",
)

# Process test split (distinct filename pattern so indices never collide with train)
_extract_hf_split(
    'test', "{label}_test_{idx}.png", "Processing HF test",
    "Warning: No label found for test example {idx}, skipping...",
    "Error processing HF test example {idx}: {e}",
)

print(f"✓ Extracted {len(hf_data)} images from Hugging Face dataset")
print(f"  Unique HF classes: {len({d['label'] for d in hf_data})}")

# Store for later use in Cell 4
# DataFrame format: file_name, label (class name as string), source='hf'
import pandas as pd  # bound here because pandas is otherwise first imported in a later cell

# Passing the columns explicitly keeps the expected schema even when no HF image
# was extracted, so Cell 4's hf_df['label'] does not raise KeyError.
hf_df = pd.DataFrame(hf_data, columns=['file_name', 'label', 'source'])
print(f"\n✓ HF dataset ready for merging in Cell 4")
print(f"  Format: {list(hf_df.columns)}")
print(f"  Will be merged with local dataset from finalized_dataset/homogenized/train_flat/")

In [None]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import timm
import albumentations as albu
from albumentations.pytorch import ToTensorV2

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import StepLR

from sklearn.metrics import f1_score
from tqdm import tqdm
import copy
import time

import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import albumentations as albu
from albumentations.pytorch import ToTensorV2

from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from peft import LoraConfig, get_peft_model, TaskType

In [None]:
class CustomDataset(Dataset):
    """
    A Dataset of image files stored in a single directory, with labels taken
    verbatim from a DataFrame.

    Args:
        df (pd.DataFrame): DataFrame with columns ['file_name', 'label'].
        data_dir (str or os.PathLike): Directory where the image files are stored.
        transform (albu.Compose, optional): Albumentations transform, called as
            ``transform(image=np_array)['image']``.

    Returns:
        (image, label): The transformed image (a torch.Tensor when the transform
        ends with ToTensorV2; otherwise an HxWx3 uint8 NumPy array) and its label.
    """
    def __init__(self, df, data_dir, transform=None):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transform

        # Pre-extract image paths and labels so __getitem__ avoids DataFrame access
        self.image_paths = [
            os.path.join(self.data_dir, fname) for fname in df['file_name'].values
        ]
        self.labels = df['label'].values

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1. Load image as RGB; the context manager closes the file handle promptly,
        #    which matters with many DataLoader workers.
        with Image.open(self.image_paths[idx]) as img:
            image = np.array(img.convert('RGB'))

        # 2. Apply transforms
        if self.transform is not None:
            image = self.transform(image=image)['image']

        # 3. Fetch label
        label = self.labels[idx]

        return image, label


In [None]:
from pathlib import Path

BATCH_SIZE = 32
SEED = 42
IMG_SIZE = 224

# ======================
# 1. Load all data from homogenized/train_flat and extract classes
# ======================
# As per plan.plan.md: Scan homogenized/train_flat/ directory for all PNG files
# Extract class name from filename:
#   - A1.png -> class "A1"
#   - D21_1.png -> class "D21" (strip _1, _2, etc. suffixes)
train_data_dir = "finalized_dataset/homogenized/train_flat"

# Get all PNG files. os.listdir returns entries in filesystem-dependent order;
# sorting fixes the row order of all_df, which the seeded train_test_split
# depends on, so the split is reproducible across machines.
all_files = sorted(f for f in os.listdir(train_data_dir) if f.endswith('.png'))
print(f"Found {len(all_files)} PNG files in {train_data_dir}")

# Extract class names from filenames (e.g., "A1.png" -> "A1", "D21_1.png" -> "D21")
def extract_class_name(filename):
    """Return the class name encoded in a local dataset filename.

    Only a trailing ".png" extension is removed (``str.replace`` would also
    delete ".png" occurring elsewhere in the name); the class is the part
    before the first underscore, so variant suffixes like "_1" are dropped.

    Args:
        filename: bare filename such as "D21_1.png".

    Returns:
        The class name, e.g. "D21".
    """
    suffix = '.png'
    name = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    # Split by underscore and take first part (e.g., "D21_1" -> "D21")
    return name.split('_')[0]

# Create DataFrame from local dataset with columns: file_name, label (class name as string)
data = [
    {'file_name': filename, 'label': extract_class_name(filename), 'source': 'local'}
    for filename in all_files
]
local_df = pd.DataFrame(data)

n_local_classes = len(local_df['label'].unique())
print(f"Local dataset: {len(local_df)} images, {n_local_classes} unique classes")

# ======================
# 2. Merge with Hugging Face dataset
# ======================
# Local images (finalized_dataset/homogenized/train_flat/) and the HF images
# extracted in Cell 1 share one schema: file_name, label (class name), source.
rule = '=' * 60
print(f"\n{rule}")
print(f"MERGING DATASETS")
print(f"{rule}")

# Work on a copy so the hf_df produced by Cell 1 stays untouched
hf_df_merged = hf_df.copy()

# Report both inputs before combining them
for source_name, frame in (("Local", local_df), ("HF", hf_df)):
    print(f"  {source_name} dataset: {len(frame)} images, {len(frame['label'].unique())} classes")

all_df = pd.concat([local_df, hf_df_merged], ignore_index=True)

print(f"  Combined dataset: {len(all_df)} images")
print(f"  Total unique classes: {len(all_df['label'].unique())}")
print(rule)

# ======================
# 3. Create class mapping (class name -> numeric index)
# ======================
# As per plan.plan.md: sorted unique class names get consecutive indices
# 0 .. num_classes-1; class_to_idx is kept for later reference.
unique_classes = sorted(all_df['label'].unique())
class_to_idx = dict(zip(unique_classes, range(len(unique_classes))))
num_classes = len(unique_classes)

banner = "=" * 60
print("\n" + banner)
print("CLASS MAPPING")
print(banner)
print(f"Total images: {len(all_df)}")
print(f"Total classes: {num_classes}")
print(f"Sample classes: {list(unique_classes[:15])}")
print(banner)

# Replace class-name strings with their integer indices
all_df['label'] = all_df['label'].map(class_to_idx)

# ======================
# 4. Apply 80/20 train/test split
# ======================
# As per plan.plan.md: Use train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)
# Stratified split ensures balanced class distribution in both sets
# Result: train_df (80%), test_df (20%), then train_df is split again into train/val
print(f"\n{'='*60}")
print(f"APPLYING 80/20 TRAIN/TEST SPLIT")
print(f"{'='*60}")

# Stratified splitting raises ValueError when any class has fewer than 2
# members, and the second (train/val) split needs >= 2 per class in what the
# first split leaves for training. A class with >= 3 samples keeps at least 2
# after the ~20% test draw, so smaller classes are kept entirely in the
# training set instead of aborting the run. When no class is that small, the
# splits below are exactly the plain stratified splits.
# NOTE: sklearn also requires each split to hold at least one sample per
# class; with very many tiny classes that aggregate constraint can still fail.
MIN_SAMPLES_PER_CLASS_FOR_SPLIT = 3
samples_per_class = all_df['label'].map(all_df['label'].value_counts())
rare_mask = samples_per_class < MIN_SAMPLES_PER_CLASS_FOR_SPLIT
rare_df = all_df[rare_mask]
splittable_df = all_df[~rare_mask]
if len(rare_df) > 0:
    print(f"⚠ {rare_df['label'].nunique()} classes ({len(rare_df)} images) have fewer than "
          f"{MIN_SAMPLES_PER_CLASS_FOR_SPLIT} samples; keeping them in the train split only")

train_df, test_df = train_test_split(
    splittable_df,
    test_size=0.2,  # 20% for test
    random_state=SEED,
    stratify=splittable_df['label'],  # Stratified to ensure balanced class distribution
)

# Further split train into train/val (70% train, 10% val from original)
# This gives us: 70% train, 10% val, 20% test
train_df, val_df = train_test_split(
    train_df,
    test_size=0.125,  # 10% of original = 0.1/0.8 = 0.125
    random_state=SEED,
    stratify=train_df['label'],  # Stratified split
)

if len(rare_df) > 0:
    train_df = pd.concat([train_df, rare_df])

print(f"Train samples: {len(train_df)} ({100*len(train_df)/len(all_df):.1f}%)")
print(f"Validation samples: {len(val_df)} ({100*len(val_df)/len(all_df):.1f}%)")
print(f"Test samples: {len(test_df)} ({100*len(test_df)/len(all_df):.1f}%)")
print(f"{'='*60}")

# ======================
# 5a. Define Transforms
# ======================
# ImageNet channel statistics.
# NOTE(review): timm's pretrained config for vit_base_patch16_224 may use
# mean/std of (0.5, 0.5, 0.5) instead — confirm via model.pretrained_cfg.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _eval_pipeline():
    """Deterministic resize -> normalize -> tensor pipeline (no augmentation)."""
    return albu.Compose([
        albu.Resize(IMG_SIZE, IMG_SIZE),
        albu.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])


# Training adds a random horizontal flip in front of the evaluation steps
train_transform = albu.Compose([
    albu.HorizontalFlip(p=0.5),
    albu.Resize(IMG_SIZE, IMG_SIZE),
    albu.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])
val_transform = _eval_pipeline()
test_transform = _eval_pipeline()


# ======================
# 5. Create Datasets & Dataloaders
# ======================
# CustomDataset reads from a single directory; merged data needs two.
class MergedDataset(Dataset):
    """
    Dataset over images that live in two directories: the local set and the
    images extracted from the Hugging Face dataset.

    Args:
        df (pd.DataFrame): columns 'file_name' and 'label', plus an optional
            'source' column. Rows whose source equals 'hf' are read from
            hf_data_dir; every other row (including all rows when the column is
            absent) is read from local_data_dir.
        local_data_dir (str or os.PathLike): directory of the local images.
        hf_data_dir (str or os.PathLike): directory of the extracted HF images.
        transform (albu.Compose, optional): called as transform(image=...)['image'].

    Returns (per item):
        (image, label): the transformed image and its label from df['label'].
    """
    def __init__(self, df, local_data_dir, hf_data_dir, transform=None):
        super().__init__()
        self.transform = transform

        local_dir = os.fspath(local_data_dir)
        hf_dir = os.fspath(hf_data_dir)
        if 'source' in df.columns:
            sources = df['source'].tolist()
        else:
            sources = ['local'] * len(df)

        # Built from column lists rather than df.iterrows(), which constructs a
        # Series per row and is very slow on large frames.
        self.image_paths = [
            os.path.join(hf_dir if source == 'hf' else local_dir, fname)
            for fname, source in zip(df['file_name'].tolist(), sources)
        ]
        self.labels = df['label'].values

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle as soon as pixels are read
        with Image.open(self.image_paths[idx]) as img:
            image = np.array(img.convert('RGB'))
        if self.transform is not None:
            image = self.transform(image=image)['image']
        label = self.labels[idx]
        return image, label

# hf_data_dir is set in Cell 1 (uses scratch directory)
# If not set, fallback to local directory
if 'hf_data_dir' not in globals():
    from pathlib import Path
    hf_data_dir = Path("hf_dataset_temp")
    print("⚠ Warning: hf_data_dir not found, using local directory. Run Cell 1 first!")

# MergedDataset will handle Path/string conversion internally
train_dataset = MergedDataset(
    df=train_df,
    local_data_dir=train_data_dir,
    hf_data_dir=hf_data_dir,
    transform=train_transform
)

val_dataset = MergedDataset(
    df=val_df,
    local_data_dir=train_data_dir,
    hf_data_dir=hf_data_dir,
    transform=val_transform
)

test_dataset = MergedDataset(
    df=test_df,
    local_data_dir=train_data_dir,
    hf_data_dir=hf_data_dir,
    transform=test_transform
)

# Pinned host memory speeds up host->GPU copies; it is useless without CUDA.
PIN_MEMORY = torch.cuda.is_available()

# A generator seeded with SEED makes the shuffle order, and the per-worker
# seeds DataLoader derives from it (used by random augmentations), reproducible
# across runs instead of following the global torch RNG.
train_loader_generator = torch.Generator().manual_seed(SEED)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=PIN_MEMORY,
    generator=train_loader_generator,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=PIN_MEMORY,
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=PIN_MEMORY,
)

# ======================
# 6. Quick Sanity Check
# ======================
# Pull a single batch from each loader to confirm tensor shapes and labels.
# Every header after the first starts with "\n" to separate the sections.
for header, loader in (
    ("Train Batch:", train_dataloader),
    ("\nValidation Batch:", val_dataloader),
    ("\nTest Batch:", test_dataloader),
):
    print(header)
    for images, labels in loader:
        print(" Images shape:", images.shape)
        print(" Labels:", labels)
        break

In [None]:
import timm

class ViTClassifier(nn.Module):
    """A timm backbone whose classification head is re-sized to ``num_classes``.

    Args:
        model_name: timm model identifier (default "vit_base_patch16_224").
        num_classes: number of output logits.
        pretrained: whether timm loads its pretrained weights.

    The timm model is stored as ``self.model``, so state_dict keys carry a
    ``model.`` prefix (e.g. ``model.head.weight`` for ViT).
    """
    def __init__(self, model_name="vit_base_patch16_224", num_classes=2, pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # reset_classifier() swaps in a fresh classifier sized for num_classes
        # (the `head` Linear for ViT) and keeps model.num_classes in sync. It
        # also works for timm families whose classifier is not named `head`,
        # unlike reading `self.model.head.in_features` directly.
        self.model.reset_classifier(num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.model(x)

model_name = "vit_base_patch16_224"
# num_classes comes from the class mapping built in Cell 4 and must already exist.

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


def _param_counts(module):
    """Return (total, trainable) parameter counts of ``module``."""
    total = trainable = 0
    for param in module.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


# ======================
# Create Model 1: Full Fine-tuning
# ======================
print("\n" + "="*60)
print("CREATING MODEL 1: FULL FINE-TUNING")
print("="*60)
model_full = ViTClassifier(model_name=model_name, num_classes=num_classes, pretrained=True).to(device)

total_params_full, trainable_params_full = _param_counts(model_full)
print(f"Total parameters: {total_params_full:,}")
print(f"Trainable parameters: {trainable_params_full:,} ({100*trainable_params_full/total_params_full:.2f}%)")

# ======================
# Create Model 2: LoRA
# ======================
print("\n" + "="*60)
print("CREATING MODEL 2: LoRA (Low-Rank Adaptation)")
print("="*60)

# Create base model for LoRA
model_lora_base = ViTClassifier(model_name=model_name, num_classes=num_classes, pretrained=True)

# LoRA configuration.
# - No task_type: ViTClassifier is a plain nn.Module. Task-specific PEFT wrappers
#   (e.g. TaskType.FEATURE_EXTRACTION) call the base model with
#   transformers-style keywords (input_ids=..., attention_mask=...), which
#   ViTClassifier.forward(x) does not accept. Without task_type, get_peft_model
#   returns a generic PeftModel that passes the image batch through unchanged.
# - modules_to_save=["head"]: ViTClassifier replaces the head with a freshly
#   initialised Linear, so it must be trained as well; otherwise PEFT freezes
#   it with the backbone and only the adapters learn.
lora_config = LoraConfig(
    r=16,  # Rank (low-rank dimension)
    lora_alpha=32,  # LoRA alpha scaling
    # Target attention and MLP layers. NOTE(review): matching is by module-name
    # suffix, so "proj" presumably also hits patch_embed.proj — confirm intended.
    target_modules=["qkv", "proj", "fc1", "fc2"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["head"],
)

# Apply LoRA to model
model_lora = get_peft_model(model_lora_base, lora_config)
model_lora.to(device)

total_params_lora = sum(p.numel() for p in model_lora.parameters())
trainable_params_lora = sum(p.numel() for p in model_lora.parameters() if p.requires_grad)
print(f"Total parameters: {total_params_lora:,}")
print(f"Trainable parameters (LoRA): {trainable_params_lora:,} ({100*trainable_params_lora/total_params_lora:.2f}%)")
print(f"\nParameter reduction: {100*(1 - trainable_params_lora/trainable_params_full):.1f}% fewer trainable parameters")
print("="*60)

In [None]:
criterion = nn.CrossEntropyLoss()

# Cosine schedule length in epochs; keep equal to EPOCHS in the training cell.
SCHEDULER_T_MAX = 20

# Create separate optimizers for both models.
# The LoRA optimizer only receives parameters that require gradients (the
# trainable PEFT subset), so it does not track the frozen backbone tensors.
optimizer_full = optim.AdamW(model_full.parameters(), lr=1e-4, weight_decay=1e-4)
optimizer_lora = optim.AdamW(
    [p for p in model_lora.parameters() if p.requires_grad],
    lr=2e-4,  # Slightly higher LR for LoRA
    weight_decay=1e-4,
)

# Create separate schedulers. `verbose` is left at its default (it was False);
# passing it explicitly emits a deprecation warning in recent PyTorch.
scheduler_full = optim.lr_scheduler.CosineAnnealingLR(optimizer_full, T_max=SCHEDULER_T_MAX)
scheduler_lora = optim.lr_scheduler.CosineAnnealingLR(optimizer_lora, T_max=SCHEDULER_T_MAX)

print("Optimizers created:")
print(f"  Full model: AdamW, lr=1e-4")
print(f"  LoRA model: AdamW, lr=2e-4")

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader``.

    Args:
        model: classifier returning logits, assumed (batch, num_classes).
        loader: iterable of (images, labels) batches, e.g. a DataLoader.
        criterion: loss called as criterion(logits, labels); assumed to average
            over the batch (reduction='mean', the CrossEntropyLoss default).
        optimizer: optimizer over the model's trainable parameters.
        device: device the batches are moved to (should match the model's).

    Returns:
        (epoch_loss, epoch_acc): per-sample mean loss and top-1 accuracy.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        # non_blocking lets the copy overlap compute when the loader pins memory
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Re-weight the batch-mean loss by batch size so the epoch average is
        # per-sample even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total

def validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: classifier returning logits, assumed (batch, num_classes).
        loader: iterable of (images, labels) batches.
        criterion: loss called as criterion(logits, labels); assumed to average
            over the batch (reduction='mean').
        device: device the batches are moved to.

    Returns:
        (epoch_loss, epoch_acc): per-sample mean loss and top-1 accuracy.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating", leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Batch-size weighting keeps the average per-sample
            running_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("validate_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total

def test_one_epoch(model, loader, criterion, device):

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [None]:
# ======================
# Training Progress Visualization
# ======================
# Run this cell after Cell 9 (training) completes. Before that it only prints a
# warning instead of failing on the missing training variables.


def training_summary_lines(history, trainable_params_full, trainable_params_lora,
                           best_val_acc_full, best_val_acc_lora,
                           time_full=None, time_lora=None):
    """Build the text summary table comparing the full and LoRA runs.

    Args:
        history: {'full'|'lora': {'train_loss'|'train_acc'|'val_loss'|'val_acc': [per-epoch values]}}
            with at least one recorded epoch.
        trainable_params_full, trainable_params_lora: trainable parameter counts.
        best_val_acc_full, best_val_acc_lora: best validation accuracies.
        time_full, time_lora: total training seconds; the time row is included
            only when both are given.

    Returns:
        The lines to print, in order.
    """
    full, lora = history['full'], history['lora']
    lines = [
        "\n" + "="*80,
        "TRAINING SUMMARY",
        "="*80,
        f"\n{'Metric':<30} {'Full Model':<20} {'LoRA Model':<20}",
        "-"*80,
        f"{'Best Val Accuracy':<30} {best_val_acc_full:<20.4f} {best_val_acc_lora:<20.4f}",
        f"{'Final Train Accuracy':<30} {full['train_acc'][-1]:<20.4f} {lora['train_acc'][-1]:<20.4f}",
        f"{'Final Val Accuracy':<30} {full['val_acc'][-1]:<20.4f} {lora['val_acc'][-1]:<20.4f}",
        f"{'Trainable Parameters':<30} {trainable_params_full:<20,} {trainable_params_lora:<20,}",
    ]
    if time_full is not None and time_lora is not None:
        lines.append(f"{'Total Training Time':<30} {time_full:<20.2f}s {time_lora:<20.2f}s")
    lines.append("="*80)
    return lines


def plot_training_progress(history, trainable_params_full, trainable_params_lora,
                           time_full=None, time_lora=None):
    """Draw the 2x2 training-progress figure for the full and LoRA runs.

    Panels: loss curves, accuracy curves, validation-accuracy comparison, and
    either total training time (when both times are given) or trainable
    parameter counts.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training Progress: Full Fine-tuning vs LoRA', fontsize=16, fontweight='bold')

    epochs = range(1, len(history['full']['train_loss']) + 1)
    # (history key, line colour, legend prefix)
    runs = (('full', 'b', 'Full'), ('lora', 'r', 'LoRA'))

    # Loss and accuracy curves: train solid, val dashed
    for ax, metric, ylabel, title in (
        (axes[0, 0], 'loss', 'Loss', 'Loss Curves'),
        (axes[0, 1], 'acc', 'Accuracy', 'Accuracy Curves'),
    ):
        for key, color, name in runs:
            ax.plot(epochs, history[key][f'train_{metric}'], f'{color}-', label=f'{name} - Train', linewidth=2)
            ax.plot(epochs, history[key][f'val_{metric}'], f'{color}--', label=f'{name} - Val', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Validation accuracy comparison
    ax = axes[1, 0]
    ax.plot(epochs, history['full']['val_acc'], 'b-o', label='Full Model', linewidth=2, markersize=6)
    ax.plot(epochs, history['lora']['val_acc'], 'r-s', label='LoRA Model', linewidth=2, markersize=6)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation Accuracy')
    ax.set_title('Validation Accuracy Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Training time if known, otherwise trainable parameter counts.
    # Named bar_labels (not `models`) so torchvision's `models` is not shadowed.
    bar_labels = ['Full Model', 'LoRA Model']
    ax = axes[1, 1]
    if time_full is not None and time_lora is not None:
        values = [time_full, time_lora]
        value_labels = [f'{v:.1f}s' for v in values]
        ax.set_ylabel('Total Training Time (seconds)')
        ax.set_title('Training Time Comparison')
    else:
        values = [trainable_params_full, trainable_params_lora]
        value_labels = [f'{v/1e6:.2f}M' for v in values]
        ax.set_ylabel('Trainable Parameters')
        ax.set_title('Trainable Parameters Comparison')
    bars = ax.bar(bar_labels, values, color=['blue', 'red'], alpha=0.7, edgecolor='black', linewidth=2)
    ax.grid(True, alpha=0.3, axis='y')
    for bar, text in zip(bars, value_labels):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(), text,
                ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()


if 'history' in globals() and len(history['full']['train_loss']) > 0:
    run_time_full = globals().get('time_full')
    run_time_lora = globals().get('time_lora')
    plot_training_progress(history, trainable_params_full, trainable_params_lora,
                           time_full=run_time_full, time_lora=run_time_lora)
    for line in training_summary_lines(history, trainable_params_full, trainable_params_lora,
                                       best_val_acc_full, best_val_acc_lora,
                                       time_full=run_time_full, time_lora=run_time_lora):
        print(line)
else:
    print("⚠ Training history not found. Please run Cell 9 (training) first.")


In [None]:
EPOCHS = 20
patience = 4

# Best validation accuracy so far and epochs without improvement, per model
best_val_acc_full = best_val_acc_lora = 0.0
wait_full = wait_lora = 0

# Per-epoch metrics for the visualization cell: one independent list per run/metric
history = {
    run: {metric: [] for metric in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    for run in ('full', 'lora')
}

# Accumulated wall-clock training time (seconds) per model
time_full = time_lora = 0.0

print("="*80)
print("STARTING DUAL MODEL TRAINING")
print("="*80)

for epoch in range(EPOCHS):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"{'='*80}")
    
    # Train Full Model
    print("\n[Full Model]")
    start_time = time.time()
    train_loss_full, train_acc_full = train_one_epoch(model_full, train_dataloader, criterion, optimizer_full, device)
    val_loss_full, val_acc_full = validate_one_epoch(model_full, val_dataloader, criterion, device)
    epoch_time_full = time.time() - start_time
    time_full += epoch_time_full
    scheduler_full.step()
    
    # Train LoRA Model
    print("\n[LoRA Model]")
    start_time = time.time()
    train_loss_lora, train_acc_lora = train_one_epoch(model_lora, train_dataloader, criterion, optimizer_lora, device)
    val_loss_lora, val_acc_lora = validate_one_epoch(model_lora, val_dataloader, criterion, device)
    epoch_time_lora = time.time() - start_time
    time_lora += epoch_time_lora
    scheduler_lora.step()
    
    # Store history
    history['full']['train_loss'].append(train_loss_full)
    history['full']['train_acc'].append(train_acc_full)
    history['full']['val_loss'].append(val_loss_full)
    history['full']['val_acc'].append(val_acc_full)
    
    history['lora']['train_loss'].append(train_loss_lora)
    history['lora']['train_acc'].append(train_acc_lora)
    history['lora']['val_loss'].append(val_loss_lora)
    history['lora']['val_acc'].append(val_acc_lora)
    
    # Print results
    print(f"\n{'─'*80}")
    print(f"{'Metric':<20} {'Full Model':<20} {'LoRA Model':<20}")
    print(f"{'─'*80}")
    print(f"{'Train Loss':<20} {train_loss_full:<20.4f} {train_loss_lora:<20.4f}")
    print(f"{'Train Acc':<20} {train_acc_full:<20.4f} {train_acc_lora:<20.4f}")
    print(f"{'Val Loss':<20} {val_loss_full:<20.4f} {val_loss_lora:<20.4f}")
    print(f"{'Val Acc':<20} {val_acc_full:<20.4f} {val_acc_lora:<20.4f}")
    print(f"{'Epoch Time':<20} {epoch_time_full:<20.2f}s {epoch_time_lora:<20.2f}s")
    print(f"{'─'*80}")
    
    # Save best models
    if val_acc_full > best_val_acc_full:
        best_val_acc_full = val_acc_full
        torch.save(model_full.state_dict(), f"VIT-full-{val_acc_full:.4f}.pth")
        print(f"✓ Full model: Best saved! (Val Acc: {val_acc_full:.4f})")
        wait_full = 0
    else:
        wait_full += 1
    
    if val_acc_lora > best_val_acc_lora:
        best_val_acc_lora = val_acc_lora
        torch.save(model_lora.state_dict(), f"VIT-lora-{val_acc_lora:.4f}.pth")
        print(f"✓ LoRA model: Best saved! (Val Acc: {val_acc_lora:.4f})")
        wait_lora = 0
    else:
        wait_lora += 1
    
    # Early stopping
    if wait_full >= patience and wait_lora >= patience:
        print("\n⚠ Both models hit early stopping patience. Training stopped.")
        break
    
    # Periodic test evaluation
    if (epoch + 1) % 5 == 0:
        print(f"\n[Test Evaluation @ Epoch {epoch+1}]")
        test_loss_full, test_acc_full = test_one_epoch(model_full, test_dataloader, criterion, device)
        test_loss_lora, test_acc_lora = test_one_epoch(model_lora, test_dataloader, criterion, device)
        print(f"  Full Model - Test Loss: {test_loss_full:.4f} | Test Acc: {test_acc_full:.4f}")
        print(f"  LoRA Model - Test Loss: {test_loss_lora:.4f} | Test Acc: {test_acc_lora:.4f}")

print(f"\n{'='*80}")
print("TRAINING COMPLETED")
print(f"{'='*80}")
print(f"Total training time - Full: {time_full:.2f}s, LoRA: {time_lora:.2f}s")
print(f"Best Val Acc - Full: {best_val_acc_full:.4f}, LoRA: {best_val_acc_lora:.4f}")

# Now run Cell 8 for visualizations

In [None]:
# Preview the held-out test split. `test_df` comes from the 80/20 split in Cell 4,
# and the matching `test_dataset` / `test_dataloader` (built with the same
# CustomDataset class, labels included) were created there too — nothing is rebuilt here.
print("Test DataFrame head:\n", test_df.head(5))
print(f"\nNumber of test samples: {test_df.shape[0]}")

In [None]:
# Load best models and evaluate on test set
print("="*80)
print("FINAL TEST EVALUATION")
print("="*80)


def load_best_checkpoint(model, path):
    """Load a checkpoint written by the training cell into `model` and switch it to eval mode.

    The checkpoints are plain `state_dict()`s saved with `torch.save`, so the restricted
    `weights_only=True` unpickler is sufficient and avoids executing arbitrary pickled
    objects. Tensors are mapped onto the notebook-global `device`.
    """
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()


def collect_predictions(model, dataloader, desc):
    """Run `model` over `dataloader` without gradients; return (predictions, labels).

    Both are flat lists in dataloader order. Assumes each batch is an
    `(images, labels)` pair of tensors and that the model output has classes on
    dim 1, so the prediction is the max index over that dim.
    """
    predictions, targets = [], []
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=desc):
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)
            predictions.extend(preds.cpu().numpy())
            # Labels are only compared on the CPU, so they never need to visit the device.
            targets.extend(labels.cpu().numpy())
    return predictions, targets


# File names must match the ones the training cell writes (accuracy formatted to 4 dp).
best_model_path_full = f"VIT-full-{best_val_acc_full:.4f}.pth"
load_best_checkpoint(model_full, best_model_path_full)
print(f"✓ Loaded best full model: {best_model_path_full}")

best_model_path_lora = f"VIT-lora-{best_val_acc_lora:.4f}.pth"
load_best_checkpoint(model_lora, best_model_path_lora)
print(f"✓ Loaded best LoRA model: {best_model_path_lora}")

print("\n[Full Model - Test Evaluation]")
all_predictions_full, all_labels_full = collect_predictions(model_full, test_dataloader, "Full Model")

print("\n[LoRA Model - Test Evaluation]")
all_predictions_lora, all_labels_lora = collect_predictions(model_lora, test_dataloader, "LoRA Model")

# Fraction of test samples whose predicted class equals the true label
accuracy_full = (np.array(all_labels_full) == np.array(all_predictions_full)).mean()
accuracy_lora = (np.array(all_labels_lora) == np.array(all_predictions_lora)).mean()

print("\n" + "="*80)
print("FINAL TEST RESULTS")
print("="*80)
print(f"Full Model Test Accuracy: {accuracy_full:.4f}")
print(f"LoRA Model Test Accuracy: {accuracy_lora:.4f}")
print("="*80)


In [None]:
# Create comparison DataFrames: one row per test image for each model.
# NOTE(review): predictions are matched to `test_df["file_name"]` by position, which
# assumes test_dataloader yields samples in test_df order (shuffle=False) — confirm in Cell 4.


def build_prediction_frame(file_names, labels, predictions, model_name):
    """Return a DataFrame with columns file_name, true_label, predicted_label, model.

    Raises:
        ValueError: if the numbers of files, labels and predictions differ, since the
            rows could then not be aligned with file names.
    """
    if not (len(file_names) == len(labels) == len(predictions)):
        raise ValueError(
            f"{model_name}: {len(file_names)} files but {len(labels)} labels / "
            f"{len(predictions)} predictions"
        )
    return pd.DataFrame({
        "file_name": file_names,
        "true_label": labels,
        "predicted_label": predictions,
        "model": model_name,  # constant per frame; distinguishes rows once frames are stacked
    })


test_file_names = test_df["file_name"].values
submission_df_full = build_prediction_frame(test_file_names, all_labels_full, all_predictions_full, "full")
submission_df_lora = build_prediction_frame(test_file_names, all_labels_lora, all_predictions_lora, "lora")

print("\nFull Model Predictions Summary:")
print(submission_df_full.head(10))
print("\nLoRA Model Predictions Summary:")
print(submission_df_lora.head(10))

In [None]:
# @title label

from matplotlib import pyplot as plt

# Histogram of predicted labels for each model, overlaid on one axis. Uses the
# per-model frames built in the comparison cell (column `predicted_label`).
fig, ax = plt.subplots()
for model_name, frame in (("full", submission_df_full), ("lora", submission_df_lora)):
    frame['predicted_label'].plot(kind='hist', bins=20, alpha=0.5, label=model_name, ax=ax)
ax.set_title('Predicted label distribution')
ax.set_xlabel('predicted label')
ax.legend()
ax.spines[['top', 'right']].set_visible(False)

In [None]:
# Stack both models' predictions into one table; the `model` column tells the rows apart.
submission_df = pd.concat([submission_df_full, submission_df_lora], ignore_index=True)
submission_df.to_csv('submission.csv', index=False)