In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's global plot styling, using the 'poster' context preset.
sns.set(context='poster')
%matplotlib inline

In [2]:
# Training-time preprocessing: random augmentation (horizontal flip, then a
# rotation configured with degrees=15) followed by conversion to a tensor.
train_transformer = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
    ]
)

# Validation-time preprocessing: tensor conversion only, no augmentation,
# so evaluation sees the images unmodified.
val_transformer = transforms.Compose([transforms.ToTensor()])

In [3]:
# Training split of FashionMNIST (downloaded into ./fmnist if missing), with
# the augmenting transform applied per sample.
# shuffle=True: DataLoader defaults to shuffle=False, which would feed SGD the
# exact same batches in the exact same order on every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.FashionMNIST(root='fmnist', train=True,
                                  transform=train_transformer,
                                  download=True),
    batch_size=64,
    shuffle=True)

# Held-out split (train=False) used for validation. Order does not matter for
# accuracy, so it is left unshuffled; a larger batch is fine because no
# gradients need to be kept for it.
val_loader = torch.utils.data.DataLoader(
    dataset=datasets.FashionMNIST(root='fmnist', train=False,
                                  transform=val_transformer,
                                  download=True),
    batch_size=128)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [4]:
class Model(nn.Module):
    """Single-layer linear classifier (multinomial logistic regression).

    Maps a flattened 28x28 input to 10 raw class scores (logits). No softmax
    is applied here.
    """

    def __init__(self):
        super().__init__()
        # nn.Linear stores weight as (out_features, in_features) = (10, 784)
        # and bias as (10,).
        self.fc = nn.Linear(in_features=28 * 28, out_features=10, bias=True)

    def forward(self, x):
        """Return logits of shape (N, 10) for input `x`.

        `x` is reshaped to (-1, 784), so any contiguous tensor whose element
        count is a multiple of 784 is accepted, e.g. (batch, 1, 28, 28).
        """
        flat = x.view(-1, 28 * 28)
        return self.fc(flat)

In [5]:
# Loss: cross-entropy computed from the model's raw logits and integer labels.
celoss = nn.CrossEntropyLoss()

# The classifier to be trained.
net = Model()

# Plain SGD with momentum over all of the model's parameters.
opt = optim.SGD(params=net.parameters(), lr=0.01, momentum=0.9)

In [6]:
# Train for 5 epochs; after each epoch report accuracy on the validation split.
for e in range(5):
    print(f'Epoch {e + 1}')

    # --- optimisation pass -------------------------------------------------
    net.train()  # undo the eval() set by the previous epoch's validation pass
    for i, (images, labels) in enumerate(train_loader):
        if (i + 1) % 100 == 0:
            # Progress marker every 100 batches: 0-based epoch, 0-based batch.
            print(e, i)
        opt.zero_grad()  # clear gradients accumulated by the previous step
        yhat = net(images)
        loss = celoss(yhat, labels)
        loss.backward()
        opt.step()

    # --- validation pass ---------------------------------------------------
    net.eval()
    preds_arr = []
    labels_arr = []
    # no_grad: evaluation needs no autograd graph, so skip building one.
    with torch.no_grad():
        for val_images, val_labels in val_loader:
            val_yhat = net(val_images)
            preds_arr.append(val_yhat.argmax(dim=1))  # predicted class per sample
            labels_arr.append(val_labels)

    # Concatenate as tensors first, then convert once to numpy. `preds` and
    # `labels` stay numpy arrays after the loop, as they were before.
    preds = torch.cat(preds_arr).numpy()
    labels = torch.cat(labels_arr).numpy()

    accuracy = (preds == labels).sum() / preds.shape[0]
    print(f'Validation accuracy: {accuracy * 100}')

Epoch 1
0 99
0 199
0 299
0 399
0 499
0 599
0 699
0 799
0 899
Validation accuracy: 76.48
Epoch 2
1 99
1 199
1 299
1 399
1 499
1 599
1 699
1 799
1 899
Validation accuracy: 77.66
Epoch 3
2 99
2 199
2 299
2 399
2 499
2 599
2 699
2 799
2 899
Validation accuracy: 78.43
Epoch 4
3 99
3 199
3 299
3 399
3 499
3 599
3 699
3 799
3 899
Validation accuracy: 78.68
Epoch 5
4 99
4 199
4 299
4 399
4 499
4 599
4 699
4 799
4 899
Validation accuracy: 78.8
