## Step 0 — Imports

In [None]:
# NOTE: feel free to utilize libraries that you find helpful

import math
import random

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from datasets import load_dataset
from IPython.display import clear_output
from torch.utils.data import DataLoader, Dataset


## Step 1 — Reproducibility + device

- **Reproducibility:** We fix random seeds so dataset shuffling, weight initialization, and dropout are as repeatable as possible across runs. This makes results comparable and debugging meaningful.

- **Determinism controls:** Setting `torch.backends.cudnn.deterministic=True` and `benchmark=False` reduces non-deterministic GPU behavior (at a potential speed cost), so you don’t “win or lose” accuracy due to hidden algorithm choices.


In [None]:
SEED: int = 42

# Seed every RNG that drives shuffling, weight init and dropout, in this order:
# Python stdlib, NumPy's legacy global RNG, torch (CPU), torch (all CUDA devices).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)
del _seed_rng

# Ask cuDNN for deterministic kernels and turn off its autotuner, which could
# otherwise choose different algorithms from run to run (at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Step 2 — Load Fashion-MNIST dataset

> Recall that in the seminar we worked with the MNIST dataset, which is only for handwritten-digit classification.

Now we're working with a standardized image classification benchmark: 28×28 grayscale **clothing** photos, 10 classes, with an official 60k/10k train/test split.

We download it with default options and do the normalization and shuffling ourselves later, so that we understand how they are done.

In [None]:
# Human-readable Fashion-MNIST class names; the list index is the numeric label id.
CLASS_NAMES: list[str] = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",  # labels 0-4
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",      # labels 5-9
]


In [None]:
# Fetch the official Fashion-MNIST splits (reused from the local Hugging Face cache when present).
ds = load_dataset("fashion_mnist")  # cached under ~/.cache/huggingface/datasets by default

# Have `datasets` return columns as NumPy arrays where it can.
train_split, test_split = (ds[split_name].with_format("numpy") for split_name in ("train", "test"))


def _images_to_u8(images: object) -> np.ndarray:
    """
    The 'image' column can come back either as:
      - a single np.ndarray of shape [N, 28, 28], or
      - a list-like of [28, 28] arrays.
    Normalize it into a uint8 ndarray [N, 28, 28].
    """
    if isinstance(images, np.ndarray):
        arr = images
    else:
        arr = np.stack(images)  # works if it's list-like of 2D arrays
    return arr.astype(np.uint8)


# Materialize each official split with fixed dtypes: uint8 pixels, int64 class ids.
train_images_u8, train_labels_i64 = (
    _images_to_u8(train_split["image"]),
    np.asarray(train_split["label"], dtype=np.int64),
)
test_images_u8, test_labels_i64 = (
    _images_to_u8(test_split["image"]),
    np.asarray(test_split["label"], dtype=np.int64),
)

# Sanity-check the official 60k/10k split of 28x28 images.
assert (train_images_u8.shape, train_labels_i64.shape) == ((60_000, 28, 28), (60_000,))
assert (test_images_u8.shape, test_labels_i64.shape) == ((10_000, 28, 28), (10_000,))


## Step 3 — Data Visualization **[1 point]**

***Task***:

- Randomly select a small set of training examples.

- Display them as a grid of grayscale images.

- Put the class name (decoded from the numeric label) as the title for each image.

- Verify visually that:
  - the images look like clothing silhouettes (not noise/corrupted)
  - labels match what you see (e.g., sneaker vs sandal)
  - pixel orientation/contrast looks reasonable


In [None]:
# Draw `grid_n` distinct training indices (replace=False) from a generator seeded
# with SEED, so the preview grid shows the same images on every run.
grid_n = 16
idx = np.random.default_rng(SEED).choice(
    len(train_images_u8), size=grid_n, replace=False
)

# Template placeholder: per the task above, plot train_images_u8[idx] as a grid
# of grayscale images titled with the decoded class name of each label.
raise NotImplementedError # TODO: <your code here>

## Step 4 — Train/val split + normalization **[1 point]**

***Task***:

- Create a **train/validation split** from the **shuffled (by you!)** original training set (e.g., 90% train, 10% val).

- Compute **normalization statistics (mean and std)** using **only the train split**.

- Apply the same normalization to **train, val, and test** using the train-derived mean/std.

- Add a channel dimension to get `[N, 1, 28, 28]`

    > We add a channel dimension so each sample matches the standard image tensor convention **[C, H, W]** (and batches **[N, C, H, W]**).  
    Even for grayscale, **C=1** is required so downstream layers and dataloaders can treat grayscale and RGB images uniformly.


In [None]:
# Fraction of the original training set held out for validation.
val_fraction: float = 0.10

n_total = train_images_u8.shape[0]
n_val = int(n_total * val_fraction)  # int() truncates, so any remainder stays in train
n_train = n_total - n_val

# Template placeholder: must define `train_idx` and `val_idx` (used below).
# Per the task, these come from a shuffle of range(n_total): disjoint index
# arrays of sizes n_train and n_val.
raise NotImplementedError # TODO: <your code here>

# Fancy indexing copies the selected samples into new arrays.
train_images = train_images_u8[train_idx]
train_labels = train_labels_i64[train_idx]
val_images = train_images_u8[val_idx]
val_labels = train_labels_i64[val_idx]


In [None]:
# Template placeholders: scalar pixel statistics computed from `train_images`
# ONLY (per the task), so the val/test splits cannot leak into normalization.
mean =  # TODO: <your code here>
std =   # TODO: <your code here>
# NOTE: we need them later for denormalization also!

print(f"Normalization (computed on TRAIN split only): mean={mean:.6f}, std={std:.6f}")


def normalize_u8(images_u8: np.ndarray, mean: float, std: float) -> np.ndarray:
    """Convert uint8 images [N,28,28] -> float32 normalized [N,1,28,28].

    Template stub. Per the task, the intended behavior is presumably:
    convert the pixels to float32 on the same scale that `mean`/`std` were
    computed on, standardize them as (x - mean) / std, and insert a channel
    axis at position 1.

    Args:
        images_u8: Raw uint8 images, one per leading index.
        mean: Pixel mean computed on the train split only.
        std: Pixel standard deviation computed on the train split only.

    Raises:
        NotImplementedError: until the TODO below is implemented.
    """
    raise NotImplementedError # TODO: <your code here>


In [None]:
# Standardize every split with the TRAIN-derived statistics (no leakage from val/test).
x_train, x_val, x_test = (
    normalize_u8(images, mean, std)
    for images in (train_images, val_images, test_images_u8)
)

# Integer class indices as int64 targets for the loss.
y_train, y_val, y_test = (
    labels.astype(np.int64) for labels in (train_labels, val_labels, test_labels_i64)
)


## Step 5 — Dataset + DataLoaders **[1 point]**

***Task***:

- Wrap `(x, y)` arrays into a custom `Dataset` that:
  - validates shapes (`x: [N,C,H,W]`, `y: [N]`):
    - validate `x` shape
    - validate `y` shape
    - validate that their lengths match

  - converts `x` and `y` into `torch.Tensor`s

  - returns one `(image_tensor, label_tensor)` per index

- Create `DataLoader`s for train/val/test that:
  - batch samples (set `batch_size`)

  - shuffle _only_ the training loader

  - enable `pin_memory` when using CUDA for faster host $\to$ GPU transfers


In [None]:
class FashionMNISTFromNumpy(Dataset[tuple[torch.Tensor, torch.Tensor]]):
    """Map-style dataset over pre-normalized NumPy image/label arrays.

    Template stub: `__init__` and `__getitem__` still raise NotImplementedError.
    Per the task, `__init__` should validate `x` as [N, C, H, W] and `y` as [N]
    with matching N, then store both as tensors. `__len__` already reads
    `self.y`, so labels must be stored there (images presumably on `self.x`).
    """

    def __init__(self, x: np.ndarray, y: np.ndarray):
        # Expected to validate shapes and store x/y as tensors (see class docstring).
        raise NotImplementedError # TODO: <your code here>

    def __len__(self) -> int:
        # Number of samples = size of the label array's leading axis.
        return self.y.shape[0]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Expected to return one (image_tensor, label_tensor) pair for `index`.
        raise NotImplementedError # TODO: <your code here>


# One dataset per split: train/val come from the held-out split of the original
# training set, test is the official test set.
train_ds = FashionMNISTFromNumpy(x_train, y_train)
val_ds = FashionMNISTFromNumpy(x_val, y_val)
test_ds = FashionMNISTFromNumpy(x_test, y_test)

batch_size: int = 256
# Page-locked host memory speeds up host->GPU copies; only enabled when CUDA exists.
pin_memory: bool = torch.cuda.is_available()

# Template placeholders: per the task, pass `batch_size` and `pin_memory` to all
# three loaders and shuffle ONLY the training loader.
train_loader = DataLoader(...)  # TODO: <your code here>
val_loader = DataLoader(...)    # TODO: <your code here>
test_loader = DataLoader(...)   # TODO: <your code here>

## Step 6 — Model + loss + optimizer **[1 point]**

***Task:***

- Create your classifier (start with some simple baseline) appropriate for 28×28 images:
  - **linear (logistic regression)** on flattened pixels *<[good for baseline]>*

  - **MLP(FFN)** on flattened pixels with `nn.Dropout` & `nn.BatchNorm` *<[our seminar choice]>*:
    - here you can use `nn.Linear`, `nn.ReLU`, `nn.Dropout`, `nn.BatchNorm`, maybe try to use `nn.Sequential` for easier way of work

  - **CNN** *<[if you're soooo cool, bored with the MLP, and ready to learn something new on your own!]>*:
    - here you can use `nn.Conv2d`, `nn.MaxPool2d`

    - > NOTE: you _won't_ get any additional points for using a CNN, so if you're new to DL, stick to the MLP

- Use `nn.CrossEntropyLoss` for 10-way classification (raw logits, no softmax in the model).

- Use an optimizer that works well out of the box (e.g. `Adam` should do the job)

- Start with reasonable hyperparameters:
  - learning rate: remember how we chose the learning rate in the seminar, using the ideas of normalizing the input data and normalizing _within_ the network (e.g. with `nn.BatchNorm`)

  - maybe try out **weight decay**, a direct form of regularization that penalizes large weights (it corresponds to the L2 penalty of Ridge regression, not the L1 penalty of LASSO)

- Move model (and later batches) to the selected device (`cpu`/`cuda`) using `.to(device)`.


In [None]:
class Classifier(nn.Module):
    """Image classifier mapping a batch of images to raw class logits.

    Template stub: no layers are defined yet. Per the task, `forward` must return
    unnormalized logits of shape [B, n_classes] with no softmax, because
    `nn.CrossEntropyLoss` applies log-softmax internally.
    """

    def __init__(self, in_features: int, n_classes: int):
        # in_features: presumably the flattened pixel count (1*28*28 = 784) for a
        # linear/MLP model — TODO confirm against how the model is instantiated.
        super().__init__()
        raise NotImplementedError # TODO: <your code here>

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is expected to be a normalized batch [B, 1, 28, 28] from the loaders;
        # a linear/MLP model needs it flattened before its first nn.Linear.
        raise NotImplementedError # TODO: <your code here>


# Template placeholders. Per the task: build a Classifier and move it to `device`,
# use nn.CrossEntropyLoss for 10-way classification, and an optimizer such as
# Adam over model.parameters() (optionally with weight decay).
model =     # TODO: <your code here>
criterion = # TODO: <your code here>
optimizer = # TODO: <your code here>

## Step 7 — Metrics helpers **[1 point]**

***Task:***

- Implement a confusion-matrix accumulator that:
  - converts logits to predicted labels (`argmax`)

  - counts `(true, pred)` pairs into an `[K, K]` matrix

  - can be summed across batches to cover the full split

- Implement metric computation from the confusion matrix:
  - compute per-class precision/recall/F1

  - compute overall accuracy

  - compute macro-F1 (mean over classes)

Make these utilities `no_grad`-safe and numerically safe (avoid division by zero).


In [None]:
@torch.no_grad()
def confusion_matrix_from_logits(
    logits: torch.Tensor,
    targets: torch.Tensor,
    n_classes: int,
) -> torch.Tensor:
    """
    Build a confusion matrix from model logits and integer targets.

    Args:
        logits: Tensor of shape [B, n_classes] containing unnormalized class scores.
        targets: Tensor of shape [B] containing ground-truth class indices in [0, n_classes-1].
        n_classes: Total number of classes.

    Returns:
        Confusion matrix - a tensor of shape [n_classes, n_classes] with dtype int64, where:
          - rows correspond to true classes
          - columns correspond to predicted classes
        Entry cm[i, j] is the number of samples with true class i predicted as class j.

    Notes:
        Runs under no-grad (evaluation utility). Computation is performed on CPU for counting.
        Per-batch matrices are meant to be summed by the caller (`cm += ...`).
        Template stub: the counting step still raises NotImplementedError, so the
        final `return cm` is unreachable (and `cm` undefined) until it is implemented.
    """
    # Predicted class = index of the largest logit. Flatten both tensors and move
    # them to CPU so counting does not depend on the model's device.
    preds = torch.argmax(logits, dim=1).view(-1).to("cpu")
    t = targets.view(-1).to("cpu")

    raise NotImplementedError # TODO: <your code here>

    return cm


def metrics_from_confusion_matrix(cm: torch.Tensor) -> dict[str, object]:
    """
    Compute classification metrics from a confusion matrix.

    Args:
        cm: Confusion matrix tensor of shape [n_classes, n_classes] with nonnegative counts,
            where rows are true classes and columns are predicted classes.

    Returns:
        Dictionary with:
          - "accuracy": float, overall accuracy (sum(diagonal) / sum(all))
          - "macro_f1": float, mean of per-class F1 scores (uniform class weighting)
          - "precision_per_class": list[float], per-class precision
          - "recall_per_class": list[float], per-class recall
          - "f1_per_class": list[float], per-class F1
          - "support_per_class": list[int], number of true samples per class (row sums)
          - "confusion_matrix": list[list[int]], cm converted to nested Python lists

    Notes:
        Uses clamping / epsilon safeguards to avoid division-by-zero when a class has zero
        support (no true samples) or zero predicted count (never predicted).
        Template stub: `recall`, `f1` and `accuracy` are still placeholders.
    """
    # Cast to float so the divisions below are true (non-integer) divisions.
    cm_f = cm.float()
    diag = torch.diag(cm_f)
    # Row sums = true samples per class; column sums = how often each class was
    # predicted. Clamping both to >= 1 turns 0/0 into 0.0 rather than NaN: with
    # nonnegative counts, a zero row or column sum implies a zero diagonal entry.
    support = cm_f.sum(dim=1).clamp_min(1.0)
    pred_count = cm_f.sum(dim=0).clamp_min(1.0)

    precision = (diag / pred_count).cpu().numpy()
    # Placeholders, per the docstring: recall = diag / support; F1 = harmonic mean
    # of precision and recall (guard the case where both are 0);
    # accuracy = sum(diag) / sum(cm) as a Python float.
    recall =    # TODO: <your code here>
    f1 =        # TODO: <your code here>
    macro_f1 = float(np.mean(f1))

    accuracy =  # TODO: <your code here>

    # Support is reported from the unclamped counts, so empty classes show 0.
    return {
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "precision_per_class": precision.tolist(),
        "recall_per_class": recall.tolist(),
        "f1_per_class": f1.tolist(),
        "support_per_class": cm.sum(dim=1).tolist(),
        "confusion_matrix": cm.tolist(),
    }

## Step 8 — Train with per-epoch validation + live visualization **[2 points]**

***Task:***

- Run training for a fixed number of epochs.

- For each epoch:
  - **Training phase:** set `model.train()`, iterate over `train_loader`, do forward → loss → backward → optimizer step, accumulate average train loss and train accuracy.

  - **Validation phase:** set `model.eval()` and `torch.no_grad()`, iterate over `val_loader`, compute average val loss and val accuracy.

- Store per-epoch metrics in lists (loss/accuracy for train and val).

- After each epoch:
  - Update plots of:
    - train vs val **loss**
    - train vs val **accuracy**
    using `clear_output(wait=True)` so the curves refresh in-place.

  - Print a compact epoch summary line with the tracked metrics.


In [None]:
# Number of full passes over train_loader (template placeholder).
epochs: int = # TODO: <your code here>

# Per-epoch history: one value appended per epoch, used for the live plots.
train_losses: list[float] = []
val_losses: list[float] = []
train_accs: list[float] = []
val_accs: list[float] = []

# Epochs are numbered from 1 so the printed summaries read naturally.
for epoch in range(1, epochs + 1):
    # ---- train
    # TODO: <your code here>

    # ---- validation
    # TODO: <your code here>


    # ---- live plots
    # TODO: <your code here>


## Step 9 — Test evaluation **[1 point]**

***Task:***

- Switch to evaluation mode: `model.eval()` and disable gradients with `with torch.no_grad(): ...`.

- Iterate over `test_loader` once and compute:
  - average **test loss**

  - total **confusion matrix** accumulated across all batches

- From the final confusion matrix, compute and report:
  - overall **accuracy**

  - **macro-F1**

  - per-class **precision/recall/F1** and **support**

- Print the confusion matrix (rows=true, cols=pred) for error analysis.

In [None]:
n_classes = len(CLASS_NAMES)

# Single pass over the test set: accumulate a sample-weighted loss sum and one
# confusion matrix across all batches.
model.eval()  # inference behavior for Dropout/BatchNorm layers
cm = torch.zeros((n_classes, n_classes), dtype=torch.int64)
test_loss_sum = 0.0
test_n = 0

with torch.no_grad():
    for x, y in test_loader:
        # TODO: <your code here>
        # The lines below expect this TODO to define `logits` (model output for
        # the batch) and `loss` (criterion applied to logits and y).

        b = x.shape[0]
        # Assuming `criterion` uses the default mean reduction, weight each batch
        # loss by its size so a smaller final batch does not skew the average.
        test_loss_sum += float(loss.item()) * b
        test_n += b

        cm += confusion_matrix_from_logits(logits, y, n_classes=n_classes)

# max(1, ...) avoids division by zero if the loader yields no samples.
test_loss = test_loss_sum / max(1, test_n)
m = metrics_from_confusion_matrix(cm)

In [None]:
# Headline numbers for the whole test split.
print("\nTest evaluation")
print(f"  loss:     {test_loss:.6f}")
print(f"  accuracy: {float(m['accuracy']):.6f}")
print(f"  macro_f1: {float(m['macro_f1']):.6f}")

# One row per class, in label-id order.
print("\nPer-class metrics:")
for k in range(n_classes):
    prec_k, rec_k, f1_k = (
        float(m[metric][k])
        for metric in ("precision_per_class", "recall_per_class", "f1_per_class")
    )
    support_k = int(m["support_per_class"][k])
    print(
        f"  class {k:02d} ({CLASS_NAMES[k]:>11s}) | P={prec_k:.4f} R={rec_k:.4f} F1={f1_k:.4f} support={support_k}"
    )

# Raw counts for error analysis: rows = true class, columns = predicted class.
print("\nConfusion matrix (rows=true, cols=pred):")
print(np.array(m["confusion_matrix"], dtype=np.int64))

## Step 10 — Show example classifications **[1 point]**

***Task:***

- Randomly pick a small set of test indices (e.g., 12).

- Build a batch from those samples and run a forward pass in `model.eval()` + `torch.no_grad()`.

- Convert logits to:
  - predicted class (`argmax`)

  - confidence (max softmax probability)

- Visualize the selected images in a grid:
  - display the image in original pixel space (invert normalization for plotting)

  - title each with **true label**, **predicted label**, and **confidence**

  - color the title by correctness (correct vs incorrect) to spot patterns quickly

In [None]:
# TODO: <your code here>