# Interactive Digit Drawing and Prediction

Now let's create an interactive tool where you can draw a digit and have the model predict what digit you drew.

## How it works:
1. Draw a digit on a 28×28 pixel canvas
2. The drawing is preprocessed the same way as training data
3. The trained model predicts the digit
4. We display the confidence scores for all 10 digits

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import json
import os

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the model architectures (same as in the training notebook)
class MNISTClassifier_MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10).

    Attribute names (``fc1``, ``fc2``) define the ``state_dict`` keys, so they
    must stay in sync with whatever code saved the weights.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the 10 raw output scores (no softmax applied) for each input."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

class MNISTClassifier_MLP_WithHiddenLayer(nn.Module):
    """Three-layer perceptron: 784 -> 64 -> 16 -> 10 with ReLU between layers.

    Attribute names (``fc1``..``fc3``) define the ``state_dict`` keys, so they
    must stay in sync with whatever code saved the weights.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 16)
        self.fc3 = nn.Linear(16, 10)

    def forward(self, x):
        """Return the 10 raw output scores (no softmax applied) for each input."""
        first = self.relu(self.fc1(self.flatten(x)))
        second = self.relu(self.fc2(first))
        return self.fc3(second)

# Model loading function
def load_model(model_name, models_dir='../models'):
    """
    Load a saved model and its metadata from disk.

    Looks for ``model_<model_name>.pt`` (weights) and
    ``model_<model_name>_metadata.json`` (metadata) inside ``models_dir``.
    The metadata's ``model_class`` entry selects which architecture to build.

    Args:
        model_name: Name of the model (without 'model_' prefix)
        models_dir: Directory holding the saved files. Defaults to
            '../models'; adjust if the notebook lives elsewhere.

    Returns:
        A ``(model, metadata)`` tuple, with the model moved to ``device`` and
        switched to eval mode, or ``None`` when a file is missing or the
        metadata names an unknown model class.
    """
    model_path = os.path.join(models_dir, f'model_{model_name}.pt')
    metadata_path = os.path.join(models_dir, f'model_{model_name}_metadata.json')

    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        return None

    # Treat a missing metadata file like a missing weights file instead of
    # letting open() raise FileNotFoundError into the calling cell.
    if not os.path.exists(metadata_path):
        print(f"Metadata file not found: {metadata_path}")
        return None

    # Load metadata to determine model class
    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # .get() sends a metadata file without the key to the "unknown" branch
    model_class_name = metadata.get('model_class')

    # Create appropriate model instance
    if model_class_name == 'MNISTClassifier_MLP':
        model = MNISTClassifier_MLP()
    elif model_class_name == 'MNISTClassifier_MLP_WithHiddenLayer':
        model = MNISTClassifier_MLP_WithHiddenLayer()
    else:
        print(f"Unknown model class: {model_class_name}")
        return None

    # Load weights.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (newer PyTorch releases offer weights_only=True for this).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # inference mode

    return model, metadata

# Try to load available models
def discover_models(models_dir):
    """
    Scan a directory for saved-model metadata files.

    A metadata file is named ``model_<name>_metadata.json`` and must contain
    an ``accuracy`` entry. Files that cannot be read, are not valid JSON, or
    lack that entry are reported and skipped so one bad file does not abort
    the scan.

    Args:
        models_dir: Directory to scan.

    Returns:
        Dict mapping model name to the accuracy stored in its metadata.
        Empty when the directory does not exist or holds no usable metadata.
    """
    prefix, suffix = 'model_', '_metadata.json'
    found = {}
    if not os.path.isdir(models_dir):
        return found

    for filename in os.listdir(models_dir):
        if not (filename.startswith(prefix) and filename.endswith(suffix)):
            continue
        # Guard against names where prefix and suffix overlap
        # (e.g. 'model_metadata.json'), which carry no model name.
        if len(filename) <= len(prefix) + len(suffix):
            continue
        # Strip only the leading prefix and trailing suffix. str.replace would
        # also remove 'model_' from the middle of a name such as
        # 'my_model_v2', producing a name that cannot be loaded later.
        model_name = filename[len(prefix):-len(suffix)]
        try:
            with open(os.path.join(models_dir, filename), 'r') as f:
                found[model_name] = json.load(f)['accuracy']
        except (OSError, ValueError, KeyError, TypeError) as err:
            print(f"Skipping unreadable metadata file {filename}: {err!r}")
    return found


print("\nSearching for available models...")
models_dir = '../models'
available_models = discover_models(models_dir)

if not os.path.exists(models_dir):
    print("Models directory not found. Run the training notebook first to save models.")
elif available_models:
    print(f"Found {len(available_models)} saved models:")
    # Best accuracy first
    for name, accuracy in sorted(available_models.items(), key=lambda x: x[1], reverse=True):
        print(f"  - {name}: {accuracy}% accuracy")
else:
    print("No saved models found. You can save models from the training notebook.")

# Load the best model by default (highest accuracy)
# Start from an explicit "nothing loaded" state so that both globals exist on
# every path; the original left current_metadata undefined when loading failed.
current_model = None
current_metadata = None

if available_models:
    # Key with the highest stored accuracy
    best_model_name = max(available_models, key=available_models.get)
    print(f"\nLoading best model: {best_model_name}")
    result = load_model(best_model_name)
    if result:
        current_model, current_metadata = result
        print(f"Model loaded successfully! Accuracy: {current_metadata['accuracy']}%")
    else:
        print("Failed to load model")
else:
    print("No model available to load")

In [1]:
%pip install ipywidgets ipycanvas > out.log

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: C:\Users\rpenaguiao\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [2]:
from PIL import Image, ImageDraw
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np
from ipycanvas import Canvas, hold_canvas
print("Imports successful")

Imports successful


In [None]:
def on_predict_click(b):
    """
    Button handler: classify the current drawing and plot the result.

    Reads the drawing via ``get_image_array()``, runs it through
    ``current_model`` and shows two panels: the input image with the
    predicted digit, and a horizontal bar chart of the softmax scores for
    digits 0-9 with the predicted digit highlighted.

    Args:
        b: The widget that triggered the callback (unused).
    """
    if current_model is None:
        print("Error: No model loaded. Please run the model loading cell first.")
        return

    # Get the image from the canvas
    # NOTE(review): get_image_array is defined outside this cell; it is assumed
    # to return a 2-D NumPy array preprocessed like the training data - confirm.
    img_array = get_image_array()

    # Convert to a float32 tensor and add batch and channel dimensions.
    # nn.Linear parameters are float32 by default, so an array of any other
    # dtype would otherwise raise a dtype mismatch in the forward pass.
    img_tensor = torch.from_numpy(img_array).float().unsqueeze(0).unsqueeze(0)
    img_tensor = img_tensor.to(device)

    # Make prediction
    with torch.no_grad():
        outputs = current_model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]
        predicted_digit = torch.argmax(outputs, dim=1).item()

    # Display results
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    image_ax, scores_ax = axes

    # Show the drawn image
    image_ax.imshow(img_array, cmap='gray')
    image_ax.set_title(f'Your Drawing\n(Predicted: {predicted_digit})', fontsize=14, fontweight='bold')
    image_ax.axis('off')

    # Show confidence scores, highlighting the predicted digit
    digits = list(range(10))
    probs = probabilities.cpu().numpy()
    colors = ['green' if digit == predicted_digit else 'skyblue' for digit in digits]
    scores_ax.barh(digits, probs, color=colors)
    scores_ax.set_xlabel('Confidence')
    scores_ax.set_ylabel('Digit')
    scores_ax.set_title('Model Confidence Scores', fontsize=14, fontweight='bold')
    scores_ax.set_xlim(0, 1)
    scores_ax.invert_yaxis()  # digit 0 at the top

    plt.tight_layout()
    plt.show()

In [None]:
# Output area for the 28x28 preview, and the button that fills it
output_28x28 = widgets.Output()
show_28x28_button = widgets.Button(
    button_style='info',
    description='Show 28x28',
)

def on_show_28x28_click(b):
    """
    Button handler: show the downsampled drawing with a per-pixel grid.

    Clears ``output_28x28`` and draws the array returned by
    ``get_image_array()`` into it, overlaying light grid lines on the pixel
    boundaries.

    Args:
        b: The widget that triggered the callback (unused).
    """
    output_28x28.clear_output()
    with output_28x28:
        img_array = get_image_array()
        # Take the grid extent from the array instead of assuming 28x28
        height, width = np.shape(img_array)[:2]

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(img_array, cmap='gray')
        ax.set_title('28x28 Pixel Version', fontsize=12, fontweight='bold')

        # Add grid to see individual pixels: minor ticks sit on pixel edges
        ax.set_xticks(np.arange(-0.5, width, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, height, 1), minor=True)
        ax.grid(which='minor', color='lightgray', linestyle='-', linewidth=0.5)

        # Hide tick marks and labels only. ax.axis('off') would also suppress
        # the grid lines, which are drawn as part of the axis.
        ax.tick_params(which='both', bottom=False, left=False,
                       labelbottom=False, labelleft=False)

        plt.tight_layout()
        plt.show()

# Hook the preview handler up to its button
show_28x28_button.on_click(on_show_28x28_click)

# Show the instructions, then the canvas, the row of controls (now including
# the 28x28 preview button) and the preview output area, in that order
print("Draw a digit on the canvas below (use your mouse):")
print("The canvas is 400x400 pixels. Your drawing will be processed as a 28x28 image.")
button_row = widgets.HBox([clear_button, predict_button, show_28x28_button])
display(canvas, button_row, output_28x28)

In [None]:
# Model selector: only offered when at least one saved model was discovered
if available_models:
    # Best accuracy first; each option is shown as "name (acc%)" and carries
    # the bare model name as its value
    ranked_models = sorted(available_models.items(), key=lambda item: item[1], reverse=True)
    dropdown_options = [(f"{name} ({acc}%)", name) for name, acc in ranked_models]

    model_dropdown = widgets.Dropdown(
        options=dropdown_options,
        description='Select Model:',
        style={'description_width': '100px'},
    )

    def on_model_change(change):
        """Swap the active model when the dropdown value changes.

        On success the module-level ``current_model`` / ``current_metadata``
        are replaced; on failure they are left untouched.
        """
        global current_model, current_metadata
        selected_model_name = change['new']
        print(f"Loading model: {selected_model_name}...")
        loaded = load_model(selected_model_name)
        if not loaded:
            print("Failed to load model")
            return
        current_model, current_metadata = loaded
        print(f"Model loaded! Accuracy: {current_metadata['accuracy']}%")

    # Fire only on value changes, not on other trait updates
    model_dropdown.observe(on_model_change, names='value')

    print("Model selector available:")
    display(model_dropdown)