In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Hyper parameters
num_epochs = 5          # full passes over the training set
num_classes = 10        # number of output classes (MNIST digits)
batch_size = 100        # samples per mini-batch (600 steps per epoch in the log below)
learning_rate = 0.001   # Adam step size
seed = 0                # RNG seed for reproducible weight init and data shuffling

# Seed PyTorch's RNGs before any random work (weight init, DataLoader
# shuffling) happens in later cells, so Restart & Run All reproduces results.
# This does not force deterministic cuDNN kernels; GPU runs may still differ slightly.
torch.manual_seed(seed)

In [4]:
# MNIST dataset. Both splits share one root directory and one transform so the
# location and preprocessing are defined in exactly one place.
data_root = '../../data/'
to_tensor = transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           transform=to_tensor,
                                           download=True)

# download=True is a no-op when the files are already present; passing it here
# means the test split does not rely on the train call above having fetched them.
test_dataset = torchvision.datasets.MNIST(root=data_root,
                                          train=False,
                                          transform=to_tensor,
                                          download=True)

# Data loader: shuffle only the training data; keep test order fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 67506188.29it/s]


Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 43921571.36it/s]


Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 29985264.66it/s]


Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3520053.36it/s]


Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw



The first layer applies a convolution to extract features, typically small, fine-grained patterns (for example, edges or short stroke segments). It then applies batch normalization for stable training, introduces non-linearity with a ReLU activation, and downsamples the result with max pooling.

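Max pooling is how the layer downsamples. It slides a window over the feature map (here 2×2 with stride 2, which halves the height and width) and keeps only the maximum value in each window. The maximum is used because it represents the most strongly activated, i.e. most significant, feature in that region, so important features are retained while the map shrinks.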

The second layer works like the first but uses more filters (32 instead of 16). It combines the patterns recognized by the first layer into larger, more complex patterns. It then again applies batch normalization for stability, ReLU activation, and max-pooling downsampling.

Finally, the third layer is a fully connected layer. It maps the flattened output of the convolutional layers (32 channels × 7 × 7, because the two 2×2 poolings shrink the 28×28 input to 7×7) to one score per output class. This lets the network classify input images based on the learned features.

In [5]:
# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

# Build the network and move its parameters/buffers to the selected device.
model = ConvNet(num_classes=num_classes).to(device)

# Loss and optimizer: CrossEntropyLoss consumes the raw scores returned by
# ConvNet.forward (no softmax there); Adam updates every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Mini-batches per epoch, reported in the training log.
total_step = len(train_loader)

In [6]:
# Train the model.
for epoch in range(num_epochs):
    # Switch BatchNorm back to training mode at the start of every epoch. If an
    # earlier run put the model in eval mode (e.g. an evaluation cell calling
    # model.eval()), re-running this cell would otherwise train with frozen
    # running statistics.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 steps.
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_step}], '
                  f'Loss: {loss.item():.4f}')

Epoch [1/5], Step [100/600], Loss: 0.0601
Epoch [1/5], Step [200/600], Loss: 0.0812
Epoch [1/5], Step [300/600], Loss: 0.1294
Epoch [1/5], Step [400/600], Loss: 0.0247
Epoch [1/5], Step [500/600], Loss: 0.0191
Epoch [1/5], Step [600/600], Loss: 0.0081
Epoch [2/5], Step [100/600], Loss: 0.0373
Epoch [2/5], Step [200/600], Loss: 0.0642
Epoch [2/5], Step [300/600], Loss: 0.1292
Epoch [2/5], Step [400/600], Loss: 0.0399
Epoch [2/5], Step [500/600], Loss: 0.1057
Epoch [2/5], Step [600/600], Loss: 0.0336
Epoch [3/5], Step [100/600], Loss: 0.0289
Epoch [3/5], Step [200/600], Loss: 0.0195
Epoch [3/5], Step [300/600], Loss: 0.0164
Epoch [3/5], Step [400/600], Loss: 0.1026
Epoch [3/5], Step [500/600], Loss: 0.0102
Epoch [3/5], Step [600/600], Loss: 0.0617
Epoch [4/5], Step [100/600], Loss: 0.0540
Epoch [4/5], Step [200/600], Loss: 0.0243
Epoch [4/5], Step [300/600], Loss: 0.0178
Epoch [4/5], Step [400/600], Loss: 0.0136
Epoch [4/5], Step [500/600], Loss: 0.0340
Epoch [4/5], Step [600/600], Loss:

In [8]:
# Save only the learned parameters and buffers (the state_dict), not the module
# object; reload later with model.load_state_dict(torch.load('model.ckpt')).
torch.save(obj=model.state_dict(), f='model.ckpt')