In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import v2
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import numpy as np


# Image size fed to the model; input_size is derived from it so the two can't drift apart.
resize = (28, 28)
input_size = resize[0] * resize[1]  # flattened pixels per image (assumes single-channel input)
batch_size = 16
learning_rate = 0.01
epochs = 10

classes_count = 10  # digits 0-9

# Seed torch's global RNG so weight init, DataLoader shuffling and random sampling are reproducible.
seed = 0
torch.manual_seed(seed)

In [None]:
# Image pipeline: PIL image -> resized tensor image -> float32 scaled into [0, 1].
# (MNIST digits are natively 28x28, so Resize is a no-op with the default `resize`.)
transf = v2.Compose(
    [
        v2.Resize(resize),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)


def target_transf(y):
    """One-hot encode an integer label into a float vector of length `classes_count`."""
    label = torch.tensor(y)
    return F.one_hot(label, num_classes=classes_count).float()


# Arguments shared by both splits; only `train=` differs.
mnist_common = dict(transform=transf, target_transform=target_transf, download=True)
dataset_train = MNIST('./resources', train=True, **mnist_common)
dataset_test = MNIST('./resources', train=False, **mnist_common)

# Shuffle only the training data; evaluation order doesn't matter.
dataloader_train = DataLoader(dataset_train, batch_size, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size, shuffle=False)


In [None]:
class Model:
    """Linear classifier ``X @ w + b`` trained with hand-rolled SGD (no ``torch.nn``).

    Produces raw class scores (logits); any softmax is applied by the loss function.
    """

    def __init__(self, in_features=None, out_features=None, init_std=0.01):
        """Create ``w`` (in_features, out_features) and ``b`` (out_features,) ~ N(0, init_std**2).

        ``in_features`` / ``out_features`` default to the notebook-level ``input_size`` and
        ``classes_count`` so ``Model()`` keeps working unchanged.
        """
        if in_features is None:
            in_features = input_size
        if out_features is None:
            out_features = classes_count
        self.w = torch.normal(0.0, init_std, (in_features, out_features), requires_grad=True)
        self.b = torch.normal(0.0, init_std, (out_features,), requires_grad=True)

    def parameters(self):
        """Return the trainable tensors in a fixed order (weights, then bias)."""
        return (self.w, self.b)

    def __call__(self, X):
        """Flatten every sample of ``X`` and apply the affine map.

        X: tensor whose first axis is the batch (e.g. a (batch, 1, 28, 28) image batch).
        Returns logits of shape (batch, out_features).
        """
        X = X.reshape(X.shape[0], -1)  # (batch, ...) -> (batch, in_features)
        return X @ self.w + self.b

    def zero_grad(self):
        """Reset accumulated gradients; safe to call before the first ``backward()``."""
        for p in self.parameters():
            if p.grad is not None:  # .grad stays None until backward() populates it
                p.grad.zero_()

    def update(self, lr=None):
        """Apply one in-place SGD step ``p -= lr * p.grad`` to every parameter.

        ``lr`` defaults to the notebook-level ``learning_rate``. Parameters that have no
        gradient yet are left untouched.
        """
        if lr is None:
            lr = learning_rate
        with torch.no_grad():  # the update itself must not be tracked by autograd
            for p in self.parameters():
                if p.grad is not None:
                    p -= lr * p.grad

In [None]:
def softmax(x, dim=-1):
    """Numerically stable softmax along ``dim`` (default: the last axis).

    Each slice along ``dim`` is normalised independently, so for a (batch, classes) tensor
    every row sums to 1. (Summing over *all* elements would normalise across the whole
    batch instead.) The per-slice max is subtracted before ``exp`` to avoid overflow; this
    does not change the result because softmax is invariant to a constant shift.
    """
    shifted = x - torch.amax(x, dim=dim, keepdim=True)
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=dim, keepdim=True)

def cross_entropy_loss(y, y_hat):
    """
        Mean cross-entropy between one-hot targets and raw model scores.

        y       - OHE vector of labels, expected shape (batch, classes)
        y_hat   - raw values (logits) predicted from the model, same shape as y

        Per sample: -sum_c y[c] * log_softmax(y_hat)[c]; the result is averaged over the batch.
        log-softmax is computed as y_hat - logsumexp(y_hat) instead of log(softmax(y_hat)),
        because large logits overflow exp() and a zero probability yields 0 * log(0) = nan.
    """
    log_probs = y_hat - torch.logsumexp(y_hat, dim=-1, keepdim=True)
    return -(y * log_probs).sum(dim=-1).mean()


In [None]:
def train(model, loss_fn, dataset, log_every=500):
    """Run one epoch of manual SGD over ``dataset``.

    For each ``(X, y)`` batch: forward pass, ``loss_fn(y, pred)``, ``backward()``,
    ``model.update()``, then ``model.zero_grad()`` so gradients don't accumulate into
    the next step.

    Args:
        model: object providing ``__call__(X)``, ``update()`` and ``zero_grad()``.
        loss_fn: called as ``loss_fn(y, pred)``; must return a scalar tensor.
        dataset: iterable of ``(X, y)`` batches (e.g. a DataLoader).
        log_every: print the running mean loss every ``log_every`` batches
            (0 or None disables logging).

    Returns:
        Mean per-batch training loss for the epoch as a float, or ``nan`` if
        ``dataset`` yielded no batches.
    """
    total_loss = 0.0
    n_batches = 0

    for i, (X, y) in enumerate(dataset):
        pred = model(X)
        loss = loss_fn(y, pred)
        loss.backward()
        model.update()
        model.zero_grad()
        total_loss += loss.item()
        n_batches += 1

        if log_every and i % log_every == 0:
            print(f'Current train loss: {total_loss / n_batches}')

    return total_loss / n_batches if n_batches else float('nan')

def test(model, loss_fn, dataset):
    with torch.no_grad():
        full_loss = 0
        correct_predictions = 0

        for i, (X, y) in enumerate(dataset):
            pred = model(X)
            loss = loss_fn(y, pred)
            full_loss += loss
            correct_predictions += torch.sum(torch.argmax(pred, dim=1) == torch.argmax(y, dim=1))
            

        print(f'Test loss: {full_loss / len(dataset)}')
        print(f'Test accuracy: {correct_predictions / len(dataset.dataset)}')


In [None]:
# Fresh model, then alternate one training epoch with one evaluation pass.
model = Model()

for epoch in range(1, epochs + 1):
    print(f'EPOCH: {epoch}')
    train(model, cross_entropy_loss, dataloader_train)
    test(model, cross_entropy_loss, dataloader_test)

In [None]:
# Human-vs-model demo: show a random held-out digit, ask for a guess, compare with the model.
# Uses the *test* split so the model is judged on an image it never trained on.
# NOTE: input() blocks "Restart & Run All"; skip this cell in non-interactive runs.
r = torch.randint(0, len(dataset_test), (1, )).item()
sample_image, sample_target = dataset_test[r]  # index once: every access re-runs the transforms
img = sample_image.squeeze(0)  # (1, H, W) -> (H, W) for display

fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.axis('off')
plt.show()

# Re-prompt instead of crashing on non-numeric input.
while True:
    try:
        peasant_input = int(input("What number is this?"))
        break
    except ValueError:
        print('Please enter an integer.')

with torch.no_grad():  # inference only; don't build an autograd graph
    result = model(sample_image.unsqueeze(0))  # add batch axis: (1, C, H, W)
rr = torch.argmax(result).item()
real_label = torch.argmax(sample_target).item()
print(f'Meatball peasant suggestion: {peasant_input}')
print(f"My Awesome Model's prediction: {rr}")
print(f"Actual label: {real_label}")
if peasant_input != real_label and rr == real_label:
    for _ in range(6):
        print('Literally worse than a machine - LUZER!!!')