In [51]:
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision import transforms

In [52]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Layout: conv 3x3 (8 maps) -> ReLU -> 2x2 max-pool -> conv 3x3 (16 maps)
    -> ReLU -> 2x2 max-pool -> flatten -> linear layer with ``out_size`` units.

    Args:
        in_channels: Number of channels of the input images.
        out_size: Number of output units (one score per class).
        input_size: Spatial size of the input images, either a single int
            (square images) or a ``(height, width)`` pair. Defaults to
            ``(28, 28)``, which yields the original ``16 * 7 * 7`` features
            in front of the linear layer.

    Raises:
        ValueError: If ``input_size`` is too small to survive the two 2x2
            poolings (i.e. a side shorter than 4 pixels).
    """

    def __init__(self, in_channels, out_size, input_size=(28, 28)):
        super().__init__()

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        # The convolutions use stride 1 with "same" padding, so only the two
        # 2x2 / stride-2 poolings shrink the feature map; each one floors the
        # side length in half.
        pooled_height = height // 2 // 2
        pooled_width = width // 2 // 2
        if pooled_height < 1 or pooled_width < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small: each side must "
                "be at least 4 pixels to pass through two 2x2 poolings"
            )

        # Layers are created in the same order and under the same attribute
        # names as before, so state_dict keys and seeded initialisation match.
        self.c1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding="same",
        )
        self.p1 = nn.MaxPool2d(
            kernel_size=2,
            stride=2,
            padding=0,
        )
        self.c2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding="same",
        )
        self.p2 = nn.MaxPool2d(
            kernel_size=2,
            stride=2,
            padding=0,
        )
        self.fl1 = nn.Flatten()
        self.fc1 = nn.Linear(16 * pooled_height * pooled_width, out_size)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch; it is fed straight into the first ``nn.Conv2d``,
                so it must have ``in_channels`` channels and the spatial size
                given as ``input_size`` at construction time.

        Returns:
            The output of the final linear layer (``out_size`` values per
            sample). No softmax is applied here.
        """
        a1 = F.relu(self.c1(x))
        a1 = self.p1(a1)
        a2 = F.relu(self.c2(a1))
        a2 = self.p2(a2)
        a3 = self.fl1(a2)
        out = self.fc1(a3)

        return out

In [59]:
def check_accuracy(loader, model):
    """Measure the classification accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while iterating. Inputs are moved to the device that holds
    the model's parameters, and labels are moved to the device of the
    predictions, so the comparison works regardless of where the loader
    produces its tensors.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Module whose output has one score per class along dim 1.

    Returns:
        The accuracy in percent, formatted as a string such as ``"75.0%"``.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    num_correct = 0
    num_samples = 0

    # Use the device the model actually lives on rather than a notebook-level
    # global, which may not match it. A parameter-less model leaves inputs
    # where they are.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            if model_device is not None:
                x = x.to(device=model_device)

            logits = model(x)
            _, y_pred = logits.max(1)

            # Labels must sit on the same device as the predictions,
            # otherwise the elementwise comparison fails on CUDA.
            y = y.to(device=y_pred.device)

            num_correct += (y_pred == y).sum()
            num_samples += y_pred.size(0)

    return f"{(num_correct / num_samples) * 100}%"

In [57]:
# --- Configuration -----------------------------------------------------------
batch_size = 2
lr = 1e-2
epochs = 10
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weight initialisation and DataLoader shuffling both draw from torch's global
# RNG; seeding it makes a Restart & Run All reproduce the same run.
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
trans = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root="mnist", train=True, transform=trans, download=True)
test_dataset = datasets.MNIST(root="mnist", train=False, transform=trans, download=True)

# Read the metadata from the full dataset: the Subset wrappers created below
# do not expose `.classes`.
n_channels = train_dataset[0][0].shape[0]
n_classes = len(train_dataset.classes)

# Keep only the first 100 training / 20 test samples so the demo runs quickly.
train_dataset = Subset(train_dataset, indices=range(100))
test_dataset = Subset(test_dataset, indices=range(20))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- Model, loss, optimiser --------------------------------------------------
# The model must live on the same device the batches are moved to below;
# without `.to(device)` the forward pass fails as soon as CUDA is available.
cnn = CNN(n_channels, n_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(cnn.parameters(), lr=lr)

# --- Training loop -----------------------------------------------------------
cnn.train()
for ep in range(epochs):
    bar = tqdm(train_dataloader, desc=f"<EP>{ep+1}")
    for X, y in bar:
        X = X.to(device)
        y = y.to(device)

        # The network returns raw scores; CrossEntropyLoss consumes them directly.
        logits = cnn(X)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the scalar value, not the tensor repr with its grad_fn.
        bar.set_postfix(loss=loss.item())

<EP>1: 100%|██████████| 50/50 [00:00<00:00, 102.19it/s, loss=tensor(2.3408, grad_fn=<NllLossBackward0>)]
<EP>2: 100%|██████████| 50/50 [00:00<00:00, 98.96it/s, loss=tensor(0.3603, grad_fn=<NllLossBackward0>)]
<EP>3: 100%|██████████| 50/50 [00:00<00:00, 93.26it/s, loss=tensor(0.4841, grad_fn=<NllLossBackward0>)]
<EP>4: 100%|██████████| 50/50 [00:00<00:00, 95.71it/s, loss=tensor(0.0338, grad_fn=<NllLossBackward0>)] 
<EP>5: 100%|██████████| 50/50 [00:00<00:00, 98.36it/s, loss=tensor(0.4973, grad_fn=<NllLossBackward0>)] 
<EP>6: 100%|██████████| 50/50 [00:00<00:00, 94.43it/s, loss=tensor(0.0084, grad_fn=<NllLossBackward0>)]
<EP>7: 100%|██████████| 50/50 [00:00<00:00, 96.16it/s, loss=tensor(0.0008, grad_fn=<NllLossBackward0>)] 
<EP>8: 100%|██████████| 50/50 [00:00<00:00, 85.63it/s, loss=tensor(0.0046, grad_fn=<NllLossBackward0>)]
<EP>9: 100%|██████████| 50/50 [00:00<00:00, 90.21it/s, loss=tensor(0.0002, grad_fn=<NllLossBackward0>)]
<EP>10: 100%|██████████| 50/50 [00:00<00:00, 92.93it/s, loss

In [60]:
check_accuracy(train_dataloader, cnn)

100%|██████████| 50/50 [00:00<00:00, 508.53it/s]


'100.0%'

In [61]:
check_accuracy(test_dataloader, cnn)

100%|██████████| 10/10 [00:00<00:00, 374.49it/s]


'70.0%'