<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>


# Deep Learning Basics with PyTorch

**Dr. Yves J. Hilpisch with GPT-5**


# Chapter 8 — Organizing Code with torch.nn
Refactor the tiny MLP into an `nn.Module` subclass; add train/eval mode switching and checkpointing.

In [None]:
# !pip -q install torch numpy matplotlib scikit-learn
import torch, numpy as np, matplotlib.pyplot as plt
from torch import nn
plt.style.use('seaborn-v0_8')  # plotting
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
%config InlineBackend.figure_format = 'retina'


## Define the model, data, and training helpers

In [None]:
class TinyMLP(nn.Module):
    """Small multilayer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_dim: number of input features per sample.
        hidden: width of the single hidden layer.
        out_dim: number of output scores (logits) per sample.
    """

    def __init__(self, in_dim=2, hidden=16, out_dim=2):
        super().__init__()
        # Layers are created in this order so parameter initialisation draws
        # from the RNG in a fixed sequence. Wrapping them in one Sequential
        # named ``net`` yields state_dict keys net.0.* and net.2.*.
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw, unnormalised output scores for ``x``."""
        return self.net(x)

# Prepare data: 600 noisy two-moons points, split 75/25 into train/test with
# class stratification (both steps use fixed random seeds).
X, y = make_moons(noise=0.25, n_samples=600, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, stratify=y, test_size=0.25, random_state=42
)
# Features become float32 model inputs. Labels become int64 class indices,
# which is the target format nn.CrossEntropyLoss expects.
X_tr, X_te = (torch.tensor(arr, dtype=torch.float32) for arr in (X_tr, X_te))
y_tr, y_te = (torch.tensor(arr, dtype=torch.long) for arr in (y_tr, y_te))

# Model, optimizer, loss
# Seed PyTorch's global RNG before building the model. The weight
# initialisation, and therefore the reported final loss and accuracy, is then
# reproducible, just like the data generation and split, which use fixed
# random_state values.
torch.manual_seed(0)
model = TinyMLP()
opt = torch.optim.Adam(model.parameters(), lr=0.01)
# CrossEntropyLoss applies log-softmax internally, so the model returns raw logits.
loss_fn = nn.CrossEntropyLoss()

def accuracy(logits, y):
    """Return the fraction of samples whose top-scoring class matches ``y``.

    Args:
        logits: score tensor whose class scores lie along dimension 1.
        y: tensor of integer class labels, comparable with the argmax result.

    Returns:
        The accuracy as a plain Python float.
    """
    predicted = torch.argmax(logits, dim=1)
    correct = predicted.eq(y).to(torch.float32)
    return correct.mean().item()


## Train and evaluate

In [None]:
# Full-batch training: each epoch takes one optimizer step on the whole
# training set, then measures accuracy on the held-out test split.
losses, acc_history = [], []
for _ in range(50):
    # Training phase. train() mode has no effect on this Linear/ReLU model,
    # but it matters for layers such as dropout or batch norm.
    model.train()
    opt.zero_grad()                  # clear gradients from the previous step
    logits = model(X_tr)             # raw scores; CrossEntropyLoss takes logits
    loss = loss_fn(logits, y_tr)
    loss.backward()
    opt.step()
    losses.append(loss.item())       # keep a Python float, not a graph-attached tensor

    # Evaluation phase: eval() mode and no autograd graph construction.
    model.eval()
    with torch.no_grad():
        acc = accuracy(model(X_te), y_te)
    acc_history.append(acc)

# Final training loss and test accuracy (shown as the cell's output).
losses[-1], acc_history[-1]


<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>
