In [None]:
import sys
import os

In [None]:
# Make the project root (the parent of the notebook's working directory) importable so
# the `src.` package imports below resolve. Guarded so that re-running this cell does
# not keep appending duplicate entries to sys.path.
project_root = os.path.dirname(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from src.deep_learning.datahandling.dataset import DataLoader, CSVImageDataset
from src.deep_learning.optimizer.optimizer import Momentum
from src.deep_learning.RGrad.transform import ReLUBlock, Flatten, Transform
import src.deep_learning.RGrad.function
from src.deep_learning.RGrad.function import cross_entropy

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# The CSV datasets live in <project root>/datasets, where the project root is the
# parent of the notebook's working directory.
dataset_path = os.path.join(os.path.dirname(os.getcwd()), 'datasets')

error_string = """datasets not found. 
The train and test mnist datasets in csv form can be downloaded from: 
https://www.kaggle.com/datasets/oddrationale/mnist-in-csv?resource=download.
They should be saved as mnist_train.csv and mnist_test.csv under the top-level datasets directory"""
try:
    # (28, 28) is passed to CSVImageDataset; presumably the image shape — confirm.
    mnist_train_dataset = CSVImageDataset(os.path.join(dataset_path, 'mnist_train.csv'), (28, 28))
    mnist_test_dataset = CSVImageDataset(os.path.join(dataset_path, 'mnist_test.csv'), (28, 28))
except FileNotFoundError as err:
    # Re-raise with download instructions, chaining the original error so the
    # traceback still shows exactly which path could not be opened.
    raise FileNotFoundError(error_string) from err

In [None]:
# Mini-batch size shared by both loaders (and by the training progress bar's total),
# so changing it here keeps everything consistent.
batch_size = 16
train_dataloader = DataLoader(mnist_train_dataset, batch_size, shuffle=True)
test_dataloader = DataLoader(mnist_test_dataset, batch_size, shuffle=False)

In [None]:
class MLP(Transform):
    """Fully connected classifier for 28x28 images flattened to 784 features.

    Layers, applied in order: Flatten -> ReLUBlock(784, 60) -> ReLUBlock(60, 60)
    -> ReLUBlock(60, 10).
    """

    def __init__(self):
        n_inputs, n_hidden, n_classes = 784, 60, 10  # 784 = 28 * 28 pixels
        self.flatten = Flatten()
        self.relu1 = ReLUBlock(n_inputs, n_hidden)
        self.relu2 = ReLUBlock(n_hidden, n_hidden)
        self.relu3 = ReLUBlock(n_hidden, n_classes)

    def __call__(self, inpt):
        """Run ``inpt`` through flatten and the three ReLU blocks; returns 10 scores per example."""
        # NOTE(review): the output layer is a ReLUBlock too. If ReLUBlock applies a ReLU
        # after its affine map (as the name suggests — confirm), the logits fed to
        # cross_entropy are clamped at zero; a plain linear output layer is the usual choice.
        activations = self.flatten(inpt)
        for layer in (self.relu1, self.relu2, self.relu3):
            activations = layer(activations)
        return activations

In [None]:
model = MLP()
trainable_params = model.params()
# NOTE(review): 0.9 and 0.01 are passed positionally — presumably momentum and learning
# rate, but Momentum's argument order isn't visible here; confirm and consider keywords.
optimizer = Momentum(trainable_params, 0.9, 0.01)

In [None]:
def get_accuracy(model, dataloader):
    """Classification accuracy of ``model`` over every batch yielded by ``dataloader``.

    Args:
        model: callable mapping an input batch to an object whose ``.elems`` holds
            per-class scores, with classes along axis 1.
        dataloader: iterable of ``(inpt, labels)`` pairs; ``labels.elems`` holds the
            class index of each example (compared against the argmax prediction).

    Returns:
        ``(accuracy, num_right, num_wrong)`` where
        ``accuracy = num_right / (num_right + num_wrong)``.

    Raises:
        ValueError: if ``dataloader`` yields no examples, so accuracy is undefined.
    """
    num_right = 0
    num_wrong = 0
    for inpt, labels in dataloader:
        logits = model(inpt)
        predictions = np.argmax(logits.elems, axis=1)
        # Flatten labels so a (batch, 1) column compares element-wise with the (batch,)
        # predictions instead of broadcasting to a (batch, batch) matrix and over-counting.
        targets = np.ravel(labels.elems)
        num_right_batch = int(np.sum(predictions == targets))
        num_right += num_right_batch
        num_wrong += len(predictions) - num_right_batch
    total = num_right + num_wrong
    if total == 0:
        raise ValueError('dataloader yielded no examples; accuracy is undefined')
    return num_right / total, num_right, num_wrong

In [None]:
# Baseline before any training: score the freshly constructed model on the test split.
baseline_metrics = get_accuracy(model, test_dataloader)
accuracy, num_right, num_wrong = baseline_metrics
print(f'accuracy: {accuracy}')

In [None]:
# One pass over the training set with the Momentum optimizer; per-batch losses are
# collected for the moving-average plot further down.
losses = []

# Context manager guarantees the progress bar is closed even if a step raises.
with tqdm(total=mnist_train_dataset.num_datapoints()//batch_size, desc='training') as pbar:
    for inpt, labels in train_dataloader:
        logits = model(inpt)
        loss = cross_entropy(logits, labels)
        loss.backward()
        losses.append(loss.elems)
        optimizer.update()
        # Presumably resets the gradients accumulated through the loss graph so they
        # don't carry into the next batch — confirm against RGrad.
        loss.zero_grads()
        pbar.update(1)

In [None]:
# Evaluate the trained model on both splits (test first, then train).
test_accuracy, num_right, num_wrong = get_accuracy(model, test_dataloader)
train_accuracy, num_right, num_wrong = get_accuracy(model, train_dataloader)
for split_name, split_accuracy in (('test', test_accuracy), ('train', train_accuracy)):
    print(f'{split_name} accuracy: {split_accuracy}')

In [None]:
# Smooth the noisy per-batch losses with a trailing moving average of `window` batches;
# the first `window - 1` points average over however many batches exist so far.
window = 10
averaged_losses = [
    sum(losses[max(0, index - window + 1):index + 1]) / min(index + 1, window)
    for index in range(len(losses))
]

fig, ax = plt.subplots()
ax.plot(np.arange(len(losses)), averaged_losses)
ax.set_xlabel('batch number')
ax.set_ylabel(f'size {window} moving average batch loss')
ax.set_title('Training loss')
plt.show()