In [1]:
import torch 
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import math


In [2]:
## Dataset ##

# Standard transforms: turn each image into a tensor, then standardize its
# single grayscale channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set mean and
# std; they are not recomputed in this notebook -- confirm before reusing.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits live under ./data and are downloaded if not already present.
mnist_train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)

# Sizes for carving a validation hold-out from the official training split.
# val_size is computed as the remainder so the two sizes always sum to
# total_train, whatever rounding floor() applies to train_size.
total_train = len(mnist_train_dataset)  # Should be 60000
train_fraction = 0.8
train_size = math.floor(total_train * train_fraction)
val_size = total_train - train_size

# Report each quantity exactly once (the validation size used to be printed twice).
print(f"Total training samples: {total_train}, Training size: {train_size}, Validation size: {val_size}")


Total training samples: 60000, Training size: 48000, Validation size: 12000, Validation size: 12000


In [3]:
# Pull a single (image, label) pair out of the training set as a sanity check
image, label = mnist_train_dataset[0]

# Show the tensor layout of one sample; the run recorded below printed
# torch.Size([1, 28, 28])
print(image.size())

torch.Size([1, 28, 28])


In [4]:
class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier for 1-channel 28x28 images.

    The network is two (5x5 conv -> tanh -> 2x2 max-pool) stages followed by
    three fully connected layers. It returns raw logits (no softmax), which is
    what ``nn.CrossEntropyLoss`` expects.

    Spatial size through the feature extractor, with unpadded 5x5 convolutions
    and 2x2 pooling: 28 -> 24 -> 12 -> 8 -> 4. That is why the first linear
    layer takes ``16 * 4 * 4`` inputs; other input resolutions will not match
    it.

    Args:
        num_classes: Number of output logits. Defaults to 10 (one per digit),
            which reproduces the original fixed-size head.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        # The attribute name `net` and the layer order are kept as-is so that
        # state_dict keys ("net.0.weight", ...) stay compatible with any
        # checkpoints already saved from this class.
        self.net = nn.Sequential(
            nn.Conv2d(1, 6, 5),       # 1x28x28 -> 6x24x24
            nn.Tanh(),
            nn.MaxPool2d(2),          # -> 6x12x12
            nn.Conv2d(6, 16, 5),      # -> 16x8x8
            nn.Tanh(),
            nn.MaxPool2d(2),          # -> 16x4x4
            nn.Flatten(),             # -> 256
            nn.Linear(16 * 4 * 4, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, 1, 28, 28)`` tensor to ``(batch, num_classes)`` logits."""
        return self.net(x)

In [5]:
# --- Run configuration ---
epochs = 10
lr = 0.001
# NOTE(review): with the 48000-sample training split this is 8000 optimizer
# steps per epoch -- confirm such a small batch is intentional.
batch_size = 6
num_workers = 2   # worker processes per DataLoader
shuffle = True    # reshuffle the training data each epoch

# --- Model, objective, optimizer ---
model = LeNet5()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [6]:
# Partition the MNIST training set into train / validation subsets.
# A dedicated seeded generator pins the partition: without it the split is
# drawn from the global RNG and changes on every run, so validation numbers
# are not comparable across runs and a reloaded checkpoint could be validated
# on images it was trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    mnist_train_dataset,
    [train_size, val_size],
    generator=split_generator,
)

# Create DataLoaders for train, val, and test.
# Only the training loader follows the `shuffle` setting; validation and test
# are iterated in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [7]:
# Training + validation loops, followed by a single pass over the test set.


def _evaluate(model, loader, loss_fn):
    """Return ``(mean_loss, accuracy)`` of ``model`` over all samples in ``loader``.

    Puts the model in eval mode and runs without gradient tracking.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Criterion called as ``loss_fn(logits, targets)``. It is
            assumed to return the *mean* loss of the batch (the default
            reduction of ``nn.CrossEntropyLoss``).

    Returns:
        Tuple of per-sample mean loss and fraction of correct predictions.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x)
            n_batch = y.size(0)
            # Re-weight the batch mean by the batch size so that a short final
            # batch is not over-represented in the overall average.
            loss_sum += loss_fn(preds, y).item() * n_batch
            n_correct += (preds.argmax(1) == y).sum().item()
            n_seen += n_batch
    if n_seen == 0:
        raise ValueError("cannot evaluate on an empty loader")
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(epochs):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        preds = model(x)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()

    # Validation after every epoch
    val_loss, val_acc = _evaluate(model, val_loader, loss_fn)
    print(f"Epoch {epoch}: val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

# Final test: run once, after all epochs have finished
test_loss, test_acc = _evaluate(model, test_loader, loss_fn)
print(f"Test   : loss={test_loss:.4f}, acc={test_acc:.4f}")


libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x108efed40>
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1663, in __del__
    self._shutdown_workers()
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1627, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/popen_fork.py", line 41, in wait
    if not wait([self.sentinel], ti

KeyboardInterrupt: 

In [None]:
# Checkpoint only the learned parameters (the state_dict), not the module
# object itself; loading it back requires building a LeNet5 instance first.
torch.save(obj=model.state_dict(), f='mnist_model.pth')