# PyTorch Tutorial

## Problem statement

## Data: PyTorch Dataset and DataLoader classes

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    
    def __init__(self, N):
        super().__init__()
        # Generate N random samples
        self.samples = torch.rand(N, 256)
        # Generate labels
        _labels = self.samples.reshape(N, 4, -1).mean(-1)  # (N, 4)
        self.labels = _labels.argmax(1)  # (N,)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index], self.labels[index]


train_loader = DataLoader(
    MyDataset(40000),
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=2
)
test_loader = DataLoader(
    MyDataset(10000),
    batch_size=256,
    shuffle=False,
    drop_last=False,
    num_workers=2
)

## Model

In [3]:
from torch import nn


class MyModel(nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_dim)
        )

    def forward(self, x):
        x = x - 0.5
        return self.net(x)

## Running one epoch

In [4]:
from torch.nn import functional as F


def single_epoch_loop(model, loader, optimizer=None, mode='train'):
    n_samples, n_correct = 0, 0
    # Loop through my data loader
    for batch in loader:
        samples, labels = batch

        # Forward pass
        out = model(samples)

        # Loss: multi-class classification -> cross-entropy
        loss = F.cross_entropy(out, labels)  # no explicit softmax on out!

        # Backward pass
        if mode == 'train':  # don't update if not train
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Accuracy
        n_samples += len(samples)
        n_correct += (out.argmax(1) == labels).sum().item()
    return n_correct / n_samples, loss.item()

## Putting all together: our train/test script

In [5]:
from torch.optim import SGD, Adam


model = MyModel(256, 4)
# optimizer = SGD(model.parameters(), lr=1e-1)
optimizer = Adam(model.parameters(), lr=5e-4)

for epoch in range(20):
    model.train()
    train_acc, train_loss = single_epoch_loop(model, train_loader, optimizer, mode='train')

    model.eval()  # disables dropout/batchnorm update etc
    with torch.no_grad():  # disable gradient computations
        test_acc, test_loss = single_epoch_loop(model, test_loader, mode='test')
    print(
        epoch, train_acc, train_loss,
        test_acc, test_loss
    )

0 0.6664162660256411 0.7797598838806152 0.9014 0.8465317487716675
1 0.9112830528846154 0.3555120527744293 0.9583 0.37055617570877075
2 0.9362479967948718 0.2592712938785553 0.9655 0.24488228559494019
3 0.9442608173076923 0.19257815182209015 0.9695 0.2047320455312729
4 0.9473157051282052 0.22139495611190796 0.9732 0.17845094203948975
5 0.9496444310897436 0.16154563426971436 0.9726 0.14041639864444733
6 0.9524238782051282 0.16834506392478943 0.9736 0.13506561517715454
7 0.9537760416666666 0.1187991127371788 0.9741 0.14165055751800537
8 0.9550530849358975 0.10179155319929123 0.9761 0.15123522281646729
9 0.9585586939102564 0.13134028017520905 0.9741 0.1254160851240158
10 0.9584084535256411 0.1346886157989502 0.9774 0.11818530410528183
11 0.9592848557692307 0.1368156224489212 0.9767 0.11666432023048401
12 0.9622145432692307 0.11158645153045654 0.975 0.09952695667743683
13 0.962089342948718 0.10493985563516617 0.9743 0.08962585777044296
14 0.9641426282051282 0.13586169481277466 0.9745 0.0938

## Save/load model

In [6]:
# Save
torch.save({
    "epoch": epoch + 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "best_acc": test_acc
}, 'ckpt.pt')

# Load
ckpnt = torch.load('ckpt.pt')
model.load_state_dict(ckpnt["model_state_dict"], strict=True)
optimizer.load_state_dict(ckpnt["optimizer_state_dict"])
start_epoch = ckpnt["epoch"]
val_acc_prev_best = ckpnt['best_acc']

## What can we do better in real-world projects?