In [8]:
import os
import cv2
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import time
import random
from tensorflow.keras import layers, models, optimizers, callbacks
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import pickle

In [2]:
#example usage:
from tensorflow.keras.models import load_model
IMG_SIZE = 128
DATA_DIR = "/mnt/d/MyEverything/PythonProjects/Recent_projects/cnn_analysis/Hand_Drawing/quickdraw_images"
model = load_model("checkpoints/step_90000.keras")
def predict_image(model, image_path):
    """Predict the class of a single image."""
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Image at {image_path} could not be loaded.")

    classes = sorted(os.listdir(DATA_DIR))
    class_to_idx = {cls: i for i, cls in enumerate(classes)}
    #preprocess image like training data
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    img = img.astype(np.float32) / 255.0
    img = np.expand_dims(img, axis=-1)  # Add channel dimension
    img = np.expand_dims(img, axis=0)  # Add batch dimension
    
    #view the image
    plt.imshow(img.squeeze(), cmap='gray')
    plt.axis('off')
    plt.show()
    # Make prediction
    prediction = model.predict(img)
    predicted_class_idx = np.argmax(prediction, axis=1)[0]
    #show top 10 predictions
    top_indices = np.argsort(prediction[0])[::-1][:20]
    top_classes = [list(class_to_idx.keys())[i] for i in top_indices]
    top_probs = prediction[0][top_indices]
    print(f"Predicted class: {list(class_to_idx.keys())[predicted_class_idx]}")
    print("Top 10 predictions:")
    for cls, prob in zip(top_classes, top_probs):
        print(f"  {cls}: {prob:.4f}")
    return predicted_class_idx
# Example usage
image_path = "/mnt/d/MyEverything/PythonProjects/Recent_projects/cnn_analysis/Hand_Drawing/clock.jpg"
#predicted_class_idx = predict_image(model, image_path)


# Note: The above code assumes you have a trained model and the necessary directories set up.
# If you want to run the scribble pad, make sure to uncomment the last line. 

I0000 00:00:1749697834.940313    1728 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3620 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 6GB Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Left click to draw, right click to erase. Press 'q' to quit.


In [7]:
!pip install tkinter



Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
[31mERROR: Could not find a version that satisfies the requirement tkinter (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for tkinter[0m[31m
[0m

In [1]:
import tkinter as tk
from tkinter import ttk
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageTk
import threading
import time
from tensorflow.keras.models import load_model
import os
import random

class DrawingInterface:
    def __init__(self, model_path, data_dir):
        self.root = tk.Tk()
        self.root.title("Real-time Drawing Prediction")
        self.root.geometry("800x600")
        self.root.configure(bg='#f0f0f0')
        
        # Constants
        self.CANVAS_SIZE = 400  # Display size
        self.IMG_SIZE = 128     # Model input size
        self.BRUSH_SIZE = 8
        
        # Load model and classes
        self.model = load_model(model_path)
        self.classes = sorted(os.listdir(data_dir))
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        
        # Drawing state
        self.drawing = False
        self.last_x = None
        self.last_y = None
        
        # Create PIL image for drawing (white background)
        self.image = Image.new('L', (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw = ImageDraw.Draw(self.image)
        
        # Prediction state
        self.prediction_active = True
        self.last_prediction_time = 0
        self.num_predictions = 5  # Default number of predictions
        
        self.setup_ui()
        self.start_prediction_thread()
        
    def setup_ui(self):
        # Main frame
        main_frame = tk.Frame(self.root, bg='#f0f0f0')
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
        
        # Left side - Drawing area
        left_frame = tk.Frame(main_frame, bg='#f0f0f0')
        left_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        
        # Canvas title
        title_label = tk.Label(left_frame, text="Draw Here", font=('Arial', 16, 'bold'), bg='#f0f0f0')
        title_label.pack(pady=(0, 5))
        
        # Random challenge title
        self.challenge_label = tk.Label(left_frame, text="Click 'New Challenge' to get a drawing prompt!", 
                                       font=('Arial', 14, 'italic'), bg='#f0f0f0', fg='#4ecdc4', 
                                       wraplength=350)
        self.challenge_label.pack(pady=(0, 10))
        
        # Canvas frame with border
        canvas_frame = tk.Frame(left_frame, bg='black', padx=2, pady=2)
        canvas_frame.pack()
        
        # Drawing canvas
        self.canvas = tk.Canvas(canvas_frame, width=self.CANVAS_SIZE, height=self.CANVAS_SIZE, 
                               bg='white', cursor='crosshair')
        self.canvas.pack()
        
        # Bind mouse events
        self.canvas.bind('<Button-1>', self.start_draw)
        self.canvas.bind('<B1-Motion>', self.draw_line)
        self.canvas.bind('<ButtonRelease-1>', self.stop_draw)
        
        # Control buttons
        button_frame = tk.Frame(left_frame, bg='#f0f0f0')
        button_frame.pack(pady=10)
        
        new_challenge_btn = tk.Button(button_frame, text="New Challenge", command=self.new_challenge,
                                     bg='#45b7aa', fg='white', font=('Arial', 12, 'bold'),
                                     padx=20, pady=5)
        new_challenge_btn.pack(side=tk.LEFT, padx=5)
        
        clear_btn = tk.Button(button_frame, text="Clear Canvas", command=self.clear_canvas,
                             bg='#ff6b6b', fg='white', font=('Arial', 12, 'bold'),
                             padx=20, pady=5)
        clear_btn.pack(side=tk.LEFT, padx=5)
        
        toggle_btn = tk.Button(button_frame, text="Pause Predictions", command=self.toggle_predictions,
                              bg='#4ecdc4', fg='white', font=('Arial', 12, 'bold'),
                              padx=20, pady=5)
        toggle_btn.pack(side=tk.LEFT, padx=5)
        self.toggle_btn = toggle_btn
        
        # Right side - Predictions
        right_frame = tk.Frame(main_frame, bg='#f0f0f0', width=300)
        right_frame.pack(side=tk.RIGHT, fill=tk.Y, padx=(20, 0))
        right_frame.pack_propagate(False)
        
        # Predictions title and controls
        pred_title_frame = tk.Frame(right_frame, bg='#f0f0f0')
        pred_title_frame.pack(fill=tk.X, pady=(0, 10))
        
        pred_title = tk.Label(pred_title_frame, text="Live Predictions", font=('Arial', 16, 'bold'), bg='#f0f0f0')
        pred_title.pack(side=tk.LEFT)
        
        # Number of predictions control
        pred_control_frame = tk.Frame(pred_title_frame, bg='#f0f0f0')
        pred_control_frame.pack(side=tk.RIGHT)
        
        tk.Label(pred_control_frame, text="Show top:", font=('Arial', 10), bg='#f0f0f0').pack(side=tk.LEFT)
        
        self.num_pred_var = tk.StringVar(value="5")
        pred_spinbox = tk.Spinbox(pred_control_frame, from_=1, to=20, width=3, 
                                 textvariable=self.num_pred_var, font=('Arial', 10),
                                 command=self.update_prediction_count)
        pred_spinbox.pack(side=tk.LEFT, padx=(5, 0))
        
        # Predictions frame with scrollbar
        pred_outer_frame = tk.Frame(right_frame, bg='#f0f0f0')
        pred_outer_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)
        
        # Create canvas for scrolling
        self.pred_canvas = tk.Canvas(pred_outer_frame, bg='white', relief=tk.RAISED, bd=2)
        scrollbar = tk.Scrollbar(pred_outer_frame, orient="vertical", command=self.pred_canvas.yview)
        self.scrollable_frame = tk.Frame(self.pred_canvas, bg='white')
        
        self.scrollable_frame.bind(
            "<Configure>",
            lambda e: self.pred_canvas.configure(scrollregion=self.pred_canvas.bbox("all"))
        )
        
        self.pred_canvas.create_window((0, 0), window=self.scrollable_frame, anchor="nw")
        self.pred_canvas.configure(yscrollcommand=scrollbar.set)
        
        self.pred_canvas.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")
        
        # Status label
        self.status_label = tk.Label(self.scrollable_frame, text="Start drawing to see predictions!", 
                                    font=('Arial', 11), bg='white', fg='#666')
        self.status_label.pack(pady=20)
        
        # Container for prediction labels
        self.predictions_container = tk.Frame(self.scrollable_frame, bg='white')
        self.predictions_container.pack(fill=tk.BOTH, expand=True)
        
        # Initialize prediction widgets
        self.pred_labels = []
        self.prob_bars = []
        self.create_prediction_widgets()
        
        # Instructions
        instructions = tk.Label(right_frame, 
                               text="• Draw with mouse\n• Predictions update every second\n• Use clear button to start over\n• Adjust 'Show top' to see more predictions",
                               font=('Arial', 9), bg='#f0f0f0', fg='#666', justify=tk.LEFT)
        instructions.pack(pady=(10, 0))
        
    def start_draw(self, event):
        self.drawing = True
        self.last_x = event.x
        self.last_y = event.y
        
    def draw_line(self, event):
        if self.drawing and self.last_x and self.last_y:
            # Draw on canvas
            self.canvas.create_line(self.last_x, self.last_y, event.x, event.y,
                                   width=self.BRUSH_SIZE, fill='black', capstyle=tk.ROUND, smooth=tk.TRUE)
            
            # Draw on PIL image
            self.draw.line([self.last_x, self.last_y, event.x, event.y], fill=0, width=self.BRUSH_SIZE)
            
            self.last_x = event.x
            self.last_y = event.y
            
    def stop_draw(self, event):
        self.drawing = False
        self.last_x = None
        self.last_y = None
        
    def create_prediction_widgets(self):
        """Create prediction display widgets based on current num_predictions."""
        # Clear existing widgets
        for widget in self.predictions_container.winfo_children():
            widget.destroy()
        
        self.pred_labels = []
        self.prob_bars = []
        
        for i in range(self.num_predictions):
            # Container for each prediction
            pred_container = tk.Frame(self.predictions_container, bg='white')
            pred_container.pack(fill=tk.X, padx=15, pady=5)
            
            # Rank number
            rank_label = tk.Label(pred_container, text=f"{i+1}.", font=('Arial', 12, 'bold'), 
                                 bg='white', width=2)
            rank_label.pack(side=tk.LEFT)
            
            # Class name label
            class_label = tk.Label(pred_container, text="", font=('Arial', 11), 
                                  bg='white', anchor='w', width=15)
            class_label.pack(side=tk.LEFT, padx=(5, 10))
            
            # Probability bar background
            bar_bg = tk.Frame(pred_container, bg='#e0e0e0', height=20, width=100)
            bar_bg.pack(side=tk.LEFT, padx=(0, 10))
            bar_bg.pack_propagate(False)
            
            # Probability bar
            prob_bar = tk.Frame(bar_bg, bg='#4ecdc4', height=20)
            prob_bar.place(x=0, y=0)
            
            # Probability percentage
            prob_label = tk.Label(pred_container, text="", font=('Arial', 10), bg='white', width=6)
            prob_label.pack(side=tk.LEFT)
            
            self.pred_labels.append((class_label, prob_label))
            self.prob_bars.append(prob_bar)
    
    def update_prediction_count(self):
        """Update the number of predictions to display."""
        try:
            new_count = int(self.num_pred_var.get())
            if 1 <= new_count <= 20:
                self.num_predictions = new_count
                self.create_prediction_widgets()
                # Update canvas scroll region
                self.pred_canvas.configure(scrollregion=self.pred_canvas.bbox("all"))
        except ValueError:
            # Reset to previous valid value
            self.num_pred_var.set(str(self.num_predictions))
    
    def new_challenge(self):
        """Select and display a random drawing challenge."""
        random_class = random.choice(self.classes)
        # Format the class name nicely
        formatted_class = random_class.replace('_', ' ').title()
        self.challenge_label.config(text=f"🎯 Draw: {formatted_class}")
        
        # Also clear the canvas when starting a new challenge
        self.clear_canvas()
        
    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new('L', (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw = ImageDraw.Draw(self.image)
        
        # Clear predictions
        self.status_label.config(text="Start drawing to see predictions!")
        for i in range(len(self.pred_labels)):
            self.pred_labels[i][0].config(text="")
            self.pred_labels[i][1].config(text="")
            self.prob_bars[i].config(width=0)
            
    def toggle_predictions(self):
        self.prediction_active = not self.prediction_active
        if self.prediction_active:
            self.toggle_btn.config(text="Pause Predictions", bg='#4ecdc4')
            self.status_label.config(text="Predictions active")
        else:
            self.toggle_btn.config(text="Resume Predictions", bg='#ff9f43')
            self.status_label.config(text="Predictions paused")
            
    def preprocess_for_prediction(self):
        # Convert PIL image to numpy array
        img_array = np.array(self.image)
        
        # Check if canvas is empty (all white)
        if np.all(img_array == 255):
            return None
            
        # Resize to model input size
        img_resized = cv2.resize(img_array, (self.IMG_SIZE, self.IMG_SIZE))
        
        # Normalize
        img_normalized = img_resized.astype(np.float32) / 255.0
        
        # Add dimensions for model
        img_final = np.expand_dims(img_normalized, axis=-1)  # Add channel dimension
        img_final = np.expand_dims(img_final, axis=0)        # Add batch dimension
        
        return img_final
        
    def make_prediction(self):
        try:
            processed_img = self.preprocess_for_prediction()
            
            if processed_img is None:
                return None
                
            # Make prediction
            prediction = self.model.predict(processed_img, verbose=0)
            
            # Get top N predictions
            top_indices = np.argsort(prediction[0])[::-1][:self.num_predictions]
            top_classes = [self.classes[i] for i in top_indices]
            top_probs = prediction[0][top_indices]
            
            return list(zip(top_classes, top_probs))
            
        except Exception as e:
            print(f"Prediction error: {e}")
            return None
            
    def update_predictions(self, predictions):
        if predictions is None:
            self.status_label.config(text="Start drawing to see predictions!")
            for i in range(len(self.pred_labels)):
                self.pred_labels[i][0].config(text="")
                self.pred_labels[i][1].config(text="")
                self.prob_bars[i].config(width=0)
            return
            
        self.status_label.config(text="")
        
        for i, (class_name, prob) in enumerate(predictions):
            # Update class name and probability
            self.pred_labels[i][0].config(text=class_name.replace('_', ' ').title())
            self.pred_labels[i][1].config(text=f"{prob*100:.1f}%")
            
            # Update probability bar width
            bar_width = int(prob * 100)  # Scale to 100px max
            self.prob_bars[i].config(width=bar_width)
            
            # Color code the bars
            if i == 0:  # Top prediction
                color = '#4ecdc4'
            elif i == 1:  # Second
                color = '#45b7aa' 
            else:  # Others
                color = '#95e1d3'
            self.prob_bars[i].config(bg=color)
            
    def prediction_worker(self):
        while True:
            try:
                current_time = time.time()
                
                if (self.prediction_active and 
                    current_time - self.last_prediction_time >= 1.0):  # Update every second
                    
                    predictions = self.make_prediction()
                    
                    # Update UI in main thread
                    self.root.after(0, self.update_predictions, predictions)
                    
                    self.last_prediction_time = current_time
                    
                time.sleep(0.1)  # Small sleep to prevent high CPU usage
                
            except Exception as e:
                #print(f"Prediction worker error: {e}")
                time.sleep(1)
                
    def start_prediction_thread(self):
        prediction_thread = threading.Thread(target=self.prediction_worker, daemon=True)
        prediction_thread.start()
        
    def run(self):
        self.root.mainloop()

# Example usage
if __name__ == "__main__":
    # Update these paths to match your setup
    MODEL_PATH = "checkpoints/step_110000.keras"
    DATA_DIR = "/mnt/d/MyEverything/PythonProjects/Recent_projects/cnn_analysis/Hand_Drawing/quickdraw_images"
    
    try:
        app = DrawingInterface(MODEL_PATH, DATA_DIR)
        app.run()
    except Exception as e:
        print(f"Error starting application: {e}")
        print("Please make sure the model path and data directory are correct.")

2025-06-17 01:57:03.577063: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-17 01:57:10.762488: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750093033.270814    1419 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750093034.111926    1419 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-17 01:57:20.801247: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr