In [1]:
"""
Parte del codice è tratta dal seguente sito:
https://keras.io/examples/vision/mnist_convnet/
"""

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

from PIL import Image


# ---- Dataset / model configuration ----
num_classes = 10             # number of output classes (MNIST digits)
input_shape = (28, 28, 1)    # (height, width, channels) expected by the network

# MNIST comes already split into a training and a test partition.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Rescale pixel values from [0, 255] to float32 in [0, 1], then append a
# trailing channel axis so every sample has shape `input_shape`.
x_train = np.expand_dims(x_train.astype("float32") / 255, -1)
x_test = np.expand_dims(x_test.astype("float32") / 255, -1)

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# One-hot encode the integer labels into (n_samples, num_classes) matrices,
# as required by the categorical cross-entropy loss used below.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [2]:
# Define the CNN classifier one layer at a time.
model = keras.Sequential()
model.add(keras.Input(shape=input_shape))

# Feature extractor: two (3x3 conv + ReLU) -> (2x2 max-pool) stages.
model.add(layers.Conv2D(32, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))
model.add(layers.Conv2D(64, kernel_size=(3, 3), activation="relu"))
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Classifier head: flatten, light dropout, softmax over the classes.
model.add(layers.Flatten())
model.add(layers.Dropout(0.1))
model.add(layers.Dense(num_classes, activation="softmax"))

model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 1600)              0         
                                                                 
 dropout (Dropout)           (None, 1600)              0

In [3]:
# Training hyperparameters (used by the fit call further below).
epochs = 5
batch_size = 128

# For the choice of loss function, see https://keras.io/api/losses/
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [4]:
# Baseline: test-set loss/accuracy of the freshly compiled, still untrained model.
score = model.evaluate(x_test, y_test, verbose=0)
for caption, value in zip(("Test loss:", "Test accuracy:"), score):
    print(caption, value)

Test loss: 2.29482364654541
Test accuracy: 0.14710000157356262


In [5]:
# Reuse a previously saved model if one exists on disk; otherwise train the
# model compiled above and save it so later runs can skip training.
#
# An explicit existence check replaces the former bare `except:`, which also
# swallowed real failures (a corrupt/incompatible saved model, a
# KeyboardInterrupt, ...) and then silently retrained and overwrote the file.
import os

model_path = 'mnist_model'
if os.path.exists(model_path):
    # Replaces the freshly compiled `model` with the one restored from disk.
    model = keras.models.load_model(model_path)
else:
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
    model.save(model_path)

In [6]:
# Test-set performance of the trained (or reloaded) model.
score = model.evaluate(x_test, y_test, verbose=0)
for position, caption in enumerate(("Test loss:", "Test accuracy:")):
    print(caption, score[position])

Test loss: 0.02450031042098999
Test accuracy: 0.9909999966621399


In [7]:
def get_prediction(model, data):
    """Classify a single image, print the predicted class index and return it.

    Args:
        model: object exposing ``predict`` (e.g. the Keras model above). It
            receives a batch of shape (1, 28, 28, 1).
        data: array with 28 * 28 elements, reshaped to (1, 28, 28, 1).

    Returns:
        int: index of the highest-scoring class. Ties resolve to the lowest
        index, matching the previous ``np.where(...)[0][0]`` behaviour.
    """
    predictions = model.predict(data.reshape((1, 28, 28, 1)))
    # np.argmax is the idiomatic equivalent of the former
    # np.where(p == max(p))[0][0]; unlike it, it does not raise IndexError
    # when the scores contain NaN.
    label = int(np.argmax(predictions[0]))
    print(label)
    return label

def load_and_normalize_data(path='image.png'):
    """Load an image as a 28x28 grayscale float32 array scaled to [0, 1].

    Args:
        path: image file to read. Defaults to 'image.png', the file the GUI's
            "Save image" button writes.

    Returns:
        numpy.ndarray: float32 array of shape (28, 28) with values in [0, 1].

    Raises:
        Whatever ``PIL.Image.open`` raises, e.g. FileNotFoundError when
        ``path`` does not exist.
    """
    # `with` closes the underlying file handle once the pixels are decoded.
    with Image.open(path) as image:
        gray = image.convert('L').resize((28, 28))
    return np.asarray(gray, dtype=np.float32) / 255

def get_label(model, inverse=False):
    """Load the saved drawing and print the model's predicted class for it.

    Args:
        model: classifier passed through to ``get_prediction``.
        inverse: when True, flip intensities (``1 - pixel``) before predicting.
            The drawing canvas is black-on-white, so this yields light strokes
            on a dark background (presumably to match MNIST's style; confirm).
    """
    pixels = load_and_normalize_data()
    get_prediction(model, 1 - pixels if inverse else pixels)


In [8]:
import sys
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *


class Drawer(QWidget):
    """Free-hand drawing canvas backed by a fixed 280x280 QImage.

    Mouse drags are recorded in a QPainterPath and rendered into
    ``self.image``; ``saveImage`` writes that image to disk.
    """

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        # Contents stay anchored on resize, so Qt may repaint only newly
        # exposed regions; paintEvent therefore has to honour event.rect().
        self.setAttribute(Qt.WA_StaticContents)
        h = 280
        w = 280
        self.myPenWidth = 30
        self.myPenColor = Qt.black
        self.image = QImage(w, h, QImage.Format_RGB32)
        self.path = QPainterPath()
        self.clearImage()

    def setPenColor(self, newColor):
        """Set the colour used for subsequent strokes."""
        self.myPenColor = newColor

    def setPenWidth(self, newWidth):
        """Set the pen width (pixels) used for subsequent strokes."""
        self.myPenWidth = newWidth

    def clearImage(self):
        """Discard all strokes and reset the canvas to a white background."""
        self.path = QPainterPath()
        # White background; strokes use myPenColor (black by default).
        self.image.fill(Qt.white)
        self.update()

    def saveImage(self, fileName, fileFormat):
        """Write the canvas to ``fileName`` in ``fileFormat`` (e.g. "PNG").

        Returns:
            bool: the success flag reported by ``QImage.save``.
        """
        return self.image.save(fileName, fileFormat)

    def paintEvent(self, event):
        """Copy the dirty region of the backing image onto the widget."""
        painter = QPainter(self)
        # Source and target must be the same rectangle. Mapping the whole
        # widget rect into a smaller dirty rect (the previous behaviour)
        # squeezed the full image into the exposed area on partial repaints.
        dirty = event.rect()
        painter.drawImage(dirty, self.image, dirty)
        painter.end()

    def mousePressEvent(self, event):
        """Start a new sub-path at the press position."""
        self.path.moveTo(event.pos())

    def mouseMoveEvent(self, event):
        """Extend the current stroke and render the path into the image."""
        self.path.lineTo(event.pos())
        p = QPainter(self.image)
        p.setPen(QPen(self.myPenColor,
                      self.myPenWidth, Qt.SolidLine, Qt.RoundCap,
                      Qt.RoundJoin))
        # The whole accumulated path (every stroke since the last clear) is
        # redrawn on each move event.
        p.drawPath(self.path)
        p.end()
        self.update()

    def sizeHint(self):
        """Preferred widget size (slightly larger than the 280x280 image)."""
        return QSize(300, 300)


In [9]:
# Drawing GUI: draw a digit, click "Save image" (writes image.png), then
# "Get label" to classify the saved image with the trained model.

# Qt allows one QApplication per process; reuse it when this cell is re-run
# in the same kernel instead of failing on a second construction.
app = QApplication.instance() or QApplication([])
w = QWidget()
btnSave = QPushButton("Save image")
btnClear = QPushButton("Clear")
# Classifies the colour-inverted saved image (the canvas is black-on-white).
btnGetLabelInverse = QPushButton("Get label")
drawer = Drawer()

w.setLayout(QVBoxLayout())
w.layout().addWidget(btnSave)
w.layout().addWidget(btnGetLabelInverse)
w.layout().addWidget(btnClear)
w.layout().addWidget(drawer)


def _predict_saved_image():
    """Slot for "Get label": report a missing image instead of raising.

    In PyQt5 an exception escaping a slot can abort the whole process
    (qFatal) when no custom sys.excepthook is installed.
    """
    try:
        get_label(model, inverse=True)
    except FileNotFoundError:
        print("image.png not found: click 'Save image' first.")


btnSave.clicked.connect(lambda: drawer.saveImage("image.png", "PNG"))
btnGetLabelInverse.clicked.connect(_predict_saved_image)
btnClear.clicked.connect(drawer.clearImage)

w.show()
app.exec_()
# sys.exit(app.exec_())

0
9
8
7
5
6
8
9


0