In [None]:
import pyarrow.parquet as pa
import numpy as np
import pandas as pd
from PIL import Image
import io

# download train and test parquet files from https://huggingface.co/datasets/ylecun/mnist
table = pa.read_table('mnist/train.parquet')
table_test = pa.read_table('mnist/test.parquet')

df = table.to_pandas()
df_test = table_test.to_pandas()

from stefgrad.nn.layers import Linear
from stefgrad.nn.activation import softmax, sigmoid, tanh, relu
from stefgrad.tensor import Tensor

class MLP:
    def __init__(self, nin, nouts: list, hidden_fn, last_fn):
        size = [nin] + nouts
        self.layers = [Linear(size[i], size[i+1], hidden_fn) for i in range(len(nouts) - 1)]
        self.layers += [Linear(size[len(nouts) - 1], size[len(nouts)], last_fn)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
    
    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]

    def parameter_shapes(self):
        return [p.data.shape for layer in self.layers for p in layer.parameters()]

model = MLP(784, [512, 256, 10], relu, None)

In [2]:
X = [
    list(Image.open(io.BytesIO(df.iloc[row]['image']['bytes'])).getdata()) for row in range(len(df))
]
Y = [
    [0 if x != df.iloc[row]['label'] else 1 for x in range(10)] for row in range(len(df))
]
X_test = [
    Tensor(list(Image.open(io.BytesIO(df_test.iloc[row]['image']['bytes'])).getdata())) for row in range(len(df_test))
]
Y_test = [
    Tensor([0 if x != df_test.iloc[row]['label'] else 1 for x in range(10)]) for row in range(len(df_test))
]

In [None]:
from stefgrad.tensor import Tensor

model = MLP(784, [512, 256, 10], sigmoid, softmax)

X = [
    Tensor(list(Image.open(io.BytesIO(row['image']['bytes'])).getdata())) for row in df.iloc
]
Y = [
    Tensor([0 if x != row['label'] else 1 for x in range(10)]) for row in df.iloc
]

## non batched
for k in range(len(X)):
    for p in model.parameters():
        p.grad = np.zeros_like(p.data)

    ypred_outputs = model(X[k])

    loss = (-(Y[k] * (ypred_outputs + Tensor(1e-12)).log())).sum()

    loss.backward()

    learning_rate = 0.05
    for p in model.parameters():
        p.data = p.data - learning_rate * p.grad

    if k % 100 == 0:
        print(f"Iteration {k}: loss = {loss.data:.10f}")

print(f"Final loss: {loss.data:.6f}")

In [3]:
## batched training
batch_size = 100
epochs = 3

for epoch in range(3):
    for start in range(0, len(X), batch_size):
        x_batch = Tensor(X[start:start + batch_size])
        y_batch = Tensor(Y[start:start + batch_size])

        for p in model.parameters():
            p.grad = np.zeros_like(p.data)

        # normalize the pixel values
        x = x_batch / 255.0
        y = y_batch / 255.0

        z = model(x) # run model prediction

        logexpsum = z.logsumexp(axis=1) # softmax logsumexp trick
        loss = -(y * logexpsum).sum(axis=1) # cross-entropy loss

        loss.backward() # propagate gradient

        learning_rate = 0.05
        for p in model.parameters():
            p.data = p.data - learning_rate * p.grad # update parameters

print(f"Final loss: {loss.sum().data}")

Final loss: 0.5778870965008347


In [4]:
print("\ntesting accurracy")

accuracy = 0
for i in range(len(X_test)):
    z = model(X_test[i])
    pred = z.logsumexp()
    true_num = int(np.nanargmax(Y_test[i].data))
    pred_num = int(np.nanargmax(pred.data))
    if true_num == pred_num: accuracy += 1
    # print(f"Sample {i}: target={true_num}, prediction={pred_num}")

print(f"Accuracy: {accuracy / len(X_test)}")


testing accurracy
Accuracy: 0.8922


In [7]:
model_not_trained = MLP(784, [512, 256, 10], relu, softmax)

print("\ntesting accurracy")
accuracy = 0
for i in range(len(X_test)):
    z = model_not_trained(X_test[i])
    pred = z.logsumexp()
    true_num = int(np.nanargmax(Y_test[i].data))
    pred_num = int(np.nanargmax(pred.data))
    if true_num == pred_num: accuracy += 1
    # print(f"Sample {i}: target={true_num}, prediction={pred_num}")

print(f"Accuracy: {accuracy / len(X_test)}")


testing accurracy
Accuracy: 0.0787
