In [22]:
import os
from pathlib import Path
from dotenv import load_dotenv

# Read settings from a local .env file (override=False is passed explicitly).
load_dotenv(dotenv_path=Path(".env"), override=False)

# --- Weights & Biases identifiers ---
PROJECT = os.getenv("WANDB_PROJECT", "tinyimagenet-resnet")
ENTITY = os.getenv("WANDB_ENTITY", "sairohith")
RUN_NAME = os.getenv("RUN_NAME", "resnet18_run1")

# --- Data location and the share of it to use ---
DATA_DIR = os.getenv("DATA_DIR", "./tiny-imagenet-200")
DATA_FRACTION = float(os.getenv("DATA_FRACTION", 0.01))
IMG_SIZE = int(os.getenv("IMG_SIZE", 64))

# --- Optimisation hyper-parameters ---
EPOCHS = int(os.getenv("EPOCHS", 3))
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 128))
LR = float(os.getenv("LR", 1e-3))

# --- Outputs and monitoring ---
SAVE_PATH = os.getenv("SAVE_PATH", "best_resnet_wandb.pth")
DRIFT_THRESHOLD = float(os.getenv("DRIFT_THRESHOLD", 0.1))


In [23]:
# NOTE(review): the configuration cell above already does `from dotenv import load_dotenv`,
# so on a fresh kernel this install has to run before that cell.
%pip install -q python-dotenv


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip
[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.



In [24]:
import os
import math
import copy
from pathlib import Path
from PIL import Image
import random
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.models as models
from tqdm import tqdm
import wandb

# ---------------- transforms ----------------
def get_transforms(img_size=64, train=True):
    """Build the preprocessing pipeline for one split.

    Args:
        img_size: Target side length in pixels.
        train: When truthy, use random-resized-crop + horizontal flip;
            otherwise a plain resize to ``(img_size, img_size)``.

    Returns:
        A ``T.Compose`` ending in ``ToTensor`` and the fixed mean/std normalisation.
    """
    normalize = T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    if not train:
        spatial_steps = [T.Resize((img_size, img_size))]
    else:
        spatial_steps = [T.RandomResizedCrop(img_size), T.RandomHorizontalFlip()]
    return T.Compose(spatial_steps + [T.ToTensor(), normalize])


In [25]:

# ---------------- dataloaders ----------------
def _stratified_sample_indices(labels, fraction, rng):
    """Return sorted indices holding ``ceil(fraction * n)`` (at least one) items per label."""
    per_class = {}
    for idx, label in enumerate(labels):
        per_class.setdefault(label, []).append(idx)
    chosen = []
    for indices in per_class.values():
        take = max(1, math.ceil(len(indices) * fraction))
        chosen.extend(rng.sample(indices, take))
    chosen.sort()
    return chosen


def load_tiny_imagenet_dataloaders(data_dir, batch_size=128, img_size=64, num_workers=4, fraction=1.0, seed=42):
    """Create train/validation loaders for a Tiny-ImageNet style directory.

    Training data is read with ``ImageFolder`` from ``<data_dir>/train``.
    Validation data is read from ``<data_dir>/val/images`` using the labels in
    ``<data_dir>/val/val_annotations.txt`` and is preprocessed once into an
    in-memory ``TensorDataset``.

    Args:
        data_dir: Dataset root.
        batch_size: Batch size for both loaders.
        img_size: Side length passed to ``get_transforms``.
        num_workers: Worker count for both loaders.
        fraction: Share of every class to keep, in ``(0, 1]``; at least one
            sample per class is always kept.
        seed: Seed for the per-class sampling.

    Returns:
        ``(train_loader, val_loader, num_classes)``.

    Raises:
        ValueError: If ``fraction`` is outside ``(0, 1]``.
        RuntimeError: If no validation image could be located.
        KeyError: If an annotation names a class that has no training folder.
    """
    data_dir = Path(data_dir)
    if not 0 < fraction <= 1:
        raise ValueError("fraction must be in (0, 1].")
    # The global seeding is kept from the original implementation for backward
    # compatibility; the sampling below only uses the private generator.
    random.seed(seed)
    rng = random.Random(seed)

    # train uses ImageFolder structure: train/<wnid>/images/*.JPEG
    train_dir = data_dir / "train"
    train_tf = get_transforms(img_size, train=True)
    train_ds = datasets.ImageFolder(root=str(train_dir), transform=train_tf)

    if fraction < 1.0:
        train_labels = [label for _, label in train_ds.samples]
        train_data = Subset(train_ds, _stratified_sample_indices(train_labels, fraction, rng))
    else:
        train_data = train_ds
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

    # validation: resolve (path, label) pairs first, images are decoded later
    val_img_dir = data_dir / "val" / "images"
    ann_file = data_dir / "val" / "val_annotations.txt"
    val_tf = get_transforms(img_size, train=False)

    # Use the class index mapping built by ImageFolder so labels agree with training.
    wnid_to_idx = train_ds.class_to_idx
    val_entries = []
    with open(ann_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 2:
                continue
            img_name, wnid = parts[0], parts[1]
            img_path = val_img_dir / img_name
            if not img_path.exists():
                continue
            val_entries.append((img_path, wnid_to_idx[wnid]))
    if len(val_entries) == 0:
        raise RuntimeError("No validation images found — check data_dir layout.")

    # Subsample BEFORE decoding so a small fraction does not pay for preprocessing
    # the entire validation split. Indices are drawn exactly as before, so the
    # selected images and their order are unchanged.
    if fraction < 1.0:
        keep = _stratified_sample_indices([label for _, label in val_entries], fraction, rng)
        val_entries = [val_entries[i] for i in keep]

    tensors = []
    for img_path, _ in val_entries:
        # Context manager closes the file handle as soon as the pixels are converted.
        with Image.open(img_path) as img:
            tensors.append(val_tf(img.convert("RGB")))
    x = torch.stack(tensors)
    y = torch.tensor([label for _, label in val_entries], dtype=torch.long)
    val_dataset = TensorDataset(x, y)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, len(train_ds.classes)

In [26]:


# ---------------- model ----------------
def build_pretrained_resnet(num_classes, device=None):
    """Create a pretrained ResNet-18 whose final layer outputs ``num_classes`` logits.

    Args:
        num_classes: Size of the replacement ``fc`` layer.
        device: Target device; when ``None``, CUDA is used if available, else CPU.

    Returns:
        ``(model, device)`` with the model already moved to ``device``.
    """
    # Pick the weights API by feature detection instead of a blanket
    # ``except Exception``, so genuine failures (for example a failed weight
    # download) surface instead of being retried through the legacy argument.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.DEFAULT)
    else:
        # torchvision builds without the weights enum only accept ``pretrained``.
        model = models.resnet18(pretrained=True)
    in_f = model.fc.in_features
    model.fc = nn.Linear(in_f, num_classes)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), device

In [27]:


# ---------------- training / eval ----------------
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the number of samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch_x, batch_y in tqdm(loader, leave=False):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        logits = model(batch_x)
        loss = criterion(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = batch_x.size(0)
        # The loss is weighted by batch size so the epoch mean is per-sample.
        loss_sum += loss.item() * batch_n
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()
        n_seen += batch_n
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without computing gradients.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over the number of samples seen.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, leave=False):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            batch_n = batch_x.size(0)
            loss_sum += criterion(logits, batch_y).item() * batch_n
            n_correct += (logits.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_n
    return loss_sum / n_seen, n_correct / n_seen

In [28]:


# ---------------- full training with W&B ----------------
def train_resnet_with_wandb(data_dir,
                           project="tinyimagenet-resnet",
                           entity=None,
                           run_name=None,
                           epochs=10,
                           batch_size=128,
                           img_size=64,
                           lr=1e-3,
                           fraction=1.0,
                           save_path="best_resnet_wandb.pth"):
    """Fine-tune ResNet-18 and track the run in Weights & Biases.

    Per-epoch metrics are logged, the weights of the epoch with the highest
    validation accuracy are kept, and (when ``save_path`` is truthy) they are
    saved to disk and logged as the ``resnet18-tinyimagenet`` model artifact.

    Args:
        data_dir: Dataset root handed to ``load_tiny_imagenet_dataloaders``.
        project: W&B project name.
        entity: W&B entity, or ``None``.
        run_name: Display name of the run.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        img_size: Input side length.
        lr: Initial SGD learning rate.
        fraction: Per-class share of the data to use.
        save_path: Checkpoint destination; a falsy value skips saving.

    Returns:
        ``(best_val_acc, checkpoint_path)``; ``checkpoint_path`` is ``None``
        when nothing was saved.
    """
    config = dict(epochs=epochs, batch_size=batch_size, img_size=img_size, lr=lr, data_dir=str(data_dir), data_fraction=fraction)
    # Using the run as a context manager finishes it even when training or
    # saving raises, so a failed cell does not leave a run open in the kernel.
    with wandb.init(project=project, entity=entity, name=run_name, config=config) as run:
        cfg = run.config

        # data + model
        train_loader, val_loader, num_classes = load_tiny_imagenet_dataloaders(cfg.data_dir, batch_size=cfg.batch_size, img_size=cfg.img_size, fraction=cfg.data_fraction)
        model, device = build_pretrained_resnet(num_classes)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Class names live on the ImageFolder, which may be wrapped in a Subset.
        base_dataset = train_loader.dataset.dataset if isinstance(train_loader.dataset, Subset) else train_loader.dataset
        class_names = getattr(base_dataset, "classes", [])

        best_val_acc = 0.0
        best_epoch = 0
        best_state = copy.deepcopy(model.state_dict())
        for epoch in range(1, cfg.epochs + 1):
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = evaluate(model, val_loader, criterion, device)
            run.log({
                "epoch": epoch,
                "train/loss": train_loss, "train/acc": train_acc,
                "val/loss": val_loss, "val/acc": val_acc,
                "lr": scheduler.get_last_lr()[0] if hasattr(scheduler, "get_last_lr") else optimizer.param_groups[0]["lr"],
                "data_fraction": cfg.data_fraction
            })
            print(f"Epoch {epoch}/{cfg.epochs}  train_acc={train_acc:.4f}  val_acc={val_acc:.4f}")
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                # Deep copy: state_dict() returns references to the live parameters.
                best_state = copy.deepcopy(model.state_dict())
            scheduler.step()

        checkpoint_path = None
        if save_path:
            # Tensors are moved to CPU before saving.
            cpu_state = {k: v.cpu() for k, v in best_state.items()}
            checkpoint = {
                "epoch": best_epoch,
                "model_state": cpu_state,
                "val_acc": best_val_acc,
                "class_names": class_names
            }
            torch.save(checkpoint, save_path)
            checkpoint_path = save_path
            art = wandb.Artifact("resnet18-tinyimagenet", type="model")
            art.add_file(save_path)
            run.log_artifact(art)
            print(f"Saved best checkpoint from epoch {best_epoch} -> {save_path}")

        run.summary["best_val_acc"] = best_val_acc
    return best_val_acc, checkpoint_path





In [29]:
import wandb
# Log in to Weights & Biases; the recorded output of this cell is `True`.
wandb.login()

True

In [30]:
# Launch training with the values from the configuration cell.
# ENTITY is passed explicitly so the run is created under the same entity that the
# deployment cell uses as its default when it builds the W&B artifact path.
best_acc, checkpoint_path = train_resnet_with_wandb(
    data_dir=DATA_DIR,
    project=PROJECT,
    entity=ENTITY,
    run_name=RUN_NAME,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    img_size=IMG_SIZE,
    lr=LR,
    fraction=DATA_FRACTION,
    save_path=SAVE_PATH,
)
print("Best val accuracy:", best_acc)
print("Checkpoint saved to:", checkpoint_path)

                                               

Epoch 1/3  train_acc=0.0126  val_acc=0.0290



                                               

Epoch 2/3  train_acc=0.0652  val_acc=0.0780



                                               

Epoch 3/3  train_acc=0.1359  val_acc=0.1410

Saved best checkpoint from epoch 3 -> best_resnet_wandb.pth
Saved best checkpoint from epoch 3 -> best_resnet_wandb.pth


0,1
data_fraction,▁▁▁
epoch,▁▅█
lr,▁▁▁
train/acc,▁▄█
train/loss,█▄▁
val/acc,▁▄█
val/loss,█▄▁

0,1
best_val_acc,0.141
data_fraction,0.1
epoch,3.0
lr,0.001
train/acc,0.1359
train/loss,4.40628
val/acc,0.141
val/loss,4.14922


Best val accuracy: 0.141
Checkpoint saved to: best_resnet_wandb.pth
 0.141
Checkpoint saved to: best_resnet_wandb.pth


## Deploy to Hugging Face Space
Follow the cells below to package the trained checkpoint as a Gradio app and push it to a Hugging Face Space.

**High-level flow:**
1. Make sure the training cell above finished and saved `best_resnet_wandb.pth` (or another checkpoint).
2. Configure the required secrets (Hugging Face access token, W&B entity/project/artifact).
3. Generate the Gradio app (`app.py`) that knows how to download the artifact and run inference.
4. Create the `requirements.txt` file for the Space.
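5. Push the contents of the `hf_space` folder (`app.py`, `requirements.txt`, and the checkpoint copy if one was found locally) to a new (or existing) Hugging Face Space using `huggingface_hub`.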

In [31]:
import os
from pathlib import Path
from getpass import getpass
from dotenv import load_dotenv

load_dotenv(dotenv_path=Path(".env"), override=False)

def _ensure_env(var_name, prompt_text, is_secret=False, default=""):
    """Resolve a setting from the environment, falling back to an interactive prompt.

    Args:
        var_name: Environment variable to read and, if a value is found, write back.
        prompt_text: Prompt shown when neither the environment nor ``default``
            provides a non-empty value.
        is_secret: Mask the value in the log line and read it with ``getpass``.
        default: Value used when the variable is not set.

    Returns:
        The resolved value; an empty string if the user entered nothing.

    Note:
        The log line reads "from environment (.env)" even when the value came
        from ``default``.
    """
    existing = os.environ.get(var_name, default)
    if not existing:
        reader = getpass if is_secret else input
        entered = reader(prompt_text).strip()
        if entered:
            os.environ[var_name] = entered
        return entered

    shown = "***" if is_secret else existing
    print(f"Using {var_name} from environment (.env): {shown}")
    os.environ[var_name] = existing
    return existing

# Hugging Face credentials and target Space.
hf_token = _ensure_env("HF_TOKEN", "HF token (input hidden): ", is_secret=True)
hf_user = _ensure_env("HF_USER", "Hugging Face username: ")
space_name = _ensure_env("SPACE_NAME", "Desired Space name (e.g., tinyimagenet-demo): ")

# W&B coordinates; the training-time constants act as defaults.
wandb_entity = _ensure_env("WANDB_ENTITY", "W&B entity (username/team used during training): ", default=ENTITY)
wandb_project = _ensure_env("WANDB_PROJECT", "W&B project name: ", default=PROJECT)

# A default artifact path can only be derived when both parts are known.
default_artifact = (
    f"{wandb_entity}/{wandb_project}/resnet18-tinyimagenet:latest"
    if wandb_entity and wandb_project
    else ""
)
_ensure_env("WANDB_ARTIFACT", "W&B artifact path [entity/project/name:alias]: ", default=default_artifact)

print("Environment variables ready for Hugging Face upload.")

Using HF_TOKEN from environment (.env): ***
Using HF_USER from environment (.env): SaiRohith24816
Using SPACE_NAME from environment (.env): tinyimagenet-demo
Using WANDB_ENTITY from environment (.env): ir2023
Using WANDB_PROJECT from environment (.env): tinyimagenet-resnet
Using WANDB_ARTIFACT from environment (.env): ir2023/tinyimagenet-resnet/resnet18-tinyimagenet:latest
Environment variables ready for Hugging Face upload.
Using HF_USER from environment (.env): SaiRohith24816
Using SPACE_NAME from environment (.env): tinyimagenet-demo
Using WANDB_ENTITY from environment (.env): ir2023
Using WANDB_PROJECT from environment (.env): tinyimagenet-resnet
Using WANDB_ARTIFACT from environment (.env): ir2023/tinyimagenet-resnet/resnet18-tinyimagenet:latest
Environment variables ready for Hugging Face upload.



In [32]:
# NOTE(review): nothing visible in this notebook imports a `git_lfs` module, and the upload
# cell goes through `HfApi.upload_folder`; the PyPI `git-lfs` distribution is presumably not
# the Git LFS command-line tool — confirm it is actually needed.
%pip install -q gradio huggingface_hub git-lfs


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip
[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.



### Generate Gradio app
The cell below creates an `hf_space` folder (safe to re-run) and writes `app.py`. The app will:
- Download the checkpoint from the configured W&B artifact if `MODEL_PATH` is missing.
- Rebuild the Tiny-ImageNet ResNet-18, load weights, and expose a Gradio interface.
- Display the top-5 predictions with their confidences, and print the top label and inference latency to the app log.

In [33]:
from pathlib import Path
import textwrap
import shutil

# Staging folder for the Space; exist_ok=True keeps this cell safe to re-run.
hf_dir = Path("hf_space")
hf_dir.mkdir(exist_ok=True)
app_path = hf_dir / "app.py"
# Source of the Gradio app, written to hf_space/app.py below. Everything inside the
# triple-quoted string is runtime text for the generated file, not notebook code.
app_code = textwrap.dedent("""\
import os
import time
from pathlib import Path
from typing import Dict, List
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
import gradio as gr

try:
    from dotenv import load_dotenv
    load_dotenv(dotenv_path=Path(".env"), override=False)
except Exception as err:
    print(f"Could not load .env file: {err}")

MODEL_PATH = Path(os.environ.get("MODEL_PATH", "best_resnet_wandb.pth"))
ARTIFACT_PATH = os.environ.get("WANDB_ARTIFACT")
DOWNLOAD_ROOT = Path(os.environ.get("MODEL_DOWNLOAD_DIR", "./downloaded_artifacts"))

def maybe_download_artifact():
    if MODEL_PATH.exists():
        return
    if not ARTIFACT_PATH:
        print("MODEL_PATH missing and WANDB_ARTIFACT not set -- cannot download weights.")
        return
    try:
        import wandb
        api_key = os.environ.get("WANDB_API_KEY")
        if api_key:
            wandb.login(key=api_key)
        else:
            wandb.login()
        api = wandb.Api()
        artifact = api.artifact(ARTIFACT_PATH)
        DOWNLOAD_ROOT.mkdir(parents=True, exist_ok=True)
        artifact_dir = Path(artifact.download(root=str(DOWNLOAD_ROOT), recursive=False))
        candidate = artifact_dir / MODEL_PATH.name
        if candidate.exists():
            candidate.replace(MODEL_PATH)
        else:
            matches = sorted(artifact_dir.glob("*.pth"))
            if matches:
                matches[0].replace(MODEL_PATH)
        print(f"Downloaded artifact files to {artifact_dir}")
    except Exception as err:
        print(f"Failed to download artifact: {err}")

def load_class_names(ckpt) -> List[str]:
    names = ckpt.get("class_names")
    if isinstance(names, list) and len(names) > 0:
        return names
    wnids_file = Path("./tiny-imagenet-200/wnids.txt")
    if wnids_file.exists():
        return [line.strip() for line in wnids_file.read_text().splitlines() if line.strip()]
    return [f"class_{idx}" for idx in range(200)]

def build_model(num_classes: int) -> nn.Module:
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def format_top_k(probs: torch.Tensor, class_names: List[str], k: int = 5) -> Dict[str, float]:
    scores = probs.squeeze(0)
    k = min(k, scores.numel())
    top_scores, top_indices = torch.topk(scores, k)
    formatted: Dict[str, float] = {}
    for score, idx in zip(top_scores, top_indices):
        label = class_names[idx] if 0 <= idx < len(class_names) else f"class_{int(idx)}"
        formatted[label] = float(score)
    return formatted

def main():
    maybe_download_artifact()
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    state_dict = checkpoint.get("model_state", checkpoint)
    class_names = load_class_names(checkpoint)
    model = build_model(len(class_names))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    transform = T.Compose([
        T.Resize((64, 64)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    def predict(image: Image.Image):
        start = time.time()
        x = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)
        elapsed = (time.time() - start) * 1000.0
        topk_dict = format_top_k(probs, class_names, k=5)
        top_label = next(iter(topk_dict)) if topk_dict else "unknown"
        print(f"Predicted {top_label} in {elapsed:.2f} ms")
        return topk_dict

    title = "Tiny-ImageNet ResNet18"
    description = ("Upload an image to see the top-5 Tiny-ImageNet predictions from the fine-tuned ResNet18 model.")
    gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(num_top_classes=5),
        title=title,
        description=description,
    ).launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    main()

""")
app_path.write_text(app_code)
print(f"Wrote {app_path}")

# Copy the local checkpoint next to app.py under its own file name. The generated app
# looks for MODEL_PATH (default "best_resnet_wandb.pth") and only tries the W&B
# download when that file is missing.
# NOTE(review): if SAVE_PATH uses a different file name, MODEL_PATH presumably has to
# be set on the Space to match — confirm.
src_checkpoint = Path(SAVE_PATH).expanduser()
if src_checkpoint.exists():
    dest_checkpoint = hf_dir / src_checkpoint.name
    shutil.copy2(src_checkpoint, dest_checkpoint)
    print(f"Copied checkpoint to {dest_checkpoint}")
else:
    print("Warning: checkpoint not found locally; set WANDB_ARTIFACT so the Space can download weights.")

Wrote hf_space\app.py

Copied checkpoint to hf_space\best_resnet_wandb.pth
Copied checkpoint to hf_space\best_resnet_wandb.pth


### Create requirements.txt
This requirements list is kept minimal for the Gradio app. Add pins if the Space needs reproducibility.

In [34]:
# Runtime dependencies of the generated hf_space/app.py: one entry per third-party
# package that file imports (torch, torchvision, gradio, PIL, wandb, dotenv).
# Packages the app never imports (onnxruntime, huggingface_hub, git-lfs) are no longer
# listed, so the Space does not install unused wheels.
requirements = """\
torch
torchvision
gradio
pillow
wandb
python-dotenv
"""
req_path = Path("hf_space") / "requirements.txt"
# Explicit encoding keeps the file identical regardless of the platform default.
req_path.write_text(requirements.strip() + "\n", encoding="utf-8")
print(f"Wrote {req_path}")

Wrote hf_space\requirements.txt



### Push to Hugging Face Space
The next cell creates (or reuses) the target Space and uploads everything in the `hf_space` folder: `app.py`, `requirements.txt`, and the checkpoint copy if one was found locally. It expects the environment variables you set above to be present.

In [35]:
import sys
from huggingface_hub import HfApi

# Re-read the deployment settings so this cell does not rely on earlier variables.
hf_token, hf_user, space_name = (
    os.environ.get(key) for key in ("HF_TOKEN", "HF_USER", "SPACE_NAME")
)
if not (hf_token and hf_user and space_name):
    raise RuntimeError("HF_TOKEN, HF_USER, and SPACE_NAME must be set before pushing.")

repo_id = "/".join((hf_user, space_name))
api = HfApi(token=hf_token)

# exist_ok=True reuses the Space when it already exists.
repo_url = api.create_repo(
    repo_id=repo_id,
    repo_type="space",
    space_sdk="gradio",
    exist_ok=True,
)
print(f"Space repo: {repo_url}")

# Uploads the whole hf_space folder to the root of the Space repository.
api.upload_folder(
    folder_path="hf_space",
    path_in_repo=".",
    repo_id=repo_id,
    repo_type="space",
    commit_message="Update Tiny-ImageNet Gradio app",
)
print("Uploaded app.py and requirements.txt to the Space.")
print("Next: add WANDB_API_KEY as a secret on the Space if artifact download is required.")

Space repo: https://huggingface.co/spaces/SaiRohith24816/tinyimagenet-demo



best_resnet_wandb.pth:   0%|          | 0.00/45.2M [00:00<?, ?B/s]

Uploaded app.py and requirements.txt to the Space.
Next: add WANDB_API_KEY as a secret on the Space if artifact download is required.

Next: add WANDB_API_KEY as a secret on the Space if artifact download is required.


In [36]:
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image, ImageEnhance
import wandb
from pathlib import Path
from tqdm import tqdm
import random
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

class BrightnessShift:
    """Callable transform that rescales image brightness by a fixed factor.

    The image is passed to ``ImageEnhance.Brightness``, so it is applied
    before ``ToTensor`` in the pipelines built in this notebook.
    """

    def __init__(self, factor=0.5):
        self.factor = factor

    def __call__(self, img):
        enhancer = ImageEnhance.Brightness(img)
        return enhancer.enhance(self.factor)

class AddGaussianNoise:
    """Callable transform that adds zero-mean Gaussian noise to a tensor.

    Args:
        std: Standard deviation of the noise.

    The result is not clamped, so values may leave the input's range.
    """

    def __init__(self, std=0.1):
        self.std = std

    def __call__(self, tensor):
        if tensor.is_floating_point():
            # randn_like creates the noise with the input's device and dtype, so the
            # transform also works for tensors that are not on the CPU.
            noise = torch.randn_like(tensor)
        else:
            # randn_like rejects integer tensors; draw default-dtype noise on the
            # same device and rely on type promotion, as the original code did.
            noise = torch.randn(tensor.shape, device=tensor.device)
        return tensor + noise * self.std

def get_drift_transform(img_size, drift_type, intensity, combined_brightness=0.4, combined_noise_std=0.15):
    """Build a preprocessing pipeline that simulates one kind of input drift.

    Args:
        img_size: Side length for the initial resize.
        drift_type: ``"brightness"``, ``"noise"`` or ``"combined"``. Any other
            value produces the undistorted pipeline.
        intensity: Brightness factor for ``"brightness"`` or noise standard
            deviation for ``"noise"``; not used by ``"combined"``.
        combined_brightness: Brightness factor used by ``"combined"``.
        combined_noise_std: Noise standard deviation used by ``"combined"``.

    Returns:
        A ``T.Compose`` of: resize, optional brightness shift, ``ToTensor``,
        optional Gaussian noise, normalisation.
    """
    # Brightness works on the PIL image and noise on the tensor, so the two kinds of
    # distortion are collected separately and placed around ToTensor.
    image_steps = [T.Resize((img_size, img_size))]
    tensor_steps = []
    if drift_type == "brightness":
        image_steps.append(BrightnessShift(intensity))
    elif drift_type == "noise":
        tensor_steps.append(AddGaussianNoise(intensity))
    elif drift_type == "combined":
        image_steps.append(BrightnessShift(combined_brightness))
        tensor_steps.append(AddGaussianNoise(combined_noise_std))
    normalize = T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    return T.Compose(image_steps + [T.ToTensor()] + tensor_steps + [normalize])

def load_model(checkpoint_path):
    """Rebuild the ResNet-18 classifier from a saved checkpoint.

    Args:
        checkpoint_path: File accepted by ``torch.load``; either a dict with a
            ``"model_state"`` entry (and optionally ``"class_names"``) or a bare
            state dict.

    Returns:
        ``(model, class_names, device)`` with the model on ``device`` in eval
        mode. ``class_names`` is ``[]`` when the checkpoint does not store it.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt.get("model_state", ckpt)
    class_names = ckpt.get("class_names", [])
    # Size the head from the stored weights when possible: a checkpoint without
    # class names would otherwise produce a zero-output layer that cannot be loaded.
    fc_weight = state_dict.get("fc.weight")
    num_classes = fc_weight.shape[0] if fc_weight is not None else len(class_names)
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    return model, class_names, device

def create_dataset(data_dir, class_names, num_samples=200, seed=42):
    """Load a random sample of validation images as ``(PIL image, label)`` pairs.

    Args:
        data_dir: Dataset root containing ``val/val_annotations.txt`` and
            ``val/images``.
        class_names: Class identifiers; a sample's label is the position of its
            class in this sequence.
        num_samples: Upper bound on the number of annotations drawn.
        seed: Seed of the private random generator used for the draw.

    Returns:
        A list of ``(image, label)`` tuples. Samples whose class is not in
        ``class_names`` are dropped, so the list can be shorter than
        ``num_samples``.
    """
    data_dir = Path(data_dir)
    val_dir = data_dir / "val"
    samples = []
    with open(val_dir / "val_annotations.txt") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                img_path = val_dir / "images" / parts[0]
                if img_path.exists():
                    samples.append((img_path, parts[1]))

    # A private generator yields the same draw as seeding the module-level one,
    # without resetting global random state for the rest of the notebook.
    rng = random.Random(seed)
    samples = rng.sample(samples, min(num_samples, len(samples)))

    # First-occurrence lookup table, equivalent to list.index but O(1) per sample.
    index_of = {}
    for idx, name in enumerate(class_names):
        index_of.setdefault(name, idx)

    dataset = []
    for img_path, wnid in samples:
        label = index_of.get(wnid)
        if label is None:
            continue
        # convert() returns a new in-memory image, so the file can be closed right away.
        with Image.open(img_path) as img:
            dataset.append((img.convert("RGB"), label))
    return dataset

# NOTE: this definition replaces the earlier `evaluate(model, loader, criterion, device)`
# in the kernel namespace once this cell has run.
def evaluate(model, dataset, transform, device):
    """Score ``model`` one sample at a time on ``(image, label)`` pairs.

    Args:
        model: Classifier to evaluate.
        dataset: Iterable of ``(image, label)`` pairs.
        transform: Callable that turns an image into a model input tensor.
        device: Device the inputs are moved to.

    Returns:
        ``(accuracy, mean_loss)`` over the samples in ``dataset``.
    """
    loss_fn = nn.CrossEntropyLoss()
    hits = 0
    seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for image, target in tqdm(dataset, leave=False):
            batch = transform(image).unsqueeze(0).to(device)
            target_t = torch.tensor([target], dtype=torch.long).to(device)
            logits = model(batch)
            loss_sum += loss_fn(logits, target_t).item()
            if logits.argmax(1).item() == target:
                hits += 1
            seen += 1
    return hits / seen, loss_sum / seen

print("Simple drift simulation functions loaded.")

Simple drift simulation functions loaded.



In [37]:
# Baseline (undistorted) evaluation. Paths, image size and W&B coordinates come from the
# configuration cell instead of being repeated here as literals, so overriding them in
# .env affects training and drift evaluation alike.
model, class_names, device = load_model(SAVE_PATH)
dataset = create_dataset(DATA_DIR, class_names, num_samples=200)
baseline_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)), T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# The context manager finishes the run even if evaluation raises.
with wandb.init(project=PROJECT, entity=ENTITY, name="baseline",
                tags=["drift"], config={"samples": len(dataset)}) as run:
    baseline_acc, baseline_loss = evaluate(model, dataset, baseline_transform, device)
    run.log({"accuracy": baseline_acc, "loss": baseline_loss})
    run.summary.update({"baseline_accuracy": baseline_acc, "baseline_loss": baseline_loss})

print(f"Baseline: Accuracy={baseline_acc:.4f}, Loss={baseline_loss:.4f}")

                                                  

0,1
accuracy,▁
loss,▁

0,1
accuracy,0.18
baseline_accuracy,0.18
baseline_loss,4.16186
loss,4.16186


Baseline: Accuracy=0.1800, Loss=4.1619



In [38]:
# Drift scenarios: each entry names the distortion type and its strength.
scenarios = [
    {"name": "dark_30", "type": "brightness", "intensity": 0.3},
    {"name": "dark_20", "type": "brightness", "intensity": 0.2},
    {"name": "bright_180", "type": "brightness", "intensity": 1.8},
    {"name": "noise_low", "type": "noise", "intensity": 0.1},
    {"name": "noise_high", "type": "noise", "intensity": 0.25},
    {"name": "combined", "type": "combined", "intensity": 0.0},
]

results = []
# Absolute accuracy drop that raises an alert; taken from the configuration cell so the
# DRIFT_THRESHOLD setting is honoured.
threshold = DRIFT_THRESHOLD

for s in scenarios:
    transform = get_drift_transform(IMG_SIZE, s['type'], s['intensity'])
    # One W&B run per scenario; the context manager finishes it even on failure.
    with wandb.init(project=PROJECT, entity=ENTITY,
                    name=f"drift_{s['name']}", tags=["drift", s['type']],
                    config={**s, "baseline_acc": baseline_acc}) as run:

        acc, loss = evaluate(model, dataset, transform, device)
        drop = baseline_acc - acc
        # Guard the relative drop against a zero baseline.
        drop_pct = (drop / baseline_acc * 100) if baseline_acc > 0 else 0
        alert = drop > threshold

        run.log({"accuracy": acc, "loss": loss, "accuracy_drop": drop,
                 "accuracy_drop_percent": drop_pct})
        run.summary.update({"accuracy": acc, "accuracy_drop": drop,
                            "alert_triggered": alert})

        if alert:
            run.alert(title=f"Drift: {s['name']}",
                      text=f"Acc drop: {drop:.4f} ({drop_pct:.2f}%)",
                      level=wandb.AlertLevel.WARN)

        results.append({"scenario": s['name'], "accuracy": acc, "loss": loss,
                        "drop": drop, "drop_pct": drop_pct, "alert": alert,
                        "url": run.url})
    print(f"{s['name']}: Acc={acc:.4f}, Drop={drop:.4f} ({'⚠️' if alert else '✓'})")

print(f"\nCompleted {len(scenarios)} drift scenarios")

                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.105
accuracy_drop,0.075
accuracy_drop_percent,41.66667
alert_triggered,False
loss,4.61847


dark_30: Acc=0.1050, Drop=0.0750 (✓)



                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.065
accuracy_drop,0.115
accuracy_drop_percent,63.88889
alert_triggered,True
loss,4.84289


dark_20: Acc=0.0650, Drop=0.1150 (⚠️)



                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.145
accuracy_drop,0.035
accuracy_drop_percent,19.44444
alert_triggered,False
loss,4.31817


bright_180: Acc=0.1450, Drop=0.0350 (✓)



                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.11
accuracy_drop,0.07
accuracy_drop_percent,38.88889
alert_triggered,False
loss,4.50471


noise_low: Acc=0.1100, Drop=0.0700 (✓)



                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.03
accuracy_drop,0.15
accuracy_drop_percent,83.33333
alert_triggered,True
loss,5.33197


noise_high: Acc=0.0300, Drop=0.1500 (⚠️)



                                                  

0,1
accuracy,▁
accuracy_drop,▁
accuracy_drop_percent,▁
loss,▁

0,1
accuracy,0.02
accuracy_drop,0.16
accuracy_drop_percent,88.88889
alert_triggered,True
loss,5.24387


combined: Acc=0.0200, Drop=0.1600 (⚠️)

Completed 6 drift scenarios


Completed 6 drift scenarios


In [39]:
# Tabulate the per-scenario results and point out the largest accuracy drop.
df = pd.DataFrame(results)
summary_columns = ['scenario', 'accuracy', 'drop', 'drop_pct', 'alert']
worst_scenario = df.loc[df['drop'].idxmax(), 'scenario']

print("\nDrift Results Summary:")
print(df[summary_columns].to_string(index=False))
print(f"\nWorst: {worst_scenario} (drop: {df['drop'].max():.4f})")


Drift Results Summary:
  scenario  accuracy  drop  drop_pct  alert
   dark_30     0.105 0.075 41.666667  False
   dark_20     0.065 0.115 63.888889   True
bright_180     0.145 0.035 19.444444  False
 noise_low     0.110 0.070 38.888889  False
noise_high     0.030 0.150 83.333333   True
  combined     0.020 0.160 88.888889   True

Worst: combined (drop: 0.1600)

  scenario  accuracy  drop  drop_pct  alert
   dark_30     0.105 0.075 41.666667  False
   dark_20     0.065 0.115 63.888889   True
bright_180     0.145 0.035 19.444444  False
 noise_low     0.110 0.070 38.888889  False
noise_high     0.030 0.150 83.333333   True
  combined     0.020 0.160 88.888889   True

Worst: combined (drop: 0.1600)


In [41]:
# Render the drift results as Markdown and save them next to the notebook.
report_path = Path("drift_report.md")
report = f"""# Drift Report

**Baseline:** {baseline_acc:.4f} | **Threshold:** {threshold}
**Samples:** {len(dataset)} | **Alerts:** {df['alert'].sum()}

## Results
{df[['scenario', 'accuracy', 'drop', 'drop_pct', 'alert']].to_markdown(index=False)}

## W&B Links
"""
for _, r in df.iterrows():
    report += f"- [{r['scenario']}]({r['url']})\n"

report_path.write_text(report, encoding='utf-8')
# Print the path that was actually written; the message used to name a different file.
print(f"Saved: {report_path}")

Saved: drift_simple_report.md

