# In [None]:
import pickle
import math
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import cv2

# -------------------- Load Trained Model --------------------
# NOTE: unpickling can execute arbitrary code, so only load model files that
# come from a trusted source (e.g. your own training run).
with open('mnist_scratch_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

# Pull the weight and bias entries out of the loaded dict in one step.
W1, b1, W2, b2 = (model[key] for key in ('W1', 'b1', 'W2', 'b2'))

print("✅ Model loaded successfully.")

# -------------------- Activation Functions --------------------
def relu(x):
    """Apply ReLU element-wise to the values in *x*.

    Every entry that is not strictly greater than zero is replaced by the
    integer ``0``; positive entries are passed through unchanged.

    Args:
        x: Iterable of numbers.

    Returns:
        A new list with one entry per element of *x*.
    """
    return [value if value > 0 else 0 for value in x]
def softmax(x):
    """Return the softmax probabilities for the scores in *x*.

    The largest score is subtracted from every entry before exponentiating,
    so ``math.exp`` only ever receives arguments that are <= 0 and cannot
    overflow for large scores.

    Args:
        x: Non-empty sequence of real-valued scores.

    Returns:
        A list of floats, one per score, that sum to 1 (up to rounding).

    Raises:
        ValueError: If *x* is empty (raised by ``max``).
    """
    max_x = max(x)
    exps = [math.exp(i - max_x) for i in x]
    # The normalizer is the same for every element, so compute it once
    # instead of re-summing the whole list inside the comprehension.
    total = sum(exps)
    return [e / total for e in exps]

# -------------------- Prediction Function --------------------
def predict(x):
    """Run one input vector through the two-layer network.

    The input is passed through a dense layer (``W1``/``b1``), ReLU, a second
    dense layer (``W2``/``b2``) and softmax.

    Args:
        x: Sequence of input features; it is paired element-wise with each
            row of ``W1``.

    Returns:
        A ``(pred, confidence)`` tuple: the index of the largest softmax
        output and the value of that output.
    """
    def dense(weights, biases, inputs):
        # One output per (row, bias) pair: dot(row, inputs) + bias.
        return [
            sum(weight * value for weight, value in zip(row, inputs)) + bias
            for row, bias in zip(weights, biases)
        ]

    hidden = relu(dense(W1, b1, x))
    probabilities = softmax(dense(W2, b2, hidden))

    confidence = max(probabilities)
    return probabilities.index(confidence), confidence

# -------------------- Image Preprocessing --------------------
def preprocess_image(frame):
    """Convert a captured frame into a flat, normalized 28x28 pixel list.

    Pipeline: grayscale -> invert -> binarize at 30 -> crop to the bounding
    box of the non-zero pixels -> fit/pad to 28x28 -> sharpen -> scale to
    the range [0, 1].

    Args:
        frame: Array accepted by ``Image.fromarray``.

    Returns:
        A list of 28 * 28 floats in [0, 1], in the order produced by
        ``Image.getdata``.
    """
    target_size = (28, 28)

    def binarize(level):
        # Levels at or above the cut-off become white, everything else black.
        return 255 if level >= 30 else 0

    gray = Image.fromarray(frame).convert('L')
    # Invert so a dark digit on a light background becomes light-on-dark.
    inverted = ImageOps.invert(gray)
    mask = inverted.point(binarize, '1').convert('L')

    # getbbox() yields None when the mask has no non-zero pixels; in that
    # case the mask is used uncropped.
    bounds = mask.getbbox()
    digit = mask.crop(bounds) if bounds else mask

    fitted = ImageOps.pad(
        digit,
        target_size,
        method=Image.Resampling.LANCZOS,
        color=0,
    )
    sharpened = fitted.filter(ImageFilter.SHARPEN)

    # Flatten and normalize the 8-bit pixel values.
    return [pixel / 255.0 for pixel in sharpened.getdata()]

# -------------------- Start Webcam --------------------
cap = cv2.VideoCapture(0)

print("📷 Press SPACE to predict a digit, ESC to exit.")

# try/finally guarantees that the camera and the preview window are released
# even if processing a frame raises or the user interrupts with Ctrl+C.
try:
    if not cap.isOpened():
        # read() below will fail and end the loop; say why instead of
        # exiting silently.
        print("❌ Could not open the webcam (device 0).")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Centered square region of interest, shrunk when the frame is
        # smaller than 200 px so the slice never uses negative indices.
        h, w = frame.shape[:2]
        size = min(200, h, w)
        x1, y1 = w // 2 - size // 2, h // 2 - size // 2
        x2, y2 = x1 + size, y1 + size
        # Copy the ROI: a plain slice is a view into `frame`, so the overlay
        # drawn below would otherwise end up in the pixels handed to the
        # classifier.
        roi = frame[y1:y2, x1:x2].copy()

        # Overlay: ROI outline and usage hint (drawn on the preview only).
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, "Draw a digit and press SPACE", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        # Show the frame
        cv2.imshow("Digit Recognition (Scratch NN)", frame)

        # Keep only the low byte: on some platforms waitKey() sets higher
        # bits in the returned key code. "No key" (-1) becomes 255, which
        # matches neither branch below.
        key = cv2.waitKey(1) & 0xFF
        if key == 27:  # ESC to quit
            break
        elif key == 32:  # SPACE to predict
            roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
            img_vec = preprocess_image(roi_gray)
            pred, conf = predict(img_vec)
            print(f"🔢 Prediction: {pred} (Confidence: {conf * 100:.1f}%)")
finally:
    cap.release()
    cv2.destroyAllWindows()