In [None]:
import torch
import torchvision
import torchvision.transforms as transforms


# Preprocessing shared by both splits: tensor conversion followed by
# per-channel normalisation (0.1307 / 0.3081 are presumably the MNIST
# mean / std -- confirm against the dataset statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split: downloaded into ./data on first use, shuffled on each pass.
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2
)

# Test split: same preprocessing, iterated in a fixed order.
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)


100.0%
100.0%
100.0%
100.0%


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class PlainNN(nn.Module):
    """Fully connected classifier for 28x28 inputs: 784 -> 512 -> 256 -> 10.

    Each hidden layer is followed by ReLU and dropout (p=0.5). The final
    layer returns its raw linear outputs (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order; the attribute names double as
        # the state_dict keys, so they are kept as-is.
        self.fc1 = nn.Linear(28 * 28, 512)  # flattened 28*28 = 784 inputs
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        # A single dropout module is reused after both hidden layers.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return 10 scores per row."""
        flat = x.view(-1, 28 * 28)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# Instantiate the fully connected network; later cells train and evaluate this object.
plain_net = PlainNN()

def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Report the trainable-parameter count of the fully connected model.
print("Number of parameters in PlainNN:", count_parameters(plain_net))


Number of parameters in PlainNN: 535818


In [None]:
class MNIST_CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 inputs.

    Two 3x3 convolutions (padding 1, so spatial size is preserved) with ReLU,
    one 2x2 max-pool (28x28 -> 14x14), then a 128-unit hidden layer with
    dropout (p=0.25) and a 10-way linear output.
    """

    def __init__(self):
        super().__init__()
        # Creation order and attribute names are kept so parameter ordering
        # and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 14 * 14 spatial positions after pooling.
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return 10 scores per sample for a batch of 1x28x28 images."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.pool(features)
        flat = pooled.view(-1, 64 * 14 * 14)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

# Instantiate the convolutional network; later cells train and evaluate this object.
cnn_net = MNIST_CNN()

# Report the trainable-parameter count of the convolutional model.
print("Number of parameters in CNN:", count_parameters(cnn_net))


Number of parameters in CNN: 1625866


In [None]:
import torch.optim as optim

# Loss shared by both models.
criterion = nn.CrossEntropyLoss()

# One independent Adam optimizer per network, both with the same learning rate.
optimizer_plain, optimizer_cnn = (
    optim.Adam(net.parameters(), lr=0.001) for net in (plain_net, cnn_net)
)

# Number of full passes over the training loader for each model.
num_epochs = 10


In [None]:
def train_model(model, optimizer, trainloader, num_epochs, criterion=None, log_every=100):
    """Train ``model`` in place and print the running mean loss periodically.

    Args:
        model: Module to optimise; it is switched to training mode.
        optimizer: Optimizer already bound to ``model``'s parameters.
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        num_epochs: Number of full passes over ``trainloader``.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor. Defaults to ``nn.CrossEntropyLoss()`` so the function
            no longer depends on a module-level ``criterion`` variable.
        log_every: Print the mean loss of the last ``log_every`` batches each
            time that many batches have been processed within an epoch.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.

    Note:
        The loss of any trailing batches of an epoch that do not fill a whole
        ``log_every`` window is accumulated but never printed.
    """
    if log_every < 1:
        raise ValueError(f'log_every must be >= 1, got {log_every}')
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.train()  # enable dropout and other train-only behaviour
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            # Gradients accumulate by default, so clear them before each step.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{i + 1}], Loss: {running_loss / log_every:.4f}')
                running_loss = 0.0
    print('Finished Training')


In [None]:
# Train the fully connected network for num_epochs passes over the training loader.
print("Training Plain Neural Network:")
train_model(plain_net, optimizer_plain, trainloader, num_epochs)


In [None]:
# Train the convolutional network with its own optimizer on the same loader.
print("\nTraining CNN:")
train_model(cnn_net, optimizer_cnn, trainloader, num_epochs)


In [None]:
def evaluate_model(model, testloader):
    """Return the accuracy of ``model`` over ``testloader`` as a percentage.

    The model is switched to evaluation mode (and left in it), and gradient
    tracking is disabled while the loader is iterated.

    Args:
        model: Callable module; ``model(images)`` must return one row of
            class scores per sample.
        testloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        float: ``100 * correct / total``, or ``0.0`` when ``testloader``
        yields no samples (this case used to raise ``ZeroDivisionError``).
    """
    model.eval()  # disable dropout and other train-only behaviour
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            # The index of the largest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total

# Score both trained networks on the held-out test loader.
accuracy_plain = evaluate_model(plain_net, testloader)
accuracy_cnn = evaluate_model(cnn_net, testloader)

# Accuracies are percentages; print them with two decimal places.
print(f'Accuracy of Plain Neural Network: {accuracy_plain:.2f}%')
print(f'Accuracy of CNN: {accuracy_cnn:.2f}%')


Accuracy of Plain Neural Network: 97.95%
Accuracy of CNN: 99.17%
