In [13]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pathlib
from data import get_mnist



In [14]:
images, labels = get_mnist()
w_i_h = np.random.uniform(-0.5, 0.5, (20, 784))
w_h_o = np.random.uniform(-0.5, 0.5, (10, 20))
b_i_h = np.zeros((20, 1))
b_h_o = np.zeros((10, 1))

In [16]:
print(images.shape)

(60000, 784)


In [18]:
learn_rate = 0.01
nr_correct = 0
epochs = 3

In [19]:
for epoch in range(epochs):
    for img, l in zip(images, labels):
        img.shape += (1,)
        l.shape += (1,)
        # Forward propagation input -> hidden
        h_pre = b_i_h + w_i_h @ img
        h = 1 / (1 + np.exp(-h_pre))
        # Forward propagation hidden -> output
        o_pre = b_h_o + w_h_o @ h
        o = 1 / (1 + np.exp(-o_pre))

        # Cost / Error calculation
        e = 1 / len(o) * np.sum((o - l) ** 2, axis=0)
        nr_correct += int(np.argmax(o) == np.argmax(l))

        # Backpropagation output -> hidden (cost function derivative)
        delta_o = o - l
        w_h_o += -learn_rate * delta_o @ np.transpose(h)
        b_h_o += -learn_rate * delta_o
        # Backpropagation hidden -> input (activation function derivative)
        delta_h = np.transpose(w_h_o) @ delta_o * (h * (1 - h))
        w_i_h += -learn_rate * delta_h @ np.transpose(img)
        b_i_h += -learn_rate * delta_h

    # Show accuracy for this epoch
    print(f"Acc: {round((nr_correct / images.shape[0]) * 100, 2)}%")
    nr_correct = 0

Acc: 85.39%
Acc: 92.63%
Acc: 93.69%


In [20]:
# Show results
def convert_img_to_mnist(path):
    image = Image.open(path)
    
    # Resize the image to 28x28 pixels
    image = image.resize((28, 28))
    # Convert the image to grayscale
    image = image.convert('L')

    image_array = np.array(image).flatten()

    return image_array 


In [32]:
image_array = convert_img_to_mnist(f'./img/8.jpg')
h_pre = b_i_h + w_i_h @ image_array.reshape(784, 1)
h = 1 / (1 + np.exp(-h_pre))
o_pre = b_h_o + w_h_o @ h
o = 1 / (1 + np.exp(-o_pre))
print(o.argmax())

3


  h = 1 / (1 + np.exp(-h_pre))
