In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [12]:
# Define the Neural Network Architecture
class SimpleNN(nn.Module):
    """
    A simple feed-forward neural network for image classification.
    """
    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size1),
            nn.ReLU(),
            nn.Linear(hidden_size1, hidden_size2),
            nn.ReLU(),
            nn.Linear(hidden_size2, num_classes)
        )

    def forward(self, x):
        """
        Defines the forward pass of the neural network.
        """
        x = self.flatten(x)
        logits = self.network(x)

        return logits


In [13]:
# Set up Hyperparameters and Device
INPUT_SIZE = 784
HIDDEN_SIZE1 = 128
HIDDEN_SIZE2 = 64
NUM_CLASSES = 10
LEARNING_RATE = 0.001
BATCH_SIZE = 64
NUM_EPOCHS = 5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [14]:
# Load and prepare data 
# Define a transform to normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the datasets
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

In [15]:
# Create DataLoaders
train_loaders = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [16]:
# Initialize the Model, Loss Function and Optimizer
model = SimpleNN(INPUT_SIZE, HIDDEN_SIZE1, HIDDEN_SIZE2, NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [17]:
# Train the Network
print("\n Starting Training... \n")
for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_loaders):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loaders)}], Loss: {loss.item():.4f}')

print("Finished training epoch", epoch + 1)


 Starting Training... 

Epoch [1/5], Step [100/938], Loss: 0.5558
Epoch [1/5], Step [200/938], Loss: 0.5109
Epoch [1/5], Step [300/938], Loss: 0.2030
Epoch [1/5], Step [400/938], Loss: 0.3530
Epoch [1/5], Step [500/938], Loss: 0.3674
Epoch [1/5], Step [600/938], Loss: 0.1706
Epoch [1/5], Step [700/938], Loss: 0.2283
Epoch [1/5], Step [800/938], Loss: 0.1956
Epoch [1/5], Step [900/938], Loss: 0.3906
Epoch [2/5], Step [100/938], Loss: 0.2362
Epoch [2/5], Step [200/938], Loss: 0.2176
Epoch [2/5], Step [300/938], Loss: 0.2065
Epoch [2/5], Step [400/938], Loss: 0.0483
Epoch [2/5], Step [500/938], Loss: 0.2332
Epoch [2/5], Step [600/938], Loss: 0.1448
Epoch [2/5], Step [700/938], Loss: 0.2084
Epoch [2/5], Step [800/938], Loss: 0.1631
Epoch [2/5], Step [900/938], Loss: 0.1146
Epoch [3/5], Step [100/938], Loss: 0.2169
Epoch [3/5], Step [200/938], Loss: 0.1256
Epoch [3/5], Step [300/938], Loss: 0.0958
Epoch [3/5], Step [400/938], Loss: 0.1063
Epoch [3/5], Step [500/938], Loss: 0.2103
Epoch [3/

In [18]:
# Model evaluation
model.eval()
with torch.no_grad():
  correct = 0
  total = 0
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)

    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  print(f'Accuracy of the network on the 10000 test images : {accuracy: .2f} %')

Accuracy of the network on the 10000 test images :  96.99 %
