In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import warnings

# Suppress every DeprecationWarning for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Dropout

In [None]:
# Bare expression: the cell output shows the class object itself (nothing is instantiated).
nn.Dropout

In [None]:
SEED = 0  # passed to torch.manual_seed in the demo cells below
N = 10  # length of the all-ones vector fed to the dropout demo

In [None]:
# Fix the RNG so the dropout masks drawn below are reproducible.
torch.manual_seed(SEED)

drop = nn.Dropout(p=0.5)
x = torch.ones(N)

# Training mode: each element is zeroed with probability p and the survivors
# are multiplied by 1/(1-p). Every call draws a fresh mask.
drop.train()
y_train1, y_train2 = drop(x), drop(x)

# Evaluation mode: dropout is switched off and acts as the identity.
drop.eval()
y_eval = drop(x)

print(f"x        = {x}\n")
print(f"train #1 = {y_train1}", f"train #2 = {y_train2}\n", sep="\n")

# The sample means are usually close to 1.0, but not exactly 1.0.
print(
    f"train #1 mean = {y_train1.mean().item():.3f}",
    f"train #2 mean = {y_train2.mean().item():.3f}\n",
    sep="\n",
)

print(f"eval     = {y_eval}")

# Takeaway: which elements are dropped is random, but the rescaling factor is
# fixed in advance.
# -> the expected value is only preserved on average (law of large numbers)
# -> try a larger N and watch the means move closer to 1.0

# BatchNorm

In [None]:
# Bare expression: the cell output shows the class object itself (nothing is instantiated).
nn.BatchNorm1d

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$

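- Not only normalization: with `affine=True` (the default) the layer also applies a learnable per-feature scale $\gamma$ and shift $\beta$ (an element-wise affine transform, not a full `nn.Linear`)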

In [None]:
BATCHES = 100  # number of synthetic training batches to generate
N = 8  # rows per batch; NOTE: rebinds N, which the Dropout section set to 10
C = 2  # feature columns per row (also the BatchNorm1d feature count)
BIAS = 10  # added to the standard-normal samples of the training batches
STD = 2  # multiplies the standard-normal samples

In [None]:
# Reseed, then draw BATCHES synthetic batches of shape (N, C):
# standard-normal noise scaled by STD and shifted by BIAS.
torch.manual_seed(SEED)
train_batches = [torch.randn(N, C).mul(STD).add(BIAS) for _ in range(BATCHES)]

# Show the first batch as the cell output.
train_batches[0]

In [None]:
# Batch normalization over (N, C) inputs, one statistic per feature column.
# affine=False removes gamma/beta so only the normalization itself is visible.
bn = nn.BatchNorm1d(num_features=C, affine=False)

# Normalize the first training batch and show the result as the cell output.
bn(train_batches[0])

Besides normalizing the current batch, the layer also accumulates running statistics (mean/variance) during training; these are what it uses at **inference (evaluation)** time.

In [None]:
# Start from pristine running statistics so this cell is idempotent: any
# earlier train-mode forward pass (or a previous run of this cell) has already
# moved running_mean/running_var away from their initial values, which would
# make the "[BEFORE]" printout misleading.
bn.reset_running_stats()

print(f"[BEFORE] Learned running mean: {bn.running_mean}")
print(f"[BEFORE] Learned running var : {bn.running_var}\n")

bn.train()
for xb in train_batches:
    _ = bn(xb)  # a train-mode forward pass updates running_mean/running_var

print(f"[AFTER] Learned running mean: {bn.running_mean}")
print(f"[AFTER] Learned running var : {bn.running_var}")

BatchNorm's running mean/var approximate the training **data distribution**.  
So using them keeps feature scaling consistent with the weights.

At **inference** you want the same normalization the model trained against.

In [None]:
# One batch from a shifted distribution: same scale factor (STD) as the
# training batches, but offset by TEST_BATCH_BIAS instead of BIAS.
TEST_BATCH_BIAS = 20
x_test = torch.randn(N, C).mul(STD).add(TEST_BATCH_BIAS)

print(x_test)
print(f"\nTest batch mean: {x_test.mean(dim=0)}  (should be near {TEST_BATCH_BIAS})")

In [None]:
# > run this cell several times

# Set train mode explicitly: if the eval cell below has already been executed,
# `bn` would otherwise still be in eval mode and this demonstration would
# silently stop showing the problem.
bn.train()

# Wrong inference: in train mode the batch is normalized with its own
# statistics, and every call also drags the running statistics toward x_test.
print(bn(x_test))

print(f"Learned running mean: {bn.running_mean}")
print(f"Learned running var : {bn.running_var}")

In [None]:
# Re-fit the running statistics on the training batches first. The previous
# cell ran `bn` in train mode on `x_test`, which pulled running_mean/running_var
# toward the test batch. Doing it here removes the need to re-run the training
# cell by hand, so the notebook gives the intended result under "Run All".
bn.reset_running_stats()
bn.train()
for xb in train_batches:
    _ = bn(xb)

bn.eval()
y_eval = bn(x_test)  # Correct Inference: uses running stats from training

print(y_eval)
print(f"y mean: {y_eval.mean(dim=0)}, y std: {y_eval.std(dim=0)}")

# Data + Model + Train + Eval

In [None]:
# Experiment configuration for the MNIST section.
seed = 0  # seeds torch's global RNG here and the train/val split generator
val_size = 5000  # samples held out of the training set for validation
batch_size = 128  # samples per DataLoader batch
lr = 1e-3  # Adam learning rate
epochs = 8  # full passes over the training split

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.manual_seed(seed)

## 1. Data

In [None]:
# Convert images to tensors, then standardize them with the MNIST mean/std.
tfm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Train and test sets, stored under ./data (downloaded when missing).
train_full = datasets.MNIST("./data", train=True, transform=tfm, download=True)
test_ds = datasets.MNIST("./data", train=False, transform=tfm, download=True)

# Carve `val_size` samples out of the training set; the dedicated generator
# keeps the split identical from run to run.
train_ds, val_ds = random_split(
    dataset=train_full,
    lengths=[len(train_full) - val_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size, num_workers=1)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=batch_size, num_workers=1)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=batch_size, num_workers=1)

In [None]:
@torch.no_grad()
def show_mnist(x, y, pred=None, n=12):
    """Plot up to ``n`` images of a batch in a 6-column grid with their labels.

    Args:
        x: Image batch, indexed as ``x[i, 0]`` (first channel of sample ``i``).
            Values are de-normalized with the mean/std 0.1307/0.3081 and
            clamped to [0, 1] before plotting.
        y: Labels; ``y[i].item()`` is shown in each title.
        pred: Optional predictions; when given, ``p=<pred[i]>`` is appended
            to each title.
        n: Maximum number of images to show. It is clamped to the batch size,
            so a batch with fewer than ``n`` samples no longer raises
            ``IndexError``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Never index past the end of the batch (e.g. a short last batch).
    n = min(n, len(x))
    if n <= 0:
        return  # nothing to draw; avoids creating a zero-height figure

    # Undo the normalization so pixel values are back in the displayable range.
    x_vis = (x.cpu() * 0.3081 + 0.1307).clamp(0, 1)

    cols = 6
    rows = (n + cols - 1) // cols  # ceil(n / cols)
    plt.figure(figsize=(cols * 2, rows * 2))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(x_vis[i, 0], cmap="gray")
        prediction = f", p={pred[i].item()}" if pred is not None else ""
        plt.title(f"y={y[i].item()}{prediction}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()


# Pull one batch from the shuffled training loader and preview it.
x, y = next(iter(train_loader))
show_mnist(x, y)

## 2. Model

In [None]:
class ClassifierMLP(nn.Module):
    """MLP classifier: (Linear -> BatchNorm -> ReLU -> Dropout) x 2 -> Linear.

    Args:
        hidden_dim: Width of both hidden layers.
        in_features: Number of features per sample after flattening
            (default ``28 * 28``).
        num_classes: Number of output logits (default 10, one per class of
            [0, 1, ..., 9]).
        dropout: Drop probability of both dropout layers (default 0.2).
    """

    def __init__(self, hidden_dim=128, in_features=28 * 28, num_classes=10, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.drop1 = nn.Dropout(p=dropout)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.drop2 = nn.Dropout(p=dropout)

        # One logit per class.
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        The input is flattened to ``(batch, -1)`` first. ``reshape`` is used
        instead of ``view`` so non-contiguous inputs (e.g. transposed or
        permuted images) work instead of raising a RuntimeError.
        """
        x = x.reshape(x.size(0), -1)  # flatten everything but the batch axis
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        return x


# Build the network and move it to the selected device before handing its
# parameters to the Adam optimizer.
model = ClassifierMLP()
model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=lr)

## 3. Train

In [None]:
@torch.no_grad()
def eval_loss_and_acc(loader):
    """Evaluate the module-level ``model`` on every batch of ``loader``.

    Puts ``model`` into eval mode (it is left in that mode afterwards) and
    runs without gradient tracking.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Tuple ``(loss, accuracy)``: the cross-entropy summed over all samples
        divided by the sample count, and the fraction of samples whose argmax
        logit equals the target.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        scores = model(inputs)
        # Sum instead of averaging per batch, so a smaller last batch is
        # weighted by its true size in the final mean.
        loss_sum += F.cross_entropy(scores, targets, reduction="sum").item()
        hits += (scores.argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)
    return loss_sum / seen, hits / seen

In [None]:
# Per-epoch history; reset whenever this cell is re-run.
train_losses, val_losses = [], []

for epoch in range(epochs):
    model.train()  # enable dropout and batch-statistics BatchNorm
    running = 0.0
    n = 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        logits = model(x)
        loss = F.cross_entropy(logits, y)

        # Clear gradients *before* backward: if a previous run of this cell was
        # interrupted between backward() and the old trailing zero_grad(),
        # stale gradients would otherwise be accumulated into this step.
        opt.zero_grad()
        loss.backward()
        opt.step()

        # `loss` is a batch mean; weight it by batch size for an exact epoch mean.
        running += loss.item() * y.size(0)
        n += y.size(0)

    train_loss = running / n
    val_loss, val_acc = eval_loss_and_acc(val_loader)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # wait=True defers clearing until the new figure is ready (no flicker).
    clear_output(wait=True)
    plt.figure(figsize=(10, 4))
    plt.plot(range(epoch + 1), train_losses, label="train")
    plt.plot(range(epoch + 1), val_losses, label="val")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("MNIST MLP: train vs val loss")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Report the validation accuracy too; it was computed but never shown.
    print(
        f"epoch {epoch + 1}/{epochs} | train loss={train_loss:.4f} | "
        f"val loss={val_loss:.4f} | val acc={val_acc*100:.2f}%"
    )

## 4. Test

In [None]:
# Final evaluation on the test loader: per-sample mean loss and accuracy.
test_loss, test_acc = eval_loss_and_acc(test_loader)
print(f"TEST | loss={test_loss:.4f} | acc={test_acc*100:.2f}%")

In [None]:
# Eval mode: dropout is disabled and BatchNorm uses its running statistics.
model.eval()

# Take the first test batch (the test loader is not shuffled).
x, y = next(iter(test_loader))
x, y = x.to(device), y.to(device)

# Inference only: skip autograd bookkeeping so no graph is built or kept alive.
with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1)  # predicted class = index of the largest logit

show_mnist(x, y, pred)