In [1]:
import torch
import torchvision 
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Load the MNIST training split from ./data, downloading it on first use.
# `train=True` is the constructor's default and is spelled out only for clarity.
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
)
# A bare name as the last expression makes the notebook render the dataset summary.
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train

In [3]:
# Pull the image tensor (X) and the label tensor (Y) out of the dataset object.
X, Y = dataset.data, dataset.targets

In [4]:
# Split the loaded samples into a training part (the first `train_size` samples)
# and a held-out part (everything after). Both parts come from the same `dataset`
# object, so "test" here is effectively a validation hold-out.
train_size = 45000

# Scale the images by 1/255 (assumes raw pixel values in 0..255 - TODO confirm);
# true division also yields a floating-point tensor.
Xtrain = X[:train_size] / 255
Xtest = X[train_size:] / 255

# Labels get a trailing singleton dimension; later cells squeeze it away again
# before using the labels as indices.
Ytrain = Y[:train_size].unsqueeze(1)
Ytest = Y[train_size:].unsqueeze(1)

# Sanity checks: the two parts cover the dataset and images/labels stay aligned.
assert len(Xtrain) + len(Xtest) == len(X)
assert len(Xtrain) == len(Ytrain) and len(Xtest) == len(Ytest)

In [5]:
# Layer sizes: flattened 28x28 input, hidden units, output classes.
ninput, nhidden, nclasses = 28 * 28, 256, 10

In [6]:
# Parameters of a two-layer network: (ninput -> nhidden) followed by (nhidden -> nclasses).
#
# Seed the global RNG so that re-running this cell (or Restart & Run All) always
# produces the same initial parameters.
torch.manual_seed(0)

# Glorot/Xavier-normal scale, std = sqrt(2 / (fan_in + fan_out)), computed from each
# layer's own fan-in and fan-out.
W1_std = (2.0 / (ninput + nhidden)) ** 0.5
W2_std = (2.0 / (nhidden + nclasses)) ** 0.5

# The weights are scaled first and only then flagged with requires_grad, so W1 and W2
# are leaf tensors just like b1 and b2. Multiplying a tensor that already requires grad
# would produce a non-leaf result whose .grad is never populated.
W1 = (torch.randn(ninput, nhidden) * W1_std).requires_grad_()
b1 = torch.randn(1, nhidden, requires_grad=True)
W2 = (torch.randn(nhidden, nclasses) * W2_std).requires_grad_()
b2 = torch.randn(1, nclasses, requires_grad=True)

In [7]:
# First 10 examples of the training split. The training loop further down optimises
# on exactly these tensors (X_train / Y_train), not on the full Xtrain / Ytrain.
# NOTE(review): looks like a deliberate overfit-a-tiny-batch sanity check - confirm.
X_train, Y_train = Xtrain[:10], Ytrain[:10]

In [8]:
# Display the dataset's list of class-name strings (one entry per digit).
dataset.classes

['0 - zero',
 '1 - one',
 '2 - two',
 '3 - three',
 '4 - four',
 '5 - five',
 '6 - six',
 '7 - seven',
 '8 - eight',
 '9 - nine']

In [15]:
# Full-batch gradient descent with a hand-written backward pass.
#
# Model:  logits = tanh(x @ W1 + b1) @ W2 + b2,  loss = mean cross-entropy.
# The softmax / cross-entropy is unrolled into elementary steps and differentiated one
# step at a time, so every dL_d<name> tensor has the same shape as <name>.
#
# Notes:
# * The loop trains on X_train / Y_train (the small subset defined in an earlier cell),
#   while validation uses the whole Xtest / Ytest hold-out.
# * The parameters are NOT re-initialised here: re-running this cell continues training
#   from the current W1 / b1 / W2 / b2. Re-run the initialisation cell for a fresh start.
input_size = 28 * 28
alpha = 0.1        # learning rate
losses = []        # training loss, one entry per epoch
val_losses = []    # validation loss, one entry per epoch
val_accs = []      # validation accuracy in percent, one entry per epoch

# Loop-invariant tensors, built once instead of once per epoch.
X_flat = X_train.view(-1, input_size)
X_val_flat = Xtest.view(-1, input_size)
y_true = Y_train.squeeze(1)   # squeeze(1) keeps a 1-D label vector even for one sample
y_val = Ytest.squeeze(1)
n_train = len(y_true)
rows = torch.arange(n_train)
rows_val = torch.arange(len(y_val))

for epoch in range(100):
    # ---------------- forward pass ----------------
    Z1 = X_flat @ W1 + b1
    A1 = torch.tanh(Z1)
    Z2 = A1 @ W2 + b2

    # Softmax with the row-wise maximum subtracted before exponentiating.
    zmax, zmax_idx = Z2.max(dim=1, keepdim=True)
    znorm = Z2 - zmax

    zexp = znorm.exp()
    zexp_sum = zexp.sum(dim=1, keepdim=True)
    zexp_sum_inv = zexp_sum ** (-1)
    probs = zexp * zexp_sum_inv
    log_probs = probs.log()
    L = -log_probs[rows, y_true].mean()
    losses.append(L.item())

    # Everything below is computed by hand, so autograd tracking is switched off:
    # no graph is recorded for the backward formulas, the update or the validation pass.
    with torch.no_grad():
        # ---------------- backward pass ----------------
        # L = -mean(log_probs[i, y_i])
        dL_dL = torch.ones_like(L)
        dL_dlogprobs = torch.zeros_like(log_probs)
        dL_dlogprobs[rows, y_true] = -dL_dL / n_train
        # log_probs = log(probs)
        dL_dprobs = dL_dlogprobs / probs

        # probs = zexp * zexp_sum_inv  (zexp_sum_inv is broadcast over the class axis)
        dL_dzexp = dL_dprobs * zexp_sum_inv
        dL_dzexp_sum_inv = (dL_dprobs * zexp).sum(1, keepdim=True)
        # zexp_sum_inv = zexp_sum ** -1
        dL_dzexp_sum = -dL_dzexp_sum_inv * zexp_sum ** (-2)
        # zexp_sum = zexp.sum(1): the gradient flows back equally to every class column.
        dL_dzexp = dL_dzexp + dL_dzexp_sum
        # zexp = exp(znorm)
        dL_dznorm = dL_dzexp * zexp

        # znorm = Z2 - zmax: direct path plus the path through the arg-max entry.
        dL_dzmax = -dL_dznorm.sum(1, keepdim=True)
        dL_dZ = dL_dznorm + torch.nn.functional.one_hot(zmax_idx.squeeze(1), nclasses) * dL_dzmax

        # Z2 = A1 @ W2 + b2
        dL_dW2 = A1.T @ dL_dZ
        dL_db2 = dL_dZ.sum(0, keepdim=True)
        dL_dA1 = dL_dZ @ W2.T

        # A1 = tanh(Z1)  =>  dA1/dZ1 = 1 - A1**2
        dL_dZ1 = dL_dA1 * (1 - A1**2)

        # Z1 = X @ W1 + b1
        dL_dW1 = X_flat.T @ dL_dZ1
        dL_db1 = dL_dZ1.sum(0, keepdim=True)

        # ---------------- gradient descent step ----------------
        W1 -= alpha * dL_dW1
        b1 -= alpha * dL_db1
        W2 -= alpha * dL_dW2
        b2 -= alpha * dL_db2

        # ---------------- validation on the hold-out set ----------------
        Z1_val = X_val_flat @ W1 + b1
        A1_val = torch.tanh(Z1_val)
        Z2_val = A1_val @ W2 + b2
        val_loss = -torch.nn.functional.log_softmax(Z2_val, dim=1)[rows_val, y_val].mean()
        val_acc = (Z2_val.argmax(dim=1) == y_val).float().mean().item() * 100

    val_losses.append(val_loss.item())
    val_accs.append(val_acc)
    print(f"Epoch {epoch}, Loss: {L.item()}, Validation Loss: {val_loss.item()}")

Epoch 0, Loss: 0.0004093652532901615, Validation Loss: 3.2141315937042236
Epoch 1, Loss: 0.0004091386799700558, Validation Loss: 3.214205265045166
Epoch 2, Loss: 0.00040894787525758147, Validation Loss: 3.2142786979675293
Epoch 3, Loss: 0.0004087093402631581, Validation Loss: 3.2143521308898926
Epoch 4, Loss: 0.0004085304099135101, Validation Loss: 3.2144253253936768
Epoch 5, Loss: 0.0004083276726305485, Validation Loss: 3.214498519897461
Epoch 6, Loss: 0.00040812493534758687, Validation Loss: 3.2145721912384033
Epoch 7, Loss: 0.0004079221689607948, Validation Loss: 3.2146449089050293
Epoch 8, Loss: 0.00040770749910734594, Validation Loss: 3.2147185802459717
Epoch 9, Loss: 0.00040748092578724027, Validation Loss: 3.214791774749756
Epoch 10, Loss: 0.00040731392800807953, Validation Loss: 3.214864730834961
Epoch 11, Loss: 0.0004071171279065311, Validation Loss: 3.214937925338745
Epoch 12, Loss: 0.0004069143906235695, Validation Loss: 3.2150111198425293
Epoch 13, Loss: 0.00040667588473297

In [None]:
# Training-loss curve: one point per entry of `losses`.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="epoch", ylabel="training loss", title="Training loss per epoch")
plt.show()

In [None]:
# Accuracy of the current parameters on the held-out samples (Xtest / Ytest).
# The whole hold-out set is pushed through the network in one batched forward pass
# instead of one Python-level iteration per sample.
with torch.no_grad():
    # Forward pass (each image is flattened to one row)
    Z_test = Xtest.view(len(Xtest), -1) @ W1 + b1
    A1_test = torch.tanh(Z_test)
    Z2_test = A1_test @ W2 + b2
    predicted_class = torch.argmax(Z2_test, dim=1)

    # Ytest carries a trailing singleton dimension; drop it to compare element-wise.
    correct = (predicted_class == Ytest.squeeze(1)).sum().item()

total = len(Xtest)
accuracy = correct / total
print(f"Test Accuracy: {accuracy}")
