In [1]:
import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms

class MLP(nn.Module):
  '''
    Multilayer Perceptron.

    Flattens every sample (all dimensions after the first are merged) and
    feeds it through a stack of fully connected layers with ReLU
    activations. The final layer is followed by a log-softmax over
    dimension 1, so the output holds log-probabilities and pairs with
    nn.NLLLoss.

    The default arguments build the 784 -> 64 -> 32 -> 10 network.

    Args:
      in_features: number of values per flattened sample.
      hidden_sizes: iterable with the width of each hidden layer, in
        order. May be empty, which yields a single linear layer.
      num_classes: number of output classes.

    Raises:
      ValueError: if any of the sizes is not a positive number.
  '''
  def __init__(self, in_features=28 * 28 * 1, hidden_sizes=(64, 32), num_classes=10):
    super().__init__()

    # Materialise once so generators are accepted and can be validated
    hidden_sizes = tuple(hidden_sizes)
    if in_features <= 0 or num_classes <= 0 or any(width <= 0 for width in hidden_sizes):
      raise ValueError('in_features, hidden_sizes and num_classes must all be positive')

    # Module order is Flatten, (Linear, ReLU)*, Linear, LogSoftmax; with the
    # default sizes the Sequential indices (and state_dict keys) are
    # layers.1, layers.3 and layers.5 for the three linear layers.
    modules = [nn.Flatten()]
    previous_width = in_features
    for width in hidden_sizes:
      modules.append(nn.Linear(previous_width, width))
      modules.append(nn.ReLU())
      previous_width = width
    modules.append(nn.Linear(previous_width, num_classes))
    modules.append(nn.LogSoftmax(dim = 1))

    self.layers = nn.Sequential(*modules)


  def forward(self, x):
    '''Forward pass: returns log-probabilities of shape (batch, num_classes).'''
    return self.layers(x)
  
  
if __name__ == '__main__':

  # Seed PyTorch's global RNG before anything random happens
  torch.manual_seed(42)

  # MNIST is stored in (and, if needed, downloaded to) the working directory;
  # the loader serves shuffled mini-batches of 10 samples via one worker
  dataset = MNIST(root=os.getcwd(), download=True, transform=transforms.ToTensor())
  trainloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)

  # Model to be trained
  mlp = MLP()

  # Negative log-likelihood loss (the model emits log-probabilities) and Adam
  loss_function = nn.NLLLoss()
  optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

  # Training loop: 5 passes over the data
  for epoch in range(5):

    print(f'Starting epoch {epoch+1}')

    # Running sum of the loss since the last report
    current_loss = 0.0

    for i, data in enumerate(trainloader):

      # Unpack the mini-batch
      inputs, targets = data

      # Clear gradients left over from the previous step
      optimizer.zero_grad()

      # Forward pass and loss
      outputs = mlp(inputs)
      loss = loss_function(outputs, targets)

      # Backward pass and parameter update
      loss.backward()
      optimizer.step()

      # Report the mean loss of every 500 mini-batches, then reset the sum
      current_loss += loss.item()
      if (i + 1) % 500 == 0:
        print('Loss after mini-batch %5d: %.3f' % (i + 1, current_loss / 500))
        current_loss = 0.0

  # All epochs done
  print('Training process has finished.')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9760603.72it/s] 


Extracting /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/train-images-idx3-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 47207986.68it/s]

Extracting /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 8408347.93it/s]


Extracting /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 8585186.47it/s]

Extracting /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/ngocp/Documents/projects/pyml/viblo-classify/research/MNIST/raw

Starting epoch 1





Loss after mini-batch   500: 1.987
Loss after mini-batch  1000: 1.118
Loss after mini-batch  1500: 0.699
Loss after mini-batch  2000: 0.561
Loss after mini-batch  2500: 0.490
Loss after mini-batch  3000: 0.451
Loss after mini-batch  3500: 0.420
Loss after mini-batch  4000: 0.409
Loss after mini-batch  4500: 0.402
Loss after mini-batch  5000: 0.374
Loss after mini-batch  5500: 0.368
Loss after mini-batch  6000: 0.340
Starting epoch 2
Loss after mini-batch   500: 0.315
Loss after mini-batch  1000: 0.337
Loss after mini-batch  1500: 0.316
Loss after mini-batch  2000: 0.312
Loss after mini-batch  2500: 0.290
Loss after mini-batch  3000: 0.322
Loss after mini-batch  3500: 0.288
Loss after mini-batch  4000: 0.293
Loss after mini-batch  4500: 0.294
Loss after mini-batch  5000: 0.291
Loss after mini-batch  5500: 0.294
Loss after mini-batch  6000: 0.270
Starting epoch 3
Loss after mini-batch   500: 0.265
Loss after mini-batch  1000: 0.287
Loss after mini-batch  1500: 0.255
Loss after mini-batch