In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(7*7*64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def prepare_data_loader(dataset, batch_size=256, train=True):
    if dataset == 'MNIST':
        dataset_class = datasets.MNIST
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])
        data = dataset_class(root='./data', train=train, download=True, transform=transform)
    elif dataset == 'SVHN':
        dataset_class = SVHN
        split = 'train' if train else 'test'
        transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                        transforms.Resize((28, 28)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (0.5,))])
        data = dataset_class(root='./data', split=split, download=True, transform=transform)

    return DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=8)

epochs = 15

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN()
model.to(device)
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()


data_loader = prepare_data_loader('MNIST')

model.train()
for epoch in range(epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')


data_loader = prepare_data_loader('MNIST', train=False)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the model: %d %%' % (100 * correct / total))

data_loader = prepare_data_loader('SVHN')

model.train()
for epoch in range(epochs):
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

data_loader = prepare_data_loader('SVHN', train=False)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the model: %d %%' % (100 * correct / total))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 138491068.58it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 86156254.50it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 78785744.03it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 939560.50it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw





Epoch 1, Loss: 0.09171704202890396
Epoch 2, Loss: 0.026677899062633514
Epoch 3, Loss: 0.02370590716600418
Epoch 4, Loss: 0.005984825547784567
Epoch 5, Loss: 0.03890741243958473
Epoch 6, Loss: 0.007693937513977289
Epoch 7, Loss: 0.010660524480044842
Epoch 8, Loss: 0.0034923783969134092
Epoch 9, Loss: 0.006934292148798704
Epoch 10, Loss: 0.050870489329099655
Epoch 11, Loss: 0.00012621360656339675
Epoch 12, Loss: 0.0029706647619605064
Epoch 13, Loss: 0.001243883976712823
Epoch 14, Loss: 0.0025125076062977314
Epoch 15, Loss: 0.029517492279410362
Accuracy of the model: 99 %
Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./data/train_32x32.mat


100%|██████████| 182040794/182040794 [00:33<00:00, 5359392.14it/s] 


Epoch 1, Loss: 0.25316864252090454
Epoch 2, Loss: 0.41190671920776367
Epoch 3, Loss: 0.21338558197021484
Epoch 4, Loss: 0.284019410610199
Epoch 5, Loss: 0.1736878901720047
Epoch 6, Loss: 0.20226913690567017
Epoch 7, Loss: 0.23257097601890564
Epoch 8, Loss: 0.11025325208902359
Epoch 9, Loss: 0.19612009823322296
Epoch 10, Loss: 0.1562546342611313
Epoch 11, Loss: 0.33906006813049316
Epoch 12, Loss: 0.2796209752559662
Epoch 13, Loss: 0.06671099364757538
Epoch 14, Loss: 0.08226464688777924
Epoch 15, Loss: 0.15355302393436432
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./data/test_32x32.mat


100%|██████████| 64275384/64275384 [00:12<00:00, 5033553.48it/s] 


Accuracy of the model: 88 %
