# Implementing a CNN with PyTorch

In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import trange
from tqdm.notebook import tqdm

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            # First set of convolutions
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Second set of convolutions
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # final size should be (7x7)
            
            # Fully connected Layer
            nn.Flatten(),
            nn.Linear(7*7*64, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
            nn.Softmax() # output
        )

    def forward(self, x):
        # conv layer 1
        return self.layers(x)

def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

# Load the data
mnist_train = datasets.MNIST(root="./datasets", 
                             train=True, 
                             transform=transforms.ToTensor(), 
                             download=True)
mnist_test = datasets.MNIST(root="./datasets", 
                            train=False, 
                            transform=transforms.ToTensor(), 
                            download=True)

# Split train set into train (90%) and validation (10%)
train_size = int(0.9 * len(mnist_train))
val_size = len(mnist_train) - train_size
# If this causes errors, restart the kernel lmao
mnist_train, mnist_val = random_split(mnist_train, [train_size, val_size])

train_loader = DataLoader(mnist_train, 
                          batch_size=100, 
                          shuffle=True)
val_loader = DataLoader(mnist_val,
                        batch_size=100, 
                        shuffle=True)
test_loader = DataLoader(mnist_test, 
                         batch_size=100, 
                         shuffle=False)

model = CNN().to(device)
print(model)

# Training
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    def validate():
        correct = 0
        total = val_size

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                x = images
                y = model(x)
                
                predictions = torch.argmax(y, dim=1)
                correct += torch.sum((predictions == labels).float())
        
        return correct/total

    prev = 0
    for epoch in trange(20):
        curr = validate()
        print("Current validation accuracy: ", curr.item()) # type: ignore
        if curr > 0.95 and abs(curr - prev) < 0.001: 
            print("training converged")
            return
        prev = curr
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            x = images
            y = model(images)
            
            loss(y, labels).backward()
            optimizer.step()
            

train()

# Testing
correct = 0
total = len(mnist_test)

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        x = images
        y = model(x)
        
        predictions = torch.argmax(y, dim=1)
        correct += torch.sum((predictions == labels).float())

print(f"Test accuracy {correct/total}")

Using device: mps
CNN(
  (layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=3136, out_features=256, bias=True)
    (12): ReLU()
    (13): Linear(in_features=256, out_features=10, bias=True)
    (14): Softmax(dim=None)
  )
)


  0%|          | 0/20 [00:00<?, ?it/s]

Current validation accuracy:  0.10983332991600037


  0%|          | 0/540 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:11<03:33, 11.25s/it]

Current validation accuracy:  0.9639999866485596


  0%|          | 0/540 [00:00<?, ?it/s]

 10%|█         | 2/20 [00:21<03:13, 10.72s/it]

Current validation accuracy:  0.972000002861023


  0%|          | 0/540 [00:00<?, ?it/s]

 15%|█▌        | 3/20 [00:32<03:05, 10.91s/it]

Current validation accuracy:  0.9778333306312561


  0%|          | 0/540 [00:00<?, ?it/s]

 20%|██        | 4/20 [00:44<02:57, 11.07s/it]

Current validation accuracy:  0.9816666841506958


  0%|          | 0/540 [00:00<?, ?it/s]

 25%|██▌       | 5/20 [00:56<02:54, 11.66s/it]

Current validation accuracy:  0.9776666760444641


  0%|          | 0/540 [00:00<?, ?it/s]

 30%|███       | 6/20 [01:09<02:50, 12.20s/it]

Current validation accuracy:  0.9735000133514404


  0%|          | 0/540 [00:00<?, ?it/s]

 35%|███▌      | 7/20 [01:21<02:35, 11.95s/it]

Current validation accuracy:  0.984666645526886


  0%|          | 0/540 [00:00<?, ?it/s]

 40%|████      | 8/20 [01:32<02:19, 11.65s/it]

Current validation accuracy:  0.9778333306312561


  0%|          | 0/540 [00:00<?, ?it/s]

 45%|████▌     | 9/20 [01:45<02:12, 12.00s/it]

Current validation accuracy:  0.9791666865348816


  0%|          | 0/540 [00:00<?, ?it/s]

 50%|█████     | 10/20 [01:58<02:03, 12.32s/it]

Current validation accuracy:  0.984333336353302


  0%|          | 0/540 [00:00<?, ?it/s]

 55%|█████▌    | 11/20 [02:11<01:53, 12.59s/it]

Current validation accuracy:  0.9789999723434448


  0%|          | 0/540 [00:00<?, ?it/s]

 60%|██████    | 12/20 [02:22<01:37, 12.13s/it]

Current validation accuracy:  0.9860000014305115


  0%|          | 0/540 [00:00<?, ?it/s]

 65%|██████▌   | 13/20 [02:33<01:22, 11.84s/it]

Current validation accuracy:  0.9856666922569275
training converged





Test accuracy 0.9869999885559082
