# In [None]:
import sys
import numpy as np
import cv2
import tensorflow as tf
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QLabel, QVBoxLayout, QHBoxLayout
from PyQt5.QtGui import QPainter, QPen, QImage
from PyQt5.QtCore import Qt, QPoint
import matplotlib.pyplot as plt

class DrawApp(QWidget):
    """Qt window for drawing a digit with the mouse and classifying it.

    Strokes are painted onto an off-screen 400x400 white ``QImage``.
    "Predict" converts that image into a (1, 28, 28, 1) float batch and
    feeds it to a Keras model loaded from ``model.h5``.
    """

    # Side length (pixels) of the square drawing canvas. The window is 50 px
    # taller to leave room for the button row below the canvas.
    CANVAS_SIZE = 400

    def __init__(self):
        super().__init__()

        self.setWindowTitle("Draw a Number")
        self.setGeometry(100, 100, self.CANVAS_SIZE, self.CANVAS_SIZE + 50)
        self.setStyleSheet("background-color: white;")

        # Off-screen drawing surface; paintEvent only blits it to the widget.
        self.image = QImage(self.CANVAS_SIZE, self.CANVAS_SIZE, QImage.Format_RGB32)
        self.image.fill(Qt.white)

        self.clear_button = QPushButton("Clear", self)
        self.predict_button = QPushButton("Predict", self)
        self.result_label = QLabel("Prediction: ", self)
        self.quit_button = QPushButton("Quit", self)

        # Load without the saved compile state, then compile explicitly.
        self.model = tf.keras.models.load_model("model.h5", compile=False)
        self.model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        self.drawing = False
        self.last_point = QPoint()

        self.init_ui()

    def init_ui(self):
        """Lay out the result label and buttons along the bottom edge and wire signals."""
        button_layout = QHBoxLayout()
        button_layout.addWidget(self.result_label)
        button_layout.addWidget(self.clear_button)
        button_layout.addWidget(self.predict_button)
        button_layout.addWidget(self.quit_button)

        main_layout = QVBoxLayout(self)
        main_layout.addStretch(1)  # stretch above the row pushes the buttons down
        main_layout.addLayout(button_layout)

        self.clear_button.clicked.connect(self.clear_canvas)
        self.predict_button.clicked.connect(self.predict_digit)
        self.quit_button.clicked.connect(QApplication.instance().quit)

    def mousePressEvent(self, event):
        """Start a stroke on left-button press."""
        if event.button() == Qt.LeftButton:
            self.drawing = True
            self.last_point = event.pos()

    def mouseMoveEvent(self, event):
        """Extend the current stroke with a line segment while the left button is held."""
        if self.drawing and event.buttons() & Qt.LeftButton:
            painter = QPainter(self.image)
            painter.setPen(QPen(Qt.black, 12, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
            painter.drawLine(self.last_point, event.pos())
            # Release the QImage explicitly instead of relying on garbage collection.
            painter.end()
            self.last_point = event.pos()
            self.update()

    def mouseReleaseEvent(self, event):
        """Finish the current stroke on left-button release."""
        if event.button() == Qt.LeftButton:
            self.drawing = False

    def paintEvent(self, event):
        """Blit the canvas at its native size.

        Mouse coordinates are used directly as image coordinates. Stretching
        the image over the whole widget, which is taller than the canvas,
        would make strokes appear offset from the cursor.
        """
        canvas = QPainter(self)
        canvas.drawImage(self.image.rect(), self.image, self.image.rect())
        canvas.end()

    def clear_canvas(self):
        """Reset the canvas to white and repaint."""
        self.image.fill(Qt.white)
        self.update()

    def preprocess_image(self, show_debug=True):
        """Convert the canvas into a single model-ready sample.

        The drawing is inverted (ink high, background 0) and cropped to the
        bounding box of the ink. It is then padded to a square to keep its
        aspect ratio, scaled into a 20x20 box and centred in a 28x28 frame.
        NOTE(review): the 20-in-28 layout follows the MNIST convention and
        assumes model.h5 was trained on MNIST-style input; confirm.

        Args:
            show_debug: if True (the default, matching earlier behaviour), show
                the processed image with matplotlib before returning.

        Returns:
            float32 numpy array of shape (1, 28, 28, 1) with values in [0, 1].
        """
        ink = self._canvas_to_array()
        digit = self._center_digit(ink)

        batch = digit.astype("float32") / 255.0
        batch = np.expand_dims(batch, axis=(0, -1))

        if show_debug:
            self._show_debug_image(batch)

        return batch

    def _canvas_to_array(self):
        """Return the canvas as an inverted (H, W) uint8 array: ink high, background 0."""
        img = self.image.convertToFormat(QImage.Format_Grayscale8)
        height, width = img.height(), img.width()
        # Scanlines may be padded; read whole rows, then drop the padding.
        stride = img.bytesPerLine()
        buffer = img.bits()
        buffer.setsize(stride * height)
        rows = np.frombuffer(buffer, dtype=np.uint8).reshape(height, stride)
        # The subtraction produces a new array, so nothing aliases Qt-owned memory.
        return 255 - rows[:, :width]

    @staticmethod
    def _center_digit(ink, out_size=28, box_size=20):
        """Crop ``ink`` to its strokes and centre them, aspect preserved, in an out_size square."""
        coords = np.argwhere(ink > 0)
        if coords.size == 0:
            # Blank canvas: an all-background frame.
            return np.zeros((out_size, out_size), dtype=np.uint8)

        (min_y, min_x), (max_y, max_x) = coords.min(axis=0), coords.max(axis=0)
        cropped = ink[min_y:max_y + 1, min_x:max_x + 1]

        # Pad the short side with background (0) so resizing keeps the aspect ratio.
        side = max(cropped.shape)
        square = np.zeros((side, side), dtype=np.uint8)
        off_y = (side - cropped.shape[0]) // 2
        off_x = (side - cropped.shape[1]) // 2
        square[off_y:off_y + cropped.shape[0], off_x:off_x + cropped.shape[1]] = cropped

        scaled = cv2.resize(square, (box_size, box_size), interpolation=cv2.INTER_AREA)
        frame = np.zeros((out_size, out_size), dtype=np.uint8)
        start = (out_size - box_size) // 2
        frame[start:start + box_size, start:start + box_size] = scaled
        return frame

    @staticmethod
    def _show_debug_image(batch):
        """Display the first sample of ``batch`` in a fresh matplotlib figure."""
        # A new figure each time avoids stacking colorbars on a reused figure.
        fig, ax = plt.subplots()
        shown = ax.imshow(batch[0, :, :, 0], cmap="gray")
        ax.set_title("Processed Image for Prediction")
        fig.colorbar(shown, ax=ax)
        plt.show()

    def predict_digit(self):
        """Run the model on the current drawing and show the arg-max class."""
        processed_img = self.preprocess_image()
        prediction = self.model.predict(processed_img)
        print("Raw Prediction:", prediction)
        digit = np.argmax(prediction)
        self.result_label.setText(f"Prediction: {digit}")

    @staticmethod
    def show_processed_image(image):
        """Show ``image`` in grayscale with matplotlib, or report that it is missing.

        Declared static: the original had no ``self`` parameter, so calling
        it on an instance would have passed the widget as ``image``.
        """
        if image is not None:
            plt.imshow(image, cmap='gray')
            plt.title("Captured Input from Canvas")
            plt.show()
        else:
            print("No image received from the frontend!")

    def closeEvent(self, event):
        """Log and accept the window-close request."""
        print("Closing the application...")
        event.accept()

if __name__ == "__main__":
    # Script entry point: build the Qt application, show the drawing window,
    # and hand the event loop's exit status back to the shell.
    qt_app = QApplication(sys.argv)
    main_window = DrawApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)