In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [22]:
class CNN(nn.Module):
    def __init__(self,in_channels = 1, num_classes = 10):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=8,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.pool = nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
        self.conv2 = nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.fc1 = nn.Linear(16*7*7,num_classes) # 16 = output channels on conv2, 7 = 28/2/2 (maxpool2d is applied 2 times and divides the tensor in half each time)

    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0],-1) # flatten tensor to go on fully connected layer
        x = self.fc1(x)
        return x

In [25]:
model = CNN()
x = torch.randn(64,1,28,28)
model(x).shape

torch.Size([64, 10])

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

in_channels = 1
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 1

train_dataset = datasets.MNIST(root='dataset/',train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='dataset/',train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=learning_rate)

model = CNN(in_channels=in_channels, num_classes=num_classes).to(device)

In [29]:
for epoch in range(num_epochs):
    for batch_idx, (data,targets) in enumerate(train_loader):
        data = data.to(device=device)
        targets = targets.to(device=device)

        scores = model(data.to(device)) # call forward
        loss = criterion(scores,targets) # compute loss

        optimizer.zero_grad() # set model weights gradients to 0 for each batch
        loss.backward() # backpropagation (compute gradients)

        optimizer.step() # update weights based on gradients

In [31]:
def check_accuracy(loader,model):
    if loader.dataset.train:
        print("checking on training data")
    else:
        print("checking on test data")
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad(): # don't compute gradients on testing
        for x,y in loader:
            x = x.to(device)
            y = y.to(device)
        
            scores = model(x)
            _,predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
            print(f"correct: {num_correct}, total: {num_samples}, accuracy: {num_correct/num_samples}")

    model.train()
    return num_correct/num_samples

check_accuracy(train_loader,model)

checking on training data
correct: 9, total: 64, accuracy: 0.140625
correct: 21, total: 128, accuracy: 0.1640625
correct: 33, total: 192, accuracy: 0.171875
correct: 43, total: 256, accuracy: 0.16796875
correct: 54, total: 320, accuracy: 0.16875000298023224
correct: 61, total: 384, accuracy: 0.1588541716337204
correct: 71, total: 448, accuracy: 0.1584821492433548
correct: 84, total: 512, accuracy: 0.1640625
correct: 95, total: 576, accuracy: 0.1649305522441864
correct: 104, total: 640, accuracy: 0.16250000894069672
correct: 112, total: 704, accuracy: 0.15909090638160706
correct: 130, total: 768, accuracy: 0.1692708432674408
correct: 137, total: 832, accuracy: 0.16466346383094788
correct: 153, total: 896, accuracy: 0.1707589328289032
correct: 162, total: 960, accuracy: 0.16875000298023224
correct: 170, total: 1024, accuracy: 0.166015625
correct: 179, total: 1088, accuracy: 0.16452206671237946
correct: 183, total: 1152, accuracy: 0.1588541716337204
correct: 192, total: 1216, accuracy: 0.

tensor(0.1589, device='cuda:0')