# Test Matching: Quantity Expansion (M9) + Feature Matching (M10) + Scoring

This notebook tests the matching pipeline:

- **M9 -- Quantity Expander**: Expands quantity-bearing drawing callouts and
  SolidWorks (SW) features into individual instances before matching. A
  `4X ⌀.500` callout becomes 4 separate Hole instances.
- **M10 -- FeatureMatcher**: Matches expanded drawing callouts against expanded
  SW features using type-specific tolerances. Produces MatchResult with status:
  MATCHED, MISSING, EXTRA, TOLERANCE_FAIL, or SKIPPED.
- **Scoring**: `compute_scores()` calculates instance_match_rate and total_rate
  with SKIPPED excluded from denominators.

**Runtime requirement:** CPU is sufficient. No GPU or models needed.

**Tests use synthetic data** with known expected outcomes.

In [None]:
# Cell 1: Install dependencies
# NOTE: GPU not required for this notebook.
%pip install pillow --quiet

# Clone the repo
!git clone https://github.com/skaumbdoallsaws-coder/AI-Drawing-Inspector.git /content/AI-Drawing-Inspector 2>/dev/null || \
    (cd /content/AI-Drawing-Inspector && git pull)

import sys
sys.path.insert(0, '/content/AI-Drawing-Inspector')

print('Dependencies installed.')

In [None]:
# Cell 2: Import modules
# M9: quantity expansion helpers (split quantity-bearing items into instances).
from ai_inspector.comparison.quantity_expander import (
    expand_drawing_callouts,
    expand_sw_features,
    expand_both_sides,
    expansion_summary,
)
# M10: the matcher plus its per-pair result and status types.
from ai_inspector.comparison.matcher import (
    FeatureMatcher,
    MatchResult,
    MatchStatus,
)
from ai_inspector.comparison.sw_extractor import SwFeature
from ai_inspector.detection.classes import FUTURE_TYPES

# Echo the skip-list up front: the tests below expect callouts of these types
# to come back SKIPPED rather than MISSING/EXTRA.
print(f'FUTURE_TYPES (skipped in scoring): {FUTURE_TYPES}')
print('Imports OK.')

In [None]:
# Cell 3: Create synthetic drawing callouts and SW features (known match)
#
# Every quantified drawing callout has a SolidWorks counterpart with the same
# type, size and quantity. The GDT and SurfaceFinish callouts have no SW
# counterpart and are expected to be SKIPPED downstream.

# Drawing side: dicts in the shape the OCR pipeline produces.
drawing_callouts = [
    # 4X through-hole (depth None means THRU) -> should match 4 SW hole instances.
    {'calloutType': 'Hole', 'diameter': 0.500, 'depth': None,
     'quantity': 4, 'raw': '4X \u2300.500 THRU'},
    # Single metric tapped hole, M10x1.5 (diameter ~0.394 in, i.e. ~M10).
    {'calloutType': 'TappedHole', 'diameter': 0.394,
     'thread': {'standard': 'Metric', 'nominalDiameterMm': 10.0, 'pitch': 1.5},
     'quantity': 1, 'raw': 'M10x1.5'},
    # Two fillets, called out once with TYP.
    {'calloutType': 'Fillet', 'radius': 0.125, 'quantity': 2, 'raw': 'R.125 TYP.'},
    # Chamfer: the drawing side carries its distance in the "size" field.
    {'calloutType': 'Chamfer', 'size': 0.045, 'angle': 45,
     'quantity': 1, 'raw': '.045 x 45\u00b0'},
    # No quantity key on these two; both should end up SKIPPED.
    {'calloutType': 'GDT', 'raw': 'TRUE POS \u2300.010 M'},
    {'calloutType': 'SurfaceFinish', 'raw': 'Ra 63'},
]

# SolidWorks side: features as the SW extractor produces them.
sw_features = [
    SwFeature(feature_type='Hole', diameter_inches=0.500, depth_inches=None,
              quantity=4, location='Hole1'),
    SwFeature(feature_type='TappedHole', diameter_inches=0.394,
              thread={'standard': 'Metric', 'nominalDiameterMm': 10.0, 'pitch': 1.5},
              quantity=1, location='TapHole1'),
    SwFeature(feature_type='Fillet', radius_inches=0.125,
              quantity=2, location='Fillet1'),
    # The chamfer distance travels in radius_inches on the SW side.
    SwFeature(feature_type='Chamfer', radius_inches=0.045,
              quantity=1, location='Chamfer1'),
]

print(f'Drawing callouts: {len(drawing_callouts)}')
for callout in drawing_callouts:
    callout_qty = callout.get("quantity", 1)
    print(f'  {callout["calloutType"]:20s} qty={callout_qty}  raw="{callout["raw"]}"')

print(f'\nSW features: {len(sw_features)}')
for feature in sw_features:
    print(f'  {feature.feature_type:20s} qty={feature.quantity}  loc={feature.location}')

In [None]:
# Cell 4: Test expand_both_sides() -- verify expansion counts
print('=== Expansion Test ===')

expanded_callouts, expanded_sw = expand_both_sides(drawing_callouts, sw_features)

summary = expansion_summary(
    drawing_callouts, expanded_callouts,
    sw_features, expanded_sw,
)

print(f'Drawing: {summary["drawing"]["before"]} -> {summary["drawing"]["after"]} '
      f'(+{summary["drawing"]["expanded_count"]})')
print(f'SW:      {summary["solidworks"]["before"]} -> {summary["solidworks"]["after"]} '
      f'(+{summary["solidworks"]["expanded_count"]})')

# Expected instance counts per type after expansion. Callouts without a
# 'quantity' key (GDT, SurfaceFinish) are expected to expand to one instance.
# Drawing: 4 holes + 1 tapped + 2 fillets + 1 chamfer + 1 GDT + 1 SF = 10 instances
# SW:      4 holes + 1 tapped + 2 fillets + 1 chamfer                = 8 instances
expected_drawing_by_type = {
    'Hole': 4, 'TappedHole': 1, 'Fillet': 2, 'Chamfer': 1,
    'GDT': 1, 'SurfaceFinish': 1,
}
expected_sw_by_type = {'Hole': 4, 'TappedHole': 1, 'Fillet': 2, 'Chamfer': 1}
expected_drawing = sum(expected_drawing_by_type.values())  # 10
expected_sw = sum(expected_sw_by_type.values())  # 8
print(f'\nExpected drawing instances: {expected_drawing}, got: {len(expanded_callouts)}')
print(f'Expected SW instances: {expected_sw}, got: {len(expanded_sw)}')
assert len(expanded_callouts) == expected_drawing, \
    f'Drawing expansion mismatch: expected {expected_drawing}, got {len(expanded_callouts)}'
assert len(expanded_sw) == expected_sw, \
    f'SW expansion mismatch: expected {expected_sw}, got {len(expanded_sw)}'

# Totals alone can hide compensating errors (e.g. one extra Hole and one
# missing Fillet), so also check how many instances each type expanded to.
drawing_by_type = {}
for c in expanded_callouts:
    drawing_by_type[c['calloutType']] = drawing_by_type.get(c['calloutType'], 0) + 1
sw_by_type = {}
for f in expanded_sw:
    sw_by_type[f.feature_type] = sw_by_type.get(f.feature_type, 0) + 1
assert drawing_by_type == expected_drawing_by_type, \
    f'Drawing per-type mismatch: expected {expected_drawing_by_type}, got {drawing_by_type}'
assert sw_by_type == expected_sw_by_type, \
    f'SW per-type mismatch: expected {expected_sw_by_type}, got {sw_by_type}'

# Show expanded instances
print('\nExpanded drawing callouts:')
for i, c in enumerate(expanded_callouts):
    inst = c.get('_instance_index', '?')
    orig_qty = c.get('_original_quantity', '?')
    print(f'  [{i:2d}] {c["calloutType"]:15s} inst={inst}/{orig_qty}')

print('\nExpanded SW features:')
for i, f in enumerate(expanded_sw):
    print(f'  [{i:2d}] {f.feature_type:15s} loc={f.location}')

print('\nPASS')

In [None]:
# Cell 5: Test FeatureMatcher.match_all() -- verify correct matches
print('=== Matching Test ===')

matcher = FeatureMatcher()
match_results = matcher.match_all(expanded_callouts, expanded_sw)

print(f'Total match results: {len(match_results)}')
print()


def _result_type(result):
    """Return the feature type a MatchResult refers to.

    Uses the drawing callout's 'calloutType' when a drawing side is present,
    otherwise the SW feature's feature_type, otherwise ''.
    """
    if result.drawing_callout:
        return result.drawing_callout.get('calloutType', '')
    if result.sw_feature:
        return result.sw_feature.feature_type
    return ''


# Count by status
status_counts = {}
for r in match_results:
    status_counts[r.status.value] = status_counts.get(r.status.value, 0) + 1

print('Status breakdown:')
for status, count in sorted(status_counts.items()):
    print(f'  {status:20s}: {count}')

# Show each match
print(f'\n{"#":>3s} {"Status":15s} {"Type":15s} {"Delta":>10s} {"Notes"}')
print('-' * 80)
for i, r in enumerate(match_results):
    callout_type = _result_type(r)
    delta_str = f'{r.delta:+.4f}' if r.delta is not None else 'N/A'
    print(f'{i:3d} {r.status.value:15s} {callout_type:15s} {delta_str:>10s} {r.notes}')

# Verify: all 4 holes matched, tapped hole matched, 2 fillets matched, 1 chamfer matched
matched_count = sum(1 for r in match_results if r.status == MatchStatus.MATCHED)
skipped_count = sum(1 for r in match_results if r.status == MatchStatus.SKIPPED)
print(f'\nMatched: {matched_count} (expected 8: 4 holes + 1 tapped + 2 fillets + 1 chamfer)')
print(f'Skipped: {skipped_count} (expected 2: GDT + SurfaceFinish)')
assert matched_count == 8, f'Expected 8 matches, got {matched_count}'
assert skipped_count == 2, f'Expected 2 skipped, got {skipped_count}'

# Any status other than MATCHED/SKIPPED (MISSING, EXTRA, TOLERANCE_FAIL) means
# the known-good fixture was paired incorrectly.
unexpected_statuses = [
    r.status.value for r in match_results
    if r.status not in (MatchStatus.MATCHED, MatchStatus.SKIPPED)
]
assert not unexpected_statuses, f'Unexpected match statuses: {unexpected_statuses}'

# The 8 matches must be distributed across types exactly as in the fixture.
matched_by_type = {}
for r in match_results:
    if r.status == MatchStatus.MATCHED:
        result_type = _result_type(r)
        matched_by_type[result_type] = matched_by_type.get(result_type, 0) + 1
expected_matched_by_type = {'Hole': 4, 'TappedHole': 1, 'Fillet': 2, 'Chamfer': 1}
assert matched_by_type == expected_matched_by_type, \
    f'Per-type matches: expected {expected_matched_by_type}, got {matched_by_type}'
print('\nPASS')

In [None]:
# Cell 6: Test compute_scores() -- verify scoring math
print('=== Scoring Test ===')

scores = matcher.compute_scores(match_results)

print('Scores:')
for score_name, score_value in scores.items():
    print(f'  {score_name:25s}: {score_value}')

# Expected scoring for the all-match fixture:
#   matched=8, missing=0, extra=0, skipped=2, tolerance_fail=0
#   instance_match_rate = 8 / (8 + 0 + 0)     = 1.0
#   total_rate          = 8 / (8 + 0 + 0 + 0) = 1.0
# SKIPPED results are excluded from both denominators.
expected_scores = {
    'matched': 8,
    'missing': 0,
    'extra': 0,
    'skipped': 2,
    'instance_match_rate': 1.0,
    'total_rate': 1.0,
}
# Checked in the order above; the first mismatch raises with its key name.
for score_name, expected_value in expected_scores.items():
    assert scores[score_name] == expected_value, \
        f'Expected {score_name}={expected_value}, got {scores[score_name]}'

print('\nAll scoring assertions passed.')
print('\nPASS')

In [None]:
# Cell 7: Test edge cases
print('=== Edge Case Tests ===')

# Test 1: FUTURE_TYPES get SKIPPED, not EXTRA/MISSING
print('\n--- Test 1: FUTURE_TYPES get SKIPPED ---')
future_callouts = [
    {'calloutType': 'Slot', 'raw': 'SLOT 1.000 x .500', 'quantity': 1},
    {'calloutType': 'Bend', 'raw': 'BEND R.250', 'quantity': 1},
    {'calloutType': 'Note', 'raw': 'BREAK ALL SHARP EDGES', 'quantity': 1},
]
future_sw = []  # No SW features for these

exp_future_c, exp_future_sw = expand_both_sides(future_callouts, future_sw)
future_results = matcher.match_all(exp_future_c, exp_future_sw)

for r in future_results:
    ct = r.drawing_callout.get('calloutType', '') if r.drawing_callout else ''
    print(f'  {ct:15s} -> {r.status.value:15s} ({r.notes})')
    assert r.status == MatchStatus.SKIPPED, \
        f'Expected SKIPPED for {ct}, got {r.status.value}'

future_scores = matcher.compute_scores(future_results)
print(f'  Scores: {future_scores}')
# Also guards against an empty result list, which would make the loop above
# pass vacuously.
assert future_scores['skipped'] == 3, \
    f'Expected skipped=3, got {future_scores["skipped"]}'
# With only skipped items, rates should be 1.0 (no denominator)
assert future_scores['instance_match_rate'] == 1.0, \
    f'Expected instance_match_rate=1.0, got {future_scores["instance_match_rate"]}'
print('  PASS')

# Test 2: Chamfer uses "size" field for matching
print('\n--- Test 2: Chamfer "size" field matching ---')
chamfer_callouts = [
    {'calloutType': 'Chamfer', 'size': 0.060, 'raw': '.060 x 45', 'quantity': 1},
]
chamfer_sw = [
    SwFeature(feature_type='Chamfer', radius_inches=0.060, quantity=1),
]
exp_cc, exp_cs = expand_both_sides(chamfer_callouts, chamfer_sw)
chamfer_results = matcher.match_all(exp_cc, exp_cs)
# One callout vs. one feature: a successful match yields exactly one result.
# Check the length before indexing so a failure reports the observed statuses
# instead of raising an opaque IndexError.
assert len(chamfer_results) == 1, \
    f'Expected 1 chamfer result, got {len(chamfer_results)}: ' \
    f'{[r.status.value for r in chamfer_results]}'
chamfer_result = chamfer_results[0]
assert chamfer_result.status == MatchStatus.MATCHED, \
    f'Expected MATCHED for chamfer, got {chamfer_result.status.value}'
print(f'  Chamfer matched: delta={chamfer_result.delta}')
print('  PASS')

# Test 3: Tolerance failure
print('\n--- Test 3: Tolerance failure ---')
tol_callouts = [
    {'calloutType': 'Hole', 'diameter': 0.500, 'raw': '\u2300.500', 'quantity': 1},
]
tol_sw = [
    SwFeature(feature_type='Hole', diameter_inches=0.530, quantity=1),  # 0.030 off
]
exp_tc, exp_ts = expand_both_sides(tol_callouts, tol_sw)
tol_results = matcher.match_all(exp_tc, exp_ts)
for r in tol_results:
    print(f'  Status: {r.status.value}, delta={r.delta}, notes="{r.notes}"')
assert any(r.status == MatchStatus.TOLERANCE_FAIL for r in tol_results), \
    f'Expected TOLERANCE_FAIL for 0.030 delta, got {[r.status.value for r in tol_results]}'
print('  PASS')

print('\nAll edge case tests passed.')