<a href="https://colab.research.google.com/github/skaumbdoallsaws-coder/AI-Drawing-Inspector/blob/main/tests/notebooks/test_ocr_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Test OCR Pipeline: Rotation (M3) + OCR Adapter (M4) + Crop Reader (M5)

This notebook tests the OCR stages of the AI Inspector pipeline:

- **M3 -- Rotation Selector**: Tries 4 rotations (0°, 90°, 180°, 270°) of each crop,
  runs OCR on each, and keeps the rotation with the highest text-quality score.
- **M4 -- OCR Adapter**: Wraps LightOnOCR-2-1B with confidence estimation and
  canonicalization. Exposes a `read_simple(image) -> (text, confidence)` interface.
- **M5 -- Crop Reader**: Full OCR-to-parse pipeline per crop: OCR -> canonicalize
  -> regex parse by YOLO class -> optional VLM fallback.

**Runtime requirement:** GPU with at least 8 GB VRAM (A100 preferred).

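**Note:** LightOnOCR-2-1B is a gated model on Hugging Face. You need a Hugging Face
access token (`HF_TOKEN`) from an account that has been granted access to the model.
Store it as a Colab secret named `HF_TOKEN`, or paste it when Cell 2 prompts for it.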

In [None]:
# Cell 1: Install dependencies
# NOTE: Set your runtime to GPU before running (Runtime > Change runtime type > A100)
# Uses IPython magics (`%pip`, `!`), so this cell only runs inside a notebook kernel.
%pip install transformers torch pillow accelerate matplotlib --quiet

# Clone the repo. If `git clone` fails -- normally because a previous run
# already created the directory -- pull the existing checkout instead.
# NOTE(review): clone errors go to /dev/null, so any other failure (e.g. no
# network) also falls through to `git pull`; check that pull's output if the
# imports in Cell 3 fail.
!git clone https://github.com/skaumbdoallsaws-coder/AI-Drawing-Inspector.git /content/AI-Drawing-Inspector 2>/dev/null || \
    (cd /content/AI-Drawing-Inspector && git pull)

# Printed unconditionally; it does not confirm that the installs succeeded.
print('Dependencies installed.')

In [None]:
# Cell 2: Set paths, set HF_TOKEN (Drive mount skipped)
#
# Resolves the Hugging Face token in this order:
#   1. the Colab secret named 'HF_TOKEN' (Colab only),
#   2. an HF_TOKEN already exported in the process environment,
#   3. an interactive getpass prompt (Colab only).
# A token that is found is exported to os.environ. Also picks CROPS_DIR, the
# directory that Cell 5 scans for crop PNGs.
import sys
import os

# Add repo to Python path (guarded so re-running the cell adds no duplicates)
REPO_DIR = '/content/AI-Drawing-Inspector'
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# ---- Environment detection ----
try:
    from google.colab import files, userdata
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    files = None
    userdata = None

print("Skipping Google Drive mount as requested.")

# ---- HuggingFace token (required for LightOnOCR-2) ----
HF_TOKEN = None
if IN_COLAB and userdata is not None:
    try:
        HF_TOKEN = userdata.get('HF_TOKEN')
    except Exception as exc:
        # Best-effort: a missing secret or denied notebook access is not fatal,
        # but report it instead of failing silently.
        print(f'Colab secret HF_TOKEN unavailable ({type(exc).__name__}).')

# Fall back to an already-exported token (local Jupyter, CI, or a previous run
# of this cell). Outside Colab this is the only source; previously it was
# ignored, so Cell 4 always refused to load the model there.
if not HF_TOKEN:
    HF_TOKEN = os.environ.get('HF_TOKEN') or None

if not HF_TOKEN and IN_COLAB:
    try:
        from getpass import getpass
        entered = getpass('Enter HF_TOKEN (leave blank to skip): ').strip()
        HF_TOKEN = entered or None
    except Exception as exc:
        print(f'HF_TOKEN prompt failed ({type(exc).__name__}); continuing without a token.')

if HF_TOKEN:
    os.environ['HF_TOKEN'] = HF_TOKEN
    print(f'HF_TOKEN set (length={len(HF_TOKEN)})')
else:
    print('WARNING: HF_TOKEN not set. OCR model loading will fail until token is provided.')

# ---- Paths ----
# First existing directory wins; '/content' is also the fallback if none exist.
candidate_dirs = [
    '/content/debug/crops',
    '/content',
]

CROPS_DIR = next((d for d in candidate_dirs if os.path.isdir(d)), '/content')
print(f'Crops directory: {CROPS_DIR}')

# Reminder for manual upload
if IN_COLAB:
    print('Ready for manual file uploads.')

In [None]:
# Cell 3: Import modules
# Imports the stages under test -- the M4 OCR adapter, the M3 rotation selector,
# and the M5 crop reader -- plus canonicalize() and the result contract types.
# The `ai_inspector` package resolves through the repo path that Cell 2 puts on
# sys.path. _compute_text_quality is a private helper of the rotation module;
# Cell 5 uses it to display per-rotation scores.
# NOTE(review): MockOCRAdapter, read_crops_batch and the OCRResult /
# RotationResult / ReaderResult contracts are not referenced by the later cells
# shown here. They are presumably kept as an import smoke test; confirm or drop them.
from ai_inspector.extractors.ocr_adapter import OCRAdapter, MockOCRAdapter
from ai_inspector.extractors.rotation import (
    select_best_rotation, _compute_text_quality, ROTATIONS
)
from ai_inspector.extractors.crop_reader import read_crop, read_crops_batch
from ai_inspector.extractors.canonicalize import canonicalize
from ai_inspector.contracts import OCRResult, RotationResult, ReaderResult
from PIL import Image

print('All OCR pipeline modules imported successfully.')

In [None]:
# Cell 4: Load OCR model (LightOnOCR-2-1B)
#
# Reports the visible GPU, requires the HF_TOKEN resolved in Cell 2, loads the
# OCR adapter, and runs it once on a blank (all-white) image as a smoke test.
import torch

cuda_ok = torch.cuda.is_available()
print(f'CUDA available: {cuda_ok}')
if cuda_ok:
    device_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {device_props.total_memory / 1e9:.1f} GB')

# The model is gated on Hugging Face, so stop here rather than fail mid-load.
if not HF_TOKEN:
    raise RuntimeError('HF_TOKEN is required to load LightOnOCR-2. Set token in Cell 2.')

adapter = OCRAdapter(hf_token=HF_TOKEN)
adapter.load()
print(f'OCR adapter loaded: {adapter.is_loaded}')

# Smoke test: a 200x50 white canvas contains no text, so this mainly checks
# that inference runs end to end.
blank = Image.new('RGB', (200, 50), 'white')
result = adapter.read(blank)
smoke_text, smoke_conf = result.text, result.confidence
print(f'Smoke test -- text: "{smoke_text}", confidence: {smoke_conf:.2f}')

In [None]:
# Cell 5: Test rotation selection on sample crops
#
# Collects up to MAX_CROPS crop PNGs (from CROPS_DIR, then topped up with PNGs
# uploaded to /content) and falls back to synthetic text crops when none exist.
# For each crop it shows OCR text, confidence and _compute_text_quality() for
# every candidate rotation, then runs the official M3 selector and prints its choice.
import glob

import matplotlib.pyplot as plt

MAX_CROPS = 3


def _load_rgb(path):
    """Return an RGB copy of the image at *path*, closing the file afterwards."""
    with Image.open(path) as src:
        return src.convert('RGB')


def _synthetic_font(size=28):
    """Return ``(font, is_truetype)`` for rendering synthetic callout text.

    Prefers DejaVu Sans, which matplotlib bundles, at a legible size. Falls
    back to Pillow's default font when no TrueType file can be loaded.
    """
    from PIL import ImageFont
    try:
        from matplotlib import font_manager
        return ImageFont.truetype(font_manager.findfont('DejaVu Sans'), size), True
    except (OSError, ValueError):
        return ImageFont.load_default(), False


def _make_synthetic_crops():
    """Render three single-callout crops into /content and return their paths."""
    from PIL import ImageDraw
    font, is_truetype = _synthetic_font()
    paths = []
    for text, fname in [
        ('\u2300.500 THRU', 'synth_hole.png'),
        ('R.125', 'synth_fillet.png'),
        ('M10x1.5', 'synth_thread.png'),
    ]:
        if not is_truetype:
            # Pillow's default font may lack U+2300 DIAMETER SIGN (its legacy
            # bitmap font cannot even encode it). U+00D8 is the diameter
            # look-alike that Cell 6 feeds through canonicalize().
            text = text.replace('\u2300', '\u00d8')
        img = Image.new('RGB', (300, 80), 'white')
        ImageDraw.Draw(img).text((10, 20), text, fill='black', font=font)
        path = f'/content/{fname}'
        img.save(path)
        paths.append(path)
    return paths


def _rotate_for_preview(img, angle):
    """Rotate *img* clockwise by *angle* degrees (0 returns it unchanged)."""
    # PIL's rotate() is counter-clockwise for positive angles, hence -angle.
    # NOTE(review): assumes select_best_rotation uses the same convention.
    if angle == 0:
        return img
    return img.rotate(-angle, resample=Image.BICUBIC, expand=True)


# Load sample crop images (PNG files from crops dir, then uploaded files)
crop_files = sorted(glob.glob(f'{CROPS_DIR}/*.png'))[:MAX_CROPS]
# Include uploaded PNGs from /content when not already in CROPS_DIR
extra_uploads = [p for p in sorted(glob.glob('/content/*.png')) if p not in crop_files]
crop_files.extend(extra_uploads[:MAX_CROPS - len(crop_files)])

if not crop_files:
    print('No crop files found. Creating synthetic test crops...')
    crop_files = _make_synthetic_crops()

# Test rotation selection on each crop
for crop_path in crop_files:
    crop_img = _load_rgb(crop_path)
    crop_name = os.path.basename(crop_path)
    print(f'\n--- {crop_name} ({crop_img.size}) ---')

    # One panel per candidate rotation. squeeze=False keeps `axes` 2-D even
    # when ROTATIONS has a single entry.
    fig, axes = plt.subplots(1, len(ROTATIONS), figsize=(16, 3), squeeze=False)
    for ax, angle in zip(axes[0], ROTATIONS):
        rotated = _rotate_for_preview(crop_img, angle)
        text, conf = adapter.read_simple(rotated)
        quality = _compute_text_quality(text)
        # Flatten newlines so multi-line OCR output keeps the 3-line title layout.
        snippet = text[:30].replace('\n', ' ')

        ax.imshow(rotated)
        ax.set_title(f'{angle}deg\nq={quality:.1f} c={conf:.2f}\n"{snippet}"', fontsize=8)
        ax.axis('off')

    fig.suptitle(crop_name, fontsize=10)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; otherwise one stays open per crop

    # Run official rotation selector. NOTE: every crop is scored as 'Hole'
    # here, while Cell 7 infers the class from the filename.
    rot_result = select_best_rotation(crop_img, adapter.read_simple, yolo_class='Hole')
    print(f'  Best rotation: {rot_result.rotation_used}deg')
    print(f'  Quality score: {rot_result.quality_score:.2f}')
    print(f'  Raw text: "{rot_result.raw}"')

In [None]:
# Cell 6: Test canonicalization on sample strings
#
# Each entry pairs a raw OCR-style string (Unicode look-alikes, odd spacing)
# with a description. The table shows repr() before and after canonicalize()
# so that otherwise invisible characters can be seen.
test_strings = [
    ('\u00d8.500 THRU', 'Diameter variant Oslash'),            # U+00D8 LATIN CAPITAL LETTER O WITH STROKE
    ('\u2205.250 DEEP .750', 'Diameter variant empty set'),     # U+2205 EMPTY SET
    ('2 X \u2300.375', 'Quantity with spaces'),                 # U+2300 DIAMETER SIGN
    ('R.125 TYP.', 'Fillet radius'),
    ('M10\u00d71.5', 'Metric thread with multiplication sign'), # U+00D7 MULTIPLICATION SIGN
    ('.045 \u00d7 45\u00ba', 'Chamfer with degree variant'),    # U+00BA MASCULINE ORDINAL INDICATOR
    ('+/- .005', 'Plus-minus tolerance'),
    ('3/8\u201316 UNC', 'Thread with en-dash'),                 # U+2013 EN DASH
    ('  2X   \u2300.500   THRU  ', 'Extra whitespace'),
]

header = '{:40s} | {:40s} | Description'.format('Input', 'Canonicalized')
print(header)
print('-' * 110)

# Lazily canonicalize so each row is computed right before it is printed.
rows = ((raw, canonicalize(raw), desc) for raw, desc in test_strings)
for raw, canon, desc in rows:
    print(f'{raw!r:40s} | {canon!r:40s} | {desc}')

In [None]:
# Cell 7: Test crop_reader.read_crop() -- full OCR->parse pipeline on crops
#
# For each crop from Cell 5: infer a YOLO class from the filename, run the M3
# rotation selector, then hand its OCR text to read_crop() via `pre_ocr` so the
# M5 stage parses it. Results are collected in `results` for Cell 8.
from ai_inspector.detection.classes import CLASS_TO_CALLOUT_TYPE

# Filename keyword(s) -> YOLO class. The first matching rule wins.
_FILENAME_CLASS_RULES = (
    (('fillet',), 'Fillet'),
    (('thread', 'tapped'), 'TappedHole'),
    (('chamfer',), 'Chamfer'),
)


def _infer_yolo_class(crop_path, default='Hole'):
    """Guess the YOLO class for *crop_path* from keywords in its file name."""
    name = os.path.basename(crop_path).lower()
    for keywords, cls in _FILENAME_CLASS_RULES:
        if any(keyword in name for keyword in keywords):
            return cls
    return default


# Define test cases: (crop_path, yolo_class)
test_cases = [(crop_path, _infer_yolo_class(crop_path)) for crop_path in crop_files]

# NOTE(review): assumes CLASS_TO_CALLOUT_TYPE is keyed by YOLO class name.
# Warn (don't fail) if a filename-derived class is one it doesn't know.
unknown_classes = sorted({cls for _, cls in test_cases if cls not in CLASS_TO_CALLOUT_TYPE})
if unknown_classes:
    print(f'WARNING: classes not in CLASS_TO_CALLOUT_TYPE: {unknown_classes}')

results = []
for crop_path, yolo_class in test_cases:
    # Load eagerly and close the file handle straight away.
    with Image.open(crop_path) as src:
        crop_img = src.convert('RGB')

    # First run rotation selection to get best OCR text
    rot = select_best_rotation(crop_img, adapter.read_simple, yolo_class=yolo_class)

    # Then run crop_reader with pre-OCR text from the rotation stage, so the
    # winning rotation's OCR output is reused.
    pre_ocr = None
    if rot.ocr_result is not None:
        pre_ocr = (rot.ocr_result.text, rot.ocr_result.confidence)

    # NOTE(review): read_crop receives the *unrotated* crop. If it uses the
    # image itself (e.g. for a VLM fallback), it may need the rotated one.
    reader_result = read_crop(
        image=crop_img,
        ocr_fn=adapter.read_simple,
        yolo_class=yolo_class,
        pre_ocr=pre_ocr,
    )
    crop_name = os.path.basename(crop_path)
    results.append((crop_name, yolo_class, reader_result))

    print(f'\n--- {crop_name} (class={yolo_class}) ---')
    print(f'  Callout type: {reader_result.callout_type}')
    print(f'  Raw text:     "{reader_result.raw}"')
    print(f'  Parse source: {reader_result.source}')
    print(f'  OCR conf:     {reader_result.ocr_confidence:.2f}')
    print(f'  Parsed fields: {reader_result.parsed}')

In [None]:
# Cell 8: Print OCR results table
#
# Summarises the Cell 7 `results` as a fixed-width table, lists the parsed
# fields whose keys don't start with '_' for each crop, then unloads the OCR model.


def _cell(value, width, truncate=False):
    """Format *value* as a single-line table cell padded to *width*.

    None renders as '-', and newlines are flattened to spaces. When *truncate*
    is set, text longer than *width* is cut to ``width - 2`` chars plus '..'.
    """
    # format() (not str()) matches what the original `{value:15s}` spec printed.
    text = '-' if value is None else format(value).replace('\n', ' ')
    if truncate and len(text) > width:
        text = text[:width - 2] + '..'
    return f'{text:{width}s}'


def _conf_cell(value, width=6):
    """Right-aligned 2-decimal confidence, or '-' when none was reported."""
    if value is None:
        return f'{"-":>{width}s}'
    return f'{value:{width}.2f}'


print(f'\n{"File":25s} {"Class":15s} {"Raw Text":30s} {"Parsed Type":15s} {"Source":8s} {"Conf":>6s}')
print('=' * 110)

for fname, yolo_class, rr in results:
    print(' '.join([
        _cell(fname, 25),
        _cell(yolo_class, 15),
        _cell(rr.raw, 30, truncate=True),
        _cell(rr.callout_type, 15),
        _cell(rr.source, 8),
        _conf_cell(rr.ocr_confidence),
    ]))

# Parsed fields detail (keys starting with '_' are hidden)
print('\n--- Parsed Fields Detail ---')
for fname, _yolo_class, rr in results:
    print(f'\n{fname}:')
    public_fields = {
        k: v for k, v in (rr.parsed or {}).items() if not str(k).startswith('_')
    }
    if public_fields:
        for k, v in public_fields.items():
            print(f'  {k}: {v}')
    else:
        print('  (no parsed fields)')

# Cleanup
adapter.unload()
print('\nOCR adapter unloaded.')