# In [None]:
#  Imports & setup

import time
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms, models
import wandb

# Seed the NumPy and PyTorch global RNGs with a fixed value (optional).
np.random.seed(42)
torch.manual_seed(42)

# Print once which device is available, so the choice shows up in the logs.
print(
    f"[INFO] Using GPU: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "[INFO] Training on CPU – may be slow."
)


# Square side length images are resized to, and number of output classes.
IMAGE_SIZE, NUM_OF_CLASSES = 224, 10

#  Helper functions

def _get_tfms() -> Tuple[transforms.Compose, transforms.Compose]:
    """Return deterministic (no augmentation) train/test transforms.

    Both pipelines are identical: resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``,
    convert to a float tensor, then normalise each channel with the ImageNet
    mean/std. The normalisation matters because the model built in this file
    starts from ImageNet-pretrained weights, which expect inputs preprocessed
    this way.

    Returns:
        A ``(train_transform, test_transform)`` pair of independent
        ``transforms.Compose`` objects.
    """
    def _build() -> transforms.Compose:
        return transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            # Channel statistics documented for torchvision's ImageNet weights.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    return _build(), _build()


def split_dataset_with_class_distribution(dataset, split_ratio):
    """Split *dataset* into class-stratified train/validation subsets.

    Sample indices are grouped by class; within each class the first
    ``int(n * split_ratio)`` indices go to the training subset and the
    remainder to the validation subset, so both subsets keep the class
    proportions of the input.

    Class membership is read from ``dataset.targets`` (one label per sample,
    as exposed by ``ImageFolder``). If the dataset has no ``targets``
    attribute, the function falls back to the original hard-coded layout:
    ten contiguous classes, nine with 1000 samples and the last with 999
    (indices 9000-9998).

    Args:
        dataset: Indexable dataset to split.
        split_ratio: Fraction of each class assigned to the training subset.

    Returns:
        ``(train_subset, val_subset)`` as ``Subset`` views over *dataset*.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        # Group indices by their real label instead of assuming class sizes.
        by_label = {}
        for idx, label in enumerate(targets):
            by_label.setdefault(int(label), []).append(idx)
        groups = [by_label[label] for label in sorted(by_label)]
    else:
        # Legacy layout kept for datasets that do not expose labels.
        groups = [list(range(i * 1000, (i + 1) * 1000)) for i in range(9)]
        groups.append(list(range(9000, 9999)))  # last class has 999 samples

    train_ids, val_ids = [], []
    for idxs in groups:
        cut = int(len(idxs) * split_ratio)
        train_ids.extend(idxs[:cut])
        val_ids.extend(idxs[cut:])

    return Subset(dataset, train_ids), Subset(dataset, val_ids)


def prepare_data(h_params: Dict) -> Dict:
    """Build the train/val/test datasets and their loaders.

    The ``train`` folder under the Kaggle dataset root is split 80/20 into
    training and validation subsets; the ``val`` folder is used as the
    held-out test set. Split sizes are printed as a sanity check.

    Args:
        h_params: Hyper-parameter dict; only ``"batch_size"`` is read here.

    Returns:
        Dict with keys ``train_len``, ``val_len``, ``test_len``,
        ``train_loader``, ``val_loader`` and ``test_loader``.
    """
    tfm_train, tfm_test = _get_tfms()
    dataset_root = Path("/kaggle/input/nature1/inaturalist_12K")

    labelled_pool = ImageFolder(dataset_root / "train", transform=tfm_train)
    train_ds, val_ds = split_dataset_with_class_distribution(labelled_pool, 0.8)
    # The Kaggle "val" folder serves as the hold-out test set.
    test_ds = ImageFolder(dataset_root / "val", transform=tfm_test)

    batch = h_params["batch_size"]

    def _loader(ds, shuffle):
        # Every loader shares the batch size and worker count.
        return DataLoader(ds, batch_size=batch, shuffle=shuffle, num_workers=2)

    data = {
        "train_len": len(train_ds),
        "val_len": len(val_ds),
        "test_len": len(test_ds),
        "train_loader": _loader(train_ds, True),
        "val_loader": _loader(val_ds, False),
        "test_loader": _loader(test_ds, False),
    }
    print(f"[INFO] train:{data['train_len']}  val:{data['val_len']}  test:{data['test_len']}")
    return data


def _freeze_except_last_k(model: nn.Module, k: int) -> None:
    """Leave only the last *k* parameter tensors of *model* trainable.

    *k* counts entries of ``model.parameters()`` (each weight or bias tensor
    is one entry), not layers. For ``k <= 0`` every parameter is frozen; if
    *k* exceeds the number of parameter tensors, all of them stay trainable.
    """
    params = list(model.parameters())
    first_trainable = len(params) - k
    for pos, param in enumerate(params):
        param.requires_grad = k > 0 and pos >= first_trainable


def resnet50Model(h_params: Dict) -> nn.Module:
    """Create an ImageNet-pretrained ResNet-50 with a new classification head.

    The final fully connected layer is replaced by a fresh ``nn.Linear`` with
    ``NUM_OF_CLASSES`` outputs, then every parameter is frozen except the last
    ``h_params["last_unfreeze_layers"]`` parameter tensors.

    Args:
        h_params: Hyper-parameter dict; only ``"last_unfreeze_layers"`` is
            read. The value counts parameter tensors from the end of
            ``model.parameters()`` rather than layers, and a value of 0
            leaves the whole network, new head included, frozen.

    Returns:
        The configured model, on the default device.
    """
    # `pretrained=True` is deprecated in newer torchvision in favour of the
    # `weights=` enum; IMAGENET1K_V1 is the checkpoint the legacy flag loads.
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enum.
        model = models.resnet50(pretrained=True)

    model.fc = nn.Linear(model.fc.in_features, NUM_OF_CLASSES)
    _freeze_except_last_k(model, h_params["last_unfreeze_layers"])
    return model


# Training & evaluation
def train(h_params: Dict, data: Dict) -> None:
    """Fine-tune the ResNet-50 and log per-epoch metrics to W&B.

    Args:
        h_params: Hyper-parameter dict. ``"learning_rate"`` and ``"epochs"``
            are read here; the dict is also passed to ``resnet50Model``.
        data: Dict as returned by ``prepare_data``; uses ``"train_loader"``,
            ``"val_loader"``, ``"train_len"`` and ``"val_len"``.

    Side effects:
        Prints progress, calls ``wandb.log`` once per epoch, and writes the
        final weights to ``model_<unix-timestamp>.pth`` in the working
        directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = resnet50Model(h_params).to(device)

    # Wrap for multi-GPU training only when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=h_params["learning_rate"])

    train_loader, val_loader = data["train_loader"], data["val_loader"]

    for ep in range(h_params["epochs"]):
        # ── Training pass ────────────────────────────────────────────────
        model.train()
        ep_loss, ep_correct = 0.0, 0

        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

            # Compute these once; they feed both the running totals and the log.
            batch_loss = loss.item()
            hits = out.argmax(1) == y

            ep_loss += batch_loss
            ep_correct += hits.sum().item()

            # verbose logging every 20 mini-batches
            if step % 20 == 0:
                print(f"Ep{ep:<2d} step{step:<4d} "
                      f"batch_acc:{hits.float().mean():.3f} "
                      f"loss:{batch_loss:.4f}")

        # ── Validation pass ──────────────────────────────────────────────
        model.eval()
        val_loss, val_correct = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                val_loss += criterion(out, y).item()
                val_correct += (out.argmax(1) == y).sum().item()

        # Epoch-level metrics: accuracy is per sample, loss is the mean of
        # the per-batch mean losses.
        train_acc = ep_correct / data["train_len"]
        val_acc = val_correct / data["val_len"]
        mean_train_loss = ep_loss / len(train_loader)
        mean_val_loss = val_loss / len(val_loader)
        print(f"[EP{ep}] train_acc:{train_acc:.4f} val_acc:{val_acc:.4f} "
              f"train_loss:{mean_train_loss:.4f} "
              f"val_loss:{mean_val_loss:.4f}")

        # Send to W&B
        wandb.log({
            "epoch": ep,
            "train_accuracy": train_acc,
            "val_accuracy":   val_acc,
            "train_loss":     mean_train_loss,
            "val_loss":       mean_val_loss,
        })

        # free memory each epoch (mostly redundant)
        torch.cuda.empty_cache()

    # Save final weights with a timestamp to avoid overwriting earlier runs.
    # nn.DataParallel prefixes every state_dict key with "module.", so save
    # the wrapped module's weights; the checkpoint then loads into a plain
    # (non-parallel) model as well.
    bare_model = model.module if isinstance(model, nn.DataParallel) else model
    ts = int(time.time())
    torch.save(bare_model.state_dict(), f"model_{ts}.pth")
    print("[DONE] Training complete – model saved.")


# Main execution

if __name__ == "__main__":
    # NOTE(review): `h_params` is not defined in the visible part of this
    # file; it must already exist at module scope (with keys "model",
    # "epochs", "batch_size", "learning_rate", "last_unfreeze_layers")
    # before this block runs — confirm where it is set.
    wandb.login()  # falls back to env var if key is set elsewhere
    run_name = (f"{h_params['model']}_ep_{h_params['epochs']}"
                f"_bs_{h_params['batch_size']}"
                f"_lr_{h_params['learning_rate']}"
                f"_last_unfreeze_layers_{h_params['last_unfreeze_layers']}")

    # Use the run as a context manager so it is finished even when data
    # loading or training raises, instead of being left open.
    with wandb.init(project="DL Assignment 2B", name=run_name, config=h_params) as run:
        data = prepare_data(h_params)
        train(h_params, data)
