In [6]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from typing import Callable
from tensorflow.keras.models import load_model
import cv2
import pygame

# Load data
def load_data(data_dir):
    """Load every ``.png`` image and its label file from the train and test splits.

    Files are read from ``<data_dir>/<split>/images/<name>.png`` and the matching
    ``<data_dir>/<split>/labels/<name>.txt`` for ``split`` in ``('train', 'test')``;
    both splits are concatenated into a single dataset, in sorted filename order
    within each split so runs are reproducible.

    Each non-blank label line has the form ``digit,v1,v2,v3,v4`` and is parsed as
    ``[int, float, float, float, float]``. What the four floats mean (presumably a
    box such as normalized x_min, y_min, x_max, y_max) is not visible here —
    TODO confirm against the dataset generator.

    Args:
        data_dir: Root directory of the dataset.

    Returns:
        (images, labels):
            images: ``np.array`` of the grayscale ``img_to_array`` outputs
                (presumably (N, H, W, 1); assumes every image has the same size).
            labels: a regular float array of shape (N, k, 5) when every image has
                the same number k >= 1 of label lines; otherwise a 1-D object array
                holding one (k_i, 5) float array per image.

    Raises:
        FileNotFoundError: if an image has no matching label file.
        ValueError / IndexError: if a non-blank label line is malformed.
    """
    images = []
    labels = []
    for split in ('train', 'test'):
        img_dir = os.path.join(data_dir, split, 'images')
        label_dir = os.path.join(data_dir, split, 'labels')

        # sorted(): os.listdir order is arbitrary and filesystem-dependent.
        for img_file in sorted(os.listdir(img_dir)):
            if not img_file.endswith('.png'):
                continue
            img_path = os.path.join(img_dir, img_file)
            # Swap only the trailing extension (str.replace would rewrite every '.png').
            label_path = os.path.join(label_dir, img_file[:-len('.png')] + '.txt')

            # Load image
            img = load_img(img_path, color_mode='grayscale')
            images.append(img_to_array(img))

            # Load label: one "digit,v1,v2,v3,v4" row per line; blank lines
            # (e.g. a trailing newline) are skipped instead of crashing int('').
            label_data = []
            with open(label_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split(',')
                    label_data.append([int(parts[0]), float(parts[1]), float(parts[2]),
                                       float(parts[3]), float(parts[4])])
            labels.append(label_data)

    images = np.array(images)
    try:
        labels = np.array(labels)
    except ValueError:
        # Images with different numbers of label lines give a ragged nested list,
        # which NumPy >= 1.24 refuses to turn into a regular array. Fall back to an
        # object array with one (k_i, 5) float array per image; per-image indexing
        # such as labels[i][0][1:] keeps working.
        ragged = np.empty(len(labels), dtype=object)
        for idx, label_data in enumerate(labels):
            ragged[idx] = np.asarray(label_data, dtype=float)
        labels = ragged
    return images, labels

# Define model
def create_model(input_shape, num_classes: int = 10):
    """Build a small CNN with a classification head and a bounding-box head.

    The trunk is three Conv2D(3x3, relu) -> MaxPooling2D stages with 32, 64 and
    128 filters, then Flatten -> Dense(512, relu). Because the convolutions use
    Keras' default 'valid' padding, each spatial side must be at least 22 px or
    the feature map shrinks to nothing.

    Args:
        input_shape: Per-sample input shape passed to ``keras.Input``
            (e.g. ``images.shape[1:]``).
        num_classes: Size of the softmax ``digit`` head. Defaults to 10.

    Returns:
        An uncompiled ``keras.Model`` whose outputs are ``[digit, bbox]``:
        ``digit`` holds ``num_classes`` softmax probabilities and ``bbox`` holds
        4 sigmoid values in (0, 1).
    """
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, 3, activation='relu')(x)
        x = layers.MaxPooling2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(512, activation='relu')(x)
    # The layer names 'digit' and 'bbox' are the keys the loss/metric/target
    # dicts refer to when compiling and fitting.
    digit_output = layers.Dense(num_classes, activation='softmax', name='digit')(x)
    bbox_output = layers.Dense(4, activation='sigmoid', name='bbox')(x)

    model = keras.Model(inputs=inputs, outputs=[digit_output, bbox_output])
    return model

# Training
data_dir = './data/mnist_detection/'
images, labels = load_data(data_dir)
input_shape = images.shape[1:]

# Seed Python/NumPy/TF RNGs so weight init and shuffling repeat across runs
# (some GPU kernels may still be nondeterministic).
keras.utils.set_random_seed(42)

model = create_model(input_shape)
model.compile(optimizer='adam',
              loss={'digit': 'sparse_categorical_crossentropy', 'bbox': 'mse'},
              metrics={'digit': 'accuracy', 'bbox': 'mse'})

# Prepare labels for training: only the first label row of each image is used.
digits = np.array([label[0][0] for label in labels]).astype('int64')
bboxes = np.array([label[0][1:] for label in labels], dtype='float32')
assert len(digits) == len(bboxes) == len(images), "images and labels are misaligned"

# img_to_array returns raw 0-255 pixel values, while `predictor` divides by 255
# at inference time; train on the same [0, 1] scale so the two agree.
images_scaled = images.astype('float32') / 255.0

model.fit(images_scaled, {'digit': digits, 'bbox': bboxes}, epochs=10, batch_size=32)
model.save('mnist_object_detector.keras')


ValueError: The filepath provided must end in `.keras` (Keras model format). Received: filepath=saved_model\object_detection_model.h5

In [None]:
# Reload the model that the training cell saved as 'mnist_object_detector.keras'.
model = load_model('mnist_object_detector.keras')

def predictor(img: np.ndarray) -> list[np.ndarray]:
    """Run the loaded ``model`` on a single image.

    Args:
        img: A 2-D grayscale image (or H x W x 1) of any size, typically the
            output of ``predictor_formatter``.

    Returns:
        What ``model.predict`` returns for a batch of one. For the two-output
        model built in this notebook that is ``[digit_probs, bbox]`` with shapes
        (1, 10) and (1, 4).
    """
    # model.input_shape is (None, height, width, channels). Reading it from the
    # model rather than the training cell's `input_shape` lets this cell work
    # with only the saved model loaded.
    height, width = model.input_shape[1:3]
    # cv2.resize takes dsize as (width, height), the reverse of NumPy's (rows, cols).
    resized = cv2.resize(img, (width, height))
    if resized.ndim == 2:
        # The Conv2D input expects an explicit channel axis: (H, W) -> (H, W, 1).
        resized = resized[..., np.newaxis]
    # Normalize (values in [0, 1] for uint8 input) and add the batch dimension.
    batch = np.expand_dims(resized.astype('float32') / 255.0, axis=0)
    # verbose=0: called from an interactive loop; skip the per-call progress bar.
    return model.predict(batch, verbose=0)

def predictor_formatter(img: np.ndarray) -> np.ndarray:
    """Convert an RGB image grab to a single-channel grayscale image.

    Resizing is left to ``predictor``. Input that is already 2-D is assumed to be
    grayscale and is returned unchanged.

    Args:
        img: An (H, W, 3) RGB array, or an (H, W) grayscale array.

    Returns:
        An (H, W) grayscale array.
    """
    if img.ndim == 2:
        return img
    # Callers may pass strided / axis-swapped views (the paint program does);
    # give OpenCV a plain C-contiguous buffer.
    return cv2.cvtColor(np.ascontiguousarray(img), cv2.COLOR_RGB2GRAY)


In [None]:
def constant_paint_program(
    window_title: str,
    window_icon_path: str | None,
    predictor: Callable[[np.ndarray],np.ndarray],
    predictor_formatter: Callable[[np.ndarray],np.ndarray],
    width: int,
    height: int,
    scale: int,
    fps: int = 60,
    graph_width: int = 280,
    graph_border_width: int = 2,
    blank_color: tuple[int,int,int] = (0, 0, 0),
    draw_color: tuple[int,int,int] = (255, 255, 255),
    graph_bg_color: tuple[int,int,int] = (64, 64, 64),
    graph_color: tuple[int,int,int] = (30, 42, 92),
    graph_text_color: tuple[int,int,int] = (32, 34, 46),
    graph_percent_text_color: tuple[int,int,int] = (25, 26, 31),
    init_pygame: bool = False,
    quit_pygame: bool = False,
) -> None:
    """
    Open a paint window with a live per-digit prediction graph and let the user draw.

    Controls:
        Left mouse button: draw
        Right mouse button: erase
        Middle mouse button: clear the canvas
        ESC / closing the window: exit (nothing is saved)

    Once at start-up and after every mouse-button release (with neither draw nor
    erase held), the canvas is downsampled by ``scale`` to a (height, width, 3) RGB
    array and passed through ``predictor_formatter`` and then ``predictor``. The
    prediction must be indexable as:
        prediction[0]: 10 per-digit probabilities in any shape that flattens to
            10 values, e.g. (1, 10);
        prediction[1]: boxes that reshape to (-1, 4), read as
            (x_min, y_min, x_max, y_max) fractions of the canvas.
    This matches the two-output model in this notebook. The argmax digit is
    printed, and the boxes from the latest release-triggered prediction are drawn
    as a red overlay that is never part of the predictor's input.

    Args:
        window_title (str): The window title.
        window_icon_path (str | None): The window icon's path, or None for the default icon.
        predictor (Callable[[np.ndarray],np.ndarray]): The predictor function.
        predictor_formatter (Callable[[np.ndarray],np.ndarray]): Formats the (height, width, 3) canvas grab for the predictor.
        width (int): Canvas width in grid cells (not including the prediction graph).
        height (int): Canvas height in grid cells.
        scale (int): Screen pixels per grid cell along each axis.
        fps (int, optional): Frames Per Second. Defaults to 60.
        graph_width (int, optional): Width in pixels of the graph panel to the right of the canvas. A full (100%) bar is floor(graph_width / 100) * 100 px long. Defaults to 280.
        graph_border_width (int, optional): Thickness in pixels of the outline around each bar. Defaults to 2.
        blank_color (tuple[int,int,int], optional): The default/erase color in RGB. Defaults to (0, 0, 0).
        draw_color (tuple[int,int,int], optional): The draw color in RGB. Defaults to (255, 255, 255).
        graph_bg_color (tuple[int,int,int], optional): The background color for the graph in RGB. Defaults to (64, 64, 64).
        graph_color (tuple[int,int,int], optional): The fill-in color for the graph bars in RGB. Defaults to (30, 42, 92).
        graph_text_color (tuple[int,int,int], optional): The color of the digit labels in RGB. Defaults to (32, 34, 46).
        graph_percent_text_color (tuple[int,int,int], optional): The color of the percentage labels in RGB. Defaults to (25, 26, 31).
        init_pygame (bool, optional): Should it run pygame.init()? Defaults to False.
        quit_pygame (bool, optional): Should it run pygame.quit()? Defaults to False.
    """

    if init_pygame:
        pygame.init()

    canvas_width: int = width * scale
    canvas_height: int = height * scale
    screen: pygame.surface.Surface = pygame.display.set_mode((canvas_width + graph_width, canvas_height))
    fpsClock = pygame.time.Clock()

    pygame.display.set_caption(window_title)
    if window_icon_path is not None:
        pygame.display.set_icon(pygame.image.load(window_icon_path))

    # Strokes live on their own surface, so overlays drawn on `screen` (the
    # predicted box) never leak into the pixels fed to the predictor.
    canvas: pygame.surface.Surface = pygame.Surface((canvas_width, canvas_height))
    canvas.fill(blank_color)

    font = pygame.font.Font(None, 30)
    percent_font = pygame.font.Font(None, 20)

    # Pre-render the digit labels and every whole percentage once, not per frame.
    num_render: dict[int, pygame.surface.Surface] = {
        i: font.render(str(i), False, graph_text_color) for i in range(0, 10)
    }
    percent_render: dict[str, pygame.surface.Surface] = {
        str(i): percent_font.render(f"{i}%", False, graph_percent_text_color) for i in range(0, 101)
    }

    # Graph geometry: a full bar is bar_max_width px long. The leftover horizontal
    # space is split 75% left / 25% right of the bars, and each digit label is
    # centred in the left part.
    bar_max_width = np.floor(graph_width / 100) * 100
    bar_left = ((graph_width - bar_max_width) * 0.75) + canvas_width
    slot_step = int((screen.get_height() - 10) / 10)
    slot_ys = range(0, screen.get_height() - 10, slot_step)
    slot_centers: list[tuple[int,int]] = [
        (int(((graph_width - bar_max_width) * 0.75) / 2) + canvas_width, slot_y + 20) for slot_y in slot_ys
    ]
    percent_centers: list[tuple[int,int]] = [
        (int(((graph_width - bar_max_width) * 0.75) + (bar_max_width / 2)) + canvas_width, slot_y + 20)
        for slot_y in slot_ys
    ]

    def _predict_canvas():
        """Downsample the canvas to (height, width, 3) and run the predictor on it."""
        pixels = pygame.surfarray.array3d(canvas)  # indexed [x, y, rgb]
        grid_pixels = np.swapaxes(pixels[::scale, ::scale], 0, 1)  # -> [row, col, rgb]
        return predictor(predictor_formatter(grid_pixels))

    # Track mouse button states
    mouse_draw_down: bool = False
    mouse_erase_down: bool = False

    running: bool = True
    prediction = _predict_canvas()
    # The box overlay appears only after the user releases a mouse button.
    boxes_to_draw = None

    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
                running = False
            elif event.type == pygame.MOUSEBUTTONDOWN:
                if event.button == 1:
                    mouse_draw_down = True
                elif event.button == 2:
                    canvas.fill(blank_color)
                elif event.button == 3:
                    mouse_erase_down = True
            elif event.type == pygame.MOUSEBUTTONUP:
                if event.button == 1:
                    mouse_draw_down = False
                if event.button == 3:
                    mouse_erase_down = False
                if not mouse_draw_down and not mouse_erase_down:
                    prediction = _predict_canvas()
                    print(np.argmax(prediction[0]))
                    boxes_to_draw = np.asarray(prediction[1]).reshape(-1, 4)

            # Paint (or erase) the grid cell under the cursor while a button is held.
            # This runs per event so a press+release within one frame still paints.
            if mouse_draw_down or mouse_erase_down:
                x, y = pygame.mouse.get_pos()
                cell_x = (x // scale) * scale
                cell_y = (y // scale) * scale
                if cell_x < canvas_width:
                    color = draw_color if mouse_draw_down else blank_color
                    pygame.draw.rect(canvas, color, (cell_x, cell_y, scale, scale))

        # Compose the frame: canvas, then the box overlay, then the graph panel.
        screen.blit(canvas, (0, 0))

        if boxes_to_draw is not None:
            for x_min, y_min, x_max, y_max in boxes_to_draw:
                # Order the corners so an "inverted" box still gets a positive size.
                left, right = sorted((float(x_min) * width * scale, float(x_max) * width * scale))
                top, bottom = sorted((float(y_min) * height * scale, float(y_max) * height * scale))
                pygame.draw.rect(screen, (255, 0, 0), (left, top, right - left, bottom - top), 2)

        pygame.draw.rect(screen, graph_bg_color, pygame.Rect(canvas_width, 0, screen.get_width() - canvas_width, screen.get_height()))

        # prediction[0] is e.g. a (1, 10) batch row; flatten to one probability per digit.
        probabilities = np.asarray(prediction[0], dtype=float).reshape(-1)
        for digit, probability in enumerate(probabilities):
            text = num_render[digit]
            screen.blit(text, text.get_rect(center=slot_centers[digit]))

            bar_y_offset = slot_step * digit

            pygame.draw.rect(screen, blank_color, pygame.Rect(bar_left - graph_border_width, (10 - graph_border_width) + bar_y_offset, bar_max_width + graph_border_width * 2, 20 + graph_border_width * 2), graph_border_width)
            pygame.draw.rect(screen, graph_color, pygame.Rect(bar_left, 10 + bar_y_offset, bar_max_width * probability, 20))

            # Round to the nearest whole percent (int(round(v, 2) * 100) can land one
            # below, e.g. 0.29 -> 28) and clamp to the pre-rendered 0-100 range.
            percent_value = min(100, max(0, int(round(probability * 100))))
            percent = percent_render[str(percent_value)]
            screen.blit(percent, percent.get_rect(center=percent_centers[digit]))

        pygame.display.flip()
        fpsClock.tick(fps)

    if quit_pygame:
        pygame.quit()
