In [None]:
import torch
import torchvision # torch package for vision related things
import torch.nn.functional as F # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
from torch import optim # For optimizers like SGD, Adam, etc.
from torch import nn # All neural network modules
from torch.utils.data import DataLoader # Gives easier dataset managment by creating mini batches etc.
from tqdm import tqdm # For nice progress bar!

# Create a simple CNN

In [None]:
class CNN(nn.Module):
  """Small two-stage convolutional classifier.

  Architecture: (3x3 conv, padding 1) -> ReLU -> 2x2 max-pool, applied twice
  (8 then 16 channels), followed by one linear layer producing
  ``num_classes`` logits.

  Args:
    in_channels: Number of channels in the input images.
    num_classes: Number of output logits.
    input_size: Spatial size of the input images, either an int (square
      images) or a ``(height, width)`` tuple. Defaults to 28, which reproduces
      the original hard-coded ``16 * 7 * 7`` feature size. The 3x3/padding-1
      convolutions preserve spatial size and each pool halves it (rounding
      down), so the flattened feature size is
      ``16 * (height // 4) * (width // 4)``.
  """

  def __init__(self, in_channels = 1, num_classes = 10, input_size = 28):
    super().__init__()
    if isinstance(input_size, int):
      input_size = (input_size, input_size)
    height, width = input_size

    # Layers are created in the same order and under the same attribute names
    # as before, so seeded initialisation and state_dict keys are unchanged.
    self.conv1 = nn.Conv2d(in_channels = in_channels,
                           out_channels = 8,
                           kernel_size = (3, 3),
                           stride = (1, 1),
                           padding = (1, 1))
    self.pool = nn.MaxPool2d(kernel_size = (2, 2),
                             stride = (2, 2))
    self.conv2 = nn.Conv2d(in_channels = 8,
                           out_channels = 16,
                           kernel_size = (3, 3),
                           stride = (1, 1),
                           padding = (1, 1))
    # Two 2x2/stride-2 pools shrink each spatial dimension by a factor of 4.
    self.fc1 = nn.Linear(16 * (height // 4) * (width // 4), num_classes)

  def forward(self, x):
    """Map a batch of images to class logits.

    Args:
      x: Tensor of shape ``(batch, in_channels, height, width)`` whose spatial
        size matches the ``input_size`` given at construction.

    Returns:
      Tensor of shape ``(batch, num_classes)``.
    """
    x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 8, H // 2, W // 2)
    x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 16, H // 4, W // 4)
    # flatten(start_dim=1) keeps the batch axis and, unlike
    # reshape(x.shape[0], -1), also works when the batch is empty.
    x = torch.flatten(x, start_dim=1)
    return self.fc1(x)

In [None]:
# Smoke test: push a random batch of 64 single-channel 28x28 images through an
# untrained model to confirm the layer sizes line up.
model = CNN(in_channels=1, num_classes=10)
x = torch.rand((64, 1, 28, 28))
model(x)

tensor([[ 0.1598, -0.1109,  0.1365, -0.1395,  0.3723,  0.1270,  0.3142,  0.3179,
         -0.0734,  0.3381],
        [ 0.1562, -0.1088,  0.1550, -0.1099,  0.3849,  0.1653,  0.3104,  0.2850,
         -0.0832,  0.3145],
        [ 0.1295, -0.0268,  0.1642, -0.1646,  0.3362,  0.1110,  0.3330,  0.2690,
         -0.0771,  0.3446],
        [ 0.1377, -0.0618,  0.1537, -0.1351,  0.4253,  0.1572,  0.2920,  0.3117,
         -0.0487,  0.3546],
        [ 0.1352, -0.0534,  0.1849, -0.1484,  0.3523,  0.1609,  0.3119,  0.2734,
         -0.0937,  0.3518],
        [ 0.1491, -0.0953,  0.1642, -0.1534,  0.3865,  0.1609,  0.3208,  0.2479,
         -0.1090,  0.3767],
        [ 0.1672, -0.0915,  0.1798, -0.1574,  0.3782,  0.1834,  0.3477,  0.2824,
         -0.0813,  0.3165],
        [ 0.1542, -0.0871,  0.1331, -0.1350,  0.3777,  0.1477,  0.3355,  0.2729,
         -0.0784,  0.3358],
        [ 0.1585, -0.1050,  0.1230, -0.1419,  0.3790,  0.1525,  0.2880,  0.2579,
         -0.0896,  0.3381],
        [ 0.1732, -

In [None]:
# Expect one row of class logits per input image.
model(x).size()

torch.Size([64, 10])

Set device

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Hyperparameters

In [None]:
in_channels = 1        # channels per input image (MNIST images are grayscale)
num_classes = 10       # number of output logits (one per digit class)
learning_rate = 0.001  # Adam step size
batch_size = 64
num_epochs = 5
seed = 42              # seeds PyTorch's global RNG for reproducible runs

# Seed before the data loaders and the model are created, so weight
# initialisation and the training shuffle order are repeatable across runs.
torch.manual_seed(seed)

Load data

In [None]:
# Both MNIST splits live under one root directory; ToTensor converts each
# image into a float tensor.
DATA_ROOT = "dataset/"
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=transforms.ToTensor(), download=True)
# Shuffle only the training data. Evaluation order does not affect accuracy,
# so the test loader iterates in a fixed order and draws nothing from the RNG.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



Initialize Network

In [None]:
model = CNN(in_channels=in_channels, num_classes=num_classes)
model = model.to(device)  # nn.Module.to moves the parameters and returns the same module

Loss and optimizer

In [None]:
# Cross-entropy on the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()
# Adam over all of the model's parameters.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

Train Network

In [None]:
for epoch in range(num_epochs):
  # Make sure training mode is active; an earlier evaluation may have left the
  # model in eval() mode.
  model.train()
  for data, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
    # Move the batch to the same device as the model.
    data = data.to(device = device)
    targets = targets.to(device = device)

    # Forward pass: class logits and their cross-entropy loss.
    scores = model(data)
    loss = criterion(scores, targets)

    # Backward pass: clear the previous step's gradients, then backpropagate.
    optimizer.zero_grad()
    loss.backward()

    # Apply one optimizer update to the parameters.
    optimizer.step()

100%|██████████| 938/938 [00:11<00:00, 82.72it/s]
100%|██████████| 938/938 [00:11<00:00, 83.68it/s]
100%|██████████| 938/938 [00:11<00:00, 84.87it/s]
100%|██████████| 938/938 [00:11<00:00, 82.73it/s]
100%|██████████| 938/938 [00:11<00:00, 83.46it/s]


Check accuracy on the training and test sets to see how good our model is

In [None]:
def check_accuracy(loader, model, device=None):
  """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

  The model is evaluated in eval mode under ``torch.no_grad()``. Its previous
  train/eval mode is restored afterwards, even if an error occurs.

  Args:
    loader: Iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
    model: Classifier returning one row of class scores per sample.
    device: Device each batch is moved to. Defaults to the device of the
      model's first parameter (CPU if the model has no parameters).

  Returns:
    Accuracy as a Python float in ``[0, 1]``.

  Raises:
    ValueError: If ``loader`` yields no samples.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

  num_correct = 0
  num_samples = 0
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device = device)
        y = y.to(device = device)
        scores = model(x)
        predictions = scores.argmax(dim=1)
        # Keep the running count as a tensor; converting it every batch would
        # force a device-to-host sync on each iteration.
        num_correct += (predictions == y).sum()
        num_samples += predictions.size(0)
  finally:
    # Restore whatever mode the caller had, rather than forcing train mode.
    model.train(was_training)

  if num_samples == 0:
    raise ValueError("check_accuracy: loader yielded no samples")
  return float(num_correct) / num_samples

In [None]:
# Evaluate on both splits, then report each as a percentage.
train_accuracy = check_accuracy(train_loader, model)
test_accuracy = check_accuracy(test_loader, model)
print(f"Accuracy on training set: {train_accuracy*100:.2f}")
print(f"Accuracy on test set: {test_accuracy*100:.2f}")

Accuracy on training set: 98.74
Accuracy on test set: 98.41
