# MNIST Digit Recognizer Demo
Welcome to the demo of my MNIST digit recognizer. The model's parameters are loaded from a file of saved tensors (`95.92%_acc_parameters.pt`) produced by the notebook "mnist_full_scratch_v2.ipynb"; after several training runs, the model reached 95.92% accuracy on a validation set. The model was created from scratch by Sasha Kiselev.

With this web app you can upload your own handwritten digit (0 to 9) and see how it is recognized. Please crop your picture so that it is roughly square and the digit fills about 75% of the image. For best results, upload a digit written in dark ink on white paper.
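If your photo is not square, one simple option is to cut the largest centred square out of it. The helper below is a hypothetical sketch, not part of this notebook: `centre_square_box` only computes the crop box, returned as the `(left, upper, right, lower)` tuple that PIL's `Image.crop` expects.

```python
def centre_square_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centred square.

    Hypothetical helper: pass the result to PIL's Image.crop(box).
    """
    side = min(width, height)          # the square can be no larger than the short edge
    left = (width - side) // 2         # split the excess width evenly on both sides
    upper = (height - side) // 2       # split the excess height evenly on both sides
    return (left, upper, left + side, upper + side)

print(centre_square_box(640, 480))  # (80, 0, 560, 480)
```

On a PIL image this would be used as `img.crop(centre_square_box(*img.size))`, since `img.size` is `(width, height)`. A centred crop does not guarantee that the digit fills about 75% of the frame, so a tighter manual crop may still be needed.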

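Note: the network's final scores pass through `tanh`, which limits each of them to the range -1 to 1, before `softmax` spreads the probability over all ten digits. As a result the network cannot give a high probability even when it is very confident; the highest value it can output is about 45%. Any probability above 40% means the network is very confident, while a value around 10% (one in ten) means it is guessing.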
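Why the probabilities stay low can be checked directly. The model applies `tanh` to its final scores, bounding each to [-1, 1], before `softmax`, so the largest probability occurs when one score is 1 and the other nine are -1, giving $e^2/(e^2+9) \approx 0.45$. The plain-Python sketch below recomputes that bound and the uniform case; it uses its own small `softmax` over a list rather than the notebook's tensor version.

```python
import math

def softmax(scores):
    # Plain-Python softmax over a list of scores
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Best case after tanh: one score at the upper bound 1, the other nine at -1
best = softmax([1.0] + [-1.0] * 9)[0]
# Pure guessing: all ten scores equal
guess = softmax([0.0] * 10)[0]

print(f"max possible probability: {best:.1%}")   # 45.1%
print(f"uniform probability:      {guess:.1%}")  # 10.0%
```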

```python
# Imports
from fastai.vision.all import *
from fastai.vision.widgets import widgets, VBox
from torchvision.transforms.functional import to_grayscale
from torchvision import transforms
from IPython.display import display
```

```python
# Load model parameters: the weights and biases of the three linear layers
params = torch.load("95.92%_acc_parameters.pt")
w1, b1, w2, b2, w3, b3 = params
```

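```python
def softmax(x):
    # Subtract each row's maximum first: the result is unchanged,
    # but torch.exp never sees a large positive argument
    x = torch.exp(x - x.max(dim=1, keepdim=True).values)
    sum_exp = x.sum(dim=1, keepdim=True)
    return x / sum_exp
```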
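Subtracting each row's maximum before exponentiating is a standard way to keep softmax from overflowing, and it does not change the output because the factor $e^{-m}$ cancels between numerator and denominator. The plain-Python sketch below (hypothetical names `softmax_naive` and `softmax_shifted`, independent of the notebook's tensors) shows both properties. In this particular model the scores reaching softmax are already limited to [-1, 1] by `tanh`, so overflow cannot occur there; the shift matters for unbounded scores.

```python
import math

def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_shifted(scores):
    # e^(s - m) / sum(e^(t - m)) == e^s / sum(e^t): the factor e^(-m) cancels
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 1.0, 0.1]
print(softmax_naive(scores))
print(softmax_shifted(scores))  # same probabilities, up to rounding

# math.exp(1000.0) raises OverflowError, so only the shifted version handles this
print(softmax_shifted([1000.0, 999.0]))
```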

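```python
def tanh(x):
    # tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), rearranged as 1 - 2 / (e^(2x) + 1)
    # so that a large x gives 1 instead of inf / inf = nan
    return 1 - 2 / (torch.exp(2*x) + 1)
```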
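The direct formula `(exp(2x) - 1) / (exp(2x) + 1)` breaks down once `exp(2x)` overflows: in float32 that happens when 2x exceeds about 88.7, and `torch.exp` then returns `inf` rather than raising, so the quotient becomes `inf / inf = nan`. Rearranging it as `1 - 2 / (exp(2x) + 1)` avoids that. The plain-Python sketch below takes `e2x = exp(2x)` as input and simulates the overflow with `float("inf")`, since Python's own `math.exp` raises instead of returning `inf`.

```python
import math

def tanh_naive(e2x):
    # The direct formula, written in terms of e2x = exp(2x)
    return (e2x - 1) / (e2x + 1)

def tanh_stable(e2x):
    # Same function rearranged: (e - 1) / (e + 1) = 1 - 2 / (e + 1)
    return 1 - 2 / (e2x + 1)

# For ordinary inputs both forms agree with math.tanh
for x in (-2.0, 0.0, 0.5, 3.0):
    e2x = math.exp(2 * x)
    assert math.isclose(tanh_naive(e2x), math.tanh(x), abs_tol=1e-12)
    assert math.isclose(tanh_stable(e2x), math.tanh(x), abs_tol=1e-12)

# Simulate exp(2x) overflowing to inf, as torch.exp does in float32
overflowed = float("inf")
print(tanh_naive(overflowed))   # nan
print(tanh_stable(overflowed))  # 1.0
```

For very negative x, `exp(2x)` underflows to 0 and the rearranged form gives exactly -1, so both ends of the range are handled.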

```python
# Setup forward prop
def model(x):
    # linear -> ReLU -> linear -> ReLU -> linear -> tanh -> softmax
    res = (x @ w1 + b1).max(tensor(0.))  # ReLU: element-wise max with 0
    res = (res @ w2 + b2).max(tensor(0.))
    res = tanh(res @ w3 + b3)  # added to prevent exploding gradients
    return softmax(res)  # turns the 10 scores into probabilities used to guess the user's digit
```

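```python
def digit_classifier(image):
    # Flatten the 28x28 greyscale image and invert it to match MNIST:
    # dark ink becomes values near 1, light paper values near 0
    img_tns = tensor(1.) - tensor(image).view(1, 28*28).float()/255
    # Boost contrast: set everything below the mean to 0, scale the rest by 5, clip at 1
    img_tns = ((img_tns - img_tns.mean()).max(tensor(0.)) * 5).min(tensor(1.))
    probs = model(img_tns)
    digit = probs.argmax().item()  # most likely digit
    prob = probs[:, digit].item()  # its probability
    return digit, prob, img_tns
```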
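To see what the two preprocessing lines in `digit_classifier` do, here is the same arithmetic in plain Python on a hypothetical row of five greyscale pixels (0 = black, 255 = white): light paper with a dark stroke in the middle. In the notebook the mean is taken over all 784 pixels of the image rather than over one row.

```python
pixels = [250, 240, 30, 20, 245]  # hypothetical: light paper, dark ink in the middle

# Step 1: scale to [0, 1] and invert, so ink is close to 1 and paper close to 0
inverted = [1 - p / 255 for p in pixels]

# Step 2: subtract the mean, clamp negatives to 0, multiply by 5, clip at 1
mean = sum(inverted) / len(inverted)
boosted = [min(max(v - mean, 0.0) * 5, 1.0) for v in inverted]

print([round(v, 3) for v in inverted])
print(boosted)  # [0.0, 0.0, 1.0, 1.0, 0.0]
```

The paper pixels fall below the mean and become exactly 0, while the ink pixels are pushed to the maximum of 1, which is close to the clean white-on-black strokes of MNIST.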

```python
def greyscale_resize(image):
    # Convert to single-channel greyscale, then squish to 28x28 like MNIST
    rsz = Resize(28, method='squish')
    image = to_grayscale(image)
    image_small = rsz(image)
    return image_small
```

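```python
def tns_to_img_upscale(img_tns):
    # Turn the preprocessed 1x784 tensor (values in [0, 1]) back into a 28x28 greyscale image
    img = transforms.ToPILImage(mode="L")(img_tns.view(28, 28))
    # Enlarge it to 256x256 so it is easy to see in the widget
    rsz = Resize(256, method='squish')
    img = rsz(img)
    return img
```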

```python
### SETUP THE USER GUI
upload = widgets.FileUpload()
img_display = widgets.Output()
img_label = widgets.Label()
prediction = widgets.Label()
class_btn = widgets.Button(description="Classify my digit")

def on_click_class(btn):
    # `upload.data` (list of uploaded files' bytes) is the ipywidgets 7.x API;
    # ipywidgets 8 removed it in favour of `upload.value`
    if not upload.data:
        prediction.value = "Please upload an image first."
        return
    image = PILImage.create(upload.data[-1])
    image_small = greyscale_resize(image)
    digit, probability, img_tns = digit_classifier(image_small)
    # Show the preprocessed 28x28 input the network actually saw, enlarged
    image = tns_to_img_upscale(img_tns)
    img_display.clear_output()
    with img_display:
        display(image.to_thumb(256, 256))
    img_label.value = "Here is what I see:"
    prediction.value = (f"I am guessing this is a {digit} "
                        f"with a probability of {probability * 100:.2f}%")

class_btn.on_click(on_click_class)
```

```python
VBox([widgets.Label("What digit did I write? Let me see!"), upload, class_btn, img_label, img_display, prediction])
```

