# AI Drawing Inspector v4.0

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/skaumbdoallsaws-coder/AI-Drawing-Inspector/blob/main/notebooks/ai_inspector_v4.ipynb)

**Modular Type-Aware Architecture**

This notebook uses the `ai_inspector` package for shared utilities while retaining
all v3 functionality, including assembly context, sibling parts, and mate relationships.

In [None]:
# Cell 1: Install Dependencies
# This cell uses notebook shell commands (`!`) and an IPython magic (`%cd`), so it
# only runs inside Jupyter/Colab, not as a plain Python script.
# NOTE(review): `openai` and `bitsandbytes` are installed, but nothing visible in
# this notebook imports openai or loads a quantized model - confirm they are needed.
!pip install -q pymupdf pillow openai
!pip install -q accelerate qwen-vl-utils bitsandbytes
# transformers from GitHub main - presumably required for the LightOnOcr* classes
# imported in Cell 3 (confirm against the latest release before pinning).
!pip install -q git+https://github.com/huggingface/transformers
# json-repair: fallback parser for malformed JSON returned by Qwen.
!pip install -q json-repair

# Clone and install ai_inspector package
# The first run clones into /content/AI-tool. If the clone fails (e.g. the directory
# already exists on a re-run; the error is hidden by 2>/dev/null), the existing
# checkout is updated with `git pull` instead.
!git clone https://github.com/skaumbdoallsaws-coder/AI-Drawing-Inspector.git /content/AI-tool 2>/dev/null || (cd /content/AI-tool && git pull)
%cd /content/AI-tool
# Editable install so `import ai_inspector` resolves to this checkout.
!pip install -q -e .
print('Dependencies installed!')

In [None]:
# Cell 2: Imports and Configuration
import os
import re
import json
import gc
import torch
from pathlib import Path
from datetime import datetime
from dataclasses import asdict
from typing import Dict, List, Optional, Any, Tuple

from google.colab import files, userdata
from IPython.display import display, Markdown
from PIL import Image

# ai_inspector package imports
from ai_inspector import classify_drawing, DrawingType, __version__
from ai_inspector.utils import render_pdf, SwJsonLibrary, extract_pdf_text, load_json_robust
from ai_inspector.utils.pdf_render import PageArtifact
from ai_inspector.analyzers import resolve_part_identity, ResolvedPartIdentity
from ai_inspector.extractors.ocr import preprocess_ocr_text, parse_ocr_callouts, PATTERNS
from ai_inspector.extractors.vlm import (
    FEATURE_EXTRACTION_PROMPT, QUALITY_AUDIT_PROMPT,
    BOM_EXTRACTION_PROMPT, MANUFACTURING_NOTES_PROMPT,
    PAGE_CLASSIFICATION_PROMPT
)
from ai_inspector.comparison import (
    extract_sw_requirements, extract_mate_requirements,
    generate_diff_result, create_stub_diff_result
)

print(f'ai_inspector v{__version__} loaded')

# Configuration
# Every artifact this notebook writes goes under OUTPUT_DIR. The uploaded
# SolidWorks JSON library is unpacked into SOLIDWORKS_JSON_DIR (Cell 4).
OUTPUT_DIR = '/content/output'
SOLIDWORKS_JSON_DIR = '/content/sw_json_library'
for _required_dir in (OUTPUT_DIR, SOLIDWORKS_JSON_DIR):
    os.makedirs(_required_dir, exist_ok=True)

print(f'Output directory: {OUTPUT_DIR}')

In [None]:
# Cell 3: Load AI Models
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
from qwen_vl_utils import process_vision_info
from json_repair import repair_json

# Hugging Face token from Colab secrets; passed explicitly to the LightOnOCR downloads below.
hf_token = userdata.get('HF_TOKEN')

# Clear GPU memory: collect unreferenced Python objects, then release cached
# CUDA blocks before loading two models onto the device.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Load LightOnOCR-2
# The OCR model is moved to one explicit device. run_lighton_ocr sends its inputs
# to the same ocr_device/ocr_dtype (bfloat16 on GPU, float32 on CPU).
print('Loading LightOnOCR-2-1B...')
ocr_device = 'cuda' if torch.cuda.is_available() else 'cpu'
ocr_dtype = torch.bfloat16 if ocr_device == 'cuda' else torch.float32

ocr_processor = LightOnOcrProcessor.from_pretrained('lightonai/LightOnOCR-2-1B', token=hf_token)
ocr_model = LightOnOcrForConditionalGeneration.from_pretrained(
    'lightonai/LightOnOCR-2-1B', torch_dtype=ocr_dtype, token=hf_token
).to(ocr_device)
print(f'LightOnOCR-2 loaded: {ocr_model.get_memory_footprint() / 1e9:.2f} GB')

# Load Qwen2.5-VL
# device_map='auto' lets transformers/accelerate place the weights, so callers move
# their inputs to qwen_model.device rather than to a fixed device.
# No token is passed here. NOTE(review): presumably the checkpoint is public - confirm.
# NOTE(review): newer transformers versions rename `torch_dtype` to `dtype`; the
# old keyword may emit a deprecation warning.
print('\nLoading Qwen2.5-VL-7B...')
qwen_processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-7B-Instruct', trust_remote_code=True)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    'Qwen/Qwen2.5-VL-7B-Instruct',
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True
)
print(f'Qwen2.5-VL loaded: {qwen_model.get_memory_footprint() / 1e9:.2f} GB')

def run_lighton_ocr(image: Image.Image) -> List[str]:
    """Transcribe ``image`` with LightOnOCR-2 and return its non-blank text lines.

    The image is converted to RGB and sent as a single-turn chat with no text
    prompt. Generation is capped at 2048 new tokens. Each returned line is
    stripped of surrounding whitespace, and empty lines are dropped.
    """
    rgb_image = image.convert('RGB')
    chat = [{'role': 'user', 'content': [{'type': 'image', 'image': rgb_image}]}]
    encoded = ocr_processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors='pt',
    )

    # Floating-point tensors must match the model's dtype as well as its device.
    # Integer tensors (token ids, masks) only change device.
    model_inputs = {}
    for name, tensor in encoded.items():
        if tensor.is_floating_point():
            model_inputs[name] = tensor.to(device=ocr_device, dtype=ocr_dtype)
        else:
            model_inputs[name] = tensor.to(ocr_device)

    with torch.no_grad():
        sequences = ocr_model.generate(**model_inputs, max_new_tokens=2048)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = model_inputs['input_ids'].shape[1]
    decoded = ocr_processor.decode(sequences[0, prompt_length:], skip_special_tokens=True)

    stripped_lines = (raw_line.strip() for raw_line in decoded.split('\n'))
    return [text for text in stripped_lines if text]

def _parse_json_response(response: str) -> Dict[str, Any]:
    """Extract a JSON object from a raw Qwen response; always return a dict.

    Uses the contents of a fenced ```json block when present. Otherwise it takes
    the span from the first '{' to the last '}'. Strict ``json.loads`` is tried
    first, then ``repair_json`` as a fallback.

    On failure - no JSON found, unparseable text, or a top-level value that
    is not an object (e.g. a JSON array) - this returns
    ``{'raw_response': <first 1000 chars>, 'parse_error': <message>}`` so that
    callers can keep using ``.get`` safely.
    """
    def _failure(message: str) -> Dict[str, Any]:
        return {'raw_response': response[:1000], 'parse_error': message}

    fenced = re.search(r'```json\s*([\s\S]*?)\s*```', response)
    if fenced:
        json_str = fenced.group(1)
    else:
        braced = re.search(r'\{[\s\S]*\}', response)
        if braced is None:
            return _failure('no JSON object found in model response')
        json_str = braced.group()

    try:
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            parsed = json.loads(repair_json(json_str))
    except Exception as e:  # repair_json and the second parse can fail in varied ways
        return _failure(str(e))

    # Callers treat the result as a dict; a list/str/number would crash them.
    if not isinstance(parsed, dict):
        return _failure(f'expected a JSON object, got {type(parsed).__name__}')
    return parsed


def run_qwen_analysis(image: Image.Image, prompt: str) -> Dict[str, Any]:
    """Run Qwen2.5-VL on one image with ``prompt`` and return the parsed JSON object.

    Args:
        image: Page image to analyze.
        prompt: Instruction text sent alongside the image.

    Returns:
        The JSON object from the model's answer. Malformed model output does not
        raise: parse failures, and JSON that is not an object (e.g. an array),
        come back as a dict with ``parse_error`` and a truncated
        ``raw_response``. This is because the callers in this notebook read the
        result with ``.get``.
    """
    messages = [{'role': 'user', 'content': [{'type': 'image', 'image': image}, {'type': 'text', 'text': prompt}]}]
    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt').to(qwen_model.device)
    with torch.no_grad():
        output_ids = qwen_model.generate(**inputs, max_new_tokens=4096, temperature=0.1)
    # Drop the prompt tokens; decode only what the model generated.
    generated_ids = output_ids[0, inputs.input_ids.shape[1]:]
    response = qwen_processor.decode(generated_ids, skip_special_tokens=True)
    return _parse_json_response(response)

print('\nModels ready!')

In [None]:
# Cell 4: Load SolidWorks Library + Assembly Context Databases
import zipfile

print('Upload your sw_json_library.zip file:')
print('(Should contain part JSONs + optionally sw_part_context_complete.json and sw_inspector_requirements.json)')
uploaded = files.upload()

# Only the first uploaded .zip is extracted. The extension is matched
# case-insensitively (as the PDF upload cell does), so e.g. "LIBRARY.ZIP" also works.
zip_name = next((name for name in uploaded if name.lower().endswith('.zip')), None)
if zip_name is None:
    # Previously this case was silent; make it visible that nothing new was extracted.
    print(f'WARNING: no .zip file uploaded - loading whatever is already in {SOLIDWORKS_JSON_DIR}')
else:
    print(f'Extracting {zip_name}...')
    with zipfile.ZipFile(zip_name, 'r') as z:
        z.extractall(SOLIDWORKS_JSON_DIR)
    print(f'Extracted to {SOLIDWORKS_JSON_DIR}')

# Load SW library using package
sw_library = SwJsonLibrary()
sw_library.load_from_directory(SOLIDWORKS_JSON_DIR)
print(f'\nLibrary ready: {len(sw_library)} parts indexed')

# === Load Mating & Sibling Context Databases ===
part_context_db = {}
inspector_requirements_db = {}

def _try_load_json(name, search_paths):
    """Try loading a JSON file from multiple paths."""
    for p in search_paths:
        full = os.path.join(p, name)
        if os.path.exists(full):
            with open(full, 'r', encoding='utf-8') as f:
                data = json.load(f)
            print(f'  Loaded {name}: {len(data)} entries from {p}')
            return data
    return None

# The context files are looked for inside the extracted library, in its parent
# directory, and directly in /content, in that order.
_search_paths = [SOLIDWORKS_JSON_DIR, os.path.join(SOLIDWORKS_JSON_DIR, '..'), '/content']

print('\nLoading assembly context databases...')
part_context_db = _try_load_json('sw_part_context_complete.json', _search_paths) or {}
inspector_requirements_db = _try_load_json('sw_inspector_requirements.json', _search_paths) or {}

# Report each database that came back empty or missing.
_missing_db_notes = (
    (part_context_db, '  sw_part_context_complete.json not found - sibling/mating context unavailable'),
    (inspector_requirements_db, '  sw_inspector_requirements.json not found - mate-derived thread requirements unavailable'),
)
for _db, _note in _missing_db_notes:
    if not _db:
        print(_note)

def get_part_context(part_number: str, db: Optional[Dict] = None) -> Optional[Dict]:
    """Look up a part's assembly-context entry by part number.

    First tries direct key lookups on a few spellings of ``part_number`` (as
    given, upper-case, lower-case, without '-', without '_'). If none match, it
    scans for an entry whose ``identity.new_pn`` or ``identity.old_pn`` equals
    ``part_number`` exactly.

    Args:
        part_number: Part number to look up. Empty or None returns None.
        db: Context database to search. Defaults to the module-level
            ``part_context_db``.

    Returns:
        The matching entry dict, or None if nothing matches.
    """
    if db is None:
        db = part_context_db
    # An unresolved (empty/None) part number must not match anything. Without
    # this guard, '' equals the empty old_pn of some entry in the scan below,
    # and None crashes on .upper().
    if not db or not part_number:
        return None
    candidates = (part_number, part_number.upper(), part_number.lower(),
                  part_number.replace('-', ''), part_number.replace('_', ''))
    for key in candidates:
        if key in db:
            return db[key]
    for entry in db.values():
        identity = entry.get('identity', {})
        if part_number in (identity.get('new_pn'), identity.get('old_pn')):
            return entry
    return None

def get_inspector_requirements(part_number: str, db: Optional[Dict] = None) -> Optional[Dict]:
    """Look up a part's inspector requirements by part number.

    Tries direct key lookups on several spellings of ``part_number``. If none
    match, it falls back to the part's ``identity.old_pn`` from the assembly
    context DB (via ``get_part_context``) and tries the same spellings of that.

    Args:
        part_number: Part number to look up. Empty or None returns None.
        db: Requirements database to search. Defaults to the module-level
            ``inspector_requirements_db``.

    Returns:
        The matching requirements dict, or None if nothing matches.
    """
    if db is None:
        db = inspector_requirements_db
    # Guard against an unresolved part number (None would crash on .upper()).
    if not db or not part_number:
        return None

    def _variants(pn: str) -> Tuple[str, ...]:
        # Same spellings get_part_context tries: as-is, upper, lower, and
        # with '-' or '_' removed.
        return (pn, pn.upper(), pn.lower(), pn.replace('-', ''), pn.replace('_', ''))

    for key in _variants(part_number):
        if key in db:
            return db[key]

    # Fall back to the legacy part number recorded in the assembly context DB.
    ctx = get_part_context(part_number)
    old_pn = (ctx or {}).get('identity', {}).get('old_pn', '')
    if old_pn:
        for key in _variants(old_pn):
            if key in db:
                return db[key]
    return None

# One print, two lines: sizes of the two context databases loaded in this cell.
print(f'\nAssembly context: {len(part_context_db)} parts\n'
      f'Inspector requirements: {len(inspector_requirements_db)} entries')

In [None]:
# Cell 5: Upload and Render PDF Drawing
print('Upload your PDF drawing:')
uploaded = files.upload()

# Use the first uploaded PDF (extension matched case-insensitively). Failing fast
# here replaces a NameError further down, or, on a re-run, silently reprocessing
# the previously uploaded drawing.
DRAWING_PDF_PATH = next((name for name in uploaded if name.lower().endswith('.pdf')), None)
if DRAWING_PDF_PATH is None:
    raise ValueError('No PDF file was uploaded - upload a .pdf drawing and re-run this cell.')

print(f'\nProcessing: {DRAWING_PDF_PATH}')
print('='*50)

# Render PDF using package
artifacts = render_pdf(DRAWING_PDF_PATH)
# Every later cell indexes artifacts[0]; report an unrenderable PDF clearly.
if not artifacts:
    raise ValueError(f'No pages could be rendered from {DRAWING_PDF_PATH}')

# Classify drawing type
pdf_text = extract_pdf_text(DRAWING_PDF_PATH)
classification = classify_drawing(pdf_text)

print(f'\nDrawing Type: {classification.drawing_type.value}')
print(f'Confidence: {classification.confidence:.0%}')
print(f'Signals: {classification.signals}')
print(f'Use OCR: {classification.use_ocr}')
print(f'Use Qwen: {classification.use_qwen}')

# Display first page
display(artifacts[0].get_thumbnail(800))

In [None]:
# Cell 6: Resolve Part Identity + Display Assembly Context
part_identity = resolve_part_identity(DRAWING_PDF_PATH, artifacts, sw_library)

_banner = '='*50
print(_banner)
print('RESOLVED PART IDENTITY')
print(_banner)
# Labels are left-aligned to a 14-character column.
for _label, _value in (
    ('Part Number:', part_identity.partNumber),
    ('Confidence:', part_identity.confidence),
    ('Source:', part_identity.source),
    ('SW JSON:', part_identity.swJsonPath or "Not found"),
    ('Candidates:', part_identity.candidates_tried[:5]),
):
    print(f'{_label:<14}{_value}')

# === Assembly Context Lookup ===
assembly_context = None
mate_thread_reqs = []

ctx = get_part_context(part_identity.partNumber)
if ctx:
    assembly_context = ctx
    # `or {}` / `or []` also tolerate sections stored as JSON null. With
    # `.get(key, default)` alone those come back as None and crash the next access.
    hierarchy = ctx.get('hierarchy') or {}
    mating = ctx.get('mating') or {}
    siblings = hierarchy.get('siblings') or []
    mates_with = mating.get('mates_with') or []
    mate_reqs_from_mates = mating.get('requirements_from_mates') or []

    print()
    print('='*50)
    print('ASSEMBLY CONTEXT')
    print('='*50)
    print(f'Parent Assembly: {hierarchy.get("parent_assembly", "Unknown")}')
    print(f'Hierarchy Path:  {hierarchy.get("hierarchy_path", "N/A")}')

    if siblings:
        print(f'\nSiblings ({len(siblings)} parts in same assembly):')
        for s in siblings[:8]:
            print(f'  - {s.get("pn", s.get("name", "?"))} ({s.get("desc", s.get("description", ""))})')
        if len(siblings) > 8:
            print(f'  ... and {len(siblings) - 8} more')

    if mates_with:
        print(f'\nMate Relationships ({len(mates_with)} total):')
        for m in mates_with[:6]:
            thread_info = ''
            if m.get('thread'):
                # Append 'x<pitch>' only when a pitch is known. Previously a
                # missing pitch rendered as a dangling 'x' (e.g. '[M6x]').
                pitch = m.get('pitch')
                thread_info = f" [{m['thread']}x{pitch}]" if pitch else f" [{m['thread']}]"
            print(f'  - {m.get("mate_type", "MATE")}: {m.get("part", "?")} ({m.get("description", "")}){thread_info}')
        if len(mates_with) > 6:
            print(f'  ... and {len(mates_with) - 6} more')

    if mate_reqs_from_mates:
        print(f'\nMate-Derived Requirements ({len(mate_reqs_from_mates)}):')
        for req in mate_reqs_from_mates[:5]:
            print(f'  - {req}')
        if len(mate_reqs_from_mates) > 5:
            print(f'  ... and {len(mate_reqs_from_mates) - 5} more')
else:
    print('\nAssembly context: Not found in database')

# Check inspector requirements. If the inspector DB lists thread holes derived
# from mates, print them and keep them in mate_thread_reqs.
insp_reqs = get_inspector_requirements(part_identity.partNumber)
if insp_reqs:
    reqs_by_type = insp_reqs.get('requirements_by_type', {})
    thread_holes = reqs_by_type.get('thread_holes', [])
    if thread_holes:
        mate_thread_reqs = thread_holes
        print(f'\nThread Hole Requirements from Mates ({len(thread_holes)}):')
        print('\n'.join(f'  - {th}' for th in thread_holes))

# Save outputs: the resolved identity always; the assembly context only when found.
identity_out = os.path.join(OUTPUT_DIR, 'ResolvedPartIdentity.json')
with open(identity_out, 'w') as fh:
    json.dump(asdict(part_identity), fh, indent=2)
print(f'\nSaved: {identity_out}')

if assembly_context:
    context_payload = {'partNumber': part_identity.partNumber}
    for _section in ('hierarchy', 'mating', 'identity'):
        context_payload[_section] = assembly_context.get(_section, {})
    ctx_out = os.path.join(OUTPUT_DIR, 'AssemblyContext.json')
    with open(ctx_out, 'w') as fh:
        json.dump(context_payload, fh, indent=2)
    print(f'Saved: {ctx_out}')

In [None]:
# Cell 7: Page Classification + Run OCR and Qwen Analysis
print('='*50)
print('PAGE CLASSIFICATION')
print('='*50)

# Classify each page
pages_needing_ocr = []
pages_with_bom = []
pages_with_details = []

for art in artifacts:
    page_class = run_qwen_analysis(art.image, PAGE_CLASSIFICATION_PROMPT)
    # Normalize case/whitespace so e.g. 'part_detail' still routes correctly. A
    # missing or null pageType (including unparseable responses) defaults to
    # PART_DETAIL so the page still gets OCR.
    page_type = str(page_class.get('pageType') or 'PART_DETAIL').strip().upper()
    art.drawing_type = page_type
    art.has_bom = page_class.get('hasBOM', False)

    if page_type in ('PART_DETAIL', 'MIXED'):
        # Pages with part detail always get OCR, overriding the document-level
        # classification.use_ocr decision.
        art.needs_ocr = True
        pages_needing_ocr.append(art)
        pages_with_details.append(art)
    elif page_type == 'ASSEMBLY_BOM':
        art.needs_ocr = False
    else:
        art.needs_ocr = classification.use_ocr
        if art.needs_ocr:
            pages_needing_ocr.append(art)

    # Any page that reports a BOM is a BOM source, not only ASSEMBLY_BOM pages.
    # Otherwise a MIXED detail+BOM page printed as "BOM: True" would be left out
    # of pages_with_bom, and BOM extraction would fall back to page 1.
    if page_type == 'ASSEMBLY_BOM' or art.has_bom is True:
        pages_with_bom.append(art)

    print(f'  Page {art.page}: {page_type} (OCR: {art.needs_ocr}, BOM: {art.has_bom})')

overall_drawing_type = classification.drawing_type.value.upper()
print(f'\nOverall type: {overall_drawing_type}')
print(f'Pages needing OCR: {[p.page for p in pages_needing_ocr]}')
print(f'Pages with BOM: {[p.page for p in pages_with_bom]}')

# === Run OCR ===
print('\n' + '='*50)
print('RUNNING OCR')
print('='*50)

all_ocr_lines = []
if not pages_needing_ocr:
    print('SKIPPING OCR - No pages require text extraction')
else:
    print(f'Running OCR on {len(pages_needing_ocr)} page(s)...')
    for art in pages_needing_ocr:
        print(f'  Processing Page {art.page}...')
        # Best effort: a failure on one page is reported and the rest still run.
        try:
            lines_for_page = run_lighton_ocr(art.image)
        except Exception as exc:
            print(f'    Error: {exc}')
            continue
        print(f'    Extracted {len(lines_for_page)} lines')
        all_ocr_lines += lines_for_page
    print(f'\nTotal OCR lines: {len(all_ocr_lines)}')

ocr_lines = all_ocr_lines

# === Run Qwen Analyses ===
print('\n' + '='*50)
print('QWEN DRAWING ANALYSIS')
print('='*50)

# Feature extraction uses the first part-detail page when there is one, otherwise page 1.
if pages_with_details:
    primary_image = pages_with_details[0].image
else:
    primary_image = artifacts[0].image

print('Analyzing features...')
qwen_understanding = run_qwen_analysis(primary_image, FEATURE_EXTRACTION_PROMPT)
if 'parse_error' not in qwen_understanding:
    part_desc = qwen_understanding.get("partDescription", "N/A")
    print(f'  Part: {part_desc}')
    print(f'  Features: {len(qwen_understanding.get("features", []))}')

print('\nAuditing drawing quality...')
drawing_quality = run_qwen_analysis(artifacts[0].image, QUALITY_AUDIT_PROMPT)

print('\nExtracting BOM...')
# Prefer a page classified as carrying a BOM; fall back to page 1.
if pages_with_bom:
    bom_image = pages_with_bom[0].image
else:
    bom_image = artifacts[0].image
bom_data = run_qwen_analysis(bom_image, BOM_EXTRACTION_PROMPT)
print(f'  BOM found: {bom_data.get("hasBOM", False)}')

print('\nExtracting manufacturing notes...')
mfg_notes = run_qwen_analysis(artifacts[0].image, MANUFACTURING_NOTES_PROMPT)

# Save Qwen outputs: all four analyses in one file, keyed by analysis name.
qwen_report = {
    'featureAnalysis': qwen_understanding,
    'qualityAudit': drawing_quality,
    'bomExtraction': bom_data,
    'manufacturingNotes': mfg_notes,
}
qwen_out = os.path.join(OUTPUT_DIR, 'QwenUnderstanding.json')
with open(qwen_out, 'w') as fh:
    json.dump(qwen_report, fh, indent=2)
print(f'\nSaved: {qwen_out}')

In [None]:
# Cell 8: Merge Evidence + Compare to SW Requirements
print('='*50)
print('MERGING EVIDENCE')
print('='*50)

# Parse OCR callouts using package
ocr_callouts = parse_ocr_callouts(ocr_lines, verbose=True) if ocr_lines else []
print(f'OCR callouts parsed: {len(ocr_callouts)}')


def _normalize_callout_text(raw) -> str:
    """Return a case- and whitespace-insensitive key for de-duplicating callout text."""
    return ' '.join(str(raw).split()).upper()


# Convert Qwen features to callout format. Checks run in order, so a feature
# whose type mentions 'tapped'/'thread' is a TappedHole even if it also says
# 'through'. Unrecognized feature types are dropped.
qwen_callouts = []
for feat in qwen_understanding.get('features') or []:
    if not isinstance(feat, dict):
        continue  # skip malformed entries (e.g. bare strings) from the model
    # `or ''` also covers an explicit null type, which would crash .lower().
    ftype = str(feat.get('type') or '').lower()
    callout = feat.get('callout', '')
    if 'tapped' in ftype or 'thread' in ftype:
        qwen_callouts.append({'calloutType': 'TappedHole', 'raw': callout, 'source': 'qwen'})
    elif 'through' in ftype and 'hole' in ftype:
        qwen_callouts.append({'calloutType': 'Hole', 'isThrough': True, 'raw': callout, 'source': 'qwen'})
    elif 'blind' in ftype and 'hole' in ftype:
        qwen_callouts.append({'calloutType': 'Hole', 'isThrough': False, 'raw': callout, 'source': 'qwen'})
    elif 'fillet' in ftype:
        qwen_callouts.append({'calloutType': 'Fillet', 'raw': callout, 'source': 'qwen'})
    elif 'chamfer' in ftype:
        qwen_callouts.append({'calloutType': 'Chamfer', 'raw': callout, 'source': 'qwen'})

# Merge with deduplication. OCR callouts take precedence. A Qwen callout is added
# only if no callout with the same text (ignoring case and spacing) is already
# present. Qwen callouts with empty text are skipped.
seen_raws = {_normalize_callout_text(c.get('raw') or '') for c in ocr_callouts}
merged_callouts = list(ocr_callouts)
for qc in qwen_callouts:
    if not qc.get('raw'):
        continue
    key = _normalize_callout_text(qc['raw'])
    if key not in seen_raws:
        merged_callouts.append(qc)
        seen_raws.add(key)

print(f'Merged callouts: {len(merged_callouts)}')

# Build evidence dict: drawing-level info from Qwen plus the merged callouts and
# the first 20 raw OCR lines.
_drawing_info_defaults = (
    ('partDescription', ''),
    ('material', ''),
    ('views', []),
    ('notes', []),
)
drawing_info = {key: qwen_understanding.get(key, default) for key, default in _drawing_info_defaults}

evidence = {
    'schemaVersion': '4.0.0',
    'partNumber': part_identity.partNumber,
    'units': 'inches',
    'drawingInfo': drawing_info,
    'foundCallouts': merged_callouts,
    'rawOcrSample': ocr_lines[:20] if ocr_lines else [],
}

evidence_out = os.path.join(OUTPUT_DIR, 'Evidence.json')
with open(evidence_out, 'w') as fh:
    json.dump(evidence, fh, indent=2)
print(f'Saved: {evidence_out}')

# === Generate DiffResult ===
print('\n' + '='*50)
print('COMPARISON TO SW REQUIREMENTS')
print('='*50)


def _describe_requirement(req: Dict[str, Any]) -> str:
    """Format one SW requirement as an indented, display-only summary line.

    TappedHole shows its thread callout, Hole shows its diameter in inches
    (4 decimals), and anything else shows its ``canonical`` text. The
    requirement's ``source`` is appended in brackets. Missing or null fields
    degrade to placeholders instead of raising, because this is only an
    informational listing.
    """
    req_type = req.get('type', '?')
    source = req.get('source', '')
    if req_type == 'TappedHole':
        detail = (req.get('thread') or {}).get('callout', '')
    elif req_type == 'Hole':
        diameter = req.get('diameterInches', 0)
        # Only numbers take the fixed-point format. A null diameter used to
        # crash this print with a TypeError.
        detail = f'{diameter:.4f}"' if isinstance(diameter, (int, float)) else f'{diameter}"'
    else:
        detail = req.get('canonical', '')
    return f'  - {req_type}: {detail} [{source}]'


sw_data = None
diff_result = None
has_sw_comparison = False

if part_identity.swJsonPath:
    sw_data, err = load_json_robust(part_identity.swJsonPath)
    if sw_data:
        has_sw_comparison = True

        # Extract requirements from SW
        requirements = extract_sw_requirements(sw_data)

        # Add mate-derived requirements (insp_reqs and assembly_context come from Cell 6)
        mate_reqs = extract_mate_requirements(
            part_identity.partNumber,
            inspector_requirements=insp_reqs,
            part_context=assembly_context
        )
        if mate_reqs:
            print(f'Adding {len(mate_reqs)} mate-derived requirements')
            requirements.extend(mate_reqs)

        print(f'\nSW Requirements: {len(requirements)}')
        for req in requirements[:10]:
            print(_describe_requirement(req))

        # Generate comparison
        diff_result = generate_diff_result(
            callouts=merged_callouts,
            requirements=requirements,
            part_number=part_identity.partNumber
        )
    else:
        print(f'Error loading SW JSON: {err}')

if not has_sw_comparison:
    print('NO SOLIDWORKS DATA AVAILABLE')
    diff_result = create_stub_diff_result(part_identity.partNumber)

# Display results
print('\n' + '='*50)
print('DIFF RESULT')
print('='*50)
summary = diff_result.get('summary', {})
for _label, _key, _default in (
    ('Total Requirements: ', 'totalRequirements', 0),
    ('FOUND:   ', 'found', 0),
    ('MISSING: ', 'missing', 0),
    ('EXTRA:   ', 'extra', 0),
    ('Match Rate: ', 'matchRate', 'N/A'),
):
    print(f'{_label}{summary.get(_key, _default)}')

mate_count = summary.get('mateRequirements', 0)
if mate_count > 0:
    print(f'Mate-Derived: {mate_count}')

# Show up to five requirements that were not found on the drawing.
missing_items = diff_result.get('details', {}).get('missing')
if missing_items:
    print('\nMissing from drawing:')
    for item in missing_items[:5]:
        print(f'  X {item.get("note", "")}')

diff_out = os.path.join(OUTPUT_DIR, 'DiffResult.json')
with open(diff_out, 'w') as fh:
    json.dump(diff_result, fh, indent=2)
print(f'\nSaved: {diff_out}')

In [None]:
# Cell 9: Display Summary + Download Outputs
print('='*50)
print('INSPECTION SUMMARY')
print('='*50)

print(f'\nPart Number: {part_identity.partNumber}')
print(f'Drawing Type: {classification.drawing_type.value}')
sw_status = 'Available' if has_sw_comparison else 'NOT AVAILABLE'
print(f'SW Data: {sw_status}')

# PASS only when a CAD comparison ran and reported zero missing requirements.
if not has_sw_comparison:
    print('Status: REVIEW NEEDED - No CAD comparison available')
else:
    diff_summary = diff_result.get('summary', {})
    print(f'\nMatch Rate: {diff_summary.get("matchRate", "N/A")}')
    missing_count = diff_summary.get('missing', 0)
    if missing_count == 0:
        print('Status: PASS - All requirements found on drawing')
    else:
        print(f'Status: REVIEW NEEDED - {missing_count} requirement(s) missing')

if assembly_context:
    hierarchy_info = assembly_context.get("hierarchy", {})
    print('\nAssembly Context: Available')
    print(f'  Parent: {hierarchy_info.get("parent_assembly", "N/A")}')
    print(f'  Siblings: {len(hierarchy_info.get("siblings", []))}')
    print(f'  Mates: {len(assembly_context.get("mating", {}).get("mates_with", []))}')

# List output files
print('\n' + '='*50)
print('OUTPUT FILES')
print('='*50)
# Collect regular files once, in sorted order, so the listing and the downloads
# agree. Directories are skipped: files.download presumably cannot fetch a
# directory - confirm against the google.colab.files docs.
output_files = [
    os.path.join(OUTPUT_DIR, name)
    for name in sorted(os.listdir(OUTPUT_DIR))
    if os.path.isfile(os.path.join(OUTPUT_DIR, name))
]
for filepath in output_files:
    print(f'  {os.path.basename(filepath)} ({os.path.getsize(filepath):,} bytes)')

# Download all outputs
print('\nDownloading outputs...')
for filepath in output_files:
    files.download(filepath)

print('\nDone!')