In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

from tkinter import *
import tkinter as tk
import win32gui
from PIL import ImageGrab, Image


In [4]:
# Saved Keras model used by predict_digit; the path is relative to the current working directory.
MODEL_PATH = 'mnist-2-2.h5'
model = load_model(MODEL_PATH)


In [5]:

def predict_digit(img, classifier=None):
    """Classify a canvas snapshot and return its two highest-scoring classes.

    Parameters
    ----------
    img : PIL.Image.Image
        Snapshot of the drawing canvas, in any size or mode. It is resized to
        56x28 (width x height), converted to 8-bit grayscale and scaled to [0, 1].
    classifier : object, optional
        Anything with a Keras-style ``predict(batch)`` method that returns one
        row of class scores per batch item. Defaults to the notebook-level
        ``model`` loaded from 'mnist-2-2.h5'.

    Returns
    -------
    tuple
        ``((label_1, score_1), (label_2, score_2))`` as plain Python
        ``int``/``float`` values, ordered by descending score.
    """
    if classifier is None:
        classifier = model  # notebook-level model loaded in an earlier cell

    # PIL sizes are (width, height), so the resulting pixel array is 28 rows x 56 columns.
    gray = img.resize((56, 28)).convert('L')
    pixels = np.asarray(gray, dtype=np.float32)

    # Add the batch and channel axes expected by the model input and normalise to [0, 1].
    batch = pixels.reshape(1, 28, 56, 1) / 255.0

    # predict() works on batches; we only sent one image, so keep its score row.
    scores = np.asarray(classifier.predict(batch))[0]

    # Indices of the two largest scores, best first.
    ranked = np.argsort(scores)[::-1][:2]
    return tuple((int(i), float(scores[i])) for i in ranked)

class App(tk.Tk):
    """Tk window with a drawing canvas, two prediction labels and Recognise/Clear buttons.

    Strokes are painted white on a black canvas. "Recognise" screenshots the
    canvas and shows the two highest-scoring classes returned by ``predict_digit``.
    """

    PROMPT = "Draw.."        # label text shown when there is no current prediction
    BRUSH_RADIUS = 9         # radius of the painted dots; strokes are twice this wide
    GRAB_INSET = 4           # pixels trimmed from each side of the canvas screenshot
    MIN_SECOND_SCORE = 0.1   # the runner-up is only shown when its score exceeds this

    def __init__(self):
        tk.Tk.__init__(self)

        self.x = self.y = 0       # last pointer position of the current stroke
        self._pen_down = False    # True once a stroke has been started on the canvas

        # Creating elements
        self.title('Handwritten Digit Recognition')
        self.canvas = tk.Canvas(self, width=500, height=400, bg='black', cursor="cross")
        self.label1 = tk.Label(self, text=self.PROMPT, font=("Helvetica", 48))
        self.label2 = tk.Label(self, text=self.PROMPT, font=("Helvetica", 48))
        self.classify_btn = tk.Button(self, text="Recognise", command=self.classify_handwriting)
        self.button_clear = tk.Button(self, text="Clear", command=self.clear_all)

        # Grid structure: canvas spans the two label rows; buttons sit underneath.
        self.canvas.grid(row=0, column=0, pady=0, sticky=tk.W, rowspan=2)
        self.label1.grid(row=0, column=1, pady=0, padx=0)
        self.label2.grid(row=1, column=1, pady=0, padx=0)
        self.classify_btn.grid(row=2, column=1, pady=2, padx=2)
        self.button_clear.grid(row=2, column=0, pady=2)

        # A press starts a stroke (so a single click leaves a dot); dragging extends it.
        self.canvas.bind("<ButtonPress-1>", self.start_stroke)
        self.canvas.bind("<B1-Motion>", self.draw_lines)

    @staticmethod
    def _format_prediction(digit, score):
        """Render a class/score pair as '<digit>,<whole percent>%' plus a trailing newline."""
        return str(digit) + ',' + str(int(score * 100)) + '%\n'

    def _reset_labels(self):
        """Put both prediction labels back to the drawing prompt."""
        self.label1.configure(text=self.PROMPT)
        self.label2.configure(text=self.PROMPT)

    def clear_all(self):
        """Erase the drawing and any stale prediction shown on the labels."""
        self.canvas.delete("all")
        self._reset_labels()
        self._pen_down = False

    def classify_handwriting(self):
        """Screenshot the canvas, classify it and show the top two predictions.

        A blank canvas is not classified; the labels are reset to the prompt instead.
        """
        if not self.canvas.find_all():
            # Nothing drawn: a prediction on an empty image would be meaningless.
            self._reset_labels()
            return

        # Screen rectangle of the canvas window (win32gui is Windows-only).
        # NOTE(review): on scaled (high-DPI) displays these coordinates may not line
        # up with ImageGrab's pixels unless the process is DPI-aware — confirm.
        left, top, right, bottom = win32gui.GetWindowRect(self.canvas.winfo_id())
        inset = self.GRAB_INSET  # shrink the grab rectangle (original "padding") on every side
        snapshot = ImageGrab.grab((left + inset, top + inset, right - inset, bottom - inset))

        # predict_digit returns the best class first, then the runner-up.
        (digit_1, score_1), (digit_2, score_2) = predict_digit(snapshot)

        self.label1.configure(text=self._format_prediction(digit_1, score_1))
        if score_2 > self.MIN_SECOND_SCORE:
            self.label2.configure(text=self._format_prediction(digit_2, score_2))
        else:
            self.label2.configure(text='!')

    def _paint_dot(self, x, y):
        """Draw a filled white circle of radius BRUSH_RADIUS centred on (x, y)."""
        r = self.BRUSH_RADIUS
        self.canvas.create_oval(x - r, y - r, x + r, y + r, fill='white', outline='white')

    def start_stroke(self, event):
        """Begin a new stroke at the press position (bound to <ButtonPress-1>)."""
        self.x, self.y = event.x, event.y
        self._pen_down = True
        self._paint_dot(event.x, event.y)

    def draw_lines(self, event):
        """Extend the current stroke to the pointer position (bound to <B1-Motion>).

        Motion events arrive at discrete intervals, so isolated dots would leave
        gaps on a fast drag. A round-capped segment from the previous point
        keeps the stroke continuous.
        """
        if self._pen_down:
            self.canvas.create_line(self.x, self.y, event.x, event.y,
                                    width=2 * self.BRUSH_RADIUS, fill='white',
                                    capstyle=tk.ROUND)
        self._paint_dot(event.x, event.y)
        self.x, self.y = event.x, event.y
        self._pen_down = True

# Build the window and hand control to its Tk event loop (blocks until the window closes).
app = App()
app.mainloop()