<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>


# Deep Learning Basics with PyTorch

**Dr. Yves J. Hilpisch with GPT-5**


# Chapter 12 — Training at Scale

## Throughput quick check (toy)

In [None]:
import time, torch
# Toy throughput check: time 500 back-to-back (256, 64) @ (64, 64) matmuls.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(256, 64, device=device)
w = torch.randn(64, 64, device=device)

# CUDA work is queued asynchronously, so drain any pending kernels before
# starting the clock; on CPU there is nothing to wait for.
if device.type == 'cuda':
    torch.cuda.synchronize()
t0 = time.perf_counter()  # monotonic, high-resolution timer for short intervals
for _ in range(500):
    y = x @ w  # the matmul being timed
# Synchronize ONCE after the loop: syncing inside the loop would stall the
# device every iteration and measure launch latency instead of throughput.
if device.type == 'cuda':
    torch.cuda.synchronize()
elapsed = time.perf_counter() - t0  # total seconds for all 500 matmuls
elapsed
# a small positive number of seconds (hardware dependent)


## AMP training step

In [None]:
import torch, torch.nn.functional as F
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(128, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
use_amp = (device.type == 'cuda')
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

x = torch.randn(32, 128, device=device)
y = torch.randint(0, 10, (32,), device=device)

opt.zero_grad(set_to_none=True)
if use_amp:
    # Mixed precision region on CUDA
    with torch.amp.autocast('cuda', enabled=True):
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        else:
            # Fallback: full precision on CPU/Metal
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            opt.step()
            opt.zero_grad(set_to_none=True)


## Gradient accumulation

In [None]:
import torch, torch.nn.functional as F
from torch import nn
# Gradient accumulation: sum gradients over `accum` micro-batches and apply
# a single optimizer update per group.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(128, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
accum = 4  # micro-batches per optimizer update

opt.zero_grad(set_to_none=True)
for step in range(8):
    x = torch.randn(16, 128, device=device)         # micro-batch inputs
    y = torch.randint(0, 10, (16,), device=device)  # integer class targets
    logits = model(x)
    # Divide by `accum` so the summed gradients match the mean over the group.
    loss = F.cross_entropy(logits, y) / accum
    loss.backward()  # gradients add up in .grad across micro-batches
    if (step + 1) % accum != 0:
        continue  # still inside an accumulation group; no update yet
    opt.step()
    opt.zero_grad(set_to_none=True)


## Checkpoint save/load (toy)

In [None]:
import torch
# Toy checkpoint round trip: save a small state dict to disk, load it back
# onto the active device, and inspect what came out.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

weights = {'w': torch.randn(2)}
state = {'model': weights, 'epoch': 3}

ckpt_path = 'checkpoint.pt'
torch.save(state, ckpt_path)
# map_location remaps the stored tensors onto `device` while loading.
ckpt = torch.load(ckpt_path, map_location=device)
ckpt['epoch'], isinstance(ckpt['model'], dict)
# (3, True)

<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>
