In [None]:
import numpy as np
import torch

# Report whether PyTorch can see a CUDA device; as the last expression of the
# cell, the boolean is shown as the cell output.
torch.cuda.is_available()

In [None]:
# Floating-point type used when the model parameters are created (64-bit).
dtype = torch.double
# Device every tensor in this notebook is moved to. Hard-coded to the CPU; it
# does not depend on whether CUDA is available.
device = torch.device("cpu")

In [None]:
# yhat, y: (num_examples, num_outputs)
# return: scalar (0-dim tensor) -- the loss summed over all examples
def log_cross_entropy(yhat, y):
    """Cross-entropy loss computed from probabilities.

    yhat: (num_examples, num_outputs) predicted probabilities.
    y: (num_examples, num_outputs) targets; entries greater than zero mark
        the positions that contribute to the loss.
    Returns a 0-dim tensor: minus the sum of log(yhat) over the marked
    positions of the whole batch.

    No epsilon is added before taking the log, so a marked probability of
    exactly zero produces an infinite loss.
    """
    target_mask = y > 0
    picked_probs = yhat[target_mask]
    return -torch.sum(torch.log(picked_probs))

def cross_entropy(yhat, y):
    """Cross-entropy loss computed from values that are already logarithms.

    yhat: (num_examples, num_outputs) log-probabilities (the training loop
        passes the output of log_softmax).
    y: (num_examples, num_outputs) targets; entries greater than zero mark
        the positions that contribute to the loss.
    Returns a 0-dim tensor: minus the sum of yhat over the marked positions
    of the whole batch.
    """
    picked_log_probs = yhat[y > 0]
    return -torch.sum(picked_log_probs)

In [None]:
# W: (num_inputs, num_outputs)
# b: (num_outputs,)
# x: (num_examples, num_inputs)
# return: (num_examples, num_outputs)
def linear_transform(W, b, x):
    """Apply the affine map x @ W + b to a batch of row vectors.

    W: (num_inputs, num_outputs) weight matrix.
    b: (num_outputs,) bias, broadcast across the rows of the product.
    x: (num_examples, num_inputs) inputs; must be 2-D because mm is used.
    Returns a (num_examples, num_outputs) tensor.
    """
    projected = x.mm(W)
    return projected + b

In [None]:
# z: (num_examples, num_outputs)
# return: (num_examples, num_outputs)
def softmax(z):
    """Row-wise softmax.

    z: (num_examples, num_outputs) scores.
    Returns a (num_examples, num_outputs) tensor whose rows are non-negative
    and sum to one.
    """
    # Softmax is unchanged by adding a constant to a row. Shifting by the row
    # maximum makes every exponent <= 0, so exp() cannot overflow to inf
    # (shifting by the minimum makes every exponent >= 0 and gives inf / inf =
    # NaN once a row spreads far enough).
    z_max, _ = torch.max(z, dim=1, keepdim=True)
    exp = torch.exp(z - z_max)
    # The largest entry contributes exp(0) = 1, so the denominator is >= 1.
    return exp / torch.sum(exp, dim=1, keepdim=True)

def log_softmax(z):
    """Row-wise logarithm of the softmax, computed without forming the softmax.

    z: (num_examples, num_outputs) scores.
    Returns a (num_examples, num_outputs) tensor of log-probabilities.
    """
    # log-softmax is unchanged by adding a constant to a row. Shifting by the
    # row maximum keeps every exponent <= 0, so exp() cannot overflow (shifting
    # by the mean still leaves positive exponents that overflow for widely
    # spread rows and turn the whole row into -inf).
    z_max, _ = torch.max(z, dim=1, keepdim=True)
    shifted = z - z_max
    # The sum is >= 1 because the largest entry contributes exp(0) = 1, so the
    # log below is always finite.
    log_sum_exp = torch.log(torch.sum(torch.exp(shifted), dim=1, keepdim=True))
    return shifted - log_sum_exp

In [None]:
from dataset import MNIST

# Build the training and test datasets from the image file (first argument)
# and the label file (second argument). The paths are relative to the current
# working directory; the file names indicate gzip-compressed IDX files.
# NOTE(review): MNIST comes from the project-local `dataset` module, which is
# not shown here -- any preprocessing it applies (scaling, one-hot targets)
# should be confirmed there.
train_data = MNIST('data/mnist/train-images-idx3-ubyte.gz', 'data/mnist/train-labels-idx1-ubyte.gz')
test_data = MNIST('data/mnist/t10k-images-idx3-ubyte.gz', 'data/mnist/t10k-labels-idx1-ubyte.gz')

In [None]:
# Show the training set's input_dim, target_dim and sample_size attributes as
# the cell output (a tuple).
train_data.input_dim, train_data.target_dim, train_data.sample_size

In [None]:
# Number of samples in the training set.
num_examples = train_data.sample_size
# Flattened sizes: the product of all entries of input_dim / target_dim. Later
# cells reshape every input batch to (-1, num_inputs) and size W and b from
# these values.
num_inputs = np.prod(train_data.input_dim)
num_outputs = np.prod(train_data.target_dim)

In [None]:
# Trainable parameters of the linear model, drawn from a standard normal
# distribution. requires_grad=True makes autograd track them so that calling
# backward() on the loss fills W.grad and b.grad.
# NOTE(review): no random seed is set in the cells shown, so the initial values
# (and therefore the training results) differ between runs.
W = torch.randn(num_inputs, num_outputs, device=device, dtype=dtype, requires_grad=True)
b = torch.randn(num_outputs, device=device, dtype=dtype, requires_grad=True)

In [None]:
from common import split_data

# Hyper-parameters for mini-batch gradient descent.
batch_size = 64
epochs = 10
learning_rate = .002

for epoch in range(epochs):
    # Sum of the per-batch losses seen during this epoch; printed once per epoch.
    epoch_loss = 0
    for batch_inputs, batch_targets in split_data(train_data.inputs, train_data.targets, batch_size):
        # Flatten every example into a row of length num_inputs.
        bx = torch.from_numpy(batch_inputs.reshape(-1, num_inputs)).to(device)
        by = torch.from_numpy(batch_targets).to(device)

        # Forward pass: scores -> log-probabilities -> summed cross-entropy.
        log_probs = log_softmax(linear_transform(W, b, bx))
        batch_loss = cross_entropy(log_probs, by)
        epoch_loss += batch_loss.item()

        # Backward pass fills W.grad and b.grad.
        batch_loss.backward()

        # Plain gradient-descent step. It runs under no_grad so the in-place
        # update is not recorded by autograd; the gradients are then reset
        # because backward() accumulates into .grad.
        with torch.no_grad():
            W.sub_(learning_rate * W.grad)
            b.sub_(learning_rate * b.grad)
            W.grad.zero_()
            b.grad.zero_()
    print(epoch_loss)

In [None]:
def model_predict(W, b, x):
    """Predict one class index per example.

    W: (num_inputs, num_outputs) weight matrix.
    b: (num_outputs,) bias.
    x: (num_examples, num_inputs) inputs.
    Returns a (num_examples,) integer tensor holding, for every row, the
    index of the largest score.
    """
    z = linear_transform(W, b, x)
    # Softmax is strictly increasing within a row, so the largest probability
    # sits where the largest score is. Taking argmax on the scores directly
    # skips the exponentials, which avoids overflow-induced NaNs and ties
    # created by rounding of the probabilities.
    return torch.argmax(z, dim=1)

In [None]:
# Accuracy on the test set: the fraction of examples whose predicted class
# index equals the arg-max column of the target row.
num_correct = 0.0
num_total = len(test_data.inputs)
# Inference only: disable autograd so that no graph is recorded through W and b
# (both have requires_grad=True) while scoring the test batches.
with torch.no_grad():
    for input_batch, target_batch in split_data(test_data.inputs, test_data.targets, batch_size):
        # Flatten every example into a row of length num_inputs.
        tx = torch.from_numpy(input_batch.reshape(-1, num_inputs)).to(device)
        ty = torch.from_numpy(target_batch).to(device)
        pred = model_predict(W, b, tx)
        # NOTE(review): assumes each target row is one-hot, so its arg-max is
        # the class label -- confirm against the dataset module.
        label = torch.argmax(ty, dim=1)
        num_correct += torch.sum(pred == label).item()
print(num_correct / num_total)

In [None]:
# Display the first five test images, each captioned with the class index the
# model predicts for it.
for idx in range(5):
    # Keep a leading batch axis so the single example becomes a (1, num_inputs) row.
    sample = test_data.inputs[np.newaxis, idx].reshape(-1, num_inputs)
    predicted = model_predict(W, b, torch.from_numpy(sample).to(device))
    test_data.show_image(idx, "prediction %d" % predicted)