In [9]:
import tensorflow as tf
from tensorflow import keras
import cv2
import numpy as np

After importing the required libraries, we next load the MNIST data set to train our digit recognition model
This dataset has thousands of 28x28 black and white digit images.

In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize pixel vals to doubles in range [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

Next, we define our digit recognition model to take in 28x28 pixel images
The ReLU function optimizes the hidden layer while Softmax optimizes the output layer (good for multi-class classification)

In [10]:
digit_recognition_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

Finally, we can train the model and save the final product.

In [5]:
# Train model 
digit_recognition_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 15 epochs is near ideal for MNIST dataset
digit_recognition_model.fit(x_train, y_train, epochs=5)

test_loss, test_accuracy = digit_recognition_model.evaluate(x_test, y_test)

digit_recognition_model.save("trained_digit_model.h5")

print('Test Loss:', test_loss)
print('Test Accuracy:', test_accuracy)
print(digit_recognition_model.input_shape)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Test Loss: 0.052692003548145294
Test Accuracy: 0.9847000241279602
(None, 28, 28, 1)


Now, let's try running our model on the following example digit image:

![digit2.jpg](attachment:digit2.jpg)

In [6]:
ex_img = cv2.imread('img_1.jpg')
ex_img = cv2.cvtColor(ex_img, cv2.COLOR_BGR2GRAY)
print (ex_img.shape)

ex_img = ex_img.reshape(1, 28, 28)

(28, 28)


In [8]:
digit_probabilities = digit_recognition_model.predict(ex_img)

print(digit_probabilities)

digit = np.argmax(digit_probabilities)

print("Classified digit is: " + str(digit))


[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]
Classified digit is: 2
