In [None]:
# app.py - Trains MNIST on first launch, then reuses model
import gradio as gr
import numpy as np
import tensorflow as tf
import os
import cv2
import pandas as pd
from PIL import Image

# Model cache location. The directory is created up front so that
# model.save() below cannot fail on a missing parent directory.
# NOTE(review): Keras 3 treats the .h5 format as legacy; switching to a
# ".keras" path would be more future-proof but would stop reusing models
# already saved at this path.
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, "mnist_model.h5")

os.makedirs(MODEL_DIR, exist_ok=True)

# Reuse a model saved by a previous launch if one exists; otherwise train a
# small dense network on MNIST and save it so later launches skip training.
print("Checking for existing model...")
if os.path.exists(MODEL_PATH):
    print(" Loading pre-trained model...")
    model = tf.keras.models.load_model(MODEL_PATH)
    model.summary()
else:
    print(" Model not found. Training new model...")
    # MNIST pixels arrive as integers in [0, 255]; scale them to [0, 1].
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # Declare the 28x28 input with an explicit Input layer instead of
    # Flatten(input_shape=...): Keras 3 deprecates per-layer input_shape and
    # emits a UserWarning for it. The resulting architecture is identical.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    # Labels are integer class ids 0-9, hence the *sparse* cross-entropy.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    # Train (fast: 5 epochs on full MNIST). The test split doubles as the
    # validation set; it is only monitored, never used to select a model.
    print("Starting training...")
    model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test), verbose=1)

    # Persist so the next launch takes the load branch above.
    model.save(MODEL_PATH)
    print(f" Model saved to {MODEL_PATH}")

    # Final held-out accuracy (same data as the last validation pass).
    test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
    print(f"Training complete! Test Accuracy: {test_accuracy:.4f}")

def preprocess_image(image):
    """Turn a Gradio input into a normalized 28x28 array for the MNIST model.

    Accepted inputs:

    * ``None`` -- nothing to classify; returns ``None``.
    * ``dict`` from ``gr.ImageEditor``. Gradio 4+ sends an ``EditorValue``
      with ``'background'``, ``'layers'`` and ``'composite'`` keys. The
      flattened ``'composite'`` is preferred; ``'image'`` (older sketch
      components) and ``'background'`` are fallbacks.
    * ``str`` -- a file path, as delivered by ``gr.UploadButton``.
    * a PIL image (anything with a ``convert`` method).
    * an array that is 2-D grayscale, or 3-D with 1, 3 (RGB) or 4 (RGBA)
      channels.

    The image is converted to grayscale, resized to 28x28, inverted when its
    mean is bright (so strokes end up light-on-dark), and scaled to [0, 1].

    Returns:
        A float32 array of shape (28, 28) with values in [0, 1], or ``None``
        if the input is missing, cannot be opened, or has an unsupported
        shape.
    """
    if image is None:
        print("Warning: Received None as image input.")
        return None

    # Gradio ImageEditor output: take the first populated image entry.
    if isinstance(image, dict):
        editor_value = image
        image = None
        for key in ('composite', 'image', 'background'):
            if editor_value.get(key) is not None:
                image = editor_value[key]
                break
        if image is None:
            print("Warning: No image found in dictionary.")
            return None

    # Gradio UploadButton output: a path on disk. The context manager closes
    # the underlying file once the pixels have been copied out.
    if isinstance(image, str):
        try:
            with Image.open(image) as pil_image:
                image = np.array(pil_image.convert('L'))
        except Exception as e:
            print(f"Error opening uploaded image: {e}")
            return None

    # PIL image (e.g. an editor configured with type="pil"): 8-bit grayscale.
    if hasattr(image, 'convert'):
        image = np.array(image.convert('L'))

    image = np.asarray(image)

    # Reduce to a single 2-D grayscale plane.
    if image.ndim == 2:
        gray = image
    elif image.ndim == 3 and image.shape[2] == 1:
        gray = image[:, :, 0]
    elif image.ndim == 3 and image.shape[2] == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif image.ndim == 3 and image.shape[2] == 4:
        # Alpha channel is dropped, exactly as in the RGB case.
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        print(f"Warning: Unexpected image shape {image.shape}")
        return None

    # cv2.resize raises on empty input; treat it like any other bad shape.
    if gray.size == 0:
        print(f"Warning: Unexpected image shape {image.shape}")
        return None

    # Shrink to the model's 28x28 input; INTER_AREA averages source pixels,
    # which suits downsampling.
    resized = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)

    # MNIST digits are light strokes on a dark background. A bright mean
    # suggests dark ink on a light canvas (the UI's black brush), so invert.
    if resized.mean() > 128:
        resized = 255 - resized

    normalized = resized.astype('float32') / 255.0

    print("Preprocessed image shape:", normalized.shape)
    print("Preprocessed image mean:", normalized.mean())

    return normalized

def predict_digit(image):
    """Classify a drawn or uploaded digit with the module-level Keras model.

    Args:
        image: Raw value from the sketchpad or upload button; anything that
            ``preprocess_image`` accepts.

    Returns:
        A ``(label_scores, bar_frame)`` pair. ``label_scores`` maps the digit
        strings ``"0"``..``"9"`` to float probabilities (for ``gr.Label``).
        ``bar_frame`` is a DataFrame with ``Digit`` and ``Probability``
        columns (for ``gr.BarPlot``). If preprocessing yields nothing, every
        score is 0.0 and the frame has those two columns but no rows.
    """
    digit_names = [str(d) for d in range(10)]

    features = preprocess_image(image)
    if features is None:
        empty_frame = pd.DataFrame(columns=['Digit', 'Probability'])
        return dict.fromkeys(digit_names, 0.0), empty_frame

    # Add the leading batch axis the model expects: (1, 28, 28).
    batch = features[np.newaxis, ...]
    probabilities = model.predict(batch, verbose=0)[0]
    scores = [float(probabilities[d]) for d in range(10)]

    label_scores = dict(zip(digit_names, scores))
    bar_frame = pd.DataFrame({'Digit': digit_names, 'Probability': scores})
    return label_scores, bar_frame

# Gradio UI ------------------------------------------------------------------
# Left column: a 280x280 grayscale canvas with a fixed black brush, plus upload
# and clear buttons. Right column: the top predicted class and a horizontal
# bar chart of all ten class scores.
ui_theme = gr.themes.Soft()
pen = gr.Brush(
    default_size=15,
    colors=["#000000"],
    default_color="#000000",
    color_mode="fixed",
)
pen_eraser = gr.Eraser(default_size=15)

with gr.Blocks(theme=ui_theme) as demo:
    gr.Markdown("#  Kilele AI Tutor: Handwritten Digit Recognition")
    gr.Markdown("Draw or upload a digit (0–9). Our AI will classify it!")

    with gr.Row():
        with gr.Column():
            sketchpad = gr.ImageEditor(
                label="Draw Digit",
                image_mode="L",
                canvas_size=(280, 280),
                brush=pen,
                eraser=pen_eraser,
                sources=["upload"],
                type="numpy",
                layers=False,
                transforms=[],
            )
            upload_btn = gr.UploadButton("📁 Upload Image", file_types=["image"])
            clear_btn = gr.Button("🧹 Clear Canvas")

        with gr.Column():
            output_label = gr.Label(label="Prediction", num_top_classes=1)
            output_bars = gr.BarPlot(
                label="Confidence Scores",
                x="Digit",
                y="Probability",
                vertical=False,
                height=200,
            )

    prediction_outputs = [output_label, output_bars]

    # Re-classify whenever the canvas changes or a file is uploaded.
    sketchpad.change(fn=predict_digit, inputs=sketchpad, outputs=prediction_outputs)
    upload_btn.upload(fn=predict_digit, inputs=upload_btn, outputs=prediction_outputs)
    # Writing None back into the editor clears its value.
    clear_btn.click(fn=lambda: None, inputs=None, outputs=sketchpad)

demo.launch(debug=True)

Checking for existing model...
❌ Model not found. Training new model...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  super().__init__(**kwargs)


Starting training...
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 5ms/step - accuracy: 0.8370 - loss: 0.5337 - val_accuracy: 0.9589 - val_loss: 0.1346
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 9s 5ms/step - accuracy: 0.9511 - loss: 0.1608 - val_accuracy: 0.9699 - val_loss: 0.1005
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 9s 5ms/step - accuracy: 0.9645 - loss: 0.1181 - val_accuracy: 0.9729 - val_loss: 0.0899
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9691 - loss: 0.0976 - val_accuracy: 0.9712 - val_loss: 0.0893
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 9s 5ms/step - accuracy: 0.9726 - loss: 0.0828 - val_accuracy: 0.9766 - val_loss: 0.0797




💾 Model saved to model/mnist_model.h5
🎉 Training complete! Test Accuracy: 0.9766
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://818fb9e65555c2e7c8.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


