In [3]:
# digit_capture.py
# Draw digits with touchpad/touchscreen or mouse and save as test_images/n_m.png
# - Choose the digit class (0..9) from the dropdown or press keys 0..9
# - Click "Save" (or press 's') to create a new file with auto-incremented index (n_m.png)
# - Image is saved as a 28x28 grayscale PNG (MNIST size); tick "Invert" for MNIST's white-digit-on-black polarity

import os
import re
import tkinter as tk
from tkinter import ttk, messagebox
from PIL import Image, ImageDraw, ImageOps

# Folder that receives the captured PNGs; it is created (if missing) when this module runs.
TARGET_DIR = "test_images"
OUT_SIZE = 28                    # side length of the saved image in pixels (MNIST size)
CANVAS_SIZE = OUT_SIZE * 10      # draw on a 10x larger surface, downsampled on save
PEN_RADIUS = 10                  # brush radius in canvas pixels (stroke width = 2 * radius)
BG, INK = 255, 0                 # grayscale levels: white background, black ink

os.makedirs(TARGET_DIR, exist_ok=True)

def find_next_index_for_digit(digit: int, directory=None) -> int:
    """Return the next free index m for files named ``<digit>_<m>.png``.

    Scans ``directory`` and returns one more than the largest m already used
    for ``digit``, or 0 when there is none (including when the directory does
    not exist). Gaps are not reused: with 0_0.png and 0_5.png present the
    result for digit 0 is 6.

    Args:
        digit: digit class that the file names start with (0..9).
        directory: folder to scan; None (the default) means TARGET_DIR.

    Returns:
        The index to use for the next file of this digit.
    """
    if directory is None:
        directory = TARGET_DIR
    # Anchoring on "^<digit>_" keeps e.g. "13_9.png" from counting for digit 3;
    # re.escape keeps a non-int ``digit`` from injecting regex syntax.
    pattern = re.compile(rf"^{re.escape(str(digit))}_(\d+)\.png$")
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        # e.g. the folder was deleted after startup: nothing has been saved yet.
        return 0
    indices = (int(match.group(1)) for match in map(pattern.match, names) if match)
    return max(indices, default=-1) + 1

class DigitCaptureApp:
    """Tk window for drawing a digit and saving it as a small grayscale PNG.

    Every stroke is rendered twice: on the visible Tk canvas and on a backing
    PIL image of the same size (``self.img``). The backing image is what
    :meth:`save` downsamples and writes to TARGET_DIR.
    """

    def __init__(self, root):
        """Build the controls, the drawing canvas and the key bindings.

        Args:
            root: the Tk window hosting the UI.
        """
        self.root = root
        self.root.title("MNIST Digit Capture")

        # --- top controls ---
        top = tk.Frame(root)
        top.pack(padx=8, pady=6, fill="x")

        tk.Label(top, text="Digit:").pack(side="left")
        self.digit_var = tk.StringVar(value="0")
        self.digit_combo = ttk.Combobox(top, textvariable=self.digit_var, width=3,
                                        values=[str(i) for i in range(10)], state="readonly")
        self.digit_combo.pack(side="left", padx=6)

        # When checked, the saved image is inverted so the digit ends up white
        # on a black background (the polarity MNIST itself uses).
        self.invert_var = tk.BooleanVar(value=False)
        tk.Checkbutton(top, text="Invert (white digit on black)", variable=self.invert_var)\
            .pack(side="left", padx=10)

        tk.Button(top, text="Clear (C)", command=self.clear).pack(side="left", padx=6)
        tk.Button(top, text="Save (S)", command=self.save).pack(side="left", padx=6)

        # --- canvas for drawing ---
        self.canvas = tk.Canvas(root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg="white", cursor="tcross")
        self.canvas.pack(padx=8, pady=8)

        # backing PIL image (grayscale) that mirrors every stroke drawn on the canvas
        self.img = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=BG)
        self.draw = ImageDraw.Draw(self.img)

        # Mouse events; self.last is the previous pointer position while the
        # button is held down, None otherwise.
        self.last = None
        self.canvas.bind("<Button-1>", self._on_down)
        self.canvas.bind("<B1-Motion>", self._on_drag)
        self.canvas.bind("<ButtonRelease-1>", self._on_up)

        # Key bindings
        self.root.bind("<Key-s>", lambda e: self.save())
        self.root.bind("<Key-S>", lambda e: self.save())
        self.root.bind("<Key-c>", lambda e: self.clear())
        self.root.bind("<Key-C>", lambda e: self.clear())
        for d in map(str, range(10)):
            self.root.bind(d, self._set_digit_key)

        # Tip
        tk.Label(root, text="Tips: draw with mouse/touch • S=Save • C=Clear • 0..9 to set digit").pack(pady=(0,8))

    # --- drawing handlers ---
    def _on_down(self, e):
        """Start a stroke: remember the position and stamp a round dot there."""
        self.last = (e.x, e.y)
        self._dot(e.x, e.y)

    def _on_drag(self, e):
        """Extend the current stroke from the last pointer position to (e.x, e.y)."""
        if self.last is None:
            return
        x0, y0 = self.last
        x1, y1 = e.x, e.y
        r = PEN_RADIUS
        # draw on canvas (Tk renders the round end caps itself)
        self.canvas.create_line(x0, y0, x1, y1, fill="black", width=r*2, capstyle=tk.ROUND, smooth=True)
        # Draw on the PIL image. ImageDraw.line has flat ends, and its
        # joint="curve" only rounds joints *within* one multi-point line, so
        # successive 2-point segments would leave wedge-shaped gaps at sharp
        # turns that the canvas does not show. A disc at the segment end mimics
        # the canvas's round cap. The start point is already covered by the
        # previous segment's disc or by _on_down's dot.
        self.draw.line([x0, y0, x1, y1], fill=INK, width=r*2)
        self.draw.ellipse([x1 - r, y1 - r, x1 + r, y1 + r], fill=INK)
        self.last = (x1, y1)

    def _on_up(self, e):
        """End the current stroke."""
        self.last = None

    def _dot(self, x, y):
        """Paint a filled disc of radius PEN_RADIUS at (x, y) on canvas and image."""
        r = PEN_RADIUS
        self.canvas.create_oval(x-r, y-r, x+r, y+r, fill="black", outline="")
        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=INK)

    def _set_digit_key(self, event):
        """Select the digit class from a 0..9 key press."""
        self.digit_var.set(event.char)

    @staticmethod
    def _parse_digit(text):
        """Return ``text`` as an int in 0..9, or None if it is not a valid digit class."""
        try:
            digit = int(text)
        except (TypeError, ValueError):
            return None
        return digit if 0 <= digit <= 9 else None

    # --- actions ---
    def clear(self):
        """Erase the canvas and reset the backing image to the background colour."""
        self.canvas.delete("all")
        self.img = Image.new("L", (CANVAS_SIZE, CANVAS_SIZE), color=BG)
        self.draw = ImageDraw.Draw(self.img)

    def save(self):
        """Downsample the drawing to OUT_SIZE x OUT_SIZE and write it to TARGET_DIR.

        The file is named ``<digit>_<m>.png``, where m is the next free index
        for the selected digit. Shows a dialog and saves nothing when the digit
        selection is invalid or nothing has been drawn. Clears the canvas after
        a successful save.
        """
        # 1) validate before doing any work (explicit check: asserts vanish under -O)
        digit = self._parse_digit(self.digit_var.get())
        if digit is None:
            messagebox.showerror("Error", f"Invalid digit: {self.digit_var.get()}")
            return
        lo, hi = self.img.getextrema()
        if lo == hi:
            # Uniform image => nothing drawn (e.g. S pressed twice in a row,
            # since a successful save auto-clears the canvas).
            messagebox.showwarning("Nothing to save", "Draw a digit before saving.")
            return

        # 2) downsample to 28x28 and optionally invert
        small = self.img.resize((OUT_SIZE, OUT_SIZE), Image.LANCZOS)
        if self.invert_var.get():
            # Strokes are always black ink on white; inverting yields a white
            # digit on a black background, which matches MNIST's polarity.
            small = ImageOps.invert(small)

        # 3) determine filename n_m.png (recreate the folder if it was removed)
        os.makedirs(TARGET_DIR, exist_ok=True)
        m = find_next_index_for_digit(digit)
        out_name = f"{digit}_{m}.png"
        out_path = os.path.join(TARGET_DIR, out_name)

        # 4) save grayscale PNG (0..255); any normalisation is left to the loader
        small.save(out_path)
        messagebox.showinfo("Saved", f"Saved: {out_path}")
        print(f"Saved: {out_path}")

        # 5) auto-clear for next drawing
        self.clear()

def main():
    """Open the digit-capture window and run the Tk event loop until it is closed."""
    window = tk.Tk()
    capture_app = DigitCaptureApp(window)  # hold a reference while the loop runs
    window.mainloop()


if __name__ == "__main__":
    main()


Saved: test_images\0_0.png
Saved: test_images\0_1.png
Saved: test_images\0_2.png
Saved: test_images\0_3.png
Saved: test_images\0_4.png
Saved: test_images\1_0.png
Saved: test_images\1_1.png
Saved: test_images\1_2.png
Saved: test_images\1_3.png
Saved: test_images\1_4.png
Saved: test_images\2_0.png
Saved: test_images\2_1.png
Saved: test_images\2_2.png
Saved: test_images\2_3.png
Saved: test_images\2_4.png
Saved: test_images\3_0.png
Saved: test_images\3_1.png
Saved: test_images\3_2.png
Saved: test_images\3_3.png
Saved: test_images\3_4.png
Saved: test_images\4_0.png
Saved: test_images\4_1.png
Saved: test_images\4_2.png
Saved: test_images\4_3.png
Saved: test_images\4_4.png
Saved: test_images\5_0.png
Saved: test_images\5_1.png
Saved: test_images\5_2.png
Saved: test_images\5_3.png
Saved: test_images\5_4.png
Saved: test_images\6_0.png
Saved: test_images\6_1.png
Saved: test_images\6_2.png
Saved: test_images\6_3.png
Saved: test_images\6_4.png
Saved: test_images\7_0.png
Saved: test_images\7_1.png
S