In [78]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [95]:
# --- Configuration: everything a reader might want to tune lives here. ---
DATA_PATH = 'dataset/'   # directory where the MNIST files are stored / downloaded

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

batch_size = 32          # samples per mini-batch for both loaders
learning_rate = 0.001    # step size handed to the Adam optimiser
epochs = 32              # number of full passes over the training set

In [80]:
# Preprocessing applied to every image: convert to a tensor, then standardise
# the single channel with the given mean / std (presumably the MNIST
# training-set statistics -- confirm if the dataset changes).
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split (downloaded into DATA_PATH if missing) and held-out test split.
train_dataset = datasets.MNIST(
    root=DATA_PATH,
    train=True,
    transform=trans,
    download=True,
)
test_dataset = datasets.MNIST(
    root=DATA_PATH,
    train=False,
    transform=trans,
)

In [81]:
# Mini-batch iterators: the training split is reshuffled every epoch, while the
# test split keeps its original order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [82]:
class CNNModel(nn.Module):
    """Small convolutional classifier producing 10 raw class scores.

    Architecture: two (5x5 conv -> ReLU -> 2x2 max-pool) stages with 32 and 64
    output channels, followed by two linear layers (3136 -> 1000 -> 10).
    ``fc1`` expects 7*7*64 flattened features, which is what the two stages
    yield for a 1x28x28 input. No activation is applied between ``fc1`` and
    ``fc2`` or to the final scores.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one size-preserving 5x5 conv -> ReLU -> halving max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # names and seeded initialisation stay the same.
        self.layer1 = self._conv_stage(1, 32)
        self.layer2 = self._conv_stage(32, 64)
        flat_features = 7 * 7 * 64
        self.fc1 = nn.Linear(flat_features, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        feature_maps = self.layer2(self.layer1(x))
        # Collapse channel and spatial dimensions, keeping the batch dimension.
        flat = feature_maps.reshape(feature_maps.size(0), -1)
        return self.fc2(self.fc1(flat))

In [83]:
# Build the network, the cross-entropy criterion applied to its raw class
# scores, and an Adam optimiser over all of the network's parameters.
model = CNNModel()
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
)

In [None]:
# Train the model for `epochs` passes over `train_loader`, recording the loss
# and accuracy of every mini-batch so they can be inspected or plotted later.
loss_history = []      # per-batch loss values (Python floats)
accuracy_history = []  # per-batch fraction of correctly classified samples
batches = len(train_loader)

# Put the model in training mode explicitly so re-running this cell does not
# depend on whatever mode an earlier cell left it in.
model.train()

for j in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        error = loss(outputs, labels)
        loss_history.append(error.item())

        # Backprop and perform Adam optimisation
        optimizer.zero_grad()
        error.backward()
        optimizer.step()

        # Track the accuracy from *this batch's* scores. The predicted class
        # is the index of the largest score; argmax also avoids assigning to
        # `_`, which IPython uses for the last cell output.
        total = labels.size(0)
        predicted = outputs.argmax(dim=1)
        correct = (predicted == labels).sum().item()
        accuracy_history.append(correct / total)

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                  .format(j + 1, epochs, i + 1, batches, error.item(),
                          (correct / total) * 100))

Epoch [1/32], Step [100/1875], Loss: 2.3168, Accuracy: 12.50%
Epoch [1/32], Step [200/1875], Loss: 2.2899, Accuracy: 15.62%
Epoch [1/32], Step [300/1875], Loss: 2.3029, Accuracy: 9.38%
