In [1]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from IPython.display import clear_output
from PIL import Image
from matplotlib import cm
from time import perf_counter
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

warnings.filterwarnings('ignore')

In [2]:
mnist_train = MNIST(
    "../datasets/mnist",
    train=True,
    download=True
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../datasets/mnist\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:09<00:00, 1063142.75it/s]


Extracting ../datasets/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../datasets/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../datasets/mnist\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<?, ?it/s]


Extracting ../datasets/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../datasets/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../datasets/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 824596565.71it/s]


Extracting ../datasets/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../datasets/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../datasets/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting ../datasets/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../datasets/mnist\MNIST\raw





In [20]:
mnist_valid = MNIST(
    "../datasets/mnist",
    train=False,
    download=True,
    transform=T.ToTensor()
)

In [22]:
def create_mlp_model():
    model = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Flatten(),
        nn.Linear(4 * 4 * 64, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    return model



In [23]:
def evaluate(model: nn.Module, device: torch.device, data_loader: DataLoader, loss_fn):
    model.eval()
    total_loss = 0
    total = 0
    correct = 0
    for x, y in tqdm(data_loader, desc='Evaluate'):
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        loss = loss_fn(output, y)
        total_loss += loss.item()
        _, y_pred = torch.max(output, 1)
        total += y.size(0)
        correct += (y_pred == y).sum().item()
    return total_loss / len(data_loader), correct / total

In [None]:
# def train(model, device, train_loader, optimizer, epoch):
#     model.train()
#     for batch_idx, (data, target) in enumerate(train_loader):
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         loss = nn.CrossEntropyLoss()(output, target)
#         loss.backward()
#         optimizer.step()
#         if batch_idx % 100 == 0:
#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#                 100. * batch_idx / len(train_loader), loss.item()))
#         

In [12]:
# from torch.optim import Optimizer
# def train(model: nn.Module, data_loader: DataLoader, optimizer: Optimizer, loss_fn):
#     model.train()
#     total_loss = 0
#     total = 0
#     correct = 0
#     for x, y in tqdm(data_loader, desc='Train'):
#         optimizer.zero_grad()
#         output = model(x)
#         loss = loss_fn(output, y)
#         loss.backward()
#         total_loss += loss.item()
#         optimizer.step()
#         _, y_pred = torch.max(output, 1)
#         total += y.size(0)
#         correct += (y_pred == y).sum().item()
#     return total_loss / len(data_loader), correct / total

In [24]:
def train(model: nn.Module, device: torch.device, data_loader: DataLoader, optimizer: torch.optim.Optimizer, loss_fn):
    model.train()
    total_loss = 0
    total = 0
    correct = 0
    for x, y in tqdm(data_loader, desc='Train'):
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        _, y_pred = torch.max(output, 1)
        total += y.size(0)
        correct += (y_pred == y).sum().item()
    return total_loss / len(data_loader), correct / total

In [25]:
def plot_stats(
    train_loss: list[float],
    valid_loss: list[float],
    valid_accuracy: list[float],
    title: str
):
    plt.figure(figsize=(16, 8))

    plt.title(title + ' loss')

    plt.plot(train_loss, label='Train loss')
    plt.plot(valid_loss, label='Valid loss')
    plt.legend()
    plt.grid()

    plt.show()

    plt.figure(figsize=(16, 8))

    plt.title(title + ' accuracy')

    plt.plot(valid_accuracy)
    plt.grid()

    plt.show()

In [26]:
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
valid_loader = DataLoader(mnist_valid, batch_size=64, shuffle=False)

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [28]:
import torchvision

model = create_mlp_model().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [29]:
for epoch in range(20):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, loss_fn)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Train Accuracy: {train_accuracy}')
    test_loss, test_accuracy = evaluate(model, device, valid_loader, loss_fn)
    print(f'Epoch {epoch+1}, Test Loss: {test_loss}, Test Accuracy: {test_accuracy}')


torch.save(model.state_dict(), 'mnist_model.pth')

Train:   0%|          | 0/938 [00:00<?, ?it/s]


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>