In [1]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np

### Task 1: Load MNIST and flatten each image into a 784-dimensional row vector

In [2]:
######TASK 1######
# Download MNIST and materialise each split as one pair of tensors.
#
# Training targets are one-hot encoded (float, 10 columns) because the
# least-squares loss/gradient regress onto them directly. Test targets are
# intentionally left as integer class indices: accuracy compares them with the
# argmax of the predictions, so they must NOT be one-hot encoded.
transforms_fnc = transforms.Compose([
    transforms.ToTensor()  # PIL image -> float tensor
])

# Indicator vector for class y, float32 so it matches the regression output dtype.
target_transform_fnc = transforms.Lambda(
    lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()
)

train_data = MNIST('./data', train=True, download=True, transform=transforms_fnc, target_transform=target_transform_fnc)
test_data = MNIST('./data', train=False, download=True, transform=transforms_fnc)

# A single batch spanning the whole split pulls every sample out in one call.
train_bs = len(train_data)
test_bs = len(test_data)

train_loader = DataLoader(train_data, batch_size=train_bs, shuffle=False)
test_loader = DataLoader(test_data, batch_size=test_bs, shuffle=False)

train_data_X, train_data_y = next(iter(train_loader))
# Flatten each image to one row so the design matrix is (n_samples, n_pixels).
train_data_X = train_data_X.reshape(train_data_X.size(0), -1)

test_data_X, test_data_y = next(iter(test_loader))
test_data_X = test_data_X.reshape(test_data_X.size(0), -1)

# Invariants Task 2 relies on: same feature width, one-hot train / index test labels.
assert train_data_X.size(1) == test_data_X.size(1)
assert train_data_y.dim() == 2 and train_data_y.size(1) == 10
assert test_data_y.dim() == 1

print('train data shape: {}, label shape: {}'.format(train_data_X.size(), train_data_y.size()))
print('test data shape: {}, label shape: {}'.format(test_data_X.size(), test_data_y.size()))

train data shape: torch.Size([60000, 784]), label shape: torch.Size([60000, 10])
test data shape: torch.Size([10000, 784]), label shape: torch.Size([10000])


### Task 2: Minibatch SGD for a linear least-squares classifier

In [3]:
######TASK 2######
# code for minibatch SGD implementation
def _gradient(data, label, weight):
    return torch.matmul(torch.t(data), torch.matmul(data, weight) - label) / data.size(0)

def _loss(data, label, weight):
    inner = label - torch.matmul(data, weight)
    norm = torch.linalg.norm(inner)
    return 0.5 * (norm ** 2) / data.size(0)

def _acc(data, label, weight):
    preds = torch.matmul(data, weight)
    return torch.sum((label == torch.argmax(preds, dim=1)).int()) / data.size(0)

def sgd_train(train_data_X_arg, train_data_y_arg, test_data_X_arg, test_data_y_arg, num_of_iterations, batch_size, learning_rate, seed=None):
    """Fit a linear least-squares classifier with minibatch SGD.

    Each iteration does the following:
      * samples ``batch_size`` training rows uniformly with replacement;
      * records the minibatch loss and the full test-set accuracy, both
        measured at the current (pre-update) weight;
      * takes one gradient step.

    Progress is printed on the first iteration and every 100th iteration.

    Args:
        train_data_X_arg: 2-D training design matrix (n_train, d).
        train_data_y_arg: 2-D training targets (n_train, k), e.g. one-hot labels.
        test_data_X_arg: 2-D test design matrix (n_test, d).
        test_data_y_arg: test class indices, compared with the prediction argmax by ``_acc``.
        num_of_iterations: number of SGD steps.
        batch_size: rows sampled per step.
        learning_rate: step size.
        seed: optional int. When given, minibatch indices come from a private
            ``np.random.RandomState(seed)`` so runs are reproducible. When None
            (the default), the global ``np.random`` stream is used as before.

    Returns:
        Tuple ``(weight, running_loss, running_acc)``:
          * weight: the final (d, k) weight tensor;
          * running_loss, running_acc: per-iteration lists of floats.
        In a notebook, assign the result or end the call with ``;`` so that the
        long histories are not displayed.
    """
    n_train = train_data_X_arg.size(0)
    # Weight shape follows the data: feature width x target width (784 x 10 for flattened MNIST).
    weight = torch.zeros(train_data_X_arg.size(1), train_data_y_arg.size(1))

    rng = np.random if seed is None else np.random.RandomState(seed)

    running_loss = []
    running_acc = []
    for iter_idx in range(num_of_iterations):
        # Uniform sampling with replacement.
        sampled_idx = rng.randint(0, n_train, batch_size)
        sampled_batch_X = train_data_X_arg[sampled_idx]
        sampled_batch_y = train_data_y_arg[sampled_idx]

        # Loss and gradient at the current weight (before this step's update).
        loss = _loss(sampled_batch_X, sampled_batch_y, weight)
        gradient = _gradient(sampled_batch_X, sampled_batch_y, weight)
        running_loss.append(loss.item())

        # Accuracy on the whole test set, also before the update.
        acc = _acc(test_data_X_arg, test_data_y_arg, weight)
        running_acc.append(acc.item())

        weight = weight - learning_rate * gradient
        if iter_idx == 0 or (iter_idx + 1) % 100 == 0:
            print('iter: {}, loss: {}, acc: {}'.format(iter_idx + 1, running_loss[-1], running_acc[-1]))

    return weight, running_loss, running_acc

In [8]:
# n_train // 10 steps of 100 samples each is roughly 10 passes over the
# training set (sampling is with replacement). The trailing ';' keeps the
# cell from displaying the call's return value.
sgd_train(
    train_data_X, train_data_y, test_data_X, test_data_y,
    num_of_iterations=train_data_X.size(0) // 10,
    batch_size=100,
    learning_rate=0.001,
);

iter: 1, loss: 0.5, acc: 0.09799999743700027
iter: 100, loss: 0.36535966396331787, acc: 0.7544999718666077
iter: 200, loss: 0.33776628971099854, acc: 0.788100004196167
iter: 300, loss: 0.30938228964805603, acc: 0.7833999991416931
iter: 400, loss: 0.2985003590583801, acc: 0.795799970626831
iter: 500, loss: 0.27080923318862915, acc: 0.8014000058174133
iter: 600, loss: 0.27861061692237854, acc: 0.8065000176429749
iter: 700, loss: 0.2492048144340515, acc: 0.8115000128746033
iter: 800, loss: 0.250323623418808, acc: 0.8148999810218811
iter: 900, loss: 0.24730165302753448, acc: 0.8162000179290771
iter: 1000, loss: 0.2528766989707947, acc: 0.8203999996185303
iter: 1100, loss: 0.24798446893692017, acc: 0.8248000144958496
iter: 1200, loss: 0.2303975522518158, acc: 0.8259999752044678
iter: 1300, loss: 0.2494015246629715, acc: 0.8278999924659729
iter: 1400, loss: 0.2510981261730194, acc: 0.8299000263214111
iter: 1500, loss: 0.23252353072166443, acc: 0.8303999900817871
iter: 1600, loss: 0.265243053