<a href="https://colab.research.google.com/github/ysuter/FHNW-BAI-DeepLearning/blob/main/Beispiele/mnist_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST (PyTorch) — Colab with **Validation**, **Confusion Matrix**, and **Draw-to-Classify**
Adapted from the official [PyTorch MNIST example](https://github.com/pytorch/examples/blob/main/mnist/main.py), extended with:
- **Train/Val/Test** split (50k / 10k / 10k with the default `val_size` of 10,000)
- Best-on-**validation** checkpoint
- **Confusion Matrix** on test set
- Interactive **canvas** to draw a digit and classify with the trained model


In [None]:
#@title Install/Check dependencies (Colab usually has these pre-installed)
# NOTE(review): `--upgrade` may replace the torch/torchvision builds that Colab ships with;
# if torch was already imported in this session, restart the runtime before running the next cells.
!pip -q install torch torchvision matplotlib tqdm pillow --upgrade

## Model Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#  Model
class Net(nn.Module):
    """Small two-convolution CNN classifier for single-channel 28x28 digit images.

    Pipeline: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> Dropout2d(0.25) -> flatten -> fc(9216->128) -> ReLU -> Dropout(0.5) -> fc(128->10)
    -> log_softmax. ``forward`` therefore returns per-class log-probabilities, which is
    what ``nn.NLLLoss`` expects.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Feature extractor. Spatial size: 28 -> 26 (conv1) -> 24 (conv2) -> 12 (max-pool),
        # which is why fc1 takes 64 * 12 * 12 inputs (so inputs must be 28x28).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout2d(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head on the flattened feature map.
        self.fc1 = nn.Linear(in_features=64 * 12 * 12, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), kernel_size=2)
        features = torch.flatten(self.dropout1(features), start_dim=1)
        hidden = self.dropout2(F.relu(self.fc1(features)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

# Run on the CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Build the network and move its parameters onto the selected device (Module.to works in place).
model = Net()
model.to(device)
# Count only the parameters the optimizer will actually update.
n_params = sum(w.numel() for w in model.parameters() if w.requires_grad)
print(f"Trainable parameters: {n_params:,}")

In [None]:
#@title MNIST: data, model, training (with validation), testing, and visualization
import os, math, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

#  Config
# One shared seed for Python's, NumPy's and PyTorch's global random number generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

@dataclass
class Config:
    """Hyper-parameters and data-loader settings for the MNIST training run.

    Field order is part of the interface (positional construction), so keep it stable.
    """

    # --- batch sizes ---
    batch_size: int = 64          # training mini-batch size
    test_batch_size: int = 1000   # batch size of the validation / test loaders
    # --- optimisation schedule ---
    epochs: int = 3
    lr: float = 1.0               # initial learning rate handed to the optimizer
    gamma: float = 0.7            # multiplicative LR decay applied by the scheduler
    # --- bookkeeping / data ---
    log_interval: int = 100       # refresh the progress-bar loss every N batches
    val_size: int = 10_000        # images held out from the official training split
    use_amp: bool = True          # request mixed precision (only honoured on CUDA)
    num_workers: int = 2          # DataLoader worker processes (only set on CUDA)

cfg = Config()

#  Data
# (0.1307, 0.3081) are the usual MNIST mean/std used for normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

root = "./data"
full_train = datasets.MNIST(root, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root, train=False, download=True, transform=transform)

# Hold out `val_size` images of the official training split for validation. The dedicated,
# seeded generator makes the split identical across runs, independent of the global RNG.
val_size = cfg.val_size
train_size = len(full_train) - val_size
generator = torch.Generator().manual_seed(SEED)
train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=generator)

train_loader_kwargs = dict(batch_size=cfg.batch_size, shuffle=True)
eval_loader_kwargs = dict(batch_size=cfg.test_batch_size, shuffle=False)
if device.type == "cuda":
    # Extra loader workers and pinned host memory are only enabled when training on a GPU.
    train_loader_kwargs.update(num_workers=cfg.num_workers, pin_memory=True)
    eval_loader_kwargs.update(num_workers=cfg.num_workers, pin_memory=True)

train_loader = DataLoader(train_ds, **train_loader_kwargs)
val_loader   = DataLoader(val_ds,   **eval_loader_kwargs)
test_loader  = DataLoader(mnist_test, **eval_loader_kwargs)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(mnist_test)}")

#  Model
# Re-create the network *after* the seeds above were set. The instance built in the
# model-definition cell was initialised before seeding (so its weights were not
# reproducible), and without this, re-running this cell would keep training the previous
# run's weights while the best-accuracy bookkeeping starts over.
model = Net().to(device)

#  Optimizer & Scheduler
optimizer = torch.optim.Adadelta(model.parameters(), lr=cfg.lr)
# Multiply the learning rate by `gamma` on every scheduler.step() (called once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=cfg.gamma)
# torch.amp.GradScaler("cuda", ...) is the non-deprecated form of torch.cuda.amp.GradScaler.
# Loss scaling is only enabled for CUDA + cfg.use_amp; a disabled scaler just passes through.
scaler = torch.amp.GradScaler("cuda", enabled=cfg.use_amp and device.type == "cuda")
# The model returns log-probabilities (log_softmax), so NLL loss is the matching criterion.
criterion = nn.NLLLoss()

#  Train / Eval
train_history = {"loss": []}
val_history = {"loss": [], "acc": []}

def train_epoch(epoch):
    """Run one training epoch over ``train_loader`` and return the mean per-sample loss.

    The mean is also appended to ``train_history["loss"]`` and printed. Mixed precision
    (autocast + gradient scaling) is active only when ``scaler`` is enabled.

    Args:
        epoch: 1-based epoch number, used only for the progress-bar label.
    """
    model.train()
    running = 0.0
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Train Epoch {epoch}")
    for batch_idx, (data, target) in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        # torch.autocast(device_type=...) replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=scaler.is_enabled()):
            output = model(data)
            loss = criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # criterion averages over the batch; re-weight by batch size for a dataset-level mean.
        running += loss.item() * data.size(0)
        if batch_idx % cfg.log_interval == 0:
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    avg = running / len(train_loader.dataset)
    train_history["loss"].append(avg)
    print(f"Train avg loss: {avg:.4f}")
    return avg

@torch.no_grad()
def evaluate(loader, split_name="Val"):
    """Evaluate ``model`` on ``loader`` and return ``(avg_loss, accuracy)``.

    Both metrics are averaged over ``len(loader.dataset)`` samples and printed in a one-line
    summary. When ``split_name`` starts with "val" (case-insensitive), the results are also
    appended to ``val_history``. Runs in eval mode without gradient tracking.
    """
    model.eval()
    n_samples = len(loader.dataset)
    loss_sum = 0.0
    n_correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        log_probs = model(inputs)
        # criterion averages over the batch, so re-weight by batch size before summing.
        loss_sum += criterion(log_probs, labels).item() * inputs.size(0)
        n_correct += (log_probs.argmax(dim=1) == labels).sum().item()
    avg_loss = loss_sum / n_samples
    acc = n_correct / n_samples
    print(f"{split_name}: avg loss {avg_loss:.4f}, acc {acc*100:.2f}% ({n_correct}/{n_samples})")
    is_validation = split_name.lower().startswith("val")
    if is_validation:
        val_history["loss"].append(avg_loss)
        val_history["acc"].append(acc)
    return avg_loss, acc

#  Train Loop with best-on-val checkpoint
BEST_PATH = "best_mnist_val.pt"
best_val_acc = -1.0

for epoch in range(1, cfg.epochs + 1):
    _ = train_epoch(epoch)
    _, val_acc = evaluate(val_loader, "Val")
    # One StepLR step per epoch, after that epoch's optimizer updates.
    scheduler.step()
    # Checkpoint only on a strict improvement, so ties keep the earlier epoch's weights.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), BEST_PATH)
        print(f"✅ Saved new best model (val acc={best_val_acc*100:.2f}%)")

print(f"Best validation accuracy: {best_val_acc*100:.2f}%")

# ---------------------- Plots ----------------------
# Each history holds one value per epoch. Plot against 1-based epoch numbers so the x-axis
# matches the "Train Epoch N" labels (plt.plot(list) alone would start the axis at 0).
train_epochs = range(1, len(train_history["loss"]) + 1)
val_epochs = range(1, len(val_history["loss"]) + 1)
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(train_epochs, train_history["loss"], label="train")
plt.plot(val_epochs, val_history["loss"], label="val")
plt.title("Loss"); plt.xlabel("Epoch"); plt.grid(True); plt.legend()
plt.subplot(1,2,2)
plt.plot(val_epochs, val_history["acc"], label="val acc")
plt.title("Validation Accuracy"); plt.xlabel("Epoch"); plt.grid(True); plt.legend()
plt.show()

# ---------------------- Test evaluation ----------------------
# Report test metrics for the best-on-validation checkpoint, not the last epoch's weights.
# The file holds only a state_dict, so restricted (weights_only) unpickling is sufficient.
model.load_state_dict(torch.load(BEST_PATH, map_location=device, weights_only=True))
_ = evaluate(test_loader, "Test")

# ---------------------- Visualize predictions ----------------------
@torch.no_grad()
def show_batch_predictions(loader, n_images=16, *, cols=8, mean=0.1307, std=0.3081):
    """Plot images from the first batch of ``loader`` with predicted (P) and true (T) labels.

    Args:
        loader: iterable of ``(images, targets)`` batches; only its first batch is used.
        n_images: how many images to show; capped at the size of that first batch.
        cols: number of subplot columns in the grid.
        mean, std: normalisation constants to undo before display. The defaults equal the
            values given to ``transforms.Normalize`` for the MNIST datasets in this notebook.
    """
    model.eval()
    data, target = next(iter(loader))
    data, target = data.to(device), target.to(device)
    preds = model(data).argmax(dim=1)
    imgs = data.cpu()
    n_shown = min(n_images, imgs.size(0))
    if n_shown <= 0:
        return
    # Size the grid from what is actually shown, so a short batch leaves no empty rows.
    rows = math.ceil(n_shown / cols)
    plt.figure(figsize=(cols*2, rows*2))
    for i in range(n_shown):
        img = imgs[i, 0] * std + mean  # undo Normalize for display
        plt.subplot(rows, cols, i+1)
        plt.imshow(img.numpy(), cmap="gray")
        plt.title(f"P:{preds[i].item()} T:{target[i].item()}", fontsize=10)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

show_batch_predictions(test_loader, n_images=16)


In [None]:
#@title Confusion Matrix on the Test Set
import numpy as np
import matplotlib.pyplot as plt
import torch

@torch.no_grad()
def compute_confusion_matrix(model, loader, device, num_classes=10, normalize=True, eps=1e-12):
    """
    Compute the confusion matrix of ``model`` over every batch in ``loader``.

    The model is switched to eval mode; each prediction is the argmax over dim 1 of its output.

    Returns (cm_raw, cm_norm) where:
      - cm_raw: integer counts (num_classes x num_classes, dtype int64), rows=true, cols=pred
      - cm_norm: row-normalized (each non-empty row sums to 1.0, empty rows stay 0.0)
        if normalize=True, else None

    Raises:
        IndexError: if a target or prediction lies outside [0, num_classes).

    Counting is vectorised per batch with ``torch.bincount``. The previous per-sample Python
    loop made two ``.item()`` calls per sample, and each is a device sync on GPU.
    """
    model.eval()
    n_cells = num_classes * num_classes
    flat_counts = torch.zeros(n_cells, dtype=torch.int64)

    for data, target in loader:
        data = data.to(device)
        pred = model(data).argmax(dim=1).view(-1).cpu().long()
        true = target.view(-1).cpu().long()
        # bincount would silently grow past n_cells (or reject negatives), so check explicitly.
        if ((true < 0) | (true >= num_classes) | (pred >= num_classes)).any():
            raise IndexError(f"targets/predictions must lie in [0, {num_classes})")
        # Encode each (true, pred) pair as one flat cell index: true * num_classes + pred.
        flat_counts += torch.bincount(true * num_classes + pred, minlength=n_cells)

    cm = flat_counts.view(num_classes, num_classes).numpy()

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # eps keeps rows without samples at 0 instead of dividing by zero.
        cm_norm = cm / np.maximum(row_sums, eps)
        return cm, cm_norm
    else:
        return cm, None

# Evaluate the best-on-validation checkpoint. It holds only a state_dict, so weights_only
# (restricted) unpickling is sufficient.
model.load_state_dict(torch.load("best_mnist_val.pt", map_location=device, weights_only=True))
cm_raw, cm_norm = compute_confusion_matrix(model, test_loader, device, num_classes=10)

fig, ax = plt.subplots(figsize=(6,5))
# Colour encodes the row-normalised rate on a fixed 0..1 scale; the text shows raw counts.
im = ax.imshow(cm_norm, interpolation='nearest', vmin=0.0, vmax=1.0)
fig.colorbar(im, ax=ax, label='Fraction of true class')
ax.set_title('Confusion Matrix (Test)')
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
ax.set_xticks(range(10)); ax.set_yticks(range(10))
ax.set_xticklabels(list(range(10))); ax.set_yticklabels(list(range(10)))
for i in range(10):
    for j in range(10):
        # With the default (viridis) colormap low values are dark, so use white text there and
        # black text on the bright high-value cells to keep every count legible.
        text_color = 'black' if cm_norm[i, j] > 0.5 else 'white'
        ax.text(j, i, int(cm_raw[i, j]), ha='center', va='center', fontsize=8, color=text_color)
plt.tight_layout()
plt.show()

# Per-class accuracy = correct / samples of that true class; classes with no samples report 0.
totals = cm_raw.sum(axis=1).astype(np.float64)
with np.errstate(divide='ignore', invalid='ignore'):
    per_class_acc = np.divide(np.diag(cm_raw), totals, out=np.zeros_like(totals), where=totals>0)
print("Per-class accuracy:", np.round(per_class_acc, 4))
print("Overall accuracy:", round(np.trace(cm_raw) / max(cm_raw.sum(), 1), 4))


In [None]:
#@title Draw a digit (0–9) and classify with the trained model
# Draw white digit on black canvas. Click Predict to send to Python for classification.
from google.colab.output import eval_js
from IPython.display import display, Javascript
from base64 import b64decode
import io
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Inference mode (disables dropout), then load the best-on-validation weights saved by training.
model.eval()
try:
    # The checkpoint is a plain state_dict, so restricted (weights_only) unpickling suffices.
    model.load_state_dict(torch.load("best_mnist_val.pt", map_location=device, weights_only=True))
except Exception as e:
    # Deliberate best effort: keep the in-memory weights and warn instead of aborting the demo.
    print("Warning: could not load best_mnist_val.pt. Train the model first.", e)

# Inline JavaScript for the drawing widget. It shows a 280x280 canvas pre-filled with black,
# draws white 20 px round strokes with mouse or touch, and adds a "Clear" button (repaints the
# canvas black) and a "Predict" button. draw_digit() returns a Promise that resolves to the
# canvas contents as a PNG data URL when Predict is clicked; the widget then removes itself.
js = Javascript('''
async function draw_digit() {
  return await new Promise((resolve) => {
    const div = document.createElement('div');
    div.style.margin = '8px 0';
    const title = document.createElement('div');
    title.textContent = 'Draw a digit (white on black). Click Predict.';
    title.style.margin = '6px 0';
    const canvas = document.createElement('canvas');
    canvas.width = 280; canvas.height = 280;
    canvas.style.border = '1px solid #999';
    canvas.style.touchAction = 'none';
    const ctx = canvas.getContext('2d');
    ctx.fillStyle = 'black'; ctx.fillRect(0,0,canvas.width,canvas.height);
    let drawing=false;
    function rel(e){const r=canvas.getBoundingClientRect(); return {x:e.clientX-r.left, y:e.clientY-r.top};}
    canvas.addEventListener('mousedown', e=>{drawing=true; const p=rel(e); ctx.lineWidth=20; ctx.lineCap='round'; ctx.strokeStyle='white'; ctx.beginPath(); ctx.moveTo(p.x,p.y);});
    canvas.addEventListener('mousemove', e=>{if(!drawing)return; const p=rel(e); ctx.lineTo(p.x,p.y); ctx.stroke();});
    canvas.addEventListener('mouseup', ()=>{drawing=false;});
    canvas.addEventListener('mouseleave', ()=>{drawing=false;});
    canvas.addEventListener('touchstart', e=>{const r=canvas.getBoundingClientRect(); const t=e.touches[0]; drawing=true; ctx.lineWidth=20; ctx.lineCap='round'; ctx.strokeStyle='white'; ctx.beginPath(); ctx.moveTo(t.clientX-r.left,t.clientY-r.top); e.preventDefault();},{passive:false});
    canvas.addEventListener('touchmove', e=>{if(!drawing)return; const r=canvas.getBoundingClientRect(); const t=e.touches[0]; ctx.lineTo(t.clientX-r.left,t.clientY-r.top); ctx.stroke(); e.preventDefault();},{passive:false});
    canvas.addEventListener('touchend', ()=>{drawing=false;},{passive:false});
    const btns=document.createElement('div'); btns.style.marginTop='6px';
    const clearBtn=document.createElement('button'); clearBtn.textContent='Clear';
    const predictBtn=document.createElement('button'); predictBtn.textContent='Predict';
    clearBtn.style.marginRight='6px';
    clearBtn.onclick=()=>{ctx.fillStyle='black'; ctx.fillRect(0,0,canvas.width,canvas.height);};
    predictBtn.onclick=()=>{const dataURL=canvas.toDataURL('image/png'); resolve(dataURL); document.body.removeChild(div);};
    btns.appendChild(clearBtn); btns.appendChild(predictBtn);
    div.appendChild(title); div.appendChild(canvas); div.appendChild(btns);
    document.body.appendChild(div);
  });
}
''')
# Define draw_digit() in the output frame, then call it. The cell waits here until Predict is
# clicked; the resolved value (the PNG data-URL string) is returned to Python as `dataURL`.
display(js)
dataURL = eval_js('draw_digit()')

# Decode the PNG data URL returned by the canvas ("data:image/png;base64,<payload>").
if dataURL.startswith('data:image/png;base64,'):
    b64 = dataURL.split(',')[1]
else:
    b64 = dataURL

img = Image.open(io.BytesIO(b64decode(b64))).convert('L')


def _to_mnist_layout(gray, size=28, box=20):
    """Return a white-on-black 'L' image as a ``size`` x ``size`` float32 array in [0, 1].

    Follows the MNIST preprocessing recipe:
      1. crop to the drawn pixels;
      2. scale the longest side to ``box`` pixels, keeping the aspect ratio;
      3. paste the result into the middle of a black ``size`` x ``size`` field;
      4. shift it so the intensity centre of mass sits at the field centre.
    A blank drawing yields an all-zero array.
    """
    bbox = gray.getbbox()  # bounding box of non-zero (drawn) pixels; None when blank
    if bbox is None:
        return np.zeros((size, size), dtype=np.float32)
    digit = gray.crop(bbox)
    w, h = digit.size
    scale = box / max(w, h)
    new_w, new_h = max(1, round(w * scale)), max(1, round(h * scale))
    digit = digit.resize((new_w, new_h), Image.BILINEAR)

    field = np.zeros((size, size), dtype=np.float32)
    top, left = (size - new_h) // 2, (size - new_w) // 2
    field[top:top + new_h, left:left + new_w] = np.asarray(digit, dtype=np.float32) / 255.0

    total = field.sum()
    if total > 0:
        rows, cols = np.indices(field.shape)
        cy = (rows * field).sum() / total
        cx = (cols * field).sum() / total
        centre = (size - 1) / 2
        # Clamp the shift to the empty margin so np.roll never wraps ink across the border.
        dy = int(np.clip(round(centre - cy), -top, size - (top + new_h)))
        dx = int(np.clip(round(centre - cx), -left, size - (left + new_w)))
        field = np.roll(field, (dy, dx), axis=(0, 1))
    return field


# A plain 280->28 downscale keeps small or off-centre drawings small/off-centre, unlike the
# size-normalised, centred MNIST training digits; convert to the MNIST layout first.
arr = _to_mnist_layout(img)
if not arr.any():
    print("Warning: the canvas is empty - draw a digit before clicking Predict.")
x = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)  # -> (1, 1, 28, 28)
x = (x - 0.1307) / 0.3081  # same Normalize constants as the training transform
x = x.to(device)

with torch.no_grad():
    out = model(x)
    probs = torch.softmax(out, dim=1)
    pred = torch.argmax(probs, dim=1).item()

print(f"Predicted digit: {pred}")
print("Probabilities:", probs.squeeze(0).cpu().numpy())

# Plot the processed input + confidence bar chart
fig, axs = plt.subplots(1,2, figsize=(8,3))
axs[0].imshow(arr, cmap='gray')
axs[0].set_title(f'Input (Pred: {pred})')
axs[0].axis('off')
axs[1].bar(range(10), probs.cpu().numpy().squeeze())
axs[1].set_xticks(range(10))
axs[1].set_ylim([0,1])
axs[1].set_title("Confidence per class")
plt.show()
