# Download Dataset

In [None]:
import os
import subprocess
import datetime
from typing import Tuple


def log_event(message: str) -> None:
    """Write `message` to stdout behind a ``[YYYY-MM-DD HH:MM:SS]`` local-time prefix."""
    now = datetime.datetime.now()
    stamp = format(now, "%Y-%m-%d %H:%M:%S")
    print(f"[{stamp}] {message}")


def run_command(command: str) -> Tuple[int, str]:
    """Execute `command` through the shell and collect everything it prints.

    stderr is merged into stdout and the combined stream is decoded as text.

    Returns:
        A ``(exit_code, output)`` pair.
    """
    popen_options = {
        "shell": True,
        "stdout": subprocess.PIPE,
        "stderr": subprocess.STDOUT,  # interleave errors with regular output
        "text": True,
    }
    with subprocess.Popen(command, **popen_options) as child:
        captured = child.communicate()[0]
    return child.returncode, captured


def ensure_kaggle_config() -> bool:
    """Install ``kaggle.json`` from the working directory into ``~/.kaggle``.

    The file is copied to ``~/.kaggle/kaggle.json`` and restricted to mode 600.
    Each shell step is checked, so a failed copy or chmod is reported instead of
    being followed by a misleading success message.

    Returns:
        True when the credentials file is in place, False when ``kaggle.json``
        is missing or any set-up command exits with a non-zero status.
    """
    if not os.path.isfile("kaggle.json"):
        log_event("ERROR: kaggle.json not found. Please upload it.")
        return False

    setup_commands = (
        "mkdir -p ~/.kaggle",
        "cp kaggle.json ~/.kaggle/",
        "chmod 600 ~/.kaggle/kaggle.json",
    )
    for command in setup_commands:
        code, output = run_command(command)
        if code != 0:
            # Stop at the first failing step; later steps depend on earlier ones.
            log_event(f"ERROR: '{command}' failed while configuring Kaggle credentials.")
            log_event(f"Command Output:\n{output}")
            return False

    log_event("Kaggle API credentials configured successfully.")
    return True


def install_kaggle_api() -> None:
    """Install the ``kaggle`` package with pip and report whether it worked.

    pip is invoked as ``<current interpreter> -m pip`` so the package is
    installed into the environment this code is running in, rather than into
    whichever ``pip`` happens to be first on PATH. A non-zero exit status is
    logged together with pip's output.
    """
    import shlex
    import sys

    log_event("Installing Kaggle API if missing...")

    # sys.executable can be empty in unusual embeddings; fall back to plain `pip` then.
    pip_command = f"{shlex.quote(sys.executable)} -m pip" if sys.executable else "pip"
    code, output = run_command(f"{pip_command} install -q kaggle")

    if code != 0:
        log_event("ERROR: Kaggle API installation failed.")
        log_event(f"Command Output:\n{output}")
        return

    log_event("Kaggle API installation attempted.")


def download_dataset(dataset_path: str) -> bool:
    """Download a Kaggle dataset archive into the current working directory.

    Args:
        dataset_path: Dataset slug in ``owner/dataset-name`` form.

    Returns:
        True when the ``kaggle`` CLI exits with status 0, False otherwise
        (the CLI output is logged on failure).
    """
    import shlex

    log_event(f"Starting dataset download: {dataset_path}")
    # Quote the slug: the command is run through a shell, so an unquoted value
    # containing spaces or metacharacters would be split or interpreted.
    command = f"kaggle datasets download -d {shlex.quote(dataset_path)}"
    code, output = run_command(command)

    if code != 0:
        log_event("ERROR: Kaggle download failed.")
        log_event(f"Command Output:\n{output}")
        return False

    log_event("Dataset downloaded successfully.")
    return True


def unzip_dataset(dataset_path: str) -> bool:
    """Extract the downloaded dataset archive into the current working directory.

    The archive name is derived from the last component of `dataset_path`
    (``owner/name`` -> ``name.zip``). Extraction uses the standard-library
    ``zipfile`` module, so it needs no external ``unzip`` binary, is not
    affected by shell quoting of the file name, and overwrites existing files
    without prompting when the cell is re-run.

    Args:
        dataset_path: Dataset slug in ``owner/dataset-name`` form.

    Returns:
        True when extraction succeeds, False when the archive is missing,
        is not a valid zip file, or cannot be written out.
    """
    import zipfile

    # rstrip guards against a trailing slash producing the bogus name ".zip".
    zip_filename = f"{dataset_path.rstrip('/').split('/')[-1]}.zip"

    if not os.path.exists(zip_filename):
        log_event(f"ERROR: Expected zip file '{zip_filename}' not found.")
        return False

    log_event(f"Unzipping '{zip_filename}'...")
    try:
        with zipfile.ZipFile(zip_filename) as archive:
            archive.extractall()
    except (zipfile.BadZipFile, OSError) as exc:
        log_event("ERROR: Unzip failed.")
        log_event(f"Details: {exc}")
        return False

    log_event("Unzip completed successfully.")
    return True


def download_kaggle_dataset(dataset_path: str) -> None:
    """Run the whole pipeline: install the CLI, set up credentials, download, extract.

    The pipeline stops at the first step that reports failure.
    """
    install_kaggle_api()

    credentials_ready = ensure_kaggle_config()
    if not credentials_ready:
        # ensure_kaggle_config has already logged the reason
        return

    # Each stage takes the dataset slug and returns True on success.
    stages = (
        (download_dataset, "Download process terminated due to errors."),
        (unzip_dataset, "Unzip process terminated due to errors."),
    )
    for stage, failure_message in stages:
        if not stage(dataset_path):
            log_event(failure_message)
            return

    log_event("Dataset download and extraction completed successfully.")


# Example usage: download the "harshadakhatu/cifar-10-c" Kaggle dataset and extract it
# into the current working directory (requires kaggle.json next to the notebook).
download_kaggle_dataset("harshadakhatu/cifar-10-c")


[2025-11-17 03:49:13] Installing Kaggle API if missing...
[2025-11-17 03:49:18] Kaggle API installation attempted.
[2025-11-17 03:49:18] Kaggle API credentials configured successfully.
[2025-11-17 03:49:18] Starting dataset download: harshadakhatu/cifar-10-c
[2025-11-17 03:51:54] Dataset downloaded successfully.
[2025-11-17 03:51:54] Unzipping 'cifar-10-c.zip'...
[2025-11-17 03:52:36] Unzip completed successfully.
[2025-11-17 03:52:36] Dataset download and extraction completed successfully.


In [None]:
import os

# Folder whose contents should be listed
directory_path = "/content/CIFAR-10-C"

if not os.path.isdir(directory_path):
    # isdir is False both for a missing path and for a path that is not a directory
    print(f"The directory '{directory_path}' does not exist or is not a directory.")
else:
    # Print every entry (files and sub-directories) on its own line
    files_and_dirs = os.listdir(directory_path)
    print(f"Contents of '{directory_path}':")
    for entry_name in files_and_dirs:
        print(entry_name)

Contents of '/content/CIFAR-10-C':
speckle_noise.npy
gaussian_blur.npy
fog.npy
impulse_noise.npy
jpeg_compression.npy
motion_blur.npy
snow.npy
zoom_blur.npy
saturate.npy
spatter.npy
labels.npy
defocus_blur.npy
shot_noise.npy
gaussian_noise.npy
brightness.npy
glass_blur.npy
elastic_transform.npy
frost.npy
pixelate.npy
contrast.npy


# `inspect_cifar10c.py`: Inspect the `.npy` files in `/content/CIFAR-10-C`


In [None]:
# inspect /content/CIFAR-10-C .npy files

import os
import numpy as np
import json
from pathlib import Path
from PIL import Image
import math

# Input folder with the .npy files (change only if your path differs)
ROOT = Path("/content/CIFAR-10-C")
# Output folder for the JSON report, and a sub-folder for the sample images
OUT = Path("results")
SAMPLES = OUT / "samples"
for output_dir in (OUT, SAMPLES):
    output_dir.mkdir(parents=True, exist_ok=True)

def is_image_array(arr: np.ndarray):
    """Heuristically decide whether `arr` is a batch of images.

    Only 4-D arrays qualify. They are accepted either as channels-last
    (last axis of size 1 or 3, axes 1 and 2 at most 1024) or as channels-first
    (axis 1 of size 1 or 3, axis 2 at most 1024).
    """
    if arr.ndim != 4:
        return False

    _, dim1, dim2, dim3 = arr.shape
    channel_sizes = (1, 3)
    looks_channels_last = dim3 in channel_sizes and dim1 <= 1024 and dim2 <= 1024
    looks_channels_first = dim1 in channel_sizes and dim2 <= 1024
    return looks_channels_last or looks_channels_first

def to_uint8_image(arr: np.ndarray):
    """Convert a single image array to uint8 in HxW or HxWxC layout for saving.

    Floating-point input of any precision (float16 included) is taken to be on
    a [0, 1] scale: it is clipped to that range, multiplied by 255 and rounded.
    Every other dtype is cast to uint8 directly, so integer input is expected
    to already lie in [0, 255].

    Layout handling for 3-D input:
      * a leading axis of size 1 or 3 is treated as channels-first and moved last;
      * a trailing single channel is dropped, giving a 2-D grayscale image.

    Args:
        arr: One image (not a batch).

    Returns:
        A uint8 array shaped HxW or HxWxC.
    """
    # dtype.kind == "f" covers every float width; checking only float32/float64
    # would send float16 data down the integer path and truncate it to 0/1.
    if arr.dtype.kind == "f":
        img = np.clip(arr, 0.0, 1.0)
        img = (img * 255.0).round().astype(np.uint8)
    else:
        img = arr.astype(np.uint8)
    # if CHW -> HWC
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.transpose(img, (1, 2, 0))
    # if grayscale single channel -> HxW
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    return img

# Per-file findings, keyed by file name
report = {}
# Every entry directly inside ROOT whose suffix is exactly ".npy", in sorted order
files = sorted(entry for entry in ROOT.iterdir() if entry.suffix == ".npy")
if len(files) == 0:
    raise SystemExit(f"No .npy files found in {ROOT}. Check path and rerun.")

print(f"Found {len(files)} .npy files in {ROOT}\n")
# Inspect every discovered .npy file. For each one this records shape/dtype/size and
# global statistics, flags 1-D arrays as label vectors, and, for arrays accepted by
# is_image_array, computes per-severity channel statistics and saves the first image
# of each severity as a .png under SAMPLES. The findings are stored in report[name].
for p in files:
    name = p.name
    print(f"Loading {name} ...", end=" ", flush=True)
    # Loads the whole array into memory
    arr = np.load(p)
    print("done.")
    info = {
        "path": str(p),
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
        "nbytes": int(arr.nbytes) if hasattr(arr, "nbytes") else None,
    }

    # Global statistics over all elements; None for empty arrays. Any failure
    # (e.g. a non-numeric dtype) is recorded under "stats_error" instead of aborting.
    try:
        info["min"] = float(np.min(arr)) if arr.size>0 else None
        info["max"] = float(np.max(arr)) if arr.size>0 else None
        info["mean"] = float(np.mean(arr)) if arr.size>0 else None
        info["std"] = float(np.std(arr)) if arr.size>0 else None
    except Exception as e:
        info["stats_error"] = str(e)

    # Labels detection: every 1-D array is treated as a label vector (length is not checked)
    if arr.ndim == 1:
        unique, counts = np.unique(arr, return_counts=True)
        info["is_label_array"] = True
        info["unique_count"] = int(len(unique))
        info["length"] = int(arr.shape[0])
        # Only the first 10 distinct values and their counts are kept
        info["unique_sample"] = unique[:10].tolist()
        info["counts_sample"] = counts[:10].tolist()
        # Save sample of labels to JSON-friendly list (first 200)
        info["first_labels_sample"] = arr[:200].tolist()
    else:
        info["is_label_array"] = False

    # Image file heuristics and per-severity analysis
    if is_image_array(arr):
        info["looks_like_images"] = True
        # Normalize shape to (N, H, W, C)
        if arr.ndim == 4:
            # NOTE(review): this nested check only reaches `pass`, so it has no effect
            # and could be removed.
            if arr.shape[1] in (1,3) and arr.shape[2] <= 1024:
                # shape (N, C, H, W) -> convert to (N, H, W, C)
                if arr.shape[1] in (1,3) and arr.shape[3] <= 4:
                    # unexpected ordering, but handle common CHW or HWC
                    pass
            # Try to detect ordering
            # NOTE(review): n/h/w/c assigned below are never read afterwards.
            n = arr.shape[0]
            if arr.shape[-1] in (1,3) and arr.shape[1] not in (1,3):
                # likely (N,H,W,C)
                n, h, w, c = arr.shape
                arr_hwc = arr
            elif arr.shape[1] in (1,3):
                # likely (N,C,H,W) -> convert
                n, c, h, w = arr.shape
                arr_hwc = np.transpose(arr, (0,2,3,1))
            else:
                # fallback: treat as (N,H,W,C)
                n, h, w, c = arr.shape
                arr_hwc = arr
        else:
            # NOTE(review): is_image_array only returns True for 4-D arrays, so this
            # branch looks unreachable with the current heuristic.
            info["image_format_note"] = "unexpected ndim for images"
            arr_hwc = arr

        info["num_images_total"] = int(arr_hwc.shape[0])

        # detect severity grouping: divisible by 10000 (common CIFAR-10-C)
        if arr_hwc.shape[0] % 10000 == 0:
            groups = arr_hwc.shape[0] // 10000
            info["num_severities_in_file"] = int(groups)
            info["per_severity_counts"] = [10000] * groups
            # compute simple per-severity stats (mean pixel value per channel)
            per_sev = []
            for s in range(groups):
                start = s * 10000
                end = (s + 1) * 10000
                # Convert to float32; non-float32/float64 data is divided by 255
                subset = arr_hwc[start:end].astype(np.float32) / (255.0 if arr_hwc.dtype != np.float32 and arr_hwc.dtype != np.float64 else 1.0)
                # per-channel mean (HWC -> axis=(0,1,2), channel last)
                mean_channels = list(np.mean(subset, axis=(0,1,2)).tolist())
                std_channels = list(np.std(subset, axis=(0,1,2)).tolist())
                per_sev.append({"mean_channels": mean_channels, "std_channels": std_channels})
                # save one representative image per severity (first image);
                # a failure is recorded under "saved_sample_error" and does not stop the run
                try:
                    img = to_uint8_image(subset[0])
                    fname = SAMPLES / f"{p.stem}_s{s+1}.png"
                    Image.fromarray(img).save(fname)
                    per_sev[-1]["saved_sample"] = str(fname)
                except Exception as e:
                    per_sev[-1]["saved_sample_error"] = str(e)
            info["per_severity_stats"] = per_sev
        else:
            # not divisible by 10000: report number and attempt to treat as single-severity set
            info["num_severities_in_file"] = None
            info["per_severity_counts"] = [int(arr_hwc.shape[0])]
            # Unlike the branch above, statistics and saving share one try block here,
            # so any failure replaces the stats with a single {"error": ...} entry.
            try:
                subset = arr_hwc.astype(np.float32) / (255.0 if arr_hwc.dtype != np.float32 and arr_hwc.dtype != np.float64 else 1.0)
                mean_channels = list(np.mean(subset, axis=(0,1,2)).tolist())
                std_channels = list(np.std(subset, axis=(0,1,2)).tolist())
                info["per_severity_stats"] = [{"mean_channels": mean_channels, "std_channels": std_channels}]
                # save sample
                img = to_uint8_image(subset[0])
                fname = SAMPLES / f"{p.stem}_s1.png"
                Image.fromarray(img).save(fname)
                info["per_severity_stats"][0]["saved_sample"] = str(fname)
            except Exception as e:
                info["per_severity_stats"] = [{"error": str(e)}]
    else:
        info["looks_like_images"] = False

    report[name] = info

# Persist the collected findings as indented JSON
report_path = OUT / "inspect_report.json"
with report_path.open("w") as report_file:
    report_file.write(json.dumps(report, indent=2))

print("\nInspection complete.")
print(f"Report written to: {report_path}")
print(f"Sample images (one per severity) saved under: {SAMPLES}")
print("\nQuick summary (file : num_images / num_severities):")
# One line per file: image files, then label vectors, then everything else
for file_name, file_info in report.items():
    if file_info.get("looks_like_images"):
        summary_line = f"- {file_name}: {file_info.get('num_images_total')} images, severities={file_info.get('num_severities_in_file')}"
    elif file_info.get("is_label_array"):
        summary_line = f"- {file_name}: labels length={file_info.get('length')}, unique={file_info.get('unique_count')}"
    else:
        summary_line = f"- {file_name}: shape={file_info.get('shape')} (non-image)"
    print(summary_line)

# Repeat the bare report path on its own line for easy copy-paste
print(str(report_path))


Found 20 .npy files in /content/CIFAR-10-C

Loading brightness.npy ... done.
Loading contrast.npy ... done.
Loading defocus_blur.npy ... done.
Loading elastic_transform.npy ... done.
Loading fog.npy ... done.
Loading frost.npy ... done.
Loading gaussian_blur.npy ... done.
Loading gaussian_noise.npy ... done.
Loading glass_blur.npy ... done.
Loading impulse_noise.npy ... done.
Loading jpeg_compression.npy ... done.
Loading labels.npy ... done.
Loading motion_blur.npy ... done.
Loading pixelate.npy ... done.
Loading saturate.npy ... done.
Loading shot_noise.npy ... done.
Loading snow.npy ... done.
Loading spatter.npy ... done.
Loading speckle_noise.npy ... done.
Loading zoom_blur.npy ... done.

Inspection complete.
Report written to: results/inspect_report.json
Sample images (one per severity) saved under: results/samples

Quick summary (file : num_images / num_severities):
- brightness.npy: 50000 images, severities=5
- contrast.npy: 50000 images, severities=5
- defocus_blur.npy: 50000 i

# `dataloader.py`: CIFAR-10-C dataloader

In [None]:
# CIFAR-10-C dataloader & helpers (run this cell first)
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# CIFAR normalization constants: one value per channel. `_normalize` subtracts
# CIFAR_MEAN and divides by CIFAR_STD.
# NOTE(review): presumably RGB-ordered CIFAR-10 training-set statistics on a [0, 1]
# pixel scale — confirm the source of these numbers.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

def list_corruptions(cifar10c_root: str) -> List[str]:
    """List the corruption names available in a CIFAR-10-C folder.

    A corruption name is the stem of a ``*.npy`` file in `cifar10c_root`; the
    labels file (``labels.npy``, matched case-insensitively) is left out.

    Raises:
        FileNotFoundError: if `cifar10c_root` does not exist.
    """
    folder = Path(cifar10c_root)
    if not folder.exists():
        raise FileNotFoundError(f"{cifar10c_root} not found")

    names = []
    for npy_file in folder.glob("*.npy"):
        if npy_file.name.lower() == "labels.npy":
            continue
        names.append(npy_file.stem)
    names.sort()
    return names

def _to_chw_tensor(np_images: np.ndarray) -> torch.Tensor:
    """
    Convert numpy images to CHW torch tensor float32 in [0,1].
    Accepts (N,H,W,C) or (N,C,H,W).
    """
    if np_images.ndim != 4:
        raise ValueError(f"Expected 4D image array, got shape {np_images.shape}")
    # If last dim is channel (H,W,C)
    if np_images.shape[-1] in (1, 3) and np_images.shape[1] not in (1, 3):
        np_images = np.transpose(np_images, (0, 3, 1, 2))  # -> (N,C,H,W)
    elif np_images.shape[1] in (1, 3):
        # already (N,C,H,W)
        pass
    else:
        # fallback: assume (N,H,W,C)
        np_images = np.transpose(np_images, (0, 3, 1, 2))
    if np.issubdtype(np_images.dtype, np.integer):
        images = np_images.astype("float32") / 255.0
    else:
        images = np_images.astype("float32")
    return torch.from_numpy(images)  # (N,C,H,W)

def _normalize(images: torch.Tensor) -> torch.Tensor:
    """Normalize CHW tensor using CIFAR mean/std (applied per-channel)."""
    mean = torch.tensor(CIFAR_MEAN, dtype=images.dtype).view(1, 3, 1, 1)
    std = torch.tensor(CIFAR_STD, dtype=images.dtype).view(1, 3, 1, 1)
    return (images - mean) / std

def load_cifar10c_corruption(
    cifar10c_root: str,
    corruption_name: str,
    severity_level: int = 1,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = False,
    images_per_severity: int = 10000
) -> DataLoader:
    """
    Load one corruption+severity as a DataLoader.

    - Handles corruption files whose image count is a multiple of
      `images_per_severity` (severities concatenated along axis 0, e.g.
      (50000,32,32,3) for 5 severities) or any other count (treated as a
      single severity).
    - Handles labels.npy covering either the selected severity only or the
      whole corruption file (in which case the matching slice is taken).

    Args:
        cifar10c_root: Folder containing ``<corruption>.npy`` and ``labels.npy``.
        corruption_name: File stem of the corruption to load.
        severity_level: 1-based severity index.
        batch_size, num_workers, pin_memory: Passed to the DataLoader.
        images_per_severity: Number of images that make up one severity block.

    Returns:
        A non-shuffling DataLoader over (normalized image, int64 label) pairs.

    Raises:
        FileNotFoundError: if the corruption file or labels.npy is missing.
        ValueError: if `images_per_severity` is not positive, `severity_level`
            is out of range, or the label count matches neither the selected
            slice nor the whole file.
    """
    if images_per_severity <= 0:
        raise ValueError("images_per_severity must be a positive integer")

    root = Path(cifar10c_root)
    data_path = root / f"{corruption_name}.npy"
    labels_path = root / "labels.npy"

    if not data_path.exists():
        raise FileNotFoundError(f"{data_path} not found")
    if not labels_path.exists():
        raise FileNotFoundError(f"{labels_path} not found")

    # Memory-map the image file so only the requested severity is read from disk
    # instead of the whole multi-severity array.
    arr = np.load(str(data_path), mmap_mode="r")
    labels = np.load(str(labels_path))

    n_images = arr.shape[0]
    # detect number of groups/severities (common: 50000 -> groups=5 each 10000)
    groups = n_images // images_per_severity if (n_images % images_per_severity == 0) else 1
    if not (1 <= severity_level <= groups):
        raise ValueError(f"severity_level must be 1..{groups} for file {data_path.name}")

    # With a single group this selects the full array.
    per_group = n_images // groups
    s0 = (severity_level - 1) * per_group
    s1 = s0 + per_group
    # np.array copies the selected block into an ordinary in-memory ndarray.
    images_slice = np.array(arr[s0:s1])

    # Determine labels for the images_slice
    n_labels = labels.shape[0]
    if n_labels == images_slice.shape[0]:
        # labels describe exactly one severity block
        labels_slice = labels
    elif n_labels == n_images:
        # labels cover the whole file: take the block matching the images
        labels_slice = labels[s0:s1]
    else:
        raise ValueError(
            f"Incompatible labels length ({n_labels}) and images ({images_slice.shape[0]}) "
            f"for file {data_path.name}. Inspect files with the inspector."
        )

    images_tensor = _to_chw_tensor(images_slice)  # (N,C,H,W)
    images_tensor = _normalize(images_tensor)
    labels_tensor = torch.from_numpy(labels_slice.astype("int64"))

    dataset = TensorDataset(images_tensor, labels_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    return loader

def load_all_corruptions(
    cifar10c_root: str,
    severity_level: int = 1,
    batch_size: int = 128,
    num_workers: int = 4,
    pin_memory: bool = False
) -> Dict[str, DataLoader]:
    """Return dict: corruption_name -> DataLoader for that severity.

    Every corruption found by `list_corruptions` is loaded eagerly with
    `load_cifar10c_corruption`, so all of their tensors are held at once.

    Args:
        cifar10c_root: Folder containing the corruption ``.npy`` files.
        severity_level: 1-based severity to load for every corruption.
        batch_size, num_workers, pin_memory: Passed on to each DataLoader.
    """
    loaders: Dict[str, DataLoader] = {}
    for corruption_name in list_corruptions(cifar10c_root):
        loaders[corruption_name] = load_cifar10c_corruption(
            cifar10c_root=cifar10c_root,
            corruption_name=corruption_name,
            severity_level=severity_level,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
    return loaders

# Confirmation message so the cell output shows that the loader definitions were executed
print("CIFAR-10-C loader functions defined.")


CIFAR-10-C loader functions defined.


# `evaluate_baseline.py`: Evaluate small pretrained models (EfficientNet-B0, MobileNetV3-Small) on CIFAR-10-C


In [None]:
# Evaluate baseline model(s) on CIFAR-10-C using the dataloader above

# Imports come first so that every name used in the configuration below
# (notably `Optional`) is defined when this cell runs on its own.

# Silence warnings (set up before the heavier imports so import-time warnings are hidden too)
import warnings, os
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

# Optional: reduce PIL / image warnings (if present)
# NOTE(review): setting MAX_IMAGE_PIXELS to None also disables Pillow's
# decompression-bomb check — confirm this is wanted.
try:
    from PIL import Image
    Image.MAX_IMAGE_PIXELS = None
except ImportError:
    # Pillow is optional for this cell
    pass

import csv
import torch
import torchvision.models as models
from pathlib import Path
from tqdm import tqdm
from typing import Optional

# --- CONFIGURE ---
cifar10c_root = "/content/CIFAR-10-C"   # path we used
model_name = "efficientnet_b0"           # "efficientnet_b0" or "mobilenet_v3_small"
severities = [1]                         # list e.g. [1] or [1,2,3,4,5]
batch_size = 256
num_workers = 4
device_str = "cuda"                      # or "cpu"
output_csv = "results/baseline_metrics.csv"
checkpoint_path: Optional[str] = None    # set to "/content/finetuned_model.pth" if we have one

# --- END CONFIG ---

# Ensure results directory exists
Path(output_csv).parent.mkdir(parents=True, exist_ok=True)

# Fall back to CPU only when a CUDA device was requested but none is available;
# any other requested device is used as given.
requested_device = torch.device(device_str)
if requested_device.type == "cuda" and not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = requested_device
print(f"Using device: {device}")

# Supported architectures mapped to the name of their torchvision weights enum.
_WEIGHT_ENUM_NAMES = {
    "efficientnet_b0": "EfficientNet_B0_Weights",
    "mobilenet_v3_small": "MobileNet_V3_Small_Weights",
}


def _replace_last_linear(model: torch.nn.Module, num_classes: int) -> None:
    """Swap the last nn.Linear sub-module of `model` for a new `num_classes`-way layer."""
    last_linear_name = None
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            last_linear_name = module_name
    # An empty name would mean the model itself is the Linear layer, which has no parent to patch.
    if not last_linear_name:
        raise RuntimeError("Could not find Linear layer to adapt the classifier head.")

    *parent_path, leaf = last_linear_name.split(".")
    parent = model
    for attribute in parent_path:
        parent = getattr(parent, attribute)
    old_head = getattr(parent, leaf)
    setattr(parent, leaf, torch.nn.Linear(old_head.in_features, num_classes))


def load_model(name: str, num_classes: int = 10, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Load an ImageNet-pretrained torchvision model and give it a new classifier head.

    The final Linear layer is replaced by a freshly initialised
    ``Linear(in_features, num_classes)``, so the returned model needs
    fine-tuned weights (e.g. from a checkpoint) before its predictions are
    meaningful.

    Args:
        name: "efficientnet_b0" or "mobilenet_v3_small" (case-insensitive).
        num_classes: Output size of the new head.
        device: If given, the model is moved to this device.

    Returns:
        The adapted model.

    Raises:
        ValueError: for an unsupported model name.
        RuntimeError: if no Linear layer can be found to replace.
    """
    name = name.lower()
    weights_enum_name = _WEIGHT_ENUM_NAMES.get(name)
    if weights_enum_name is None:
        raise ValueError(f"Unsupported model: {name}")

    builder = getattr(models, name)
    weights_enum = getattr(models, weights_enum_name, None)
    if weights_enum is not None:
        # Current torchvision API; IMAGENET1K_V1 is what `pretrained=True` used to select.
        model = builder(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision releases without the weights enums.
        model = builder(pretrained=True)

    _replace_last_linear(model, num_classes)

    if device is not None:
        model.to(device)
    return model

def evaluate_loader(model: torch.nn.Module, loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """Compute top-1 accuracy (in percent) of `model` over all batches in `loader`.

    The model is switched to eval mode and gradients are disabled. Each batch
    is an ``(images, labels)`` pair that is moved to `device` before the
    forward pass. Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            # index of the largest score along the class axis
            predicted = torch.max(logits, 1).indices
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)
    if n_seen > 0:
        return 100.0 * n_correct / n_seen
    return 0.0

# load model
model = load_model(model_name, num_classes=10, device=device)
checkpoint_loaded = False
if checkpoint_path:
    if Path(checkpoint_path).exists():
        # NOTE(review): torch.load unpickles the file; only load checkpoints from a
        # trusted source (consider weights_only=True where the installed torch supports it).
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded checkpoint: {checkpoint_path}")
        checkpoint_loaded = True
    else:
        print(f"Checkpoint not found: {checkpoint_path}")
if not checkpoint_loaded:
    # load_model replaces the classifier with a new, untrained Linear layer, so without
    # a checkpoint the numbers below do not describe a trained CIFAR-10 classifier.
    print("WARNING: no fine-tuned checkpoint loaded; the classifier head is randomly "
          "initialized, so the accuracies below are not a meaningful baseline.")

# list corruptions
from IPython.display import display
from math import ceil
corruptions = list_corruptions(cifar10c_root)
print(f"Found {len(corruptions)} corruption files: {corruptions}")

# evaluate & write CSV: one row per (corruption, severity) pair
with open(output_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["corruption", "severity", "top1_accuracy"])
    for corruption in corruptions:
        for sev in severities:
            print(f"Evaluating {corruption} severity={sev} ...", flush=True)
            loader = load_cifar10c_corruption(
                cifar10c_root=cifar10c_root,
                corruption_name=corruption,
                severity_level=sev,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=(device.type == "cuda")
            )
            acc = evaluate_loader(model, loader, device)
            print(f" -> {acc:.2f}%")
            writer.writerow([corruption, sev, f"{acc:.4f}"])
            # Flush after each row so finished results survive an interrupted run
            f.flush()
print(f"Saved baseline metrics to: {output_csv}")



Wrote results/baseline_metrics.csv (96 rows incl. header). Showing first 20 data rows:
corruption,severity,top1_accuracy
clean,0,86.7000
brightness,1,81.3952
brightness,2,74.1505
brightness,3,67.8790
brightness,4,60.8751
brightness,5,54.9182
contrast,1,81.9474
contrast,2,76.6098
contrast,3,69.2146
contrast,4,63.5699
contrast,5,56.3001
defocus_blur,1,75.3061
defocus_blur,2,63.1296
defocus_blur,3,49.0157
defocus_blur,4,35.5549
defocus_blur,5,22.4257
elastic_transform,1,80.2629
elastic_transform,2,72.3196
elastic_transform,3,65.0963
elastic_transform,4,57.4527
