In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


In [8]:
# Define the CNN architecture
class CNN(nn.Module):
    """Small MNIST classifier: 3x3 conv (1 -> 4 channels), ReLU, 2x2 max-pool, linear head.

    The layer sizes imply single-channel 28x28 inputs of shape ``(N, 1, 28, 28)``:
    the padded 3x3 conv keeps 28x28, the pool halves it to 14x14, and ``fc1``
    expects ``14 * 14 * 4`` features. ``forward`` returns unnormalised class
    scores (logits) of shape ``(N, 10)``, suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # padding=1 with a 3x3 kernel and stride 1 preserves the spatial size (28x28).
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        # Halves each spatial dimension: 28x28 -> 14x14.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(14 * 14 * 4, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 1, 28, 28) to logits of shape (N, 10)."""
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool(x)
        # Flatten per sample while keeping the batch dimension explicit. With an
        # input of the wrong spatial size, fc1 now raises a shape error instead
        # of the extra features being silently folded into additional batch rows.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

In [9]:
# Set random seed for reproducibility.
# Only torch's global RNG is seeded; per the PyTorch docs this RNG is used for
# default layer initialisation and for DataLoader shuffling when no explicit
# generator is passed. NOTE(review): CUDA/cuDNN kernels may still be
# nondeterministic on GPU.
SEED = 42
# The trailing semicolon suppresses the returned Generator's repr, whose memory
# address changes on every run and would otherwise make the saved output
# differ between otherwise identical runs.
torch.manual_seed(SEED);

<torch._C.Generator at 0x1f37a944890>

In [10]:
# Load the MNIST training and test splits into data/ (downloaded if missing),
# converting each sample with ToTensor(). The train split is built first.
train_dataset, test_dataset = (
    MNIST(root='data/', train=is_train, transform=ToTensor(), download=True)
    for is_train in (True, False)
)

In [11]:
# Define the dataloaders. The training set is reshuffled every epoch; the test
# set is iterated in a fixed order, so evaluation passes are repeatable.
BATCH_SIZE = 128
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
# Create the CNN model. A new nn.Module already starts in training mode;
# .train() (which returns the module itself) just makes that explicit.
model = CNN().train()

In [13]:
# Count the trainable parameters: total element count of every parameter
# tensor that has requires_grad set.
trainable_params = sum(
    tensor.numel()
    for tensor in filter(lambda param: param.requires_grad, model.parameters())
)
print(f"Number of trainable parameters: {trainable_params}")

Number of trainable parameters: 7890


In [14]:
# Define the loss function and optimizer. CrossEntropyLoss takes the raw
# (un-softmaxed) scores that CNN.forward returns.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [17]:
# Train the model, with early stopping on held-out accuracy.
# NOTE(review): the held-out set used below for early stopping and checkpoint
# selection is the MNIST *test* split, so the reported best accuracy is
# optimistically biased. A validation split carved from the training set
# (e.g. torch.utils.data.random_split) would keep the test set untouched.
# NOTE(review): this cell continues from whatever state `model` and `optimizer`
# are currently in; re-running it keeps training the same weights. Use
# Restart Kernel -> Run All for a reproducible run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 20
best_accuracy = 0.0
best_state = None  # detached copy of the weights that achieved best_accuracy
patience = 5  # Number of epochs without improvement to trigger early stopping
counter = 0  # Counter for early stopping
for epoch in range(num_epochs):
    # Evaluation below switches to eval mode, so (re)enter training mode at the
    # start of every epoch.
    model.train()
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_dataloader)}], Loss: {loss.item():.4f}")

    # Evaluate the model on the held-out set (see NOTE at the top of the cell)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # Index of the highest score along the class dimension = predicted label.
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

    # Check if the current accuracy is better than the previous best accuracy
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        counter = 0
        # state_dict() returns references to the live parameter tensors, which
        # keep changing during training, so keep a detached clone in memory...
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        # ...and also persist the best weights to disk.
        torch.save(best_state, 'best_model.pth')
    else:
        counter += 1

    # Check if early stopping criteria is met
    if counter >= patience:
        print("Early stopping triggered.")
        break

# Restore the best weights so `model` matches the accuracy reported below,
# instead of holding the weights from the last (possibly worse) epoch.
if best_state is not None:
    model.load_state_dict(best_state)

print(f"Best Validation Accuracy: {best_accuracy:.2f}%")

Epoch [1/20], Step [100/469], Loss: 0.0776
Epoch [1/20], Step [200/469], Loss: 0.0867
Epoch [1/20], Step [300/469], Loss: 0.0851
Epoch [1/20], Step [400/469], Loss: 0.1042
Validation Accuracy: 96.33%
Epoch [2/20], Step [100/469], Loss: 0.0673
Epoch [2/20], Step [200/469], Loss: 0.0835
Epoch [2/20], Step [300/469], Loss: 0.0490
Epoch [2/20], Step [400/469], Loss: 0.0916
Validation Accuracy: 96.54%
Epoch [3/20], Step [100/469], Loss: 0.1212
Epoch [3/20], Step [200/469], Loss: 0.0830
Epoch [3/20], Step [300/469], Loss: 0.1227
Epoch [3/20], Step [400/469], Loss: 0.1231
Validation Accuracy: 96.88%
Epoch [4/20], Step [100/469], Loss: 0.0485
Epoch [4/20], Step [200/469], Loss: 0.0851
Epoch [4/20], Step [300/469], Loss: 0.0455
Epoch [4/20], Step [400/469], Loss: 0.1161
Validation Accuracy: 96.96%
Epoch [5/20], Step [100/469], Loss: 0.0633
Epoch [5/20], Step [200/469], Loss: 0.0448
Epoch [5/20], Step [300/469], Loss: 0.0454
Epoch [5/20], Step [400/469], Loss: 0.1411
Validation Accuracy: 97.01%
