In [1]:
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
import torch
import torch.utils.data as dutils

import torchvision.datasets as dataset
import torchvision.transforms as dtrans

In [4]:
device = 'cuda'
seed = 99

In [5]:
input_dim = 784
output_dim = 10
num_examples = 60000

batch_size = 32

torch.manual_seed(seed);

### Data setup

In [6]:
transform = dtrans.Compose([dtrans.ToTensor(), dtrans.Normalize((0.5,), (0.5,))])
train_set = dataset.MNIST(root="../../data/", download=True, train=True, transform=transform)
test_set = dataset.MNIST(root="../../data/", download=True, train=False, transform=transform)

train_loader = dutils.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = dutils.DataLoader(test_set, batch_size=batch_size, shuffle=False)

### Model setup

In [7]:
def cross_entropy(yhat, y):
    loss = -torch.log(yhat[:, y] + 1e-6)
    return loss.mean()

In [8]:
def softmax(y_linear):
    row_max = y_linear.max(dim=1)[0]
    row_max.unsqueeze_(1)
    row_max = row_max.repeat(1, y_linear.shape[1])
    exp = torch.exp(y_linear - row_max)
    norm = exp.sum(dim=1).unsqueeze(1).repeat(1, y_linear.shape[1])
    return exp / norm

In [9]:
# sample_y_linear = torch.randn(10, 10)
# sample_yhat = softmax(sample_y_linear)
# print(sample_y_linear)
# print(sample_yhat)

In [10]:
def SGD(params, lr):
    for param in params:
        param.data = param.data - lr * param.grad

In [11]:
class Net():
    def __init__(self, input_dim, out_dim, device):
        self.w = torch.randn(size=(input_dim, output_dim), device=device, requires_grad=True)
        self.b = torch.randn(output_dim, device=device, requires_grad=True)
#         print('Weight shape: ', self.w.shape)
#         print('Bias shape: ', self.b.shape)
        
    def forward(self, x):
        w = self.w
        bias = self.b
        out = torch.mm(x, w) + bias
        out = softmax(out)
        return out.squeeze()
    
    def parameters(self):
        return [self.w, self.b]
    
    def zero_grad(self):
        self.w.grad = None
        self.b.grad = None
        
    def __call__(self, x):
        return self.forward(x)

In [12]:
def test(net, test_loader, device):
    loss = 0
    error = 0
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, input_dim)
        with torch.no_grad():
            out = net(images)
            loss += cross_entropy(out, labels).item()
            error += (out.argmax(1) != labels).sum().item()
    loss /= i       
    print('Loss: %.4f, Error: %d' % (loss, error))

In [13]:
def train(net, train_loader, device):
    params = net.parameters()
    lr = 0.05
    
    loss = 0
    error = 0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        
        out = net(images)
        loss_batch = cross_entropy(out, labels)
        
        net.zero_grad()
        loss_batch.backward()
        SGD(params, lr)
        
        loss += loss_batch.item()
        error += (out.argmax(1) != labels).sum().item()
    loss /= i       
    print('Loss: %.4f, Error: %d' % (loss, error))

In [None]:
net = Net(input_dim, output_dim, device)

test(net, test_loader, device)
for epochs in range(50):
    train(net, train_loader, device)
    test(net, test_loader, device)

Loss: 11.9037, Error: 9237
Loss: 10.4815, Error: 53149
Loss: 9.6393, Error: 8787
Loss: 8.6890, Error: 52342
Loss: 7.6064, Error: 8752
Loss: 6.9661, Error: 52471
Loss: 6.1948, Error: 8758
Loss: 5.9001, Error: 52746
Loss: 6.3346, Error: 8815
Loss: 5.4730, Error: 52726
Loss: 5.4648, Error: 8846
Loss: 5.3808, Error: 52879
Loss: 5.0378, Error: 8764
Loss: 5.4657, Error: 52833
Loss: 5.0953, Error: 8690
Loss: 5.5544, Error: 52710
Loss: 5.2022, Error: 8989
Loss: 5.6810, Error: 52810
Loss: 4.7876, Error: 8796
Loss: 5.7221, Error: 52761
Loss: 8.5183, Error: 9007
Loss: 6.0051, Error: 52724
Loss: 4.9018, Error: 8976
Loss: 5.7725, Error: 52678
Loss: 6.0535, Error: 8668
Loss: 6.1507, Error: 52669
Loss: 6.0806, Error: 8995
Loss: 6.1019, Error: 52745
Loss: 6.5479, Error: 8973
Loss: 6.2497, Error: 52777
Loss: 5.0957, Error: 8726
Loss: 6.2973, Error: 52725
Loss: 5.4567, Error: 8794
Loss: 6.1098, Error: 52680
Loss: 4.2894, Error: 8890
Loss: 6.3927, Error: 52673
Loss: 8.3267, Error: 8843
