## Multiclass Classification Model on MNIST Dataset

### Dependencies Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image

import numpy as np
import random

### Data Preparation

In [2]:
mnist_train = datasets.MNIST(root='./data', train=True, download=False, transform=None)
mnist_test = datasets.MNIST(root='./data', train=False, download=False, transform=None)

In [3]:
# data as numpy array
np.array(mnist_train[0][0])[4:10][4:10]

array([[  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253,
        253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253,
        205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)

In [4]:
# data as pytorch tensor
transform = transforms.Compose([
    transforms.ToTensor()
])

transform(mnist_train[0][0])[0][4:10][4:10]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.8588,
         0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137, 0.9686, 0.9451,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3137,
         0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000, 0.1686, 0.6039,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]])

In [5]:
train_data = torch.cat([transform(mnist_train[i][0]) for i in range(len(mnist_train))]).view(-1, 28*28)
test_data = torch.cat([transform(mnist_train[i][0]) for i in range(len(mnist_test))]).view(-1, 28*28)

# one-hot label
train_label = torch.stack([F.one_hot(torch.tensor(im[1]), num_classes=10) for im in mnist_train])
test_label = torch.stack([F.one_hot(torch.tensor(im[1]), num_classes=10) for im in mnist_test])

print(f"train_data.shape = {train_data.shape}")
print(f"test_data.shape = {test_data.shape}")
print(f"train_label.shape = {train_label.shape}")
print(f"test_label.shape = {test_label.shape}")

train_data.shape = torch.Size([60000, 784])
test_data.shape = torch.Size([10000, 784])
train_label.shape = torch.Size([60000, 10])
test_label.shape = torch.Size([10000, 10])


### Model

In [22]:
class MNIST_Model(nn.Module):
    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.hidden1 = nn.Linear(28*28, 50)
#         self.hidden2 = nn.Linear(300, 50)
        self.output = nn.Linear(50, 10)
#         self.layer = nn.Linear(28*28, 10)
    def forward(self, x):
        x = self.hidden1(x)
        x = F.relu(x)
#         x = self.hidden2(x)
#         x = F.relu(x)
        x = self.output(x)
#         x = self.layer(x)
#         x = F.relu(x)
        x = F.log_softmax(x, dim=1)
        return x

### Loss Function

In [174]:
def L1_loss(pred, act):
    """
    Returns absolute loss between prediction and truth.
    
    Args:
    pred (Tensor): n * 10 tensor, contains likelihood the model predicted for the each of the n images to be digits
         from 0 to 9. The likelihood of the image being a digit is stored in the length 10 tensor, where the 
         probability value pred[j][i] at index i s.t. 0 <= i <= 9 indicates the probability the model predicts that
         image j is digit i.
    act (Tensor): n * 10 tensor, a stack of one-hot tensors representing true labels.
    """
    assert pred.shape == act.shape, 'incompatible pred and true shapes'
    diff = act - pred
    err = diff.abs()
    err = err.sum(1)
    err = err.mean()
    return err

### Metric

In [16]:
def accuracy(pred, act):
    _, ind_pred = torch.max(pred, 1)
    _, ind_act = torch.max(act, 1)
    assert ind_act.shape == ind_pred.shape, 'incompatible prediction and truth sizes'
    
    diff = ind_pred - ind_act
    num_diff = diff.count_nonzero()
    correct = len(pred) - num_diff
    accuracy = correct / len(pred)
    return accuracy.item()

### Training

In [23]:
# training
def train_model(model, X_train, y_train, X_test, y_test, batch_size=20, sgd=False):
    assert len(X_train) == len(y_train), 'incompatible training size'
    assert len(X_test) == len(y_test), 'incompatible test size'
    
    dataset = TensorDataset(X_train, y_train)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    learning_rate = 1e-3
    epochs = 500
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        if sgd == True:
            start_index = random.randint(0, len(X_train)-batch_size-1)
            X = X_train[start_index:start_index+batch_size]
            y = y_train[start_index:start_index+batch_size]
        else:
            X = X_train
            y = y_train
        
        # forward pass
        y_pred = model(X)
        loss = F.cross_entropy(y_pred, y.float())
        print(f"training loss = {loss}\n")
        
        test_pred = model(X_test)
        dev_error = F.cross_entropy(test_pred, y_test.float())
        print(f"validation loss = {dev_error}\n")

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # weight update
        optimizer.step()
        
dev_size = 6000
model = MNIST_Model()
train_model(model, train_data[:-dev_size], train_label[:-dev_size], test_data[-dev_size:], test_label[-dev_size:])

Epoch 1
training loss = 2.3028316497802734

validation loss = 2.3030312061309814

Epoch 2
training loss = 2.299143075942993

validation loss = 2.303001642227173

Epoch 3
training loss = 2.2952468395233154

validation loss = 2.3030102252960205

Epoch 4
training loss = 2.290877103805542

validation loss = 2.3030648231506348

Epoch 5
training loss = 2.285942316055298

validation loss = 2.3031508922576904

Epoch 6
training loss = 2.2804503440856934

validation loss = 2.3032500743865967

Epoch 7
training loss = 2.274444341659546

validation loss = 2.303354263305664



  x = F.softmax(x)


Epoch 8
training loss = 2.2679316997528076

validation loss = 2.3034534454345703

Epoch 9
training loss = 2.2608861923217773

validation loss = 2.303537368774414

Epoch 10
training loss = 2.253247022628784

validation loss = 2.303617238998413

Epoch 11
training loss = 2.2449183464050293

validation loss = 2.3036930561065674

Epoch 12
training loss = 2.235825777053833

validation loss = 2.3037800788879395

Epoch 13
training loss = 2.2259676456451416

validation loss = 2.3038902282714844

Epoch 14
training loss = 2.2153923511505127

validation loss = 2.304039716720581

Epoch 15
training loss = 2.204164743423462

validation loss = 2.3042430877685547

Epoch 16
training loss = 2.1923511028289795

validation loss = 2.304511308670044

Epoch 17
training loss = 2.1800036430358887

validation loss = 2.3048460483551025

Epoch 18
training loss = 2.167151927947998

validation loss = 2.3052468299865723

Epoch 19
training loss = 2.1538116931915283

validation loss = 2.305715560913086

Epoch 20
traini

validation loss = 2.3466129302978516

Epoch 115
training loss = 1.6127524375915527

validation loss = 2.3467111587524414

Epoch 116
training loss = 1.6117486953735352

validation loss = 2.3468074798583984

Epoch 117
training loss = 1.6107676029205322

validation loss = 2.3469035625457764

Epoch 118
training loss = 1.6098097562789917

validation loss = 2.3470005989074707

Epoch 119
training loss = 1.608875036239624

validation loss = 2.3470983505249023

Epoch 120
training loss = 1.607962965965271

validation loss = 2.3471970558166504

Epoch 121
training loss = 1.6070722341537476

validation loss = 2.3472959995269775

Epoch 122
training loss = 1.6062010526657104

validation loss = 2.3473939895629883

Epoch 123
training loss = 1.6053482294082642

validation loss = 2.3474905490875244

Epoch 124
training loss = 1.6045132875442505

validation loss = 2.3475847244262695

Epoch 125
training loss = 1.6036957502365112

validation loss = 2.3476758003234863

Epoch 126
training loss = 1.602895855903

validation loss = 2.351884365081787

Epoch 215
training loss = 1.565081000328064

validation loss = 2.3519084453582764

Epoch 216
training loss = 1.5648400783538818

validation loss = 2.3519320487976074

Epoch 217
training loss = 1.5646013021469116

validation loss = 2.3519551753997803

Epoch 218
training loss = 1.5643644332885742

validation loss = 2.351978302001953

Epoch 219
training loss = 1.5641294717788696

validation loss = 2.3520009517669678

Epoch 220
training loss = 1.5638965368270874

validation loss = 2.3520233631134033

Epoch 221
training loss = 1.563665509223938

validation loss = 2.3520455360412598

Epoch 222
training loss = 1.5634363889694214

validation loss = 2.352067470550537

Epoch 223
training loss = 1.5632089376449585

validation loss = 2.3520889282226562

Epoch 224
training loss = 1.5629830360412598

validation loss = 2.3521103858947754

Epoch 225
training loss = 1.5627589225769043

validation loss = 2.3521313667297363

Epoch 226
training loss = 1.562536716461181

training loss = 1.5469571352005005

validation loss = 2.3534741401672363

Epoch 320
training loss = 1.5468260049819946

validation loss = 2.353484630584717

Epoch 321
training loss = 1.5466951131820679

validation loss = 2.353494882583618

Epoch 322
training loss = 1.5465649366378784

validation loss = 2.3535048961639404

Epoch 323
training loss = 1.5464352369308472

validation loss = 2.3535149097442627

Epoch 324
training loss = 1.5463061332702637

validation loss = 2.353525161743164

Epoch 325
training loss = 1.5461772680282593

validation loss = 2.3535349369049072

Epoch 326
training loss = 1.546048879623413

validation loss = 2.3535449504852295

Epoch 327
training loss = 1.5459208488464355

validation loss = 2.3535549640655518

Epoch 328
training loss = 1.5457934141159058

validation loss = 2.353565216064453

Epoch 329
training loss = 1.5456660985946655

validation loss = 2.3535752296447754

Epoch 330
training loss = 1.5455396175384521

validation loss = 2.3535850048065186

Epoch 3

Epoch 418
training loss = 1.5357240438461304

validation loss = 2.3542892932891846

Epoch 419
training loss = 1.5356241464614868

validation loss = 2.354294538497925

Epoch 420
training loss = 1.5355242490768433

validation loss = 2.354299783706665

Epoch 421
training loss = 1.535425066947937

validation loss = 2.3543050289154053

Epoch 422
training loss = 1.5353257656097412

validation loss = 2.3543105125427246

Epoch 423
training loss = 1.535226583480835

validation loss = 2.3543155193328857

Epoch 424
training loss = 1.535127878189087

validation loss = 2.354320526123047

Epoch 425
training loss = 1.5350295305252075

validation loss = 2.354325771331787

Epoch 426
training loss = 1.5349311828613281

validation loss = 2.3543307781219482

Epoch 427
training loss = 1.5348331928253174

validation loss = 2.3543355464935303

Epoch 428
training loss = 1.5347357988357544

validation loss = 2.3543407917022705

Epoch 429
training loss = 1.5346384048461914

validation loss = 2.3543457984924316


In [24]:
train_pred = model(train_data)

  x = F.softmax(x)


In [25]:
accuracy(train_pred, train_label)

0.9440333247184753

In [26]:
test_pred = model(test_data)

  x = F.softmax(x)


In [27]:
accuracy(test_pred, test_label)

0.09939999878406525