# Implementing a CNN with PyTorch

In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import trange
from tqdm.notebook import tqdm

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.layers = nn.Sequential(
            # First set of convolutions
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Second set of convolutions
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # final size should be (7x7)
            
            # Fully connected Layer
            nn.Flatten(),
            nn.Linear(7*7*64, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
            nn.Softmax() # output
        )

    def forward(self, x):
        # conv layer 1
        return self.layers(x)

def get_device():
    if torch.backends.mps.is_available():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

# Load the data
mnist_train = datasets.MNIST(root="./datasets", 
                             train=True, 
                             transform=transforms.ToTensor(), 
                             download=True)
mnist_test = datasets.MNIST(root="./datasets", 
                            train=False, 
                            transform=transforms.ToTensor(), 
                            download=True)

# Split train set into train (90%) and validation (10%)
train_size = int(0.9 * len(mnist_train))
val_size = len(mnist_train) - train_size
# If this causes errors, restart the kernel lmao
mnist_train, mnist_val = random_split(mnist_train, [train_size, val_size])

train_loader = DataLoader(mnist_train, 
                          batch_size=100, 
                          shuffle=True)
val_loader = DataLoader(mnist_val,
                        batch_size=100, 
                        shuffle=True)
test_loader = DataLoader(mnist_test, 
                         batch_size=100, 
                         shuffle=False)

model = CNN().to(device)

# Training
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train():
    def validate():
        correct = 0
        total = val_size

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                x = images
                y = model(x)
                
                predictions = torch.argmax(y, dim=1)
                correct += torch.sum((predictions == labels).float())
        
        return correct/total

    prev = 0
    for epoch in trange(20):
        curr = validate()
        print("Current validation accuracy: ", curr.item()) # type: ignore
        if curr > 0.95 and abs(curr - prev) < 0.001: 
            print("training converged")
            return
        prev = curr
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            x = images
            y = model(images)
            
            loss(y, labels).backward()
            optimizer.step()
            

train()

# Testing
correct = 0
total = len(mnist_test)

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        x = images
        y = model(x)
        
        predictions = torch.argmax(y, dim=1)
        correct += torch.sum((predictions == labels).float())

print(f"Test accuracy {correct/total}")

Using device: mps


  0%|          | 0/20 [00:00<?, ?it/s]

Current validation accuracy:  0.09883332997560501


  0%|          | 0/540 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:09<03:04,  9.70s/it]

Current validation accuracy:  0.9660000205039978


  0%|          | 0/540 [00:00<?, ?it/s]

 10%|█         | 2/20 [00:19<02:53,  9.65s/it]

Current validation accuracy:  0.9693333506584167


  0%|          | 0/540 [00:00<?, ?it/s]

 15%|█▌        | 3/20 [00:29<02:50, 10.00s/it]

Current validation accuracy:  0.9760000109672546


  0%|          | 0/540 [00:00<?, ?it/s]

 20%|██        | 4/20 [00:39<02:36,  9.80s/it]

Current validation accuracy:  0.9816666841506958


  0%|          | 0/540 [00:00<?, ?it/s]

 25%|██▌       | 5/20 [00:49<02:28,  9.92s/it]

Current validation accuracy:  0.9794999957084656


  0%|          | 0/540 [00:00<?, ?it/s]

 30%|███       | 6/20 [01:00<02:22, 10.21s/it]

Current validation accuracy:  0.9781666398048401


  0%|          | 0/540 [00:00<?, ?it/s]

 35%|███▌      | 7/20 [01:10<02:11, 10.12s/it]

Current validation accuracy:  0.9808333516120911


  0%|          | 0/540 [00:00<?, ?it/s]

 40%|████      | 8/20 [01:20<02:00, 10.08s/it]

Current validation accuracy:  0.984000027179718


  0%|          | 0/540 [00:00<?, ?it/s]

 45%|████▌     | 9/20 [01:30<01:50, 10.09s/it]

Current validation accuracy:  0.9831666946411133
training converged





Test accuracy 0.984499990940094
