In [None]:
import PIL
import cv2
import glob
import os
from tkinter import *
from PIL import Image, ImageDraw, ImageGrab
# let's keep our keras backend tensorflow quiet
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'

# imports for array-handling and plotting
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

In [None]:
import keras
from keras.datasets import mnist
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
# Load MNIST, already split into train and test sets, and show the raw
# training image/label array shapes.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(*(arr.shape for arr in (x_train, y_train)))

In [None]:
# Shape the images as (samples, 28, 28, 1) -- a single grey channel, channels
# last -- scale pixels to [0, 1] as float32, and one-hot encode the labels.
# NOTE(review): not idempotent -- re-running this cell feeds already-encoded
# labels back into to_categorical and rescales the pixels a second time.
input_shape = (28, 28, 1)
num_classes = 10
x_train = x_train.reshape(x_train.shape[0], *input_shape)
x_test = x_test.reshape(x_test.shape[0], *input_shape)
y_train, y_test = [keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test)]
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

In [None]:
# Training hyper-parameters.
batch_size = 128
num_classes = 10
epochs = 10

# Two 3x3 conv layers -> 2x2 max-pool -> dropout -> dense head with a softmax
# over the `num_classes` digit classes.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

# Adadelta's learning rate is set explicitly: current Keras defaults it to
# 0.001, which trains this network very slowly in 10 epochs; the Keras docs
# recommend 1.0 to match the original Adadelta formulation.
model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=keras.optimizers.Adadelta(learning_rate=1.0),
    metrics=['accuracy'],
)

In [None]:
# Train for `epochs` passes, reporting loss/accuracy on (x_test, y_test) after
# each epoch, then write the network to mnist.h5 for the GUI cells below.
# NOTE(review): the test split is also used as validation data here, so it is
# not held out from the per-epoch monitoring.
hist = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test),
)
print("The model has successfully trained")
model.save('mnist.h5')

In [None]:
# Final test-set metrics; with the single 'accuracy' metric compiled on the
# model, `score` holds [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
for label, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(label, value)

In [None]:
# Reload the trained network from disk so the GUI cells can run without
# retraining. The training cell saved `mnist.h5` into the current working
# directory, which is why the bare file name is what gets loaded.
model_name = 'mnist.h5'
save_dir = "models/"
# NOTE(review): `model_path` (models/mnist.h5) is never used -- nothing in
# this notebook saves into `models/`.
model_path = os.path.join(save_dir, model_name)
model = load_model(model_name)

In [None]:
# Main application window, fixed size (not user-resizable).
root = Tk()
root.title("Handwritten Digit Recognition GUI App")
root.resizable(False, False)

# Previous pointer position of the stroke being drawn (None before any click).
lastx = lasty = None
# Counter used in the file names of the canvas snapshots.
image_number = 0

def clear_widget():
    """Remove every item drawn on the module-level canvas ``cv``."""
    # `cv` is only read here, so no `global` declaration is required.
    cv.delete("all")
    
def activate_event(event):
    """Start a stroke at the clicked point and route drag motion to ``draw_lines``.

    Args:
        event: Tk ``<Button-1>`` event; its ``x``/``y`` become the stroke's
            starting point stored in the globals ``lastx``/``lasty``.
    """
    global lastx, lasty
    lastx, lasty = event.x, event.y
    cv.bind('<B1-Motion>', draw_lines)
    

def draw_lines(event):
    """Extend the current stroke from the last pointer position to ``event``.

    Draws one smoothed 8px black segment on ``cv`` and then records the new
    position in ``lastx``/``lasty`` for the next motion event.
    """
    global lastx, lasty
    start = (lastx, lasty)
    end = (event.x, event.y)
    cv.create_line(start + end, width=8, fill="black",
                   capstyle=ROUND, smooth=True, splinesteps=12)
    lastx, lasty = end
    
    
def _to_mnist_input(gray, ink_threshold=50):
    """Convert a grayscale canvas capture into a (1, 28, 28, 1) model input.

    The canvas shows dark ink on a white background, whereas the network was
    trained on MNIST digits (light strokes on a black background, fitted into a
    20x20 box and centred in a 28x28 frame) scaled to [0, 1]. This inverts the
    capture, crops it to the drawn strokes, pads the crop to a square so the
    aspect ratio survives resizing, shrinks it to 20x20 and centres it in a
    28x28 frame.

    Args:
        gray: 2-D uint8 array, e.g. from ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)``.
        ink_threshold: inverted intensity above which a pixel counts as ink;
            keeps faint anti-aliasing / light widget-border pixels out of the
            bounding box.

    Returns:
        float32 array of shape (1, 28, 28, 1) with values in [0, 1], or
        ``None`` when no ink is found (empty canvas).
    """
    # Invert: ink becomes bright, background becomes 0 (as in MNIST).
    ink = 255 - gray
    rows, cols = np.nonzero(ink > ink_threshold)
    if rows.size == 0:
        return None
    digit = ink[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

    # Pad to a square so the resize below does not distort the digit.
    height, width = digit.shape
    side = max(height, width)
    square = np.zeros((side, side), dtype=np.uint8)
    top, left = (side - height) // 2, (side - width) // 2
    square[top:top + height, left:left + width] = digit

    # Fit into 20x20 and centre in 28x28, leaving a 4px margin on every side.
    box = cv2.resize(square, (20, 20), interpolation=cv2.INTER_AREA)
    frame = np.zeros((28, 28), dtype=np.float32)
    frame[4:24, 4:24] = box / 255.0
    return frame.reshape(1, 28, 28, 1)


def Recognize_Digit():
    """Capture the drawing canvas, classify the digit and report the result.

    The canvas region of the screen is saved to ``image_<n>.png`` (``n`` is
    the global ``image_number``, incremented on every call so earlier captures
    are kept), converted to an MNIST-style input and classified by the model
    loaded from ``mnist.h5``.

    Returns:
        The prediction as ``"<digit> <confidence>%"`` (also printed), or
        ``None`` when nothing has been drawn.
    """
    global image_number
    # Loaded per call so the newest saved weights are always used.
    model = load_model('mnist.h5')

    filename = f'image_{image_number}.png'
    image_number += 1

    # Screen-space bounding box of the canvas widget.
    x = root.winfo_rootx() + cv.winfo_x()
    y = root.winfo_rooty() + cv.winfo_y()
    x1 = x + cv.winfo_width()
    y1 = y + cv.winfo_height()
    ImageGrab.grab().crop((x, y, x1, y1)).save(filename)
    print("Saved")

    image = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    # NOTE(review): the pen is 8px wide; when a large drawing is shrunk to
    # 20x20 its strokes become thinner than MNIST's -- a wider pen may help.
    img = _to_mnist_input(image)
    if img is None:
        print("Nothing to recognize: draw a digit first.")
        return None

    pred = model.predict(img, verbose=0)[0]
    final_pred = int(np.argmax(pred))
    data = f'{final_pred} {int(pred[final_pred] * 100)}%'
    print(data)
    return data


# Drawing surface; Recognize_Digit screenshots exactly this widget.
cv = Canvas(root, width=640, height=480, bg="white")
cv.grid(row=0, column=0, pady=2, sticky=W, columnspan=2)

# A left click starts a stroke; activate_event then binds dragging to draw_lines.
cv.bind('<Button-1>', activate_event)

btn_save = Button(root, text="Recognize Digit", command=Recognize_Digit)
btn_save.grid(row=2, column=0, pady=1, padx=1)
button_clear = Button(root, text="Clear Widget", command=clear_widget)
# Place the clear button itself (previously btn_save was re-gridded here,
# so the clear button never appeared).
button_clear.grid(row=2, column=1, pady=1, padx=1)

root.mainloop()