In [25]:
import os
import cv2
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

# ---------------- Step 1: Preprocessing the full sheet image ----------------
def preprocess_sheet(image_path):
    """Load a scanned digit sheet and binarize it for contour extraction.

    The image is converted to grayscale and smoothed with a 5x5 Gaussian blur.
    It is then thresholded with an adaptive Gaussian threshold (11x11
    neighbourhood, constant 2 subtracted from the local weighted mean).
    ``THRESH_BINARY_INV`` makes pixels darker than their local threshold
    (the ink) white (255) and everything else black (0).

    Args:
        image_path: path to the scanned sheet image.

    Returns:
        A single-channel uint8 binary image, ink = 255, background = 0.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path`` (missing file
            or unsupported/corrupt image).
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check cvtColor fails with an opaque assertion error.
        raise FileNotFoundError(f"Could not read image: {image_path!r}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    thresh = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )
    return thresh


# ---------------- Step 2: Extract & Save 50 Digits ----------------
def extract_and_save_digits(thresh_img, save_dir="digit_dataset", *,
                            samples_per_class=5, num_classes=10,
                            row_height=50, min_area=100, size=(32, 32)):
    """Crop digit blobs from a binarized sheet and save them as labelled PNGs.

    Outer contours are found in ``thresh_img`` and ordered in reading order.
    Boxes are bucketed into horizontal bands ``row_height`` pixels tall, then
    sorted left-to-right within each band. Blobs whose bounding box is smaller
    than ``min_area`` pixels are discarded as noise.

    Labels are assigned purely by position: the first ``samples_per_class``
    blobs get label 0, the next ``samples_per_class`` label 1, and so on. The
    sheet must therefore be laid out in that order. Each crop is resized to
    ``size`` and written to ``<save_dir>/<label>_<index>.png``. Extraction
    stops after ``samples_per_class * num_classes`` digits.

    Args:
        thresh_img: binary image, assumed single-channel with white ink on
            black (e.g. the output of ``preprocess_sheet``).
        save_dir: output directory; created if missing.
        samples_per_class: number of consecutive blobs per label.
        num_classes: number of labels (0 .. num_classes - 1).
        row_height: height in pixels of the bands used to group boxes into rows.
        min_area: minimum bounding-box area (w * h) for a blob to count.
        size: (width, height) every crop is resized to.

    Raises:
        OSError: if a crop cannot be written to disk.
    """
    os.makedirs(save_dir, exist_ok=True)

    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns
    # (contours, hierarchy). Index -2 is the contour list in both.
    contours = cv2.findContours(
        thresh_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )[-2]
    boxes = [cv2.boundingRect(c) for c in contours]
    # Fixed bands: a row whose digits straddle a multiple of row_height can be
    # split across two bands, so keep rows well separated on the sheet.
    boxes.sort(key=lambda b: (b[1] // row_height, b[0]))

    total = samples_per_class * num_classes
    saved = 0
    for x, y, w, h in boxes:
        if saved >= total:
            break
        if w * h < min_area:
            continue

        # Sequential assignment: saved // samples_per_class is the label and
        # saved % samples_per_class the index within that label.
        label, index = divmod(saved, samples_per_class)
        crop = cv2.resize(thresh_img[y:y + h, x:x + w], size)
        filename = os.path.join(save_dir, f"{label}_{index}.png")
        if not cv2.imwrite(filename, crop):
            raise OSError(f"Failed to write digit image: {filename}")
        saved += 1

    if saved < total:
        print(f"⚠️ Expected {total} digits but found only {saved}; "
              f"the last labels are incomplete")
    print(f"✅ Saved {saved} digits in '{save_dir}'")


# ---------------- Step 3: ZNCC Similarity Function ----------------
def zncc(img1, img2):
    """Return the zero-mean normalized cross-correlation of two images.

    Both images are cast to float32 and centred on their own mean. The score
    is the sum of the element-wise product of the centred images, divided by
    the product of their L2 norms. It is nominally in [-1, 1], with 1 meaning
    identical up to brightness/contrast. The inputs are assumed to have the
    same shape.

    Returns 0 when the denominator is exactly zero (e.g. one image is
    constant), where the ratio would otherwise be 0/0.
    """
    a = img1.astype(np.float32)
    b = img2.astype(np.float32)
    a_centred = a - np.mean(a)
    b_centred = b - np.mean(b)

    cross = np.sum(a_centred * b_centred)
    norm = np.sqrt(np.sum(a_centred ** 2) * np.sum(b_centred ** 2))
    return 0 if norm == 0 else cross / norm


# ---------------- Step 4: Load Dataset ----------------
def load_templates(dataset_path):
    """Load labelled digit templates from a directory.

    Only ``.png`` files named ``<label>_<anything>.png`` with an integer
    ``<label>`` are used; this is the naming scheme written by
    ``extract_and_save_digits``. Other PNGs (e.g. a test image saved into the
    same folder) and files OpenCV cannot read are skipped rather than
    crashing. Images are loaded as grayscale and resized to 32x32. Files are
    visited in sorted order so the template order is deterministic.

    Args:
        dataset_path: directory containing the template PNGs.

    Returns:
        ``(templates, labels)``: parallel lists of 32x32 grayscale images and
        their integer labels.
    """
    templates = []
    labels = []

    # os.listdir order is arbitrary; sort for reproducible results.
    for name in sorted(os.listdir(dataset_path)):
        if not name.endswith('.png'):
            continue
        # Parse the label first so non-template files are skipped without
        # paying for an image read.
        try:
            label = int(name.split('_', 1)[0])
        except ValueError:
            continue
        img = cv2.imread(os.path.join(dataset_path, name), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # unreadable or corrupt image
        templates.append(cv2.resize(img, (32, 32)))
        labels.append(label)

    return templates, labels


# ---------------- Step 5: Classify a Digit Using ZNCC + KNN ----------------
def classify_digit(test_img, templates, labels, k=3):
    """Classify a digit image by a k-nearest-neighbour vote over ZNCC scores.

    ``test_img`` is resized to 32x32 and scored against every template with
    ``zncc``. The labels of the ``k`` highest-scoring templates are put to a
    majority vote. ``Counter.most_common`` orders equal counts by first
    occurrence, so a tied vote goes to the label whose best match scored
    highest.

    Args:
        test_img: grayscale image of the digit to classify.
        templates: sequence of 32x32 grayscale template images.
        labels: label for each template, parallel to ``templates``.
        k: number of nearest templates that vote.

    Returns:
        The winning label.

    Raises:
        ValueError: if ``templates`` and ``labels`` differ in length, there
            are no templates, or ``k < 1``.
    """
    if len(templates) != len(labels):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"templates ({len(templates)}) and labels ({len(labels)}) "
            f"must have the same length"
        )
    if len(templates) == 0:
        raise ValueError("no templates to compare against")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    test_img = cv2.resize(test_img, (32, 32))
    scores = [(zncc(test_img, template), label)
              for template, label in zip(templates, labels)]
    scores.sort(reverse=True)  # best score first; equal scores by label desc
    top_k = [label for _, label in scores[:k]]
    return Counter(top_k).most_common(1)[0][0]


# ---------------- Step 6: Predict a Test Digit ----------------
def test_single_digit(test_img_path, dataset_path="digit_dataset", k=3):
    """Classify one image file against the templates in ``dataset_path``.

    The test image is read as grayscale. Templates are loaded with
    ``load_templates`` and the result of ``classify_digit`` is printed and
    returned.

    NOTE: every matching ``.png`` in ``dataset_path`` is used as a template.
    If ``test_img_path`` lives in that folder, it is also one of its own
    templates and will be its own best match, which biases the vote.

    Args:
        test_img_path: path of the image to classify.
        dataset_path: directory of labelled template PNGs.
        k: number of nearest templates that vote.

    Returns:
        The predicted label, or ``None`` if the test image could not be read.
    """
    # Read the test image first so a bad path fails before loading templates.
    test_img = cv2.imread(test_img_path, cv2.IMREAD_GRAYSCALE)
    if test_img is None:
        print(f"❌ Failed to load test image: {test_img_path}")
        return None
    templates, labels = load_templates(dataset_path)
    prediction = classify_digit(test_img, templates, labels, k=k)
    print(f"✅ Predicted Digit for '{test_img_path}':", prediction)
    return prediction


# ---------------- Example Usage ----------------
# Part 1 - build the template dataset from the scanned sheet. This rewrites
# the PNGs in "digit_dataset", so it only needs to run once per sheet.
sheet_path = "no.jpg"  # path to your scanned digit sheet; replace as needed
thresh_image = preprocess_sheet(sheet_path)
extract_and_save_digits(thresh_image, save_dir="digit_dataset")

# Part 2 - classify one image with a 3-nearest-neighbour ZNCC vote.
# NOTE: this image sits inside the dataset folder, so it is also loaded as a
# template and will be its own top match; use a held-out image for a fair test.
test_single_digit(
    "digit_dataset/9_2.png",
    dataset_path="digit_dataset",
    k=3,
)


✅ Saved 50 digits in 'digit_dataset'
✅ Predicted Digit for 'digit_dataset/9_2.png': 9
