# In [None]:
#  Imports
import time
import warnings
from pathlib import Path
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import wandb

# Hide UserWarning messages so they do not clutter the cell output.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Seed the NumPy and PyTorch global RNGs so repeated runs start from the same
# random state (optional – comment out if not desired).
np.random.seed(seed=42)
torch.manual_seed(seed=42)

#  Hyper‑parameters (same variable name `h_params` required downstream)

# Default configuration for a single (non-sweep) run. The keys mirror the
# parameter names used in the W&B sweep definition further down.
h_params = {
    "epochs": 10,
    # Must be a scalar: the value is handed to the optimizer as its `lr`.
    "learning_rate": 0.0001,
    "batch_size": 128,
    "num_of_filter": 64,              # filters in the first conv layer
    "filter_size": [3, 3, 3, 3, 3],   # one kernel size per conv layer
    "actv_func": "gelu",
    "filter_multiplier": 2,           # per-layer growth factor of the filter count
    "data_augumentation": False,      # (sic) key spelling is relied on downstream
    "batch_normalization": True,
    "dropout": 0.4,
    "conv_layers": 5,
    "dense_layer_size": 256,
}

# Side length (pixels) every image is resized to before entering the CNN.
IMAGE_SIZE = 224

# Human-readable names of the 10 super-classes, indexed by integer label.
CLASS_LABELS = [
    "Amphibia",
    "Animalia",
    "Arachnida",
    "Aves",
    "Fungi",
    "Insecta",
    "Mammalia",
    "Mollusca",
    "Plantae",
    "Reptilia",
]

# Number of output classes of the classifier (10).
NUM_OF_CLASSES = len(CLASS_LABELS)

#  Helper utilities

def _get_transforms(aug: bool) -> Tuple[transforms.Compose, transforms.Compose]:
    """
    Build the image pipelines for training and for evaluation.

    Both pipelines resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and end with
    ``ToTensor``.  When ``aug`` is true the training pipeline additionally
    applies random flips, a small rotation, colour jitter and a Gaussian blur
    between those two steps; the evaluation pipeline is never augmented.

    Returns:
        ``(train_tfms, test_tfms)``.
    """
    resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
    as_tensor = transforms.ToTensor()

    augmentations = []
    if aug:
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
            transforms.GaussianBlur(kernel_size=3),
        ]

    train_pipeline = transforms.Compose([resize, *augmentations, as_tensor])
    eval_pipeline = transforms.Compose([resize, as_tensor])
    return train_pipeline, eval_pipeline


def _init_wandb_run(config: Dict) -> wandb.wandb_sdk.wandb_run.Run:
    """
    Log in to W&B and start a run whose name encodes the hyper-parameters.

    The name is the activation function followed by ``_<tag>_<value>`` pairs
    for the remaining settings, so runs can be told apart in the dashboard.
    ``config`` is also stored as the run's config.
    """
    # No key is passed explicitly; credentials are expected from env/CLI.
    wandb.login()

    # (tag used in the run name, key looked up in `config`) in display order.
    tagged_keys = (
        ("ep", "epochs"),
        ("lr", "learning_rate"),
        ("init_fltr_cnt", "num_of_filter"),
        ("fltr_sz", "filter_size"),
        ("fltr_mult", "filter_multiplier"),
        ("data_aug", "data_augumentation"),
        ("batch_norm", "batch_normalization"),
        ("dropout", "dropout"),
        ("dense", "dense_layer_size"),
    )
    suffix = "".join(f"_{tag}_{config[key]}" for tag, key in tagged_keys)
    run_name = f"{config['actv_func']}" + suffix

    return wandb.init(project="DL Assignment 2", name=run_name, config=config)


def split_dataset_with_class_distribution(
    dataset: "ImageFolder", split_ratio: float
) -> Tuple[Subset, Subset]:
    """
    Split a dataset into train/val subsets *per class* to preserve distribution.

    For every class, the first ``split_ratio`` fraction of its samples (in
    dataset order) goes to the training subset and the rest to the validation
    subset.  Class membership is read from ``dataset.targets`` (one integer
    label per sample), so the split is correct for any number of classes, any
    class sizes and any sample ordering.

    If the dataset exposes no ``targets`` attribute, the function falls back
    to the legacy assumption of 10 contiguous class blocks of 1 000 samples
    each, the last block holding 999 (indices 0..9998).

    Args:
        dataset: Dataset to split.
        split_ratio: Fraction of each class assigned to training, in [0, 1].

    Returns:
        ``(train_subset, val_subset)`` wrapping ``dataset``.

    Raises:
        ValueError: If ``split_ratio`` lies outside [0, 1].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be within [0, 1], got {split_ratio!r}")

    # Map each class label to the dataset indices that carry it.
    indices_by_class = {}
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        for idx, label in enumerate(targets):
            indices_by_class.setdefault(int(label), []).append(idx)
    else:
        # Legacy layout: blocks of 1 000 ordered by class, final block is 999.
        for cls in range(10):
            stop = 9999 if cls == 9 else (cls + 1) * 1000
            indices_by_class[cls] = list(range(cls * 1000, stop))

    train_idx, val_idx = [], []
    for label in sorted(indices_by_class):
        members = indices_by_class[label]
        split_at = int(len(members) * split_ratio)
        train_idx.extend(members[:split_at])
        val_idx.extend(members[split_at:])

    return Subset(dataset, train_idx), Subset(dataset, val_idx)


def prepare_data(h_params: Dict) -> Dict:
    """
    Build the train/val/test data loaders and their sample counts.

    The ``train`` folder is split 80/20 per class into training and
    validation subsets; the dataset's ``val`` folder serves as the test set.
    Only the training subset is ever augmented: when augmentation is enabled,
    the validation subset is re-bound to a second view of the same folder
    that uses the deterministic evaluation transform, so validation metrics
    are not perturbed by random flips/jitter.

    Args:
        h_params: Hyper-parameter mapping; reads ``data_augumentation`` and
            ``batch_size``.

    Returns:
        Dict with ``train_len``/``val_len``/``test_len`` (ints) and
        ``train_loader``/``val_loader``/``test_loader`` (``DataLoader``).
    """
    train_tfms, test_tfms = _get_transforms(h_params["data_augumentation"])

    # NOTE: update paths if your dataset directory differs
    base = Path("/kaggle/input/nature/inaturalist_12K")
    train_dir, val_dir = base / "train", base / "val"

    full_train = ImageFolder(train_dir, transform=train_tfms)
    train_ds, val_ds = split_dataset_with_class_distribution(full_train, 0.8)

    if h_params["data_augumentation"]:
        # Same files and ordering, but without the random training
        # transforms; reuse the validation indices computed above.
        eval_view = ImageFolder(train_dir, transform=test_tfms)
        val_ds = Subset(eval_view, val_ds.indices)

    test_ds = ImageFolder(val_dir, transform=test_tfms)  # Kaggle's "val" is our test

    bs = h_params["batch_size"]
    data = {
        "train_len": len(train_ds),
        "val_len": len(val_ds),
        "test_len": len(test_ds),
        "train_loader": DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=2),
        "val_loader": DataLoader(val_ds, batch_size=bs, shuffle=False, num_workers=2),
        "test_loader": DataLoader(test_ds, batch_size=bs, shuffle=False, num_workers=2),
    }

    # Sanity-check printout of the split sizes – safe to remove
    print(f"[INFO] Data splits – train:{data['train_len']}  "
          f"val:{data['val_len']}  test:{data['test_len']}")
    return data


# 4. Model definition (class name & attributes unchanged)

class CNN(nn.Module):
    """
    Configurable image classifier: (conv → [BN] → activation → 2×2 max-pool)
    repeated ``conv_layers`` times, then flatten → dense → dropout → logits.

    Hyper-parameters read from ``h_params``: ``conv_layers``, ``filter_size``
    (one kernel size per conv layer), ``num_of_filter``, ``filter_multiplier``,
    ``batch_normalization``, ``dense_layer_size``, ``dropout``, ``actv_func``.
    Convolutions are unpadded, so the size of the first dense layer is derived
    from ``IMAGE_SIZE`` and the kernel sizes.
    """

    def __init__(self, h_params: Dict) -> None:
        """
        Build all layers.

        Raises:
            ValueError: If ``conv_layers`` is < 1, if ``filter_size`` lists
                fewer kernels than ``conv_layers``, or if the kernels shrink
                the feature map to nothing for the configured image size.
        """
        super().__init__()
        self.h_params = h_params

        n_layers = self.h_params["conv_layers"]
        if n_layers < 1:
            raise ValueError(f"conv_layers must be >= 1, got {n_layers}")

        # Only the kernels of layers that are actually built count.
        kernels = list(self.h_params["filter_size"])[:n_layers]
        if len(kernels) < n_layers:
            raise ValueError(
                f"filter_size provides {len(kernels)} kernel sizes "
                f"but conv_layers is {n_layers}"
            )

        # One filter count per conv layer, growing geometrically.
        self.filters = self._filter_counts(
            self.h_params["num_of_filter"],
            self.h_params["filter_multiplier"],
            n_layers,
        )
        print(f"[DEBUG] Convolution filter counts: {self.filters}")

        # Build sequential conv–>BN blocks
        self.conv_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()

        in_ch = 3
        for out_ch, kernel in zip(self.filters, kernels):
            self.conv_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel))
            if self.h_params["batch_normalization"]:
                self.bn_layers.append(nn.BatchNorm2d(out_ch))
            in_ch = out_ch

        fmap_side = self._calc_flatten_size(kernels, IMAGE_SIZE)
        if fmap_side < 1:
            # Fail here rather than with an opaque shape error in forward():
            # squaring a negative side would still yield a positive dense_in.
            raise ValueError(
                f"filter_size {kernels} leaves no feature map for "
                f"{IMAGE_SIZE}x{IMAGE_SIZE} inputs"
            )
        dense_in = self.filters[-1] * fmap_side * fmap_side

        self.fc1 = nn.Linear(dense_in, self.h_params["dense_layer_size"])
        self.fc2 = nn.Linear(self.h_params["dense_layer_size"], NUM_OF_CLASSES)

        self.dropout = nn.Dropout(self.h_params["dropout"])
        self.activation_func = self._get_activation(self.h_params["actv_func"])

    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        """Forward pass: conv‑stack → flatten → MLP; returns raw class logits."""
        use_bn = self.h_params["batch_normalization"]
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if use_bn:
                x = self.bn_layers[i](x)
            x = self.activation_func(x)
            x = F.max_pool2d(x, 2)   # stride defaults to the kernel size (2)
        x = x.flatten(1)            # keep batch dim, merge the rest
        x = F.relu(self.fc1(x))     # dense layer always uses ReLU
        x = self.dropout(x)
        return self.fc2(x)

    # ------------------------------------------------------------------
    @staticmethod
    def _filter_counts(base, multiplier, n_layers):
        """Return ``[int(base * multiplier**i) for i in range(n_layers)]``."""
        return [int(base * (multiplier ** i)) for i in range(n_layers)]

    # ------------------------------------------------------------------
    @staticmethod
    def _get_activation(name: str):
        """Return functional activation by string; unknown names give ReLU."""
        mapping = {
            "elu":         F.elu,
            "gelu":        F.gelu,
            "silu":        F.silu,
            "selu":        F.selu,
            "leaky_relu":  F.leaky_relu,
        }
        return mapping.get(name, F.relu)

    # ------------------------------------------------------------------
    @staticmethod
    def _calc_flatten_size(kernels, size) -> int:
        """
        Compute the feature-map side left after one (conv → pool) per kernel.

        Each unpadded conv shrinks the side by ``k - 1``; each 2×2 pool halves
        it (floor).  The result is <= 0 when the kernels do not fit.
        """
        for k in kernels:  # each conv followed by 2×2 pooling
            size = (size - k + 1) // 2
        return size


# 5. Evaluation helpers

@torch.no_grad()
def evaluate_testing_model(model: nn.Module, device, loader_data: Dict) -> None:
    """
    Print the model's accuracy on ``loader_data["test_loader"]``.

    The model is switched to eval mode, every batch is moved to ``device``,
    and the number of correct arg-max predictions is divided by
    ``loader_data["test_len"]``.
    """
    model.eval()
    hits = sum(
        (model(batch.to(device)).argmax(dim=1) == target.to(device)).sum().item()
        for batch, target in loader_data["test_loader"]
    )
    accuracy = hits / loader_data["test_len"]
    print(f"[METRIC] Test accuracy: {accuracy:.4f}")


def generateGridImage(model: nn.Module, device, loader_data: Dict) -> None:
    """
    Log a 10×3 (30‑image) grid with true/pred labels to W&B.

    Images are taken in order from ``loader_data["test_loader"]``; batches
    are consumed only until 30 images have been collected.  If the loader
    yields fewer than 30 images, the remaining grid cells are left blank.

    Args:
        model: Classifier producing one logit per entry of ``CLASS_LABELS``.
        device: Device the model lives on; inputs are moved there.
        loader_data: Mapping that contains ``test_loader``.
    """
    n_rows, n_cols = 10, 3
    wanted = n_rows * n_cols

    model.eval()
    imgs, trues, preds = [], [], []

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for x, y in loader_data["test_loader"]:
            p = model(x.to(device)).argmax(1)
            imgs.extend(x.cpu())
            trues.extend(y.tolist())
            preds.extend(p.tolist())
            if len(imgs) >= wanted:
                break  # enough images; do not drain the whole loader

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 40))
    for idx, ax in enumerate(axs.ravel()):
        ax.set_axis_off()
        if idx >= len(imgs):
            continue  # fewer images than cells: leave the cell empty
        # Tensors are channel-first; imshow expects channel-last.
        ax.imshow(imgs[idx].permute(1, 2, 0).numpy())
        ax.set_title(f"T:{CLASS_LABELS[trues[idx]]}\nP:{CLASS_LABELS[preds[idx]]}",
                     fontdict={"fontsize": 10})
    fig.tight_layout()
    wandb.log({"Predictions": wandb.Image(fig)})
    plt.close(fig)


# 6. Training loop

def train(h_params: Dict, data: Dict) -> None:
    """
    Train a fresh ``CNN`` and log per-epoch metrics to the active W&B run.

    Args:
        h_params: Hyper-parameter mapping.  ``epochs`` and ``learning_rate``
            are read here; the whole mapping is forwarded to ``CNN``.
            ``learning_rate`` may be a number or a one-element list/tuple.
        data: Mapping with ``train_loader``, ``val_loader``, ``train_len``
            and ``val_len``.

    Side effects:
        Prints progress, calls ``wandb.log`` once per epoch and writes the
        weights held after the final epoch to ``bestmodel.pth``.

    Raises:
        ValueError: If ``learning_rate`` is a sequence whose length is not 1.
    """
    start_ts = time.time()

    # The optimizer needs a scalar; tolerate configs that wrap it in a list.
    lr = h_params["learning_rate"]
    if isinstance(lr, (list, tuple)):
        if len(lr) != 1:
            raise ValueError(
                f"learning_rate must be a single value, got {lr!r}"
            )
        lr = lr[0]
    lr = float(lr)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(h_params).to(device)

    # DataParallel helps on multi‑GPU nodes; skipped on single GPU / CPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    n_train_batches = len(data["train_loader"])
    n_val_batches = len(data["val_loader"])

    for epoch in range(h_params["epochs"]):
        model.train()
        epoch_loss = epoch_correct = 0

        for step, (x, y) in enumerate(data["train_loader"]):
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()
            optimizer.step()

            hits = out.argmax(1) == y
            epoch_loss += loss.item()
            epoch_correct += hits.sum().item()

            # Intra‑epoch progress message every 10 mini‑batches
            if step % 10 == 0:
                batch_acc = hits.float().mean().item()
                print(f"Ep{epoch:<2d} • step {step:<4d} "
                      f"batch_acc:{batch_acc:5.3f}  loss:{loss.item():6.4f}")

        # ── validation ────────────────────────────────────────────────────
        model.eval()
        val_loss = val_correct = 0
        with torch.no_grad():
            for x, y in data["val_loader"]:
                x, y = x.to(device), y.to(device)
                out = model(x)
                val_loss += loss_fn(out, y).item()
                val_correct += (out.argmax(1) == y).sum().item()

        # epoch‑level metrics: accuracy per sample, loss averaged per batch
        train_acc = epoch_correct / data["train_len"]
        val_acc = val_correct / data["val_len"]
        train_loss = epoch_loss / n_train_batches
        mean_val_loss = val_loss / n_val_batches

        print(f"[EPOCH {epoch}] "
              f"train_acc:{train_acc:.4f}  val_acc:{val_acc:.4f}  "
              f"train_loss:{train_loss:.4f}  "
              f"val_loss:{mean_val_loss:.4f}")

        # W&B logging
        wandb.log({
            "epoch": epoch,
            "train_accuracy": train_acc,
            "val_accuracy":   val_acc,
            "train_loss":     train_loss,
            "val_loss":       mean_val_loss,
        })

        # optional: free cached VRAM between epochs
        if device.type == "cuda":
            torch.cuda.empty_cache()

    print(f"[DONE] Training completed in {(time.time() - start_ts)/60:.1f} min")

    # Save the bare model so the checkpoint keys carry no "module." prefix
    # and load into a plain CNN regardless of how many GPUs trained it.
    # NOTE: despite the file name these are the final-epoch weights.
    to_save = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(to_save.state_dict(), "bestmodel.pth")

    # Uncomment to evaluate and/or visualise predictions
    # evaluate_testing_model(model, device, data)
    # generateGridImage(model, device, data)


# 7. Main entry point (run/sweep)
if __name__ == "__main__":
    # Regular run (single config) ------------------------------------------------
    data = prepare_data(h_params)
    # Using the run as a context manager closes it even if training raises.
    with _init_wandb_run(h_params):
        train(h_params, data)

    # Sweep run (Bayesian search) ------------------------------------------------
    sweep_params = {
        "method": "bayes",
        "name":   "DL assn 2 sweep",
        "metric": {"goal": "maximize", "name": "val_accuracy"},
        "parameters": {
            "epochs":              {"values": [10]},
            "learning_rate":       {"values": [0.0001, 0.001]},
            "batch_size":          {"values": [32, 64]},
            "num_of_filter":       {"values": [16, 32, 64]},
            # [3, 5, 7, 9, 11] is deliberately absent: with 224 px inputs and
            # unpadded convs the 5th layer would receive a 7×7 map, smaller
            # than its 11×11 kernel, so every such trial would crash.
            "filter_size":         {"values": [[3]*5, [5]*5, [7]*5,
                                               [11, 9, 7, 5, 3]]},
            "actv_func":           {"values": ["elu", "gelu", "leaky_relu", "selu"]},
            "filter_multiplier":   {"values": [1, 2]},
            "data_augumentation":  {"values": [False]},
            "batch_normalization": {"values": [True, False]},
            "dropout":             {"values": [0, 0.1, 0.2]},
            "dense_layer_size":    {"values": [64, 128, 256]},
            "conv_layers":         {"values": [5]},
        },
    }

    sweep_id = wandb.sweep(sweep=sweep_params, project="DL Assignment 2")

    def main():
        """Run one sweep trial with the configuration chosen by the agent."""
        # init() picks up the sweep-provided config; the context manager
        # finishes the run before the agent starts the next trial.
        with wandb.init(project="DL Assignment 2"):
            cfg = dict(wandb.config)                    # keep original var name
            d = prepare_data(cfg)
            train(cfg, d)

    # Launch up to 10 sweep trials
    wandb.agent(sweep_id, function=main, count=10)
