# GNR 638: Machine Learning for Remote Sensing

**Task:** Multi-class image classification of remotely sensed images  
**Competition:** [gnr638-mls4rs-a1 on Kaggle](https://www.kaggle.com/competitions/gnr638-mls4rs-a1)  
**Architecture:** ResNet-50 with ImageNet pre-training (transfer learning)  
**Classes (7):** Basketball Court, Beach, Forest, Railway, Tennis Court, Water Pool, Others

---

## Prerequisites

1. **Accept competition rules:** Visit [kaggle.com/competitions/gnr638-mls4rs-a1](https://www.kaggle.com/competitions/gnr638-mls4rs-a1) and click *Join Competition* / accept the rules before running the data download cell.
2. **Kaggle authentication:** `kagglehub` will prompt for your Kaggle username and API key on first use, or read from `~/.kaggle/kaggle.json`.

---

## 0. Environment Setup

In [None]:
import os
import sys
import random
import shutil
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from sklearn.metrics import confusion_matrix, classification_report

# ── Reproducibility ──────────────────────────────────────────────────────────
SEED = 42

# Seed every RNG the notebook relies on: Python, NumPy, PyTorch (CPU) and all
# CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Give up cuDNN autotuning in exchange for run-to-run determinism.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'PyTorch {torch.__version__} | Device: {DEVICE}')
if DEVICE.type == 'cuda':
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)} | '
          f'VRAM: {_gpu_props.total_memory / 1e9:.1f} GB')

## 1. Data Download via kagglehub

> **Important:** You must have accepted the competition rules on Kaggle before running this cell.

In [None]:
import kagglehub

DATA_DIR = Path('data')
DATA_DIR.mkdir(exist_ok=True)

TRAIN_DIR = DATA_DIR / 'train'
TEST_DIR  = DATA_DIR / 'test'

print('Downloading competition data via kagglehub ...')
print('Note: You will be prompted for Kaggle credentials if not already cached.\n')

try:
    # competition_download(handle) fetches the competition files into
    # kagglehub's own cache and returns that location. Its optional second
    # argument `path` selects a single file *inside* the competition -- it is
    # not an output directory, so DATA_DIR must not be passed there.
    download_path = kagglehub.competition_download('gnr638-mls4rs-a1')
except Exception as e:
    # Broad catch is deliberate: this only prints troubleshooting hints and
    # then re-raises the original error unchanged.
    print(f'[ERROR] {e}')
    print('\nTroubleshooting steps:')
    print('  1. Visit https://www.kaggle.com/competitions/gnr638-mls4rs-a1')
    print('  2. Click "Join Competition" and accept the rules.')
    print('  3. Ensure your Kaggle API key is valid (~/.kaggle/kaggle.json).')
    raise

print(f'Files downloaded to: {download_path}')

# Later cells look for data under DATA_DIR, so mirror the cached download there
# (skipped if the cache already *is* DATA_DIR).
_src = Path(download_path)
if _src.resolve() != DATA_DIR.resolve():
    if _src.is_dir():
        shutil.copytree(_src, DATA_DIR, dirs_exist_ok=True)
    else:
        shutil.copy2(_src, DATA_DIR / _src.name)
    print(f'Copied into: {DATA_DIR.resolve()}')

In [None]:
# Inspect what was downloaded: every regular file under DATA_DIR (recursive),
# shown relative to DATA_DIR with its size in megabytes.
print('Downloaded files:')
for f in (q for q in sorted(DATA_DIR.rglob('*')) if q.is_file()):
    print(f'  {f.relative_to(DATA_DIR)}  ({f.stat().st_size / 1e6:.1f} MB)')

In [None]:
# ── Extract zips if present ───────────────────────────────────────────────────
import zipfile

for zip_path in sorted(DATA_DIR.glob('*.zip')):
    print(f'Extracting {zip_path.name} ...')
    # ZipFile.extractall strips absolute paths and '..' components from member
    # names, so members cannot be written outside DATA_DIR.
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(DATA_DIR)
    print('  Done.')

# Show the directory structure after extraction: every directory with its
# direct file count, plus metadata files (.csv/.txt/.json) up to depth 3.
print('\nData directory structure:')
for p in sorted(DATA_DIR.rglob('*')):
    depth = len(p.relative_to(DATA_DIR).parts)
    indent = '  ' * (depth - 1)
    if p.is_dir():
        n_files = sum(1 for child in p.iterdir() if child.is_file())
        print(f'{indent}{p.name}/  [{n_files} files]')
    elif depth <= 3 and p.suffix.lower() in ('.csv', '.txt', '.json'):
        print(f'{indent}{p.name}')

## 2. Dataset Inspection

Identify train / test folder layout and label mappings.

In [None]:
# ── Locate training images and classes ───────────────────────────────────────
# The competition typically provides a folder-per-class structure for train
# and a flat folder + CSV for test.

# Same suffix set as RemoteSensingDataset, so the counts below match len(dataset).
_TRAIN_IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
_TEST_DIR_NAMES = {'test', 'testing'}


def _is_data_dir(p):
    """Return True for a directory that is not hidden/tooling metadata
    (names starting with '.' or '__', e.g. .ipynb_checkpoints, __MACOSX)."""
    return p.is_dir() and not p.name.startswith(('.', '__'))


# Try the conventional names first.
candidates = [DATA_DIR / 'train', DATA_DIR / 'Train', DATA_DIR / 'training']
TRAIN_DIR = next((c for c in candidates if c.is_dir()), None)
if TRAIN_DIR is None:
    # Fall back: the first (alphabetical, hence deterministic) non-test
    # directory that itself contains class subdirectories.
    for d in sorted(DATA_DIR.iterdir()):
        if (_is_data_dir(d) and d.name.lower() not in _TEST_DIR_NAMES
                and any(_is_data_dir(sd) for sd in d.iterdir())):
            TRAIN_DIR = d
            break

if TRAIN_DIR is None:
    raise FileNotFoundError(
        f'No folder-per-class training directory found under {DATA_DIR.resolve()}.'
    )

print(f'Train directory: {TRAIN_DIR}')

CLASS_NAMES = sorted(d.name for d in TRAIN_DIR.iterdir() if _is_data_dir(d))
if not CLASS_NAMES:
    raise FileNotFoundError(f'{TRAIN_DIR} contains no class subdirectories.')
CLASS_TO_IDX = {cls: i for i, cls in enumerate(CLASS_NAMES)}
NUM_CLASSES  = len(CLASS_NAMES)

print(f'Number of classes : {NUM_CLASSES}')
print(f'Class names       : {CLASS_NAMES}')

# Count image files per class (suffix match is case-insensitive).
counts = {
    cls: sum(1 for p in (TRAIN_DIR / cls).iterdir()
             if p.is_file() and p.suffix.lower() in _TRAIN_IMG_EXTS)
    for cls in CLASS_NAMES
}

print('\nImages per class:')
for cls, n in counts.items():
    print(f'  {cls:<20} {n}')
total = sum(counts.values())
print(f'  {"TOTAL":<20} {total}')

In [None]:
# ── Locate test directory and sample submission ───────────────────────────────
test_candidates = [DATA_DIR / 'test', DATA_DIR / 'Test', DATA_DIR / 'testing']
TEST_DIR = next((c for c in test_candidates if c.is_dir()), None)
print(f'Test directory: {TEST_DIR}')

if TEST_DIR is not None:
    # Same suffix set as TestDataset so this count matches len(test_ds) later.
    test_images = sorted(
        p for p in TEST_DIR.iterdir()
        if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
    )
    print(f'Test images   : {len(test_images)}')

# Check for sample submission CSV: files named sample*.csv first, then any other
# top-level CSV. dict.fromkeys removes the duplicates the two globs produce
# while preserving that priority order.
sample_csvs = list(dict.fromkeys(
    [*sorted(DATA_DIR.glob('sample*.csv')), *sorted(DATA_DIR.glob('*.csv'))]
))
if sample_csvs:
    print(f'\nSample submission: {sample_csvs[0].name}')
    print(pd.read_csv(sample_csvs[0]).head(5).to_string(index=False))

## 3. Exploratory Data Analysis

In [None]:
# ── Class distribution bar chart ─────────────────────────────────────────────
# Materialise labels/values once so the bars and their annotations share order.
class_labels = list(counts)
class_counts = [counts[c] for c in class_labels]

fig, ax = plt.subplots(figsize=(9, 4))
bars = ax.bar(class_labels, class_counts, color='steelblue', edgecolor='white')
ax.set_title('Training Set — Class Distribution', fontsize=13, fontweight='bold')
ax.set_xlabel('Class', fontsize=11)
ax.set_ylabel('Number of Images', fontsize=11)
ax.tick_params(axis='x', rotation=30)

# Write each bar's count just above it.
for rect, n_imgs in zip(bars, class_counts):
    x_mid = rect.get_x() + rect.get_width() / 2
    ax.text(x_mid, rect.get_height() + 0.5, str(n_imgs),
            ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('class_distribution.png', dpi=100, bbox_inches='tight')
plt.show()
print(f'Class balance ratio (min/max): {min(class_counts)/max(class_counts):.2f}')

In [None]:
# ── Sample images per class ───────────────────────────────────────────────────
SAMPLES_PER_CLASS = 3
# Same suffix set as RemoteSensingDataset.
_GRID_IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

# squeeze=False keeps `axes` 2-D for every grid shape (including a single
# row or a single column), so axes[row][col] indexing is always valid.
fig, axes = plt.subplots(NUM_CLASSES, SAMPLES_PER_CLASS,
                         figsize=(SAMPLES_PER_CLASS * 3, NUM_CLASSES * 3),
                         squeeze=False)
fig.suptitle('Sample Training Images per Class', fontsize=13, fontweight='bold', y=1.01)

for row, cls in enumerate(CLASS_NAMES):
    imgs = sorted(p for p in (TRAIN_DIR / cls).glob('*')
                  if p.suffix.lower() in _GRID_IMG_EXTS)
    samples = random.sample(imgs, min(SAMPLES_PER_CLASS, len(imgs)))
    for col, ax in enumerate(axes[row]):
        # Hide ticks and frame instead of calling ax.axis('off'): turning the
        # axis off would also hide the y-label that carries the class name.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        if col < len(samples):
            # Context manager closes the file handle once the image is decoded.
            with Image.open(samples[col]) as img:
                ax.imshow(img.convert('RGB'))
        if col == 0:
            ax.set_ylabel(cls, fontsize=9, rotation=0, labelpad=55, va='center')

plt.tight_layout()
plt.savefig('sample_images.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# ── Image size statistics ─────────────────────────────────────────────────────
# Same suffix set as RemoteSensingDataset, so the statistics describe exactly
# the images that will be trained on.
_SIZE_IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

widths, heights = [], []
all_img_paths = []
for cls in CLASS_NAMES:
    paths = [p for p in (TRAIN_DIR / cls).glob('*')
             if p.suffix.lower() in _SIZE_IMG_EXTS]
    all_img_paths.extend(paths)
    for p in paths:
        # Image.open is lazy, so .size is available without decoding pixels.
        with Image.open(p) as img:
            w, h = img.size
        widths.append(w)
        heights.append(h)

if widths:
    print('Image dimensions (W x H):')
    print(f'  Width  — min: {min(widths)}, max: {max(widths)}, mean: {np.mean(widths):.0f}')
    print(f'  Height — min: {min(heights)}, max: {max(heights)}, mean: {np.mean(heights):.0f}')
else:
    # min()/max() would raise ValueError on empty lists.
    print('No training images found — cannot compute size statistics.')
print(f'  Total training images: {len(widths)}')

## 4. Data Augmentation & DataLoaders

ImageNet normalisation statistics are used because ResNet-50 was pre-trained on ImageNet.  
Training transforms include random cropping, horizontal/vertical flips, rotations, colour jitter, and random erasing for regularisation; validation and test images are only resized and normalised.

In [None]:
# ── Hyperparameters ───────────────────────────────────────────────────────────
IMG_SIZE    = 224       # ResNet-50 native input size
BATCH_SIZE  = 32
VAL_SPLIT   = 0.2       # fraction of labelled training images held out for validation
NUM_WORKERS = 4

# ImageNet per-channel statistics, matching the pre-trained backbone's inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

_CROP_MARGIN = 32  # resize this much larger so RandomCrop can shift its window

train_transform = transforms.Compose([
    # Geometric augmentation (PIL stage)
    transforms.Resize((IMG_SIZE + _CROP_MARGIN, IMG_SIZE + _CROP_MARGIN)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=20),
    # Photometric augmentation
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),
    # Tensor stage: convert, normalise, then erase patches (RandomErasing needs a tensor)
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1)),
])

# One deterministic pipeline shared by validation and test: resize + normalise.
val_transform = test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

print('Transforms defined.')
print(f'  Input size : {IMG_SIZE}x{IMG_SIZE}')
print(f'  Batch size : {BATCH_SIZE}')
print(f'  Val split  : {VAL_SPLIT*100:.0f}%')

In [None]:
# Image suffixes (lower-case) accepted by both dataset classes; kept in one place
# so training and test discovery cannot drift apart.
_DATASET_IMG_EXTS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})


def _load_rgb(path):
    """Decode the image at *path* as 3-channel RGB and release the file handle."""
    with Image.open(path) as img:
        return img.convert('RGB')


class RemoteSensingDataset(Dataset):
    """Dataset for folder-per-class remote sensing images.

    Expects ``root_dir/<class_name>/<image file>``. Each item is
    ``(image, label)`` with ``label = class_to_idx[class_name]``; ``image`` is
    an RGB PIL image, or ``transform(image)`` when a transform is given.
    Class folders missing from ``root_dir`` are skipped. Within a class folder,
    regular files are taken in sorted order and filtered by (case-insensitive)
    suffix.
    """

    IMG_EXTS = _DATASET_IMG_EXTS

    def __init__(self, root_dir, class_to_idx, transform=None):
        self.root_dir     = Path(root_dir)
        self.class_to_idx = class_to_idx
        self.transform    = transform
        self.samples      = []   # list of (path, label)

        for cls, idx in class_to_idx.items():
            cls_dir = self.root_dir / cls
            if not cls_dir.is_dir():
                continue
            for img_path in sorted(cls_dir.iterdir()):
                # is_file() rejects sub-directories whose names end in an image suffix.
                if img_path.is_file() and img_path.suffix.lower() in self.IMG_EXTS:
                    self.samples.append((img_path, idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = _load_rgb(path)
        if self.transform:
            image = self.transform(image)
        return image, label


class TestDataset(Dataset):
    """Dataset for a flat, unlabelled test directory.

    Items are ``(image, filename)``, where ``filename`` is the file's base name
    (used as the submission id). Files are listed in sorted order, so every
    instance over the same directory yields the same ordering.
    """

    IMG_EXTS = _DATASET_IMG_EXTS

    def __init__(self, test_dir, transform=None):
        self.test_dir  = Path(test_dir)
        self.transform = transform
        self.samples   = sorted(
            p for p in self.test_dir.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMG_EXTS
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path = self.samples[idx]
        image = _load_rgb(path)
        if self.transform:
            image = self.transform(image)
        return image, path.name


print('Dataset classes defined.')

In [None]:
# ── Build train / validation split ───────────────────────────────────────────
full_dataset = RemoteSensingDataset(TRAIN_DIR, CLASS_TO_IDX, transform=None)
n_total = len(full_dataset)
if n_total < 2:
    raise ValueError(
        f'Need at least 2 training images for a train/validation split, found {n_total}.'
    )
# int() truncates, so a small dataset could otherwise get an empty validation
# split, leaving val_loader without batches for the training loop.
n_val   = max(1, int(n_total * VAL_SPLIT))
n_train = n_total - n_val

# Split indices (reproducible: a dedicated generator seeded with SEED)
generator = torch.Generator().manual_seed(SEED)
train_subset, val_subset = random_split(full_dataset, [n_train, n_val], generator=generator)


class TransformedSubset(Dataset):
    """Apply ``transform`` to the images of a subset on access.

    Lets the train and validation splits share one un-transformed base dataset
    while using different augmentation pipelines.
    """

    def __init__(self, subset, transform):
        self.subset    = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


train_ds = TransformedSubset(train_subset, train_transform)
val_ds   = TransformedSubset(val_subset,   val_transform)

# Pinned host memory only helps host-to-CUDA copies; on a CPU-only machine
# requesting it is useless (PyTorch warns and ignores it).
_pin_memory = DEVICE.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True,  num_workers=NUM_WORKERS, pin_memory=_pin_memory)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=NUM_WORKERS, pin_memory=_pin_memory)

print(f'Total training images : {n_total}')
print(f'  Train split         : {n_train}')
print(f'  Validation split    : {n_val}')
print(f'Train batches         : {len(train_loader)}')
print(f'Val batches           : {len(val_loader)}')

In [None]:
# ── Test DataLoader (if test images are present) ──────────────────────────────
test_loader = None
if TEST_DIR is not None and TEST_DIR.is_dir():
    test_ds     = TestDataset(TEST_DIR, transform=test_transform)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=NUM_WORKERS,
                             # pinning only benefits host-to-CUDA copies
                             pin_memory=DEVICE.type == 'cuda')
    print(f'Test images : {len(test_ds)}')
else:
    print('No test directory found — skipping test DataLoader.')

## 5. Model — ResNet-50 with Fine-Tuning

Strategy: load ImageNet pre-trained weights (`IMAGENET1K_V2`), replace the final fully-connected layer with a dropout + linear head producing one logit per class (7 for this competition), and fine-tune the whole network end-to-end with a single learning rate.

In [None]:
def build_resnet50(num_classes: int, pretrained: bool = True, freeze_backbone: bool = False):
    """Return a torchvision ResNet-50 whose final layer emits ``num_classes`` logits.

    Args:
        num_classes: Number of output logits of the new head.
        pretrained: If True, initialise from the ``IMAGENET1K_V2`` ImageNet
            weights; otherwise use random initialisation.
        freeze_backbone: If True, every parameter of the original network is
            set to ``requires_grad=False``. The head is created afterwards, so
            it stays trainable either way.

    Returns:
        The model, with ``fc`` replaced by ``Dropout(p=0.4)`` followed by
        ``Linear(in_features, num_classes)``, where ``in_features`` is the
        input width of the original ``fc`` layer.
    """
    model = models.resnet50(
        weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
    )

    # Freeze *before* swapping the head so only the new head remains trainable.
    if freeze_backbone:
        model.requires_grad_(False)

    head_in = model.fc.in_features
    model.fc = nn.Sequential(nn.Dropout(p=0.4), nn.Linear(head_in, num_classes))
    return model


# nn.Module.to() returns the module itself, so build and move in one step.
model = build_resnet50(NUM_CLASSES, pretrained=True, freeze_backbone=False).to(DEVICE)

# Parameter budget: everything vs. what the optimiser will actually update.
model_params     = list(model.parameters())
total_params     = sum(t.numel() for t in model_params)
trainable_params = sum(t.numel() for t in model_params if t.requires_grad)
print(f'Total parameters     : {total_params:,}')
print(f'Trainable parameters : {trainable_params:,}')
print(f'Model head           : {model.fc}')

## 6. Training

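- **Loss:** Cross-entropy with label smoothing (0.1)  
- **Optimiser:** AdamW with weight decay  
- **Scheduler:** Cosine annealing with warm restarts (cycle length 10 epochs), stepped once per epoch  
- **Early stopping:** Halts training if validation accuracy does not improve for `PATIENCE` consecutive epochs; the best-scoring weights are checkpointed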

In [None]:
# ── Training hyperparameters ──────────────────────────────────────────────────
NUM_EPOCHS    = 30
LR            = 3e-4
WEIGHT_DECAY  = 1e-4
PATIENCE      = 7       # epochs without val-accuracy gain before stopping early
CHECKPOINT    = Path('checkpoints')
CHECKPOINT.mkdir(exist_ok=True)

# Label smoothing mixes each one-hot target with a uniform distribution (weight 0.1).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine schedule that restarts every T_0 = 10 scheduler steps; T_mult = 1 keeps
# every cycle the same length.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer, T_0=10, T_mult=1,
)

print(f'Epochs       : {NUM_EPOCHS}')
print(f'Learning rate: {LR}')
print(f'Weight decay : {WEIGHT_DECAY}')
print(f'Patience     : {PATIENCE}')

In [None]:
def print_progress_bar(current, total, length=40):
    """Print a one-line text progress bar, redrawing it in place.

    Args:
        current: Number of completed steps.
        total: Total number of steps. A non-positive total (e.g. an empty
            DataLoader) is treated as already complete instead of raising
            ZeroDivisionError.
        length: Width of the bar in characters.

    A trailing newline is printed once ``current`` reaches ``total`` so that
    subsequent output starts on a fresh line.
    """
    progress = current / total if total > 0 else 1.0
    # Clamp so an over-count never draws a bar wider than `length`.
    progress = min(max(progress, 0.0), 1.0)
    filled   = int(length * progress)
    bar      = '#' * filled + '.' * (length - filled)
    print(f'\rProgress: [{bar}] {current}/{total} ({progress*100:.1f}%)',
          end='', flush=True)
    if current >= total:
        print()


def run_epoch(model, loader, criterion, optimizer=None, device=DEVICE):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and its weights
    are updated after every batch; without one the model is put in
    evaluation mode and gradient tracking is disabled.

    Args:
        model: Network mapping an image batch to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function ``criterion(outputs, labels)``; assumed to
            return a per-batch mean (default ``reduction='mean'``) -- confirm
            if a different reduction is used.
        optimizer: Optimiser to step, or None for evaluation.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the batch losses weighted by batch size
        and averaged over all samples, and the fraction of samples whose
        arg-max prediction equals the label.

    Raises:
        ValueError: if ``loader`` yields no samples (the averages would
            otherwise divide by zero).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is exactly what eval() does

    running_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss    = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            correct      += (outputs.argmax(dim=1) == labels).sum().item()
            total        += batch_size

    if total == 0:
        raise ValueError('run_epoch received an empty loader; cannot compute loss/accuracy.')

    return running_loss / total, correct / total


print('Training utilities defined.')

In [None]:
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
best_val_acc  = 0.0
epochs_no_imp = 0

print(f'Training ResNet-50 for up to {NUM_EPOCHS} epochs on {DEVICE} ...\n')
print(f'{"Epoch":>6} | {"Train Loss":>10} | {"Train Acc":>9} | '
      f'{"Val Loss":>8} | {"Val Acc":>8} | {"LR":>10}')
print('-' * 65)

for epoch in range(1, NUM_EPOCHS + 1):
    # Read the LR *before* scheduler.step(): afterwards param_groups already
    # holds the next epoch's value, so the log would be shifted by one epoch.
    current_lr = optimizer.param_groups[0]['lr']

    train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
    val_loss,   val_acc   = run_epoch(model, val_loader,   criterion)
    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Always checkpoint epoch 1: otherwise a run whose val accuracy never
    # exceeds 0.0 writes no best_model.pth, and the evaluation cell would fail
    # or silently load a stale file from an earlier run.
    improved = epoch == 1 or val_acc > best_val_acc
    marker   = ' *' if improved else ''

    print(f'{epoch:>6} | {train_loss:>10.4f} | {train_acc*100:>8.2f}% | '
          f'{val_loss:>8.4f} | {val_acc*100:>7.2f}%{marker} | {current_lr:>10.2e}')

    if improved:
        best_val_acc  = val_acc
        epochs_no_imp = 0
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'val_acc': best_val_acc},
                   CHECKPOINT / 'best_model.pth')
    else:
        epochs_no_imp += 1
        if epochs_no_imp >= PATIENCE:
            print(f'\nEarly stopping triggered at epoch {epoch} (no improvement for {PATIENCE} epochs).')
            break

print(f'\nBest validation accuracy: {best_val_acc*100:.2f}%')

## 7. Training Curves

In [None]:
epochs_ran = len(history['train_loss'])
x = range(1, epochs_ran + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# One panel per metric: (axis, history-key suffix, y-scale, title, y-label).
curve_panels = (
    (ax1, 'loss', 1,   'Cross-Entropy Loss',          'Loss'),
    (ax2, 'acc',  100, 'Classification Accuracy (%)', 'Accuracy (%)'),
)
for panel_ax, metric, scale, title, ylabel in curve_panels:
    panel_ax.plot(x, [v * scale for v in history[f'train_{metric}']],
                  label='Train', color='steelblue')
    panel_ax.plot(x, [v * scale for v in history[f'val_{metric}']],
                  label='Validation', color='coral')
    panel_ax.set_title(title, fontweight='bold')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(ylabel)
    panel_ax.legend()
    panel_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=100, bbox_inches='tight')
plt.show()

final_train_acc = history['train_acc'][-1]
final_val_acc   = history['val_acc'][-1]
print(f'Final epoch {epochs_ran}: '
      f'train acc = {final_train_acc*100:.2f}%, '
      f'val acc = {final_val_acc*100:.2f}%')

## 8. Evaluation on Validation Set

In [None]:
# ── Load best checkpoint ──────────────────────────────────────────────────────
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on GPU
# can be restored on a CPU-only machine (and vice versa).
best_ckpt_path = CHECKPOINT / 'best_model.pth'
checkpoint = torch.load(best_ckpt_path, map_location=DEVICE)
best_state = checkpoint['model_state_dict']
model.load_state_dict(best_state)
print(f'Loaded checkpoint from epoch {checkpoint["epoch"]} '
      f'(val acc: {checkpoint["val_acc"]*100:.2f}%)')

In [None]:
# ── Collect predictions on validation set ────────────────────────────────────
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(DEVICE)
        outputs = model(images)
        preds   = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

all_preds  = np.array(all_preds)
all_labels = np.array(all_labels)

val_accuracy = (all_preds == all_labels).mean()
print(f'Validation Accuracy: {val_accuracy*100:.2f}%\n')
print(classification_report(all_labels, all_preds, target_names=CLASS_NAMES, digits=3))

In [None]:
# ── Confusion Matrix ──────────────────────────────────────────────────────────
cm = confusion_matrix(all_labels, all_preds)
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

fig, ax = plt.subplots(figsize=(9, 7))
im = ax.imshow(cm_norm, interpolation='nearest', cmap='Blues', vmin=0, vmax=1)
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

tick_marks = np.arange(NUM_CLASSES)
ax.set_xticks(tick_marks); ax.set_xticklabels(CLASS_NAMES, rotation=35, ha='right', fontsize=9)
ax.set_yticks(tick_marks); ax.set_yticklabels(CLASS_NAMES, fontsize=9)

thresh = 0.5
for i in range(NUM_CLASSES):
    for j in range(NUM_CLASSES):
        colour = 'white' if cm_norm[i, j] > thresh else 'black'
        ax.text(j, i, f'{cm_norm[i, j]:.2f}\n({cm[i, j]})',
                ha='center', va='center', fontsize=8, color=colour)

ax.set_title('Normalised Confusion Matrix (Validation Set)', fontweight='bold')
ax.set_xlabel('Predicted Class'); ax.set_ylabel('True Class')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=100, bbox_inches='tight')
plt.show()

## 9. Test Set Inference & Submission Generation

In [None]:
if test_loader is None:
    print('No test loader — skipping inference.')
else:
    model.eval()
    filenames, predictions = [], []
    n_batches = len(test_loader)

    print(f'Running inference on {len(test_loader.dataset)} test images ...')
    with torch.no_grad():
        for i, (images, fnames) in enumerate(test_loader, 1):
            outputs = model(images.to(DEVICE))
            filenames.extend(fnames)
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            print_progress_bar(i, n_batches)

    IDX_TO_CLASS = {v: k for k, v in CLASS_TO_IDX.items()}
    pred_labels  = [IDX_TO_CLASS[p] for p in predictions]

    # NOTE(review): ids are full file names (with extension) and labels are
    # class-folder names; confirm this matches the competition's expected format.
    submission = pd.DataFrame({'id': filenames, 'label': pred_labels})

    # Match Kaggle sample submission column names if available. Only the header
    # row is read; a CSV with fewer than two columns can't be mapped.
    if sample_csvs:
        sample_cols = list(pd.read_csv(sample_csvs[0], nrows=0).columns)
        if len(sample_cols) >= 2:
            submission.columns = sample_cols[:2]
        else:
            print(f'[WARN] {sample_csvs[0].name} has fewer than 2 columns; '
                  'keeping default headers.')

    submission_path = 'submission.csv'
    submission.to_csv(submission_path, index=False)

    print(f'\nSubmission saved to: {submission_path}')
    print(f'Shape: {submission.shape}')
    print('\nPrediction distribution:')
    print(submission.iloc[:, 1].value_counts().to_string())
    print('\nFirst 5 rows:')
    print(submission.head(5).to_string(index=False))

## 10. Test-Time Augmentation (TTA) — Optional Refinement

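TTA averages the softmax probabilities over several deterministic views of each test image (original, horizontal flip, vertical flip) and predicts the arg-max class. This often gives a small accuracy gain, but the size of the benefit varies and should be checked on validation data.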

In [None]:
if test_loader is None:
    print('No test loader — skipping TTA.')
else:
    # Deterministic views: p=1.0 makes the "random" flips always apply.
    TTA_TRANSFORMS = [
        val_transform,  # original
        transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]),
        transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomVerticalFlip(p=1.0),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]),
    ]

    model.eval()
    tta_probs   = None   # running sum of softmax probabilities, one row per file
    tta_fnames  = None   # file names in the row order of tta_probs

    for t_idx, tta_tf in enumerate(TTA_TRANSFORMS):
        tta_ds     = TestDataset(TEST_DIR, transform=tta_tf)
        tta_loader = DataLoader(tta_ds, batch_size=BATCH_SIZE,
                                shuffle=False, num_workers=NUM_WORKERS,
                                pin_memory=DEVICE.type == 'cuda')
        probs_list, fname_list = [], []

        print(f'TTA pass {t_idx + 1}/{len(TTA_TRANSFORMS)} ...')
        with torch.no_grad():
            for i, (images, fnames) in enumerate(tta_loader, 1):
                logits = model(images.to(DEVICE))
                probs_list.append(torch.softmax(logits, dim=1).cpu().numpy())
                fname_list.extend(fnames)
                print_progress_bar(i, len(tta_loader))

        # Stack per-batch arrays; an empty test folder still gives a 2-D
        # (0, NUM_CLASSES) array so argmax(axis=1) below stays valid.
        pass_probs = (np.concatenate(probs_list) if probs_list
                      else np.zeros((0, NUM_CLASSES), dtype=np.float32))

        if tta_probs is None:
            tta_probs  = pass_probs
            tta_fnames = fname_list
        else:
            # Rows are summed positionally, so every pass must list the files
            # in the same order (TestDataset sorts them; shuffle=False).
            if fname_list != tta_fnames:
                raise RuntimeError('TTA passes produced different file orders; '
                                   'cannot combine predictions.')
            tta_probs += pass_probs

    # Summed probabilities have the same arg-max as their mean.
    tta_preds       = tta_probs.argmax(axis=1)
    IDX_TO_CLASS    = {v: k for k, v in CLASS_TO_IDX.items()}
    tta_pred_labels = [IDX_TO_CLASS[p] for p in tta_preds]

    tta_submission = pd.DataFrame({'id': tta_fnames, 'label': tta_pred_labels})
    # Match Kaggle sample submission column names if the header has >= 2 columns.
    if sample_csvs:
        sample_cols = list(pd.read_csv(sample_csvs[0], nrows=0).columns)
        if len(sample_cols) >= 2:
            tta_submission.columns = sample_cols[:2]
        else:
            print(f'[WARN] {sample_csvs[0].name} has fewer than 2 columns; '
                  'keeping default headers.')

    tta_submission.to_csv('submission_tta.csv', index=False)
    print('\nTTA submission saved to: submission_tta.csv')
    print('Prediction distribution (TTA):')
    print(tta_submission.iloc[:, 1].value_counts().to_string())

## 11. Summary

| Item | Value |
|------|-------|
| Architecture | ResNet-50 (ImageNet pre-trained) |
| Classes | 7 (basketball court, beach, forest, railway, tennis court, water pool, others) |
| Input size | 224 x 224 |
| Optimiser | AdamW (lr=3e-4, wd=1e-4) |
| Scheduler | Cosine annealing with warm restarts (T_0=10) |
| Loss | Cross-entropy with label smoothing (0.1) |
| Augmentation | Random crop, flip, rotation, colour jitter, random erasing |
| TTA passes | 3 (original, H-flip, V-flip) |

**Submission files:**
- `submission.csv` — standard inference
- `submission_tta.csv` — test-time augmented (recommended for leaderboard)

**Possible further improvements:**
- Increase training data via additional augmentation (MixUp, CutMix)
- Use EfficientNet-B4 or ViT-B/16 for higher capacity
- Apply class-weighted loss if class imbalance is significant
- Ensemble multiple ResNet variants