## Data preparation

In [None]:
import os
from torchvision import transforms
from torchvision.datasets import mnist
import torch.utils.data as data
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import itertools


def _export_split(loader, split_dir):
    """Write every sample yielded by *loader* to ``split_dir/<label>/<index>.jpg``.

    Each batch is expected to hold a single 28x28 image (the first element
    of the batch is reshaped to ``(28, 28)``) together with its label.
    The index in the file name is the position of the batch in the loader.
    """
    for i, (img, label) in enumerate(loader):
        pixels = img[0].cpu().numpy().reshape((28, 28))
        array = (pixels * 255).astype(np.uint8)
        digit = label.cpu().numpy()[0]
        # The mode is inferred from the 2-D uint8 array, so no explicit
        # mode argument is needed (it triggers a deprecation warning on
        # some Pillow releases).
        Image.fromarray(array).save(
            os.path.join(split_dir, str(digit), str(i) + '.jpg')
        )


def save():
    """Dump the MNIST train and test splits to disk as JPEG files.

    Creates ``mnist/train/<digit>/`` and ``mnist/test/<digit>/`` for the
    digits 0-9 and writes one image per sample read from the module-level
    ``train_loader`` and ``test_loader``.
    """
    split_loaders = (
        (os.path.join('mnist', 'train'), train_loader),
        (os.path.join('mnist', 'test'), test_loader),
    )
    for split_dir, loader in split_loaders:
        # makedirs also creates the parent split directory.
        for digit in range(10):
            os.makedirs(os.path.join(split_dir, str(digit)), exist_ok=True)
        _export_split(loader, split_dir)


def show():
    """Plot ten training samples (batches 2 to 11 of ``train_loader``) in a 2x5 grid.

    Reads the module-level ``train_loader``; each batch is expected to hold
    a single 28x28 image and its label. The label is shown as the title of
    each subplot.
    """
    plt.figure(figsize=(16, 9))
    for i, (img, label) in enumerate(itertools.islice(train_loader, 2, 12)):
        plt.subplot(2, 5, i + 1)
        pixels = img[0].cpu().numpy().reshape((28, 28))
        array = (pixels * 255).astype(np.uint8)
        # The label was previously computed and discarded; display it.
        plt.title(str(label.cpu().numpy()[0]))
        plt.imshow(array, cmap=plt.get_cmap('gray'))
    plt.show()


if __name__ == '__main__':
    # Fetch both MNIST splits (downloaded into ./mnist on first run).
    to_tensor = transforms.ToTensor()
    train_data = mnist.MNIST('mnist', train=True, transform=to_tensor, download=True)
    test_data = mnist.MNIST('mnist', train=False, transform=to_tensor, download=True)

    # One image per batch: save() and show() index element 0 of each batch.
    loader_options = dict(batch_size=1, shuffle=True)
    train_loader = data.DataLoader(dataset=train_data, **loader_options)
    test_loader = data.DataLoader(dataset=test_data, **loader_options)

    train_total = len(train_loader)
    test_total = len(test_loader)
    labels = train_data.targets

    # Peek at the first (image, label) pair of the training set.
    dataiter = iter(train_data)
    images, labs = next(dataiter)

    save()
    show()

## Network architecture

In [None]:
import torch
from torch import nn
from torchsummary import summary


class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Layout: conv(5x5) -> ReLU -> maxpool(2) -> conv(5x5) -> ReLU ->
    maxpool(2) -> flatten -> linear -> ReLU -> linear.

    ``forward`` returns raw, unnormalised class scores (logits). This is
    what ``nn.CrossEntropyLoss`` expects: the loss applies log-softmax
    itself, so a softmax inside the network would be applied twice and
    flatten the gradients. Use ``predict_proba`` when probabilities are
    needed.

    Args:
        classes (int): Number of output classes.
        n1 (int): Output channels of the first convolution.
        n2 (int): Output channels of the second convolution.
        n3 (int): Width of the hidden fully connected layer.
    """

    def __init__(self, classes=10, n1=16, n2=32, n3=64):
        super().__init__()
        self.classes = classes
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=self.n1, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=self.n1, out_channels=self.n2, kernel_size=5, stride=1, padding=0)

        self.maxpool = nn.MaxPool2d(2)
        # For a 28x28 input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4,
        # so the flattened feature vector has n2 * 4 * 4 entries.
        self.fc1 = nn.Linear(self.n2 * 4 * 4, self.n3)
        self.fc2 = nn.Linear(self.n3, self.classes)

        # The layers stay registered both as attributes and inside `net`,
        # so state_dict keys are the same as before. The trailing Softmax
        # had no parameters, so dropping it leaves checkpoints loadable.
        self.net = nn.Sequential(
            self.conv1,
            nn.ReLU(),
            self.maxpool,
            self.conv2,
            nn.ReLU(),
            self.maxpool,
            nn.Flatten(),
            self.fc1,
            nn.ReLU(),
            self.fc2,
        )

    def forward(self, x):
        """Return logits of shape ``(batch, classes)`` for input *x*."""
        return self.net(x)

    def predict_proba(self, x):
        """Return softmax-normalised class probabilities for input *x*."""
        return torch.softmax(self.forward(x), dim=1)


if __name__ == '__main__':
    # Use the GPU when one is visible, otherwise stay on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Print the module tree, then a per-layer summary for a 1x28x28 input.
    cnn = CNN(classes=10)
    print(cnn)
    cnn = cnn.to(device)
    summary(cnn, (1, 28, 28), device=device.type)

## Dataset

In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchvision import transforms
import matplotlib.pyplot as plt

class MnistDataSet(Dataset):
    """MNIST images stored on disk as ``<data_dir>/<label>/<file>``."""

    def __init__(self, data_dir, class_num=10, transform=None):
        """
        Args:
            data_dir (str): Directory containing data, e.g., 'mnist/train' or 'mnist/test'
            class_num (int): Number of classes (default is 10)
            transform (callable, optional): Transform applied to the
                grayscale PIL image of a sample. When omitted, the image is
                returned as a float32 array of shape (1, 28, 28) in [0, 1].

        Raises:
            FileNotFoundError: If a class sub-directory is missing.
        """
        super().__init__()
        self.data_dir = data_dir
        self.class_num = class_num
        self.transform = transform

        # List of (image path, label) pairs, one per file.
        self.data = []
        for label in range(self.class_num):
            label_dir = os.path.join(self.data_dir, str(label))
            # os.listdir returns entries in arbitrary order; sort them so a
            # given index always maps to the same sample.
            self.data.extend(
                (os.path.join(label_dir, img_name), label)
                for img_name in sorted(os.listdir(label_dir))
            )

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*."""
        img_path, label = self.data[index]
        # Close the file promptly and force a single channel, so the
        # reshape below cannot fail on an RGB-encoded file.
        with Image.open(img_path) as opened:
            img = opened.convert('L')
        if self.transform is not None:
            return self.transform(img), label
        pixels = np.asarray(img, dtype=np.float32).reshape(1, 28, 28) / 255.0
        return pixels, label


if __name__ == '__main__':
    # Quick sanity check: build the test dataset and display its first image.
    dataset = MnistDataSet('mnist/test')
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

    img, label = dataset[0]
    plt.imshow(img.squeeze(), cmap='gray')
    # Render the figure explicitly, as the other plotting code in this
    # file does, so it also appears outside an inline notebook backend.
    plt.show()

## Train & Evaluation Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Train *model* for one pass over *dataloader*.

    Prints running loss and accuracy every 100 batches.

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)`` for the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, batch in enumerate(dataloader):
        # Move both tensors of the batch to the target device.
        inputs, labels = (tensor.to(device) for tensor in batch)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # The predicted class is the index of the largest score per row.
        predicted = outputs.max(dim=1).indices
        total += labels.shape[0]
        correct += (predicted == labels).sum().item()

        if (batch_idx + 1) % 100 == 0:
            print(f'Batch {batch_idx+1}/{len(dataloader)} - Loss: {running_loss/(batch_idx+1):.4f}  Acc: {100.*correct/total:.2f}%')
    return running_loss / len(dataloader), 100. * correct / total


def evaluate(model, dataloader, criterion, device):
    """Measure loss and accuracy of *model* on *dataloader* without training.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = (tensor.to(device) for tensor in batch)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()

            # The predicted class is the index of the largest score per row.
            predicted = outputs.max(dim=1).indices
            total += labels.shape[0]
            correct += (predicted == labels).sum().item()
    return running_loss / len(dataloader), 100. * correct / total

## Training

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

batch_size = 64
num_epochs = 10

train_dataset = MnistDataSet(data_dir='mnist/train')
test_dataset = MnistDataSet(data_dir='mnist/test')

# Two worker processes per loader; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Model, loss function and optimizer.
model = CNN(classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch history, used for the curves plotted after training.
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    for history, value in (
        (train_losses, train_loss),
        (test_losses, test_loss),
        (train_accuracies, train_acc),
        (test_accuracies, test_acc),
    ):
        history.append(value)

    print(f'Epoch {epoch+1} Summary: Train Loss {train_loss:.4f}, Train Acc {train_acc:.2f}%, Test Loss {test_loss:.4f}, Test Acc {test_acc:.2f}%')
    print('-'*50)

# Loss and accuracy curves side by side: (title, y-label, curves...).
epochs = range(1, num_epochs+1)
plt.figure(figsize=(12, 5))

panels = (
    ('Loss Curve', 'Loss',
     (train_losses, 'Train Loss'), (test_losses, 'Test Loss')),
    ('Accuracy Curve', 'Accuracy (%)',
     (train_accuracies, 'Train Accuracy'), (test_accuracies, 'Test Accuracy')),
)
for position, (title, y_label, *curves) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    for values, curve_label in curves:
        plt.plot(epochs, values, label=curve_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()

# Persist the trained weights for the inference cell.
torch.save(model.state_dict(), 'mnist_cnn.pth')

## Inference

In [None]:
import torch

# Evaluate the saved checkpoint on the test split.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN(classes=10)
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
checkpoint = torch.load('mnist_cnn.pth', map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

criterion = nn.CrossEntropyLoss()
dataset = MnistDataSet('mnist/test')
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

# NOTE(review): these two counters are not read in this cell (evaluate()
# keeps its own); remove them if no later cell relies on them.
correct = 0
total = 0
test_loss, test_acc = evaluate(model, dataloader, criterion, device)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')