<a href="https://colab.research.google.com/github/rahiakela/deep-learning-research-and-practice/blob/main/pytorch-lightning-in-practice/pytorch-lightning/episode_1_training_classification_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Episode 1: Training a classification model on MNIST with PyTorch

**Reference:**

https://www.youtube.com/watch?v=OMDn66kM9Qc&list=PLaMu-SDt_RB5hhJKZC5a6HPdlDTawUT3r

In [5]:
import torch

from torch import nn
from torch import optim
from torchvision import datasets, transforms 
from torch.utils.data import random_split, DataLoader

In [6]:
# define model: fully connected classifier over flattened 28x28 images
input_dim, hidden_dim, num_classes = 28 * 28, 64, 10

model = nn.Sequential(
  nn.Linear(input_dim, hidden_dim),
  nn.ReLU(),
  nn.Linear(hidden_dim, hidden_dim),
  nn.ReLU(),
  nn.Linear(hidden_dim, num_classes),  # final layer emits one raw score (logit) per class
)

In [7]:
# define optimizer: stochastic gradient descent over all parameters of `model`
params = model.parameters()  # parameter iterator handed to the optimizer below
optimizer = optim.SGD(params=params, lr=0.01)

In [8]:
# define loss: cross-entropy between class logits and integer targets,
# averaged over the batch (the default reduction, spelled out explicitly)
loss = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# define data loader
# Fetch the MNIST training split into ./data (downloaded on first use) and
# convert every image to a tensor.
train_data = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
# Randomly hold out 5,000 samples for validation; the two lengths must add up
# to the size of `train_data`.
train, val = random_split(train_data, [55000, 5000])

# shuffle=True re-draws the sample order at the start of every epoch, so SGD
# does not replay the same fixed sequence of mini-batches each time.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
# Evaluation only reads the data, so the validation loader keeps a fixed order.
val_loader = DataLoader(val, batch_size=32)

In [13]:
# define training and validation step
# Each epoch runs one full pass over `train_loader` (with parameter updates)
# followed by one full pass over `val_loader` (no updates). The reported loss
# is the unweighted mean of the per-batch losses.
epochs = 5

for epoch in range(epochs):
  # training loop
  # Put the model in training mode. This is a no-op for plain Linear/ReLU
  # layers, but required as soon as dropout or batch-norm layers are added.
  model.train()
  losses = list()
  for batch in train_loader:
    x, y = batch

    # x: b x 1 x 28 x 28 (B x C x H x W) -> flatten each image to b x 784
    b = x.size(0)
    x = x.view(b, -1)

    # Step 1: forward
    l = model(x)  # l: logits

    # Step 2: compute the objective function
    J = loss(l, y)

    # Step 3: clearing the gradients
    model.zero_grad()
    # optimizer.zero_grad()
    # params.grad.zero_()

    # Step 4: accumulate the partial derivative of loss wrt params
    J.backward()
    # params.grad.add_(dJ/dparams)

    # Step 5: step in the opposite direction of the gradient
    optimizer.step()
    # with torch.no_grad(): params = params - eta * params.grad
    losses.append(J.item())

  print(f"Epoch: {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}")

  # validation loop
  # Switch to evaluation mode so mode-dependent layers (if any are added later)
  # behave deterministically while measuring the validation loss.
  model.eval()
  losses = list()
  for batch in val_loader:
    x, y = batch

    # x: b x 1 x 28 x 28 (B x C x H x W) -> flatten each image to b x 784
    b = x.size(0)
    x = x.view(b, -1)

    # No parameter update happens here, so run both the forward pass and the
    # loss computation without recording operations for autograd.
    with torch.no_grad():
      # Step 1: forward
      l = model(x)

      # Step 2: compute the objective function
      J = loss(l, y)

    losses.append(J.item())

  print(f"Epoch: {epoch + 1}, val loss: {torch.tensor(losses).mean():.2f}")

Epoch: 1, train loss: 1.24
Epoch: 1, val loss: 0.48
Epoch: 2, train loss: 0.40
Epoch: 2, val loss: 0.36
Epoch: 3, train loss: 0.32
Epoch: 3, val loss: 0.32
Epoch: 4, train loss: 0.28
Epoch: 4, val loss: 0.29
Epoch: 5, train loss: 0.26
Epoch: 5, val loss: 0.26
