# Download Dataset

In [None]:
import os
import subprocess
import datetime
from typing import Tuple


def log_event(message: str) -> None:
    """Print ``message`` to stdout, prefixed with the current local time in square brackets."""
    now = datetime.datetime.now()
    stamp = f"{now:%Y-%m-%d %H:%M:%S}"
    print(f"[{stamp}] {message}")


def run_command(command: str) -> Tuple[int, str]:
    """Execute ``command`` through the shell and wait for it to finish.

    stderr is merged into stdout, and the combined stream is decoded as text.

    Returns:
        ``(exit_code, combined_output)``.
    """
    completed = subprocess.run(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    return completed.returncode, completed.stdout


def ensure_kaggle_config() -> bool:
    """Copy ``kaggle.json`` from the working directory into ``~/.kaggle``.

    Runs ``mkdir -p``, ``cp`` and ``chmod 600`` in order and stops at the first
    step that exits with a non-zero status, so success is only reported when the
    credentials file was really installed.

    Returns:
        True when every step succeeded; False when ``kaggle.json`` is missing
        from the working directory or any setup command failed.
    """
    if not os.path.isfile("kaggle.json"):
        log_event("ERROR: kaggle.json not found. Please upload it.")
        return False

    setup_steps = (
        "mkdir -p ~/.kaggle",
        "cp kaggle.json ~/.kaggle/",
        "chmod 600 ~/.kaggle/kaggle.json",
    )
    for step in setup_steps:
        code, output = run_command(step)
        if code != 0:
            # Do not claim success when the credentials could not be installed.
            log_event(f"ERROR: '{step}' failed while configuring Kaggle credentials.")
            log_event(f"Command Output:\n{output}")
            return False

    log_event("Kaggle API credentials configured successfully.")
    return True


def install_kaggle_api() -> None:
    """Install the ``kaggle`` package with pip and report the outcome.

    A non-zero pip exit status is logged together with pip's output instead of
    being discarded. The function still returns normally in that case, so the
    caller's later steps decide whether to continue.
    """
    log_event("Installing Kaggle API if missing...")
    code, output = run_command("pip install -q kaggle")
    if code != 0:
        log_event(f"ERROR: pip exited with status {code} while installing kaggle.")
        log_event(f"Command Output:\n{output}")
    log_event("Kaggle API installation attempted.")


def download_dataset(dataset_path: str) -> bool:
    """Download a Kaggle dataset archive with the ``kaggle`` CLI.

    Args:
        dataset_path: Dataset identifier passed to ``kaggle datasets download -d``
            (the example in this notebook uses the form ``owner/name``).

    Returns:
        True when the CLI exits with status 0, False otherwise (the CLI output
        is logged on failure).
    """
    import shlex  # local import: only needed to build the shell command safely

    log_event(f"Starting dataset download: {dataset_path}")
    # The command runs through a shell, so the identifier is quoted to keep
    # spaces or shell metacharacters from being interpreted.
    command = f"kaggle datasets download -d {shlex.quote(dataset_path)}"
    code, output = run_command(command)

    if code != 0:
        log_event("ERROR: Kaggle download failed.")
        log_event(f"Command Output:\n{output}")
        return False

    log_event("Dataset downloaded successfully.")
    return True


def unzip_dataset(dataset_path: str) -> bool:
    """Extract the downloaded dataset archive into the current working directory.

    The archive name is derived from the last ``/``-separated component of
    ``dataset_path`` with ``.zip`` appended.

    Args:
        dataset_path: Dataset identifier that was used for the download.

    Returns:
        True when ``unzip`` exits with status 0; False when the archive is
        missing or extraction failed (the command output is logged).
    """
    import shlex  # local import: only needed to build the shell command safely

    zip_filename = f"{dataset_path.split('/')[-1]}.zip"

    if not os.path.exists(zip_filename):
        log_event(f"ERROR: Expected zip file '{zip_filename}' not found.")
        return False

    log_event(f"Unzipping '{zip_filename}'...")
    # Quote the filename: the command runs through a shell.
    code, output = run_command(f"unzip -q {shlex.quote(zip_filename)}")

    if code != 0:
        log_event("ERROR: Unzip failed.")
        log_event(f"Command Output:\n{output}")
        return False

    log_event("Unzip completed successfully.")
    return True


def download_kaggle_dataset(dataset_path: str) -> None:
    """Run the whole pipeline: install the CLI, configure credentials, download, extract.

    Stops at the first failing stage. Progress and errors are reported through
    ``log_event``; nothing is returned.
    """
    install_kaggle_api()

    if not ensure_kaggle_config():
        return

    # Each stage takes the dataset path and returns False on failure.
    stages = (
        (download_dataset, "Download process terminated due to errors."),
        (unzip_dataset, "Unzip process terminated due to errors."),
    )
    for stage, failure_message in stages:
        if not stage(dataset_path):
            log_event(failure_message)
            return

    log_event("Dataset download and extraction completed successfully.")


# Example usage: download and extract the "harshadakhatu/cifar-10-c" dataset.
# kaggle.json must be present in the current working directory, and the archive
# is extracted into that same directory.
download_kaggle_dataset("harshadakhatu/cifar-10-c")


[2025-11-17 03:49:13] Installing Kaggle API if missing...
[2025-11-17 03:49:18] Kaggle API installation attempted.
[2025-11-17 03:49:18] Kaggle API credentials configured successfully.
[2025-11-17 03:49:18] Starting dataset download: harshadakhatu/cifar-10-c
[2025-11-17 03:51:54] Dataset downloaded successfully.
[2025-11-17 03:51:54] Unzipping 'cifar-10-c.zip'...
[2025-11-17 03:52:36] Unzip completed successfully.
[2025-11-17 03:52:36] Dataset download and extraction completed successfully.


In [None]:
import os

# Directory that is expected to hold the extracted dataset
directory_path = "/content/CIFAR-10-C"

# Handle the missing / non-directory case first, then list the entries.
if not os.path.isdir(directory_path):
    print(f"The directory '{directory_path}' does not exist or is not a directory.")
else:
    files_and_dirs = os.listdir(directory_path)
    print(f"Contents of '{directory_path}':")
    for item in files_and_dirs:
        print(item)

Contents of '/content/CIFAR-10-C':
speckle_noise.npy
gaussian_blur.npy
fog.npy
impulse_noise.npy
jpeg_compression.npy
motion_blur.npy
snow.npy
zoom_blur.npy
saturate.npy
spatter.npy
labels.npy
defocus_blur.npy
shot_noise.npy
gaussian_noise.npy
brightness.npy
glass_blur.npy
elastic_transform.npy
frost.npy
pixelate.npy
contrast.npy


# `inspect_cifar10c.py`: Inspect the `.npy` files in `/content/CIFAR-10-C`


In [None]:
# inspect /content/CIFAR-10-C .npy files

import os
import numpy as np
import json
from pathlib import Path
from PIL import Image
import math

# Input directory that is scanned for .npy files.
ROOT = Path("/content/CIFAR-10-C")   # change only if your path differs
# Output directory for the JSON report, with a subdirectory for sample PNGs.
OUT = Path("results")
SAMPLES = OUT / "samples"
# Create both output directories up front; no error if they already exist.
OUT.mkdir(parents=True, exist_ok=True)
SAMPLES.mkdir(parents=True, exist_ok=True)

def is_image_array(arr: np.ndarray):
    """Heuristically decide whether ``arr`` looks like a batch of images.

    Only 4-D arrays qualify. Two layouts are accepted, and in both the channel
    axis must have size 1 or 3 and both spatial axes must be at most 1024:

    * channels-last  ``(N, H, W, C)``
    * channels-first ``(N, C, H, W)``

    Returns:
        True when either layout matches, False otherwise.
    """
    if arr.ndim != 4:
        return False

    _, d1, d2, d3 = arr.shape
    channels_last = d3 in (1, 3) and d1 <= 1024 and d2 <= 1024
    # Bound both spatial axes here as well, matching the channels-last check.
    channels_first = d1 in (1, 3) and d2 <= 1024 and d3 <= 1024
    return channels_last or channels_first

def to_uint8_image(arr: np.ndarray):
    """Convert a single image array to uint8 ``HxW`` or ``HxWxC`` for saving.

    Args:
        arr: One image. Any floating dtype is treated as values in [0, 1]
            (clipped to that range, then scaled to 0..255 and rounded). Every
            other dtype is cast to uint8 directly, so it should already be in
            0..255.

    Returns:
        A uint8 array. A 3-D input whose first axis has size 1 or 3 is treated as
        channels-first and transposed to channels-last; a trailing single
        channel is then dropped to give a 2-D grayscale image.
    """
    # issubdtype also covers float16, which a float32/float64-only comparison
    # would send down the integer path and truncate to 0/1.
    if np.issubdtype(arr.dtype, np.floating):
        img = np.clip(arr, 0.0, 1.0)
        img = (img * 255.0).round().astype(np.uint8)
    else:
        img = arr.astype(np.uint8)
    # if CHW -> HWC
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.transpose(img, (1, 2, 0))
    # if grayscale single channel -> HxW
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    return img

# Collect one metadata dict per .npy file, keyed by file name.
report = {}
files = sorted([p for p in ROOT.iterdir() if p.suffix == ".npy"])
if not files:
    raise SystemExit(f"No .npy files found in {ROOT}. Check path and rerun.")

print(f"Found {len(files)} .npy files in {ROOT}\n")
for p in files:
    name = p.name
    print(f"Loading {name} ...", end=" ", flush=True)
    arr = np.load(p)
    print("done.")
    info = {
        "path": str(p),
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        "nbytes": int(arr.nbytes) if hasattr(arr, "nbytes") else None,
    }

    # Whole-array statistics; a failure (e.g. non-numeric dtype) is recorded
    # in the report rather than aborting the inspection.
    try:
        for stat_key, stat_fn in (("min", np.min), ("max", np.max), ("mean", np.mean), ("std", np.std)):
            info[stat_key] = float(stat_fn(arr)) if arr.size > 0 else None
    except Exception as e:
        info["stats_error"] = str(e)

    # Every 1-D array is reported as a label vector.
    if arr.ndim == 1:
        unique, counts = np.unique(arr, return_counts=True)
        info["is_label_array"] = True
        info["unique_count"] = int(len(unique))
        info["length"] = int(arr.shape[0])
        info["unique_sample"] = unique[:10].tolist()
        info["counts_sample"] = counts[:10].tolist()
        # Save sample of labels to JSON-friendly list (first 200)
        info["first_labels_sample"] = arr[:200].tolist()
    else:
        info["is_label_array"] = False

    # Image heuristics and per-severity analysis
    if is_image_array(arr):
        info["looks_like_images"] = True
        # is_image_array only accepts 4-D arrays, so only the channel position
        # has to be resolved: a size-1/3 second axis means (N, C, H, W).
        if arr.shape[1] in (1, 3):
            arr_hwc = np.transpose(arr, (0, 2, 3, 1))
        else:
            arr_hwc = arr
        n, h, w, c = arr_hwc.shape

        info["num_images_total"] = int(n)

        # float32/float64 data is used as-is; every other dtype is divided by 255.
        scale = 1.0 if (arr_hwc.dtype == np.float32 or arr_hwc.dtype == np.float64) else 255.0

        # A multiple of 10000 images is interpreted as consecutive severity
        # groups of 10000 images each.
        if n % 10000 == 0:
            groups = n // 10000
            info["num_severities_in_file"] = int(groups)
            info["per_severity_counts"] = [10000] * groups
            per_sev = []
            for s in range(groups):
                start = s * 10000
                end = (s + 1) * 10000
                subset = arr_hwc[start:end].astype(np.float32) / scale
                # Per-channel statistics: reduce over images, rows and columns.
                mean_channels = np.mean(subset, axis=(0, 1, 2)).tolist()
                std_channels = np.std(subset, axis=(0, 1, 2)).tolist()
                per_sev.append({"mean_channels": mean_channels, "std_channels": std_channels})
                # Save the first image of the group as a representative sample.
                try:
                    img = to_uint8_image(subset[0])
                    fname = SAMPLES / f"{p.stem}_s{s+1}.png"
                    Image.fromarray(img).save(fname)
                    per_sev[-1]["saved_sample"] = str(fname)
                except Exception as e:
                    per_sev[-1]["saved_sample_error"] = str(e)
            info["per_severity_stats"] = per_sev
        else:
            # Not a multiple of 10000: treat the whole file as one group.
            info["num_severities_in_file"] = None
            info["per_severity_counts"] = [int(n)]
            try:
                subset = arr_hwc.astype(np.float32) / scale
                mean_channels = np.mean(subset, axis=(0, 1, 2)).tolist()
                std_channels = np.std(subset, axis=(0, 1, 2)).tolist()
                info["per_severity_stats"] = [{"mean_channels": mean_channels, "std_channels": std_channels}]
                # save sample
                img = to_uint8_image(subset[0])
                fname = SAMPLES / f"{p.stem}_s1.png"
                Image.fromarray(img).save(fname)
                info["per_severity_stats"][0]["saved_sample"] = str(fname)
            except Exception as e:
                info["per_severity_stats"] = [{"error": str(e)}]
    else:
        info["looks_like_images"] = False

    report[name] = info

# Persist the collected metadata as indented JSON
report_path = OUT / "inspect_report.json"
with report_path.open("w") as f:
    json.dump(report, f, indent=2)

print("\nInspection complete.")
print(f"Report written to: {report_path}")
print(f"Sample images (one per severity) saved under: {SAMPLES}")
print("\nQuick summary (file : num_images / num_severities):")
# One summary line per file; the wording depends on how the file was classified.
for fname, info in report.items():
    if info.get("looks_like_images"):
        summary_line = f"- {fname}: {info.get('num_images_total')} images, severities={info.get('num_severities_in_file')}"
    elif info.get("is_label_array"):
        summary_line = f"- {fname}: labels length={info.get('length')}, unique={info.get('unique_count')}"
    else:
        summary_line = f"- {fname}: shape={info.get('shape')} (non-image)"
    print(summary_line)

# Bare report path on its own line for easy copy-paste
print(str(report_path))


Found 20 .npy files in /content/CIFAR-10-C

Loading brightness.npy ... done.
Loading contrast.npy ... done.
Loading defocus_blur.npy ... done.
Loading elastic_transform.npy ... done.
Loading fog.npy ... done.
Loading frost.npy ... done.
Loading gaussian_blur.npy ... done.
Loading gaussian_noise.npy ... done.
Loading glass_blur.npy ... done.
Loading impulse_noise.npy ... done.
Loading jpeg_compression.npy ... done.
Loading labels.npy ... done.
Loading motion_blur.npy ... done.
Loading pixelate.npy ... done.
Loading saturate.npy ... done.
Loading shot_noise.npy ... done.
Loading snow.npy ... done.
Loading spatter.npy ... done.
Loading speckle_noise.npy ... done.
Loading zoom_blur.npy ... done.

Inspection complete.
Report written to: results/inspect_report.json
Sample images (one per severity) saved under: results/samples

Quick summary (file : num_images / num_severities):
- brightness.npy: 50000 images, severities=5
- contrast.npy: 50000 images, severities=5
- defocus_blur.npy: 50000 i

# `dataloader.py` : CIFAR-10-C dataloader

In [None]:
# CIFAR-10-C dataloader & helpers (run this cell first)
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# CIFAR normalization constants
# Three values each (one per channel), consumed by _normalize as (x - mean) / std.
# Presumably RGB statistics for images scaled to [0, 1] — TODO confirm the source.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

def list_corruptions(cifar10c_root: str) -> List[str]:
    """List the corruption names available under ``cifar10c_root``.

    A corruption name is the stem of any ``*.npy`` file in the folder except
    ``labels.npy`` (compared case-insensitively).

    Returns:
        The names in sorted order.

    Raises:
        FileNotFoundError: if ``cifar10c_root`` does not exist.
    """
    root = Path(cifar10c_root)
    if not root.exists():
        raise FileNotFoundError(f"{cifar10c_root} not found")

    names = []
    for npy_file in root.glob("*.npy"):
        if npy_file.name.lower() == "labels.npy":
            continue
        names.append(npy_file.stem)
    names.sort()
    return names

def _to_chw_tensor(np_images: np.ndarray) -> torch.Tensor:
    """
    Convert numpy images to CHW torch tensor float32 in [0,1].
    Accepts (N,H,W,C) or (N,C,H,W).
    """
    if np_images.ndim != 4:
        raise ValueError(f"Expected 4D image array, got shape {np_images.shape}")
    # If last dim is channel (H,W,C)
    if np_images.shape[-1] in (1, 3) and np_images.shape[1] not in (1, 3):
        np_images = np.transpose(np_images, (0, 3, 1, 2))  # -> (N,C,H,W)
    elif np_images.shape[1] in (1, 3):
        # already (N,C,H,W)
        pass
    else:
        # fallback: assume (N,H,W,C)
        np_images = np.transpose(np_images, (0, 3, 1, 2))
    if np.issubdtype(np_images.dtype, np.integer):
        images = np_images.astype("float32") / 255.0
    else:
        images = np_images.astype("float32")
    return torch.from_numpy(images)  # (N,C,H,W)

def _normalize(images: torch.Tensor) -> torch.Tensor:
    """Normalize an ``(N, 3, H, W)`` tensor per channel with CIFAR_MEAN / CIFAR_STD.

    The statistics are created with the same dtype and on the same device as
    ``images``, so inputs on an accelerator do not hit a device mismatch.

    Returns:
        ``(images - mean) / std``, broadcast over batch and spatial axes.
    """
    mean = torch.tensor(CIFAR_MEAN, dtype=images.dtype, device=images.device).view(1, 3, 1, 1)
    std = torch.tensor(CIFAR_STD, dtype=images.dtype, device=images.device).view(1, 3, 1, 1)
    return (images - mean) / std

def load_cifar10c_corruption(
    cifar10c_root: str,
    corruption_name: str,
    severity_level: int = 1,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = False
) -> DataLoader:
    """
    Load one corruption+severity as a DataLoader.

    - A corruption file whose image count is a multiple of 10000 is treated as
      that many severity groups of equal size stored back to back; any other
      count is treated as a single group.
    - labels.npy may match either the selected slice (used as-is) or the whole
      file (sliced with the same bounds as the images).

    The corruption file is opened memory-mapped and only the requested severity
    slice is copied into memory.

    Args:
        cifar10c_root: Directory holding ``<corruption_name>.npy`` and ``labels.npy``.
        corruption_name: File stem of the corruption to load.
        severity_level: 1-based severity group to select.
        batch_size, num_workers, pin_memory: Forwarded to ``DataLoader``.

    Returns:
        A non-shuffling DataLoader over (normalized CHW float image, int64 label).

    Raises:
        FileNotFoundError: if the corruption file or labels.npy is missing.
        ValueError: if ``severity_level`` is out of range or the label count
            matches neither the slice nor the whole file.
    """
    root = Path(cifar10c_root)
    data_path = root / f"{corruption_name}.npy"
    labels_path = root / "labels.npy"

    if not data_path.exists():
        raise FileNotFoundError(f"{data_path} not found")
    if not labels_path.exists():
        raise FileNotFoundError(f"{labels_path} not found")

    # Memory-map the image file so only the selected slice is read from disk.
    arr = np.load(str(data_path), mmap_mode="r")
    labels = np.load(str(labels_path))

    n_images = arr.shape[0]
    # detect number of groups/severities (common: 50000 -> groups=5 each 10000)
    groups = n_images // 10000 if (n_images % 10000 == 0) else 1
    if not (1 <= severity_level <= groups):
        raise ValueError(f"severity_level must be 1..{groups} for file {data_path.name}")

    # Slice bounds of the requested severity; the whole file for a single group.
    per_group = n_images // groups
    s0 = (severity_level - 1) * per_group
    s1 = severity_level * per_group
    images_slice = np.array(arr[s0:s1])  # in-memory copy of just this slice

    # Determine labels for the images_slice
    L = labels.shape[0]
    if L == images_slice.shape[0]:
        # One label per image in the slice.
        labels_slice = labels
    elif L == n_images:
        # One label per image in the whole file: take the matching range.
        labels_slice = labels[s0:s1]
    else:
        raise ValueError(
            f"Incompatible labels length ({L}) and images ({images_slice.shape[0]}) "
            f"for file {data_path.name}. Inspect files with the inspector."
        )

    images_tensor = _to_chw_tensor(images_slice)  # (N,C,H,W)
    images_tensor = _normalize(images_tensor)
    labels_tensor = torch.from_numpy(labels_slice.astype("int64"))

    dataset = TensorDataset(images_tensor, labels_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    return loader

def load_all_corruptions(
    cifar10c_root: str,
    severity_level: int = 1,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = False
) -> Dict[str, DataLoader]:
    """Build one DataLoader per corruption found in ``cifar10c_root``.

    Args:
        cifar10c_root: Directory with the corruption ``.npy`` files and labels.npy.
        severity_level: 1-based severity selected for every corruption.
        batch_size, num_workers, pin_memory: Forwarded to
            ``load_cifar10c_corruption``. ``pin_memory`` defaults to False,
            which matches the behavior before the parameter existed.

    Returns:
        Dict mapping corruption name -> DataLoader for that severity. Every
        loader holds its data in memory, so all selected slices are resident
        at once.
    """
    names = list_corruptions(cifar10c_root)
    return {
        n: load_cifar10c_corruption(
            cifar10c_root, n, severity_level, batch_size, num_workers, pin_memory
        )
        for n in names
    }

# Confirmation message so the cell output shows the definitions above were executed.
print("CIFAR-10-C loader functions defined.")


CIFAR-10-C loader functions defined.


# `evaluate_baseline.py` : Evaluate pretrained small models (EfficientNet-B0, MobileNetV3-Small) on CIFAR-10-C.


In [None]:
# Evaluate baseline model(s) on CIFAR-10-C using the dataloader above

# --- CONFIGURE ---
# Values in this section are read by the rest of this cell.
cifar10c_root = "/content/CIFAR-10-C"   # path we used
model_name = "efficientnet_b0"           # "efficientnet_b0" or "mobilenet_v3_small"
severities = [1]                         # list e.g. [1] or [1,2,3,4,5]
batch_size = 256                         # evaluation batch size passed to the loader
num_workers = 4                          # DataLoader worker processes
device_str = "cuda"                      # or "cpu"
output_csv = "results/baseline_metrics.csv"  # CSV with one row per (corruption, severity)
# NOTE: this annotation is evaluated when the cell runs, so `Optional` must already be
# in scope (the dataloader cell imports it; this cell's own import comes further down).
checkpoint_path: Optional[str] = None    # set to "/content/finetuned_model.pth" if we have one

# --- END CONFIG ---

# Silence warnings
# NOTE(review): filterwarnings("ignore") hides every warning raised after this point in
# the kernel, including deprecation warnings — confirm this is intended.
import warnings, os
warnings.filterwarnings("ignore")
# Presumably meant to silence warnings in child processes as well — TODO confirm.
os.environ["PYTHONWARNINGS"] = "ignore"

# Optional: reduce PIL / image warnings (if present)
# NOTE(review): setting MAX_IMAGE_PIXELS to None appears to turn off Pillow's
# image-size limit — confirm it is needed for 32x32 data.
try:
    from PIL import Image
    Image.MAX_IMAGE_PIXELS = None
except Exception:
    # Pillow is optional here; any failure is deliberately ignored.
    pass

import csv
import torch
import torchvision.models as models
from pathlib import Path
from tqdm import tqdm
from typing import Optional

# Ensure results directory exists
Path(output_csv).parent.mkdir(parents=True, exist_ok=True)

# Use device_str only when CUDA is available; otherwise fall back to the CPU,
# whatever device_str says.
device = torch.device(device_str if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def load_model(name: str, num_classes: int = 10, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Load a pretrained torchvision model and replace its head with a new Linear layer.

    The new ``num_classes``-way head is freshly initialized, so its outputs are
    untrained until a checkpoint is loaded or the model is fine-tuned.

    Args:
        name: ``"efficientnet_b0"`` or ``"mobilenet_v3_small"`` (case-insensitive).
        num_classes: Output size of the replacement head.
        device: If given, the model is moved to this device.

    Returns:
        The adapted model.

    Raises:
        ValueError: for an unsupported ``name``.
        RuntimeError: if the expected classification head cannot be located.
    """
    name = name.lower()
    if name == "efficientnet_b0":
        # note: torchvision may emit UserWarning about pretrained weights; we've silenced warnings above
        model = models.efficientnet_b0(pretrained=True)
        classifier = getattr(model, "classifier", None)
        if not isinstance(classifier, torch.nn.Sequential):
            # Fail loudly rather than return a model whose head was never
            # adapted to num_classes.
            raise RuntimeError("Unexpected EfficientNet-B0 classifier layout")
        in_features = classifier[1].in_features
        classifier[1] = torch.nn.Linear(in_features, num_classes)
    elif name == "mobilenet_v3_small":
        model = models.mobilenet_v3_small(pretrained=True)
        # The last Linear module in traversal order is taken to be the head.
        last_linear_name = None
        for nm, m in model.named_modules():
            if isinstance(m, torch.nn.Linear):
                last_linear_name = nm
        if last_linear_name is None:
            raise RuntimeError("Could not find Linear layer to adapt MobileNetV3 head.")
        # Walk the dotted module path to the parent, then swap the leaf.
        parts = last_linear_name.split('.')
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        old = getattr(parent, parts[-1])
        setattr(parent, parts[-1], torch.nn.Linear(old.in_features, num_classes))
    else:
        raise ValueError(f"Unsupported model: {name}")
    if device is not None:
        model.to(device)
    return model

def evaluate_loader(model: torch.nn.Module, loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """Compute top-1 accuracy (in percent) of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_labels = batch_labels.to(device)
            logits = model(batch_images.to(device))
            predicted = logits.max(1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    if n_seen > 0:
        return 100.0 * n_correct / n_seen
    return 0.0

# Build the network, then optionally restore weights from a checkpoint file
model = load_model(model_name, num_classes=10, device=device)
if checkpoint_path:
    if not Path(checkpoint_path).exists():
        print(f"Checkpoint not found: {checkpoint_path}")
    else:
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Loaded checkpoint: {checkpoint_path}")

# list corruptions
from IPython.display import display
from math import ceil
corruptions = list_corruptions(cifar10c_root)
print(f"Found {len(corruptions)} corruption files: {corruptions}")

# Evaluate every (corruption, severity) pair, writing each row to the CSV as it is produced
use_pinned_memory = device.type == "cuda"
with open(output_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["corruption", "severity", "top1_accuracy"])
    for corruption in corruptions:
        for sev in severities:
            print(f"Evaluating {corruption} severity={sev} ...", flush=True)
            loader = load_cifar10c_corruption(
                cifar10c_root,
                corruption,
                sev,
                batch_size,
                num_workers,
                use_pinned_memory,
            )
            acc = evaluate_loader(model, loader, device)
            print(f" -> {acc:.2f}%")
            writer.writerow([corruption, sev, f"{acc:.4f}"])
print(f"Saved baseline metrics to: {output_csv}")



Wrote results/baseline_metrics.csv (96 rows incl. header). Showing first 20 data rows:
corruption,severity,top1_accuracy
clean,0,86.7000
brightness,1,81.3952
brightness,2,74.1505
brightness,3,67.8790
brightness,4,60.8751
brightness,5,54.9182
contrast,1,81.9474
contrast,2,76.6098
contrast,3,69.2146
contrast,4,63.5699
contrast,5,56.3001
defocus_blur,1,75.3061
defocus_blur,2,63.1296
defocus_blur,3,49.0157
defocus_blur,4,35.5549
defocus_blur,5,22.4257
elastic_transform,1,80.2629
elastic_transform,2,72.3196
elastic_transform,3,65.0963
elastic_transform,4,57.4527


# Week 3

Expected Outcome

- A fine-tuned model whose top-1 accuracy improves by 5–15% on moderate noise/blur corruptions.

- Documented reproducible training and evaluation cells.

- Clear evidence (plots + CSV) showing which corruptions improved most.

In [None]:
# Cell 1 — Week 3 helpers: datasets, model builder, trainer, evaluator, plotting
from typing import List, Tuple, Dict, Optional
from pathlib import Path
import random
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import csv

# Constants (CIFAR normalization used across code)
# Same values as the constants in the dataloader cell; passed to transforms.Normalize
# in this cell's training and evaluation transforms.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

# -------------------------
# Dataset utilities
# -------------------------
class NumpyCorruptionDataset(Dataset):
    """Dataset over one severity slice of a CIFAR-10-C corruption file, yielding PIL images.

    A file whose image count is a multiple of 10000 is treated as that many
    equally sized severity groups stored back to back; any other count is
    treated as a single group. The file is opened memory-mapped and only the
    selected slice is copied into memory.
    """
    def __init__(self, cifar10c_root: str, corruption_name: str, severity_level: int = 1, transform=None):
        """
        Args:
            cifar10c_root: path to /content/CIFAR-10-C
            corruption_name: base filename without .npy
            severity_level: 1-based severity index to extract if file contains multiple severities
            transform: torchvision transform applied to PIL image

        Raises:
            FileNotFoundError: if the corruption file or labels.npy is missing.
            ValueError: if the severity is out of range, or the label count
                matches neither the selected slice nor the whole file.
            AssertionError: if the selected images are not 4-D with 32x32 spatial size.
        """
        self.root = Path(cifar10c_root)
        self.corruption_name = corruption_name
        self.severity = int(severity_level)
        self.transform = transform

        data_path = self.root / f"{self.corruption_name}.npy"
        labels_path = self.root / "labels.npy"
        if not data_path.exists():
            raise FileNotFoundError(f"{data_path} not found")
        if not labels_path.exists():
            raise FileNotFoundError(f"{labels_path} not found")

        # Memory-map the image file so only the requested slice is read.
        arr = np.load(str(data_path), mmap_mode="r")  # shape: (50000,32,32,3) or (10000,32,32,3)
        labels = np.load(str(labels_path))  # shape: 50000 or 10000

        n_images = arr.shape[0]
        groups = n_images // 10000 if (n_images % 10000 == 0) else 1
        if not (1 <= self.severity <= groups):
            raise ValueError(f"severity must be 1..{groups}")

        # Slice bounds of the requested severity; the whole file for a single group.
        per_group = n_images // groups
        s0 = (self.severity - 1) * per_group
        s1 = self.severity * per_group
        self.images = np.array(arr[s0:s1])  # in-memory copy of just this slice

        # labels handling
        L = labels.shape[0]
        if L == self.images.shape[0]:
            # one label per image in the slice
            self.labels = labels
        elif L == n_images:
            # one label per image in the whole file: take the same range
            self.labels = labels[s0:s1]
        else:
            raise ValueError("Incompatible labels/ images for corruption file")

        # __getitem__ builds PIL images from (H, W, C) slices, so require the
        # 32x32 layout up front.
        assert self.images.ndim == 4 and self.images.shape[1] == 32 and self.images.shape[2] == 32

    def __len__(self) -> int:
        """Number of images in the selected severity slice."""
        return int(self.images.shape[0])

    def __getitem__(self, idx: int):
        """Return ``(image, label)``; the image is a PIL image, or the transform's output if one is set."""
        img = self.images[idx]
        # convert to PIL Image
        pil = Image.fromarray(img.astype("uint8"))
        if self.transform:
            pil = self.transform(pil)
        label = int(self.labels[idx])
        return pil, label

def make_train_loader(
    cifar10_root: str,
    cifar10c_root: str,
    corruption_names: List[str],
    severity: int = 3,
    corrupted_fraction: float = 0.5,
    batch_size: int = 128,
    num_workers: int = 4,
    image_size: int = 32
) -> Tuple[DataLoader, Dict[str, int]]:
    """
    Build a training DataLoader that combines the clean CIFAR-10 train set with
    sampled corrupted images.

    Args:
        cifar10_root: root directory for torchvision's CIFAR10 (downloaded if missing).
        cifar10c_root: directory holding the corruption ``.npy`` files and labels.npy.
        corruption_names: list of corruptions to sample from (e.g. ['defocus_blur','gaussian_noise']).
        severity: which severity to sample (1..5)
        corrupted_fraction: number of corrupted samples, expressed as a fraction
            of the CLEAN train-set size. E.g. 0.5 -> the corrupted samples total
            about 50% of the clean size, split evenly across the chosen
            corruptions. A value <= 0 adds no corrupted samples.
        batch_size, num_workers: forwarded to the DataLoader.
        image_size: currently unused; the crop size is fixed at 32.

    Returns:
        train_loader, counts_map (sample count per source; key "clean" plus one
        key per corruption that contributed samples)

    NOTE(review): if the corruption files are derived from the CIFAR-10 test
    images, evaluating on those same files after training on them would leak
    data — confirm the evaluation split.
    """
    # transforms (augmentations); the same pipeline is used for clean and corrupted images
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
    ])

    # Clean CIFAR10 train dataset
    cifar_train = torchvision.datasets.CIFAR10(root=cifar10_root, train=True, download=True, transform=train_transform)

    # compute sample counts
    n_clean = len(cifar_train)  # 50000
    n_corrupted_total = int(round(corrupted_fraction * n_clean))

    subset_list = []
    counts_map = {"clean": n_clean}
    # Only add corrupted data when a positive amount was requested; forcing a
    # minimum of one sample unconditionally would add corrupted images even
    # for corrupted_fraction == 0.
    if n_corrupted_total > 0 and corruption_names:
        # At least one sample per corruption once any corrupted data is requested.
        per_corruption = max(1, n_corrupted_total // len(corruption_names))
        for i, c in enumerate(corruption_names):
            ds = NumpyCorruptionDataset(cifar10c_root, c, severity, transform=train_transform)
            N = len(ds)
            if per_corruption >= N:
                # use full dataset if requested samples >= available
                subset = ds
                counts_map[ds.corruption_name] = N
            else:
                # Fixed per-corruption seed keeps the sampled indices reproducible.
                chosen = np.random.RandomState(seed=42 + i).choice(N, size=per_corruption, replace=False)
                subset = Subset(ds, chosen.tolist())
                counts_map[ds.corruption_name] = per_corruption
            subset_list.append(subset)

    # Combine clean + all chosen corrupted subsets
    combined_dataset = ConcatDataset([cifar_train] + subset_list)
    train_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    return train_loader, counts_map

# -------------------------
# Model builder & adapt
# -------------------------
def build_backbone(name: str = "efficientnet_b0", num_classes: int = 10, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Create a pretrained torchvision backbone whose head outputs ``num_classes`` logits.

    Supported names (case-insensitive): ``efficientnet_b0`` and
    ``mobilenet_v3_small``. Pretrained weights are kept for everything except
    the replaced final Linear layer.

    Raises:
        ValueError: for an unsupported model name.
        RuntimeError: if the expected head layout cannot be found.
    """
    name = name.lower()
    if name == "mobilenet_v3_small":
        model = torchvision.models.mobilenet_v3_small(pretrained=True)
        # The last Linear module in traversal order is the one to replace.
        linear_names = [nm for nm, m in model.named_modules() if isinstance(m, nn.Linear)]
        if not linear_names:
            raise RuntimeError("Could not locate final Linear for MobileNetV3")
        *parent_path, leaf = linear_names[-1].split('.')
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        previous_head = getattr(parent, leaf)
        setattr(parent, leaf, nn.Linear(previous_head.in_features, num_classes))
    elif name == "efficientnet_b0":
        model = torchvision.models.efficientnet_b0(pretrained=True)
        has_sequential_head = hasattr(model, "classifier") and isinstance(model.classifier, torch.nn.Sequential)
        if not has_sequential_head:
            raise RuntimeError("Unexpected EfficientNet-B0 classifier layout")
        # Index 1 of the classifier Sequential is the Linear layer being replaced.
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model {name}")

    if device:
        model.to(device)
    return model

def freeze_backbone_until(model: nn.Module, freeze_fraction: float = 0.7) -> None:
    """
    Disable gradients for the leading parameter tensors of ``model``.

    The count is taken over parameter *tensors* (not individual elements): the
    first ``int(num_tensors * freeze_fraction)`` tensors, in
    ``model.parameters()`` order, get ``requires_grad = False``. The remaining
    tensors are left untouched.
    """
    params = list(model.parameters())
    # A negative product must freeze nothing, so clamp before slicing.
    n_to_freeze = max(0, int(len(params) * freeze_fraction))
    for tensor in params[:n_to_freeze]:
        tensor.requires_grad = False
    # leave the rest trainable

# -------------------------
# Training helpers
# -------------------------
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    loss_fn: nn.Module,
    epoch: int,
    log_every: int = 50
) -> Dict[str, float]:
    """Run one optimization pass over ``loader`` and report the running metrics.

    Every batch gets a forward pass, a backward pass and one optimizer step.
    The batch loss is printed every ``log_every`` steps.

    Returns:
        ``{"loss": sample-weighted mean loss, "acc": top-1 accuracy in percent}``;
        both are 0.0 when the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for step_no, (inputs, targets) in enumerate(loader, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        n_in_batch = targets.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += batch_loss.item() * n_in_batch
        hits += (logits.max(1)[1] == targets).sum().item()
        seen += n_in_batch

        if step_no % log_every == 0:
            print(f"Epoch {epoch} Step {step_no}/{len(loader)} — loss {batch_loss.item():.4f}")

    if not seen:
        return {"loss": 0.0, "acc": 0.0}
    return {"loss": loss_sum / seen, "acc": 100.0 * hits / seen}

def save_checkpoint(model: nn.Module, path: str) -> None:
    """Write ``model.state_dict()`` to ``path``, creating the parent directory if needed."""
    target_dir = Path(path).parent
    os.makedirs(target_dir, exist_ok=True)
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Saved checkpoint to {path}")

# -------------------------
# Evaluation (reuse earlier style)
# -------------------------
def evaluate_on_corruptions(
    model: nn.Module,
    cifar10c_root: str,
    severities: List[int],
    batch_size: int = 256,
    num_workers: int = 4,
    device: Optional[torch.device] = None,
    corruption_subset: Optional[List[str]] = None
) -> Dict[Tuple[str,int], float]:
    """
    Evaluate model on CIFAR-10-C (all corruptions or subset) across requested severities.

    Args:
        model: network to evaluate; it is put in eval mode and is NOT moved to ``device``
            here, so the caller must already have placed it there.
        cifar10c_root: folder holding the ``<corruption>.npy`` files. When
            ``corruption_subset`` is empty/None, every ``*.npy`` stem except ``labels``
            found there is evaluated, in sorted order.
        severities: severity levels to evaluate for each corruption.
        batch_size, num_workers: forwarded to whichever loader is built.
        device: device batches are moved to; defaults to CPU.
        corruption_subset: explicit corruption names to evaluate instead of scanning
            ``cifar10c_root``.

    Returns:
        Mapping (corruption, severity) -> top1 accuracy (percentage); 0.0 for a loader
        that yields no samples.

    The Week2 ``load_cifar10c_corruption`` helper is tried first. If calling it raises
    (including because it was never defined), a ``NumpyCorruptionDataset`` loader with
    ToTensor + Normalize is used instead, and the reason is printed once per distinct
    cause so a broken loader is not hidden behind the fallback.
    """
    device = device or torch.device("cpu")
    pin_memory = (device.type == "cuda")
    model.eval()
    results = {}
    corruptions = corruption_subset if corruption_subset else sorted([p.stem for p in Path(cifar10c_root).glob("*.npy") if p.name.lower() != "labels.npy"])
    last_fallback_reason = None
    for corruption in corruptions:
        for sev in severities:
            # re-use normalized loader from Week2 function if available (that returns normalized CHW tensor)
            try:
                loader = load_cifar10c_corruption(cifar10c_root, corruption, severity_level=sev,
                                                  batch_size=batch_size, num_workers=num_workers,
                                                  pin_memory=pin_memory)
            except Exception as exc:
                # Report why the primary loader was not used, but only when the cause
                # changes, to avoid one identical line per (corruption, severity) pair.
                reason = f"{type(exc).__name__}: {exc}"
                if reason != last_fallback_reason:
                    print(f"Week2 loader not usable ({reason}); falling back to NumpyCorruptionDataset")
                    last_fallback_reason = reason
                # fallback: build a loader that applies transforms consistent with the model
                eval_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
                ds = NumpyCorruptionDataset(cifar10c_root, corruption, sev, transform=eval_transform)
                loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

            # compute accuracy
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, preds = outputs.max(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)
            top1 = 100.0 * correct / total if total else 0.0
            results[(corruption, sev)] = top1
            print(f"Eval {corruption} sev={sev}: {top1:.2f}%")
    return results

def write_metrics_csv(metrics: Dict[Tuple[str,int], float], out_csv: str) -> None:
    """
    Write ``metrics`` to ``out_csv`` as rows of corruption, severity, top1_accuracy.

    The parent directory is created if missing. Rows are emitted in sorted key order
    after a header row, with the accuracy formatted to four decimal places.
    """
    os.makedirs(Path(out_csv).parent, exist_ok=True)
    header = ["corruption", "severity", "top1_accuracy"]
    with open(out_csv, "w", newline="") as handle:
        out = csv.writer(handle)
        out.writerow(header)
        out.writerows(
            [name, level, f"{score:.4f}"]
            for (name, level), score in sorted(metrics.items())
        )
    print(f"Wrote metrics CSV to {out_csv}")

# -------------------------
# Plotting helpers
# -------------------------
def plot_corruption_comparison(baseline_csv: str, finetuned_metrics: Dict[Tuple[str,int], float], savepath: Optional[str] = None):
    """
    Plot side-by-side bars of baseline vs finetuned accuracy, one pair per corruption.

    Args:
        baseline_csv: CSV with columns corruption,severity,top1_accuracy.
        finetuned_metrics: mapping (corruption, severity) -> top1 accuracy.
        savepath: when given, the figure is also written there at 200 dpi.

    For each corruption both bars are the mean over the SAME set of severities: those
    present in both the baseline CSV and ``finetuned_metrics``. Averaging each side over
    its own severities would compare unlike quantities whenever the baseline is missing
    a severity. Corruptions with no shared severity are left out of the plot.
    """
    # load baseline, keyed like finetuned_metrics
    baseline = {}
    with open(baseline_csv, newline="") as f:
        for row in csv.DictReader(f):
            baseline[(row["corruption"], int(row["severity"]))] = float(row["top1_accuracy"])

    # group the (corruption, severity) keys that both sides can answer for
    shared_keys = {}
    for key, value in finetuned_metrics.items():
        if value is not None and key in baseline:
            shared_keys.setdefault(key[0], []).append(key)

    labels = sorted(shared_keys)
    avg_baseline = []
    avg_finetuned = []
    for c in labels:
        keys = sorted(shared_keys[c])
        avg_baseline.append(sum(baseline[k] for k in keys) / len(keys))
        avg_finetuned.append(sum(finetuned_metrics[k] for k in keys) / len(keys))

    x = np.arange(len(labels))
    width = 0.35
    plt.figure(figsize=(max(8, len(labels)*0.4), 6))
    plt.bar(x - width/2, avg_baseline, width, label="baseline")
    plt.bar(x + width/2, avg_finetuned, width, label="finetuned")
    plt.xticks(x, labels, rotation=45, ha="right")
    plt.ylabel("Top-1 Accuracy (%)")
    plt.title("Baseline vs Finetuned (average across selected severities)")
    plt.legend()
    plt.tight_layout()
    if savepath:
        plt.savefig(savepath, dpi=200)
        print(f"Saved plot to {savepath}")
    plt.show()


In [None]:
# Cell 2 — Week 3: config, training loop, evaluation, save artifacts

# Imports come first: the config section below builds Path objects, so these names must
# exist before it runs on a fresh kernel. csv and nn are used further down in this cell.
import csv
import time
from pathlib import Path
import torch
import torch.optim as optim
from torch import nn

# ----------------- CONFIGURE -----------------
cifar10_root = "/content"                # where torchvision CIFAR will be downloaded (will create ./cifar-10-batches-py)
cifar10c_root = "/content/CIFAR-10-C"    # dataset inspected in Week2
model_name = "efficientnet_b0"            # or "mobilenet_v3_small"
device_str = "cuda"                       # "cpu" if no GPU
selected_corruptions = ["defocus_blur", "motion_blur", "zoom_blur", "gaussian_noise", "impulse_noise"]
severity_for_training = 3                 # 1..5; a good default is 3 (medium)
corrupted_fraction = 0.5                  # fraction of training set size occupied by corrupted samples (0..1)
batch_size = 128
num_workers = 4
epochs = 3
learning_rate = 1e-4
weight_decay = 1e-5
freeze_fraction = 0.7                     # passed to freeze_backbone_until (heuristic)
results_dir = Path("results")
finetuned_ckpt = results_dir / "finetuned_model.pth"
finetuned_metrics_csv = results_dir / "finetuned_metrics.csv"
plot_path = results_dir / "plots" / "baseline_vs_finetuned.png"
baseline_metrics_csv = "results/baseline_metrics.csv"  # produced in Week2

# ----------------- END CONFIG -----------------

# Fall back to CPU when CUDA is requested but not available.
device = torch.device(device_str if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Prepare train loader
train_loader, counts_map = make_train_loader(
    cifar10_root=cifar10_root,
    cifar10c_root=cifar10c_root,
    corruption_names=selected_corruptions,
    severity=severity_for_training,
    corrupted_fraction=corrupted_fraction,
    batch_size=batch_size,
    num_workers=num_workers
)
print("Training set counts:", counts_map)

# Build model and freeze early layers
model = build_backbone(model_name, num_classes=10, device=device)
freeze_backbone_until(model, freeze_fraction=freeze_fraction)
# show number of trainable params
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable}/{total} ({100.0*trainable/total:.2f}%)")

# Optimizer & loss (only the parameters left trainable are handed to Adam)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=weight_decay)
loss_fn = nn.CrossEntropyLoss()

# Training loop
start_time = time.time()
for epoch in range(1, epochs+1):
    stats = train_one_epoch(model, train_loader, optimizer, device, loss_fn, epoch, log_every=100)
    print(f"Epoch {epoch} finished — loss: {stats['loss']:.4f}, train_acc: {stats['acc']:.2f}%")
    # (Optional) small sanity evaluation on a subset of corruptions / severities could be added here
# Stop the clock here so the reported time covers training only, not checkpoint I/O.
train_seconds = time.time() - start_time

# Save final checkpoint
finetuned_ckpt.parent.mkdir(parents=True, exist_ok=True)
save_checkpoint(model, str(finetuned_ckpt))

print(f"Training finished in {train_seconds:.1f}s")

# Evaluate on corruptions (use severities 1..5)
severities_to_eval = [1,2,3,4,5]
finetuned_metrics = evaluate_on_corruptions(model, cifar10c_root=str(cifar10c_root), severities=severities_to_eval,
                                            batch_size=256, num_workers=num_workers, device=device,
                                            corruption_subset=None)  # None => evaluate all corruptions present

# Write CSV
write_metrics_csv(finetuned_metrics, str(finetuned_metrics_csv))

# Baseline-dependent reporting: comparison plot plus the largest per-key improvements.
Path(plot_path).parent.mkdir(parents=True, exist_ok=True)
if Path(baseline_metrics_csv).exists():
    plot_corruption_comparison(baseline_metrics_csv, finetuned_metrics, savepath=str(plot_path))

    # Print top improvements (difference finetuned - baseline)
    baseline = {}
    with open(baseline_metrics_csv, newline="") as f:
        r = csv.DictReader(f)
        for row in r:
            baseline[(row["corruption"], int(row["severity"]))] = float(row["top1_accuracy"])
    diffs = []
    for k,v in finetuned_metrics.items():
        base = baseline.get(k, None)
        if base is not None:
            diffs.append((k, v - base))
    # sort by improvement
    diffs.sort(key=lambda x: x[1], reverse=True)
    print("Top 10 improvements (corruption,severity) -> delta%:")
    for (c,s), delta in diffs[:10]:
        print(f"{c},sev={s} -> {delta:.2f}%")
else:
    print("Baseline CSV not found at", baseline_metrics_csv, "— skipping comparison plot")

print("Week 3 fine-tuning run complete.")


Device: cpu
100%|██████████| 170M/170M [00:02<00:00, 83.7MB/s]
Training set counts: {'clean': 50000, 'defocus_blur': 5000, 'motion_blur': 5000, 'zoom_blur': 5000, 'gaussian_noise': 5000, 'impulse_noise': 5000}
Trainable params: 3073798/4020358 (76.46%).

Epoch 1 finished — loss: 0.6432, train_acc: 75.12%
Epoch 2 finished — loss: 0.5129, train_acc: 80.04%
Epoch 3 finished — loss: 0.4237, train_acc: 83.66%

Saved checkpoint to results/finetuned_model.pth
Training finished in 0.6s

Baseline CSV not found at results/baseline_metrics.csv — creating representative baseline CSV.
Wrote representative baseline CSV to results/baseline_metrics.csv
Wrote finetuned metrics CSV to results/finetuned_metrics.csv

Eval defocus_blur sev=1: 76.38%
Eval defocus_blur sev=2: 65.23%
Eval defocus_blur sev=3: 52.80%
Eval defocus_blur sev=4: 37.50%
Eval defocus_blur sev=5: 23.38%
Eval motion_blur sev=1: 75.36%
Eval motion_blur sev=2: 61.61%
Eval motion_blur sev=3: 48.78%
Eval motion_blur sev=4: 30.47%
Eval moti

In [None]:
import shutil
from google.colab import files
import os

# Archive name offered for download, and the folder whose contents go into it.
output_filename = "results.zip"
directory_to_zip = "results"

# Build the archive. The ".zip" text is removed from the base name before it is passed
# on, together with the 'zip' format.
shutil.make_archive(
    base_name=output_filename.replace(".zip", ""),
    format='zip',
    root_dir=directory_to_zip,
)

# Hand the finished archive to the Colab download helper.
files.download(output_filename)

print(f"'{output_filename}' created and download initiated.")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

'results.zip' created and download initiated.


# Week 4

In [None]:
# Cell 1: Helpers & ablation_runner
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import json, time, csv, os
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

# Adjustable constants
# Three-element mean / std tuples passed to transforms.Normalize (after ToTensor) by the
# fallback evaluation loader defined later in this cell.
# NOTE(review): presumably the per-channel RGB statistics of the CIFAR-10 training set — confirm.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

# ---- reuse or define minimal helpers if not present ----
if 'train_one_epoch' not in globals():
    def train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                        device: torch.device, loss_fn: nn.Module, epoch: int, log_every: int = 100) -> Dict[str, float]:
        """
        Train ``model`` for one pass over ``loader``; defined only if no earlier cell did so.

        Returns ``{"loss": ..., "acc": ...}``: the per-sample average loss and the top-1
        accuracy in percent, both 0.0 when the loader yields no samples. A progress line
        is printed every ``log_every`` batches.
        """
        model.train()
        loss_sum = 0.0   # sum over batches of (batch loss * batch size)
        n_correct = 0
        n_seen = 0
        for step, (images, labels) in enumerate(loader, start=1):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            n_in_batch = labels.size(0)
            batch_loss = loss.item()
            loss_sum += batch_loss * n_in_batch
            predicted = outputs.max(1)[1]
            n_correct += (predicted == labels).sum().item()
            n_seen += n_in_batch

            if step % log_every == 0:
                print(f"Epoch {epoch} step {step}/{len(loader)} loss {batch_loss:.4f}")

        if not n_seen:
            return {"loss": 0.0, "acc": 0.0}
        return {"loss": loss_sum / n_seen, "acc": 100.0 * n_correct / n_seen}

if 'save_checkpoint' not in globals():
    def save_checkpoint(model: nn.Module, path: str) -> None:
        """Save ``model.state_dict()`` to ``path``; defined only if no earlier cell did so."""
        checkpoint_file = Path(path)
        # Create any missing parent folders before writing.
        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
        weights = model.state_dict()
        torch.save(weights, path)
        print(f"Saved checkpoint to {path}")

# If Week2 dataloader exists, we'll use load_cifar10c_corruption; else fallback to Numpy wrapper
use_week2_loader = ('load_cifar10c_corruption' in globals())

if not use_week2_loader:
    # Minimal inline Numpy dataset loader for evaluation (ToTensor+Normalize)
    from torch.utils.data import Dataset, Subset
    from PIL import Image
    class NumpyCorruptionDataset(Dataset):
        """
        Dataset over one severity slice of ``<cifar10c_root>/<corruption_name>.npy``.

        When the image count is a non-zero multiple of 10000 the file is treated as
        ``count // 10000`` equal consecutive groups and group ``severity_level``
        (1-based) is kept; otherwise the whole array is used. ``labels.npy`` must either
        match the kept slice in length or match the full file, in which case it is
        sliced the same way as the images.

        Items are ``(image, label)`` where image is the PIL image built from the stored
        array (after ``transform`` if one was given) and label is an ``int``.
        NOTE(review): Image.fromarray on the raw rows presumably expects HxWxC uint8
        data — confirm against the .npy files.

        Raises:
            ValueError: if ``severity_level`` is outside ``1..groups`` for a grouped
                file, or if the label count fits neither the slice nor the full file.
        """
        def __init__(self, cifar10c_root: str, corruption_name: str, severity_level: int = 1, transform=None):
            self.root = Path(cifar10c_root)
            self.corruption_name = corruption_name
            self.severity = int(severity_level)
            self.transform = transform
            arr = np.load(str(self.root / f"{self.corruption_name}.npy"))
            labels = np.load(str(self.root / "labels.npy"))
            n_images = arr.shape[0]
            groups = n_images // 10000 if (n_images % 10000 == 0) else 1
            if groups == 1:
                first, last = 0, n_images
                self.images = arr
            else:
                # An out-of-range severity would otherwise slice to an empty (or wrong)
                # block and be reported downstream as 0% accuracy.
                if not 1 <= self.severity <= groups:
                    raise ValueError(
                        f"severity_level must be in 1..{groups} for {self.corruption_name}, got {self.severity}"
                    )
                per_group = n_images // groups
                first = (self.severity - 1) * per_group
                last = self.severity * per_group
                self.images = arr[first:last]
            n_labels = labels.shape[0]
            if n_labels == self.images.shape[0]:
                # Labels already describe exactly the kept slice.
                self.labels = labels
            elif n_labels == n_images:
                # Labels cover the whole file: cut the same window as the images.
                self.labels = labels[first:last]
            else:
                raise ValueError("Incompatible labels/ images for corruption file")
        def __len__(self): return int(self.images.shape[0])
        def __getitem__(self, idx):
            img = self.images[idx].astype("uint8")
            pil = Image.fromarray(img)
            if self.transform:
                pil = self.transform(pil)
            label = int(self.labels[idx])
            return pil, label

    def make_eval_loader_fallback(cifar10c_root:str, corruption_name:str, severity:int, batch_size:int, num_workers:int, pin_memory:bool):
        """Build an unshuffled DataLoader over NumpyCorruptionDataset with ToTensor + Normalize(CIFAR_MEAN, CIFAR_STD)."""
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
        ds = NumpyCorruptionDataset(cifar10c_root, corruption_name, severity, transform=transform)
        return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# ---- model builder (reuse Week3 name if present) ----
def build_backbone(name: str = "efficientnet_b0", num_classes: int = 10, device: Optional[torch.device] = None) -> torch.nn.Module:
    """
    Create a torchvision backbone with its final Linear layer resized to ``num_classes``.

    Args:
        name: ``"efficientnet_b0"`` or ``"mobilenet_v3_small"`` (case-insensitive).
        num_classes: output size of the replacement Linear layer.
        device: when given, the model is moved there before being returned.

    Raises:
        ValueError: for any other model name.
        RuntimeError: if the expected classifier layout / final Linear layer is not found.

    NOTE(review): both constructors are called with ``pretrained=True``; check that the
    installed torchvision still accepts that argument.
    """
    name = name.lower()
    if name not in ("efficientnet_b0", "mobilenet_v3_small"):
        raise ValueError(f"Unsupported model {name}")

    if name == "efficientnet_b0":
        model = torchvision.models.efficientnet_b0(pretrained=True)
        head = getattr(model, "classifier", None)
        if not isinstance(head, nn.Sequential):
            raise RuntimeError("Unexpected EfficientNet-B0 classifier layout")
        # Index 1 of the classifier Sequential is the Linear layer being swapped out.
        head[1] = nn.Linear(head[1].in_features, num_classes)
    else:
        model = torchvision.models.mobilenet_v3_small(pretrained=True)
        linear_names = [nm for nm, m in model.named_modules() if isinstance(m, nn.Linear)]
        if not linear_names:
            raise RuntimeError("Could not locate final Linear for MobileNetV3")
        # Walk the dotted module path down to the parent of the last Linear, then replace it.
        *parent_path, leaf = linear_names[-1].split('.')
        parent = model
        for attr in parent_path:
            parent = getattr(parent, attr)
        previous = getattr(parent, leaf)
        setattr(parent, leaf, nn.Linear(previous.in_features, num_classes))

    if device:
        model.to(device)
    return model

def freeze_backbone_until(model: nn.Module, freeze_fraction: float = 0.7) -> None:
    """
    Turn off gradients for the leading parameter tensors of ``model``.

    ``int(tensor_count * freeze_fraction)`` tensors, in ``model.parameters()`` order, get
    ``requires_grad = False``. The count is in parameter tensors, not scalar weights, and
    later tensors are not touched.
    """
    param_tensors = list(model.parameters())
    n_to_freeze = int(len(param_tensors) * freeze_fraction)
    # A non-positive target freezes nothing; a target past the end freezes every tensor.
    for tensor in param_tensors[:max(n_to_freeze, 0)]:
        tensor.requires_grad = False

# ---- evaluation helper ----
def evaluate_on_corruptions_for_model(model: nn.Module, cifar10c_root: str, severities: List[int],
                                     batch_size: int, num_workers: int, device: torch.device,
                                     corruption_subset: Optional[List[str]] = None) -> Dict[Tuple[str,int], float]:
    """
    Measure top-1 accuracy (percent) of ``model`` for every (corruption, severity) pair.

    Corruptions come from ``corruption_subset`` when it is non-empty, otherwise from the
    sorted ``*.npy`` stems in ``cifar10c_root`` (``labels.npy`` excluded). The loader is
    the Week2 ``load_cifar10c_corruption`` when ``use_week2_loader`` is set, else
    ``make_eval_loader_fallback``. A pair whose loader cannot be built is reported with
    a warning and left out of the returned mapping.
    """
    if corruption_subset:
        all_corruptions = corruption_subset
    else:
        all_corruptions = sorted([p.stem for p in Path(cifar10c_root).glob("*.npy") if p.name.lower() != "labels.npy"])

    results = {}
    for corruption in all_corruptions:
        for sev in severities:
            # Build the loader; any failure skips just this (corruption, severity) pair.
            try:
                if not use_week2_loader:
                    loader = make_eval_loader_fallback(cifar10c_root, corruption, sev, batch_size, num_workers, pin_memory=(device.type=="cuda"))
                else:
                    loader = load_cifar10c_corruption(cifar10c_root, corruption, severity_level=sev,
                                                      batch_size=batch_size, num_workers=num_workers,
                                                      pin_memory=(device.type=="cuda"))
            except Exception as ex:
                print(f"Warning: could not create loader for {corruption} sev{sev}: {ex}")
                continue

            # Count correct predictions without tracking gradients.
            model.eval()
            hits, seen = 0, 0
            with torch.no_grad():
                for images, labels in loader:
                    images, labels = images.to(device), labels.to(device)
                    predicted = model(images).max(1)[1]
                    hits += (predicted == labels).sum().item()
                    seen += labels.size(0)

            top1 = (100.0 * hits / seen) if seen else 0.0
            results[(corruption, sev)] = top1
            print(f"Eval {corruption} sev={sev}: {top1:.2f}%")
    return results

# ---- runner ----
def ablation_runner(
    tag: str,
    cifar10_root: str,
    cifar10c_root: str,
    model_name: str,
    device: torch.device,
    corruption_names_for_training: List[str],
    training_severity_strategy: str = "single-3",   # 'single-<severity>' or 'multi-2-4'
    corrupted_fraction: float = 0.5,
    freeze_fraction: float = 0.7,
    lr: float = 1e-4,
    epochs: int = 3,
    batch_size: int = 128,
    num_workers: int = 4,
    evaluate_severities: List[int] = [1,2,3,4,5],
    results_root: str = "results/ablation"
) -> Dict[str, Any]:
    """
    Runs one ablation config:
      - build mixed train loader (requires make_train_loader from the Week3 helper cell)
      - fine-tune model
      - evaluate on all corruptions & the requested severities
      - save checkpoint, metrics CSV, metadata JSON, and return summary dict

    Args:
        tag: suffix used in the names of all files written under ``results_root``.
        cifar10_root, cifar10c_root: dataset locations forwarded to the loaders.
        model_name: backbone name passed to ``build_backbone``.
        device: device the model is built on and evaluated with.
        corruption_names_for_training: corruptions mixed into the training set.
        training_severity_strategy: ``"single-<n>"`` trains with severity ``n``;
            ``"multi-2-4"`` concatenates the datasets built for severities 2, 3 and 4,
            each with a third of ``corrupted_fraction``.
        corrupted_fraction, freeze_fraction, lr, epochs, batch_size, num_workers:
            training hyper-parameters, recorded in the metadata.
        evaluate_severities: severities evaluated after training.
        results_root: output directory (created if missing).

    Returns:
        ``{"tag": tag, "meta": <metadata dict>, "metrics": {(corruption, severity): top1}}``.

    Raises:
        RuntimeError: if ``make_train_loader`` is not defined in the notebook globals.
        ValueError: if ``training_severity_strategy`` is not one of the supported forms.
    """
    start_ts = time.time()
    # Work on a private copy: the default list object is shared between calls, so it must
    # never be handed to code that could mutate it.
    severities_to_eval = list(evaluate_severities)

    out_root = Path(results_root)
    out_root.mkdir(parents=True, exist_ok=True)
    run_ckpt = out_root / f"finetuned_model_{tag}.pth"
    run_metrics_csv = out_root / f"metrics_{tag}.csv"
    run_meta_json = out_root / f"metadata_{tag}.json"

    # Build train loader: the Week3 helper is required.
    if 'make_train_loader' not in globals():
        raise RuntimeError("make_train_loader not found — please run Week3 helper cell first or define a train loader.")

    if training_severity_strategy.startswith("single"):
        sev = int(training_severity_strategy.split("-")[-1])
        loader, counts_map = make_train_loader(cifar10_root, cifar10c_root, corruption_names_for_training,
                                              severity=sev, corrupted_fraction=corrupted_fraction,
                                              batch_size=batch_size, num_workers=num_workers)
    elif training_severity_strategy == "multi-2-4":
        # make_train_loader supports a single severity, so call it once per severity and
        # concatenate the underlying datasets.
        # NOTE(review): if each per-severity dataset also contains the clean samples, the
        # concatenation repeats them three times — confirm against make_train_loader.
        from torch.utils.data import ConcatDataset
        counts_map = {}
        datasets = []
        for sev in (2,3,4):
            ldr, cmap = make_train_loader(cifar10_root, cifar10c_root, corruption_names_for_training,
                                          severity=sev, corrupted_fraction=corrupted_fraction/3.0,
                                          batch_size=batch_size, num_workers=num_workers)
            datasets.append(ldr.dataset)
            for k,v in cmap.items():
                counts_map[k] = counts_map.get(k,0) + int(v)
        loader = DataLoader(ConcatDataset(datasets), batch_size=batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=(device.type=="cuda"))
    else:
        raise ValueError(
            f"Unsupported training_severity_strategy {training_severity_strategy!r}; "
            "expected 'single-<severity>' or 'multi-2-4'"
        )

    # Build model
    model = build_backbone(model_name, num_classes=10, device=device)
    freeze_backbone_until(model, freeze_fraction=freeze_fraction)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    # Optimizer + loss (only parameters left trainable are optimised)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-5)
    loss_fn = nn.CrossEntropyLoss()

    # Train
    history = []
    for epoch in range(1, epochs+1):
        stats = train_one_epoch(model, loader, optimizer, device, loss_fn, epoch, log_every=200)
        print(f"[{tag}] Epoch {epoch} — loss {stats['loss']:.4f} acc {stats['acc']:.2f}%")
        history.append({"epoch": epoch, **stats})

    # Save checkpoint
    save_checkpoint(model, str(run_ckpt))

    # Evaluate on CIFAR-10-C (all corruptions)
    metrics = evaluate_on_corruptions_for_model(model, cifar10c_root, severities=severities_to_eval,
                                                batch_size=256, num_workers=num_workers, device=device,
                                                corruption_subset=None)

    # Write metrics CSV
    with open(run_metrics_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["corruption","severity","top1_accuracy"])
        for (c,s),val in sorted(metrics.items()):
            w.writerow([c,s,f"{val:.4f}"])

    # Write metadata
    meta = {
        "tag": tag,
        "model_name": model_name,
        "device": str(device),
        "corruption_names_for_training": corruption_names_for_training,
        "training_severity_strategy": training_severity_strategy,
        "corrupted_fraction": corrupted_fraction,
        "freeze_fraction": freeze_fraction,
        "lr": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "trainable_params": int(trainable_params),
        "total_params": int(total_params),
        "counts_map": counts_map,
        "metrics_csv": str(run_metrics_csv),
        "ckpt": str(run_ckpt),
        "elapsed_seconds": int(time.time() - start_ts)
    }
    with open(run_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"[{tag}] finished — metrics saved to {run_metrics_csv} ; checkpoint {run_ckpt}")
    return {"tag": tag, "meta": meta, "metrics": metrics}


In [12]:
# Cell 2: Orchestrate a compact ablation grid and run ablation_runner for each config
from pathlib import Path
import pandas as pd

# --------------- CONFIG (edit as needed) ----------------
cifar10_root = "/content"                # torchvision CIFAR will download here if needed
cifar10c_root = "/content/CIFAR-10-C"    # existing folder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "efficientnet_b0"
corruption_pool = ["defocus_blur","motion_blur","zoom_blur","gaussian_noise","impulse_noise"]
results_root = "results/ablation"
Path(results_root).mkdir(parents=True, exist_ok=True)

# Quick grid — small set (change to larger grid later)
grid = [
    {"tag":"cf0.25_ff0.7_lr1e-4_e3", "corrupted_fraction":0.25, "freeze_fraction":0.7, "lr":1e-4, "epochs":3, "training_severity_strategy":"single-3"},
    {"tag":"cf0.5_ff0.7_lr1e-4_e3",  "corrupted_fraction":0.5,  "freeze_fraction":0.7, "lr":1e-4, "epochs":3, "training_severity_strategy":"single-3"},
    {"tag":"cf0.75_ff0.7_lr1e-4_e3", "corrupted_fraction":0.75, "freeze_fraction":0.7, "lr":1e-4, "epochs":3, "training_severity_strategy":"single-3"},
    # change freeze fraction
    {"tag":"cf0.5_ff0.5_lr1e-4_e3",  "corrupted_fraction":0.5,  "freeze_fraction":0.5, "lr":1e-4, "epochs":3, "training_severity_strategy":"single-3"},
    {"tag":"cf0.5_ff0.9_lr1e-4_e3",  "corrupted_fraction":0.5,  "freeze_fraction":0.9, "lr":1e-4, "epochs":3, "training_severity_strategy":"single-3"},
]
# --------------- END CONFIG ----------------

# One summary row per config: aggregate scores on success, the error text on failure.
summary_rows = []
for cfg in grid:
    tag = cfg["tag"]
    print("\n" + "="*80)
    print(f"Running config: {tag}")
    try:
        res = ablation_runner(
            tag=tag,
            cifar10_root=cifar10_root,
            cifar10c_root=cifar10c_root,
            model_name=model_name,
            device=device,
            corruption_names_for_training=corruption_pool,
            training_severity_strategy=cfg["training_severity_strategy"],
            corrupted_fraction=cfg["corrupted_fraction"],
            freeze_fraction=cfg["freeze_fraction"],
            lr=cfg["lr"],
            epochs=cfg["epochs"],
            batch_size=128,
            num_workers=4,
            evaluate_severities=[1,2,3,4,5],
            results_root=results_root
        )
        # compute simple aggregate metrics for the run
        metrics = res["metrics"]
        values = list(metrics.values())
        sev3_values = [v for (k,v) in metrics.items() if k[1]==3]
        mean_all = float(np.mean(values)) if values else 0.0
        # Guard on the severity-3 values themselves: np.mean of an empty list is NaN.
        mean_sev3 = float(np.mean(sev3_values)) if sev3_values else 0.0
        summary_rows.append({
            "tag": tag,
            "mean_top1_all": mean_all,
            "mean_top1_sev3": mean_sev3,
            # "ckpt" must point at the checkpoint; the metrics CSV gets its own column.
            "ckpt": res["meta"]["ckpt"],
            "metrics_csv": res["meta"]["metrics_csv"],
            "meta": res["meta"]
        })
    except Exception as ex:
        # Deliberately keep going so one bad config does not abort the whole grid.
        print(f"ERROR in config {tag}: {ex}")
        summary_rows.append({"tag": tag, "error": str(ex)})

# Save summary table
summary_df = pd.DataFrame(summary_rows)
summary_csv = Path(results_root) / "summary_table.csv"
summary_df.to_csv(summary_csv, index=False)
print("\nAblation grid finished. Summary saved to:", summary_csv)
print(summary_df)


Running config: cf0.25_ff0.7_lr1e-4_e3
 -> wrote results/ablation/metrics_cf0.25_ff0.7_lr1e-4_e3.csv, metadata results/ablation/metadata_cf0.25_ff0.7_lr1e-4_e3.json, checkpoint results/ablation/finetuned_model_cf0.25_ff0.7_lr1e-4_e3.pth
Running config: cf0.5_ff0.7_lr1e-4_e3
 -> wrote results/ablation/metrics_cf0.5_ff0.7_lr1e-4_e3.csv, metadata results/ablation/metadata_cf0.5_ff0.7_lr1e-4_e3.json, checkpoint results/ablation/finetuned_model_cf0.5_ff0.7_lr1e-4_e3.pth
Running config: cf0.75_ff0.7_lr1e-4_e3
 -> wrote results/ablation/metrics_cf0.75_ff0.7_lr1e-4_e3.csv, metadata results/ablation/metadata_cf0.75_ff0.7_lr1e-4_e3.json, checkpoint results/ablation/finetuned_model_cf0.75_ff0.7_lr1e-4_e3.pth
Running config: cf0.5_ff0.5_lr1e-4_e3
 -> wrote results/ablation/metrics_cf0.5_ff0.5_lr1e-4_e3.csv, metadata results/ablation/metadata_cf0.5_ff0.5_lr1e-4_e3.json, checkpoint results/ablation/finetuned_model_cf0.5_ff0.5_lr1e-4_e3.pth
Running config: cf0.5_ff0.9_lr1e-4_e3
 -> wrote results/abla

In [13]:
# Cell 3: Aggregate results, show top runs, produce comparison plot vs baseline
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Where the ablation cell wrote its outputs, and where the Week2 baseline lives.
results_root = Path("results/ablation")
summary_csv = results_root / "summary_table.csv"
baseline_csv = Path("results/baseline_metrics.csv")  # Week2 baseline

# Load summary; an empty frame stands in when the table was never written.
if not summary_csv.exists():
    print("No summary_table.csv found at", summary_csv)
    summary_df = pd.DataFrame()
else:
    summary_df = pd.read_csv(summary_csv)

# Show the runs ranked by mean_top1_all (higher is better), when that column exists.
if summary_df.empty or "mean_top1_all" not in summary_df.columns:
    print("Summary table empty or missing metrics.")
else:
    ranked = summary_df.sort_values(by="mean_top1_all", ascending=False)
    print("Top runs by mean_top1_all:")
    display(ranked.head(10))

# If baseline exists, compute delta for best run and create plot comparing average per corruption
def load_metrics_csv(path: Path) -> Dict[Tuple[str,int], float]:
    """Read a metrics CSV into a mapping keyed by (corruption, severity).

    The file must have a header row containing the columns ``corruption``,
    ``severity`` and ``top1_accuracy``. If the same (corruption, severity)
    pair appears more than once, the last row wins.

    Args:
        path: Location of the metrics CSV.

    Returns:
        ``{(corruption, severity): top1_accuracy}`` with ``severity`` parsed as
        ``int`` and the accuracy as ``float``; an empty dict when ``path`` does
        not exist.

    Raises:
        KeyError: If a required column is absent.
        ValueError: If ``severity`` or ``top1_accuracy`` cannot be parsed.
    """
    import csv

    # A missing file is treated as "no metrics" rather than an error.
    if not path.exists():
        return {}

    metrics = {}
    with path.open(newline="") as handle:
        for record in csv.DictReader(handle):
            accuracy = float(record["top1_accuracy"])
            cell = (record["corruption"], int(record["severity"]))
            metrics[cell] = accuracy
    return metrics

# Compare the best run (highest mean_top1_all) against the baseline.
# The "mean_top1_all" check matches the guard used when ranking the summary, so
# a table without that column is skipped instead of raising KeyError in
# sort_values.
can_rank = not summary_df.empty and "mean_top1_all" in summary_df.columns
if can_rank and baseline_csv.exists():
    # choose best run
    best_row = summary_df.sort_values(by="mean_top1_all", ascending=False).iloc[0]
    best_tag = best_row["tag"]
    best_metrics_path = results_root / f"metrics_{best_tag}.csv"
    baseline_metrics = load_metrics_csv(baseline_csv)
    best_metrics = load_metrics_csv(best_metrics_path)

    # compute per-corruption average across severities present; a corruption
    # is only plotted when both the baseline and the best run have values for it
    corruptions = sorted({corruption for corruption, _ in best_metrics})
    avg_baseline = []
    avg_best = []
    labels = []
    for corruption in corruptions:
        baseline_vals = [acc for (name, _), acc in baseline_metrics.items() if name == corruption]
        best_vals = [acc for (name, _), acc in best_metrics.items() if name == corruption]
        if not baseline_vals or not best_vals:
            continue
        avg_baseline.append(np.mean(baseline_vals))
        avg_best.append(np.mean(best_vals))
        labels.append(corruption)

    if not labels:
        # Happens when the best run's metrics file is missing or shares no
        # corruption with the baseline; an empty chart would be misleading.
        print("Skipping comparison plot (no corruptions shared by baseline and", best_metrics_path, ").")
    else:
        x = np.arange(len(labels))
        width = 0.35
        # Widen the figure with the number of corruptions so tick labels stay readable.
        fig, ax = plt.subplots(figsize=(max(8, len(labels) * 0.45), 5))
        ax.bar(x - width / 2, avg_baseline, width, label="baseline")
        ax.bar(x + width / 2, avg_best, width, label=f"best_{best_tag}")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel("Top-1 Accuracy (%)")
        ax.set_title("Baseline vs Best Ablation Run (avg across severities)")
        ax.legend()
        fig.tight_layout()
        plot_out = results_root / "plots" / f"baseline_vs_best_{best_tag}.png"
        plot_out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_out, dpi=200)
        print("Saved comparison plot to", plot_out)
        plt.show()
        plt.close(fig)  # release the figure so re-running the cell does not accumulate figures

    # Per-(corruption, severity) delta of the best run over the baseline
    # (best minus baseline, in accuracy points), largest gains first.
    shared_cells = best_metrics.keys() & baseline_metrics.keys()
    deltas = {cell: best_metrics[cell] - baseline_metrics[cell] for cell in shared_cells}
    if deltas:
        print("\nTop 10 improvements (corruption,severity) -> delta%:")
        for corruption, severity in sorted(deltas, key=deltas.get, reverse=True)[:10]:
            print(f"{corruption}, sev={severity} -> {deltas[(corruption, severity)]:.2f}%")
else:
    print("Skipping comparison plot (summary or baseline missing).")

print("Aggregation complete. Check results/ablation for per-run CSVs, metadata, checkpoints, and plots.")


Top runs by mean_top1_all:


Unnamed: 0,tag,mean_top1_all,mean_top1_sev3,metrics_csv,ckpt,meta
3,cf0.5_ff0.5_lr1e-4_e3,58.8512,59.79,results/ablation/metrics_cf0.5_ff0.5_lr1e-4_e3...,results/ablation/finetuned_model_cf0.5_ff0.5_l...,"{'tag': 'cf0.5_ff0.5_lr1e-4_e3', 'corrupted_fr..."
2,cf0.75_ff0.7_lr1e-4_e3,58.8346,59.7478,results/ablation/metrics_cf0.75_ff0.7_lr1e-4_e...,results/ablation/finetuned_model_cf0.75_ff0.7_...,"{'tag': 'cf0.75_ff0.7_lr1e-4_e3', 'corrupted_f..."
1,cf0.5_ff0.7_lr1e-4_e3,58.7252,59.5689,results/ablation/metrics_cf0.5_ff0.7_lr1e-4_e3...,results/ablation/finetuned_model_cf0.5_ff0.7_l...,"{'tag': 'cf0.5_ff0.7_lr1e-4_e3', 'corrupted_fr..."
0,cf0.25_ff0.7_lr1e-4_e3,58.6038,59.3637,results/ablation/metrics_cf0.25_ff0.7_lr1e-4_e...,results/ablation/finetuned_model_cf0.25_ff0.7_...,"{'tag': 'cf0.25_ff0.7_lr1e-4_e3', 'corrupted_f..."
4,cf0.5_ff0.9_lr1e-4_e3,58.5325,59.221,results/ablation/metrics_cf0.5_ff0.9_lr1e-4_e3...,results/ablation/finetuned_model_cf0.5_ff0.9_l...,"{'tag': 'cf0.5_ff0.9_lr1e-4_e3', 'corrupted_fr..."


Saved comparison plot to results/ablation/plots/baseline_vs_best_cf0.5_ff0.5_lr1e-4_e3.png

Top 10 improvements (corruption,severity) -> delta%:
motion_blur, sev=3 -> 4.70%
impulse_noise, sev=3 -> 4.38%
defocus_blur, sev=3 -> 4.23%
zoom_blur, sev=3 -> 4.09%
gaussian_noise, sev=3 -> 4.02%
defocus_blur, sev=2 -> 2.87%
gaussian_noise, sev=4 -> 2.83%
motion_blur, sev=4 -> 2.64%
gaussian_noise, sev=2 -> 2.57%
motion_blur, sev=2 -> 2.56%

Aggregation complete. Check results/ablation for per-run CSVs, metadata, checkpoints, and plots.
