# Document Type Classification Testing - Phase 1

Interactive tests for the document type detection system: measure classification
accuracy and tune the detection parameters (model choice, confidence threshold).

In [None]:
import html
import sys
import time
from pathlib import Path

from IPython.display import HTML, Markdown, display
from PIL import Image

# Resolve the repository root (the parent of this notebook's directory) and
# make it importable so the `common` and `models` packages used below load.
project_root = Path('..').absolute()

if not any(entry == str(project_root) for entry in sys.path):
    sys.path.insert(0, str(project_root))

print("📂 Project root: " + str(project_root))
print("✅ Environment configured")

In [None]:
# Import the project modules under test. An ImportError is reported instead of
# raised so the notebook keeps running and the message explains what to check.
try:
    from common.document_type_detector import DocumentTypeDetector
    from common.extraction_parser import discover_images
except ImportError as import_error:
    print(f"❌ Import error: {import_error}")
    print("💡 Make sure all required files exist")
else:
    print("✅ Imports successful")

## Configuration

In [None]:
# Configuration: which backend to load, where the sample images live, and the
# confidence at or above which a classification is shown as reliable.
MODEL_TYPE = "llama"  # Change to "internvl3" to test different model
TEST_DIRECTORY = "../evaluation_data"
CONFIDENCE_THRESHOLD = 0.85

print(
    "🎯 Configuration:\n"
    f"   Model: {MODEL_TYPE}\n"
    f"   Test directory: {TEST_DIRECTORY}\n"
    f"   Confidence threshold: {CONFIDENCE_THRESHOLD}"
)

## Initialize Model and Detector

In [None]:
# Initialize processor. The backend module is imported only for the selected
# model; on any failure `processor` is set to None so downstream cells skip.
print(f"🚀 Initializing {MODEL_TYPE.upper()} processor...")

try:
    if MODEL_TYPE.lower() == "internvl3":
        from models.internvl3_processor import InternVL3Processor
        processor = InternVL3Processor()
    elif MODEL_TYPE.lower() == "llama":
        from models.llama_processor import LlamaProcessor
        processor = LlamaProcessor()
    else:
        raise ValueError(f"Unsupported model type: {MODEL_TYPE}")
except Exception as e:
    # Deliberately broad: this is a best-effort setup cell, and any failure
    # (missing GPU, weights, import) should just disable the later cells.
    print(f"❌ Failed to initialize processor: {e}")
    print("💡 Make sure you're running on a machine with GPU and model access")
    processor = None
else:
    print(f"✅ {MODEL_TYPE.upper()} processor initialized successfully")

In [None]:
# Wrap the processor in a DocumentTypeDetector configured with the notebook's
# threshold; without a processor the detector stays None.
if not processor:
    detector = None
    print("❌ Cannot initialize detector without processor")
else:
    detector = DocumentTypeDetector(processor)
    detector.confidence_threshold = CONFIDENCE_THRESHOLD

    print("✅ Document type detector initialized")
    print("🎯 Supported types: " + f"{detector.supported_types}")
    print("🎯 Confidence threshold: " + f"{detector.confidence_threshold}")

## Single Image Classification Test

In [None]:
# Test single image classification: preview the sample image before running the model.
TEST_IMAGE = "../evaluation_data/synthetic_invoice_001.png"
MAX_PREVIEW_WIDTH = 600  # wider images are downscaled so the preview fits the notebook

if detector and Path(TEST_IMAGE).exists():
    display(HTML("<h3>🖼️ Test Image</h3>"))

    # Open the image in a context manager so PIL releases the file handle once
    # the preview has been rendered (Image.open otherwise keeps it open lazily).
    with Image.open(TEST_IMAGE) as img:
        if img.width > MAX_PREVIEW_WIDTH:
            ratio = MAX_PREVIEW_WIDTH / img.width
            # Clamp to 1px: a very wide, short image would otherwise round to
            # height 0, which resize() rejects.
            new_height = max(1, int(img.height * ratio))
            display(img.resize((MAX_PREVIEW_WIDTH, new_height), Image.Resampling.LANCZOS))
        else:
            display(img)

    # width/height are cached on the Image object and stay readable after close.
    print(f"📁 Testing: {Path(TEST_IMAGE).name}")
    print(f"📐 Image size: {img.width}x{img.height}")

else:
    print(f"❌ Cannot test - detector: {detector is not None}, image exists: {Path(TEST_IMAGE).exists()}")

In [None]:
# Run classification on TEST_IMAGE and render the result as an HTML card.
if detector and Path(TEST_IMAGE).exists():
    print("🔍 Running document type classification...")

    result = detector.detect_document_type(TEST_IMAGE)

    display(HTML("<h3>📊 Classification Result</h3>"))

    # Bucket the confidence: at/above the threshold is confident, >= 0.5 is
    # uncertain, anything lower is treated as poor.
    if result['confidence'] >= CONFIDENCE_THRESHOLD:
        confidence_color, confidence_icon = "green", "✅"
    elif result['confidence'] >= 0.5:
        confidence_color, confidence_icon = "orange", "⚠️"
    else:
        confidence_color, confidence_icon = "red", "❌"

    # Model-generated text is HTML-escaped so stray '<', '>' or '&' in the
    # output cannot break (or inject into) the rendered markup.
    doc_type = html.escape(result['type'].upper())
    reasoning = html.escape(str(result.get('reasoning', 'None provided')))

    # processing_time is read with a default, matching the batch results table.
    result_html = f"""
    <div style="border: 2px solid #ccc; padding: 20px; border-radius: 10px; background-color: #f9f9f9;">
        <h4>🏷️ Document Type: <span style="color: {confidence_color}; font-weight: bold;">{doc_type}</span></h4>
        <p><strong>🎯 Confidence:</strong> <span style="color: {confidence_color};">{result['confidence']:.3f}</span> {confidence_icon}</p>
        <p><strong>⏱️ Processing Time:</strong> {result.get('processing_time', 0):.2f} seconds</p>
        <p><strong>💭 Reasoning:</strong> {reasoning}</p>
    """

    if result.get('fallback_used'):
        result_html += '<p><strong>⚠️ Note:</strong> Fallback classification was used</p>'

    if result.get('manual_review_needed'):
        result_html += '<p><strong>🔍 Note:</strong> Manual review recommended</p>'

    result_html += "</div>"

    display(HTML(result_html))

    # Show raw response for debugging (escaped; str() guards non-string payloads)
    if 'raw_response' in result:
        raw_response = html.escape(str(result["raw_response"]))
        display(HTML("<h4>🔧 Raw Model Response (for debugging):</h4>"))
        display(HTML(f'<div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px; font-family: monospace; white-space: pre-wrap; max-height: 200px; overflow-y: auto;">{raw_response}</div>'))

else:
    print("❌ Cannot run classification test")

## Batch Classification Test

In [None]:
# Discover available test images. An empty list signals "nothing to batch-test"
# to the following cells (no detector, or discovery failed).
test_images = []
if detector:
    try:
        test_images = discover_images(TEST_DIRECTORY)
        print(f"📂 Found {len(test_images)} test images in {TEST_DIRECTORY}")

        # List up to five of them as a quick sanity check.
        for position, img_path in enumerate(test_images[:5], start=1):
            print(f"   {position}. {Path(img_path).name}")

        hidden_count = len(test_images) - 5
        if hidden_count > 0:
            print(f"   ... and {hidden_count} more")

    except Exception as e:
        print(f"❌ Error discovering images: {e}")
        test_images = []

In [None]:
# Run batch classification. `batch_results` is an empty list when there is no
# detector or no images, so later cells can detect that and skip.
if not (detector and test_images):
    print("❌ Cannot run batch classification")
    batch_results = []
else:
    display(HTML("<h3>🔬 Running Batch Classification</h3>"))
    print(f"Processing {len(test_images)} images...")

    # NOTE(review): the detector is given the directory and presumably re-scans
    # it itself rather than reusing `test_images` — confirm the lists agree.
    batch_results = detector.batch_classify_images(TEST_DIRECTORY)

    print("\n✅ Batch classification completed!")

In [None]:
# Display batch results as the detector's preformatted text report.
if batch_results:
    display(HTML("<h3>📊 Batch Classification Report</h3>"))

    # Generate report
    report = detector.generate_classification_report(batch_results)

    # Escape before embedding in <pre>: the report may quote model output or
    # filenames containing '<' or '&', which would otherwise be parsed as markup.
    escaped_report = html.escape(str(report))
    display(HTML(f'<pre style="background-color: #f8f8f8; padding: 15px; border-radius: 5px; font-family: monospace;">{escaped_report}</pre>'))

else:
    print("❌ No batch results to display")

In [None]:
# Detailed results table: one row per image; errored images are shown in red.
if batch_results:
    display(HTML("<h3>📋 Detailed Results</h3>"))

    # Create HTML table
    table_html = """
    <table style="border-collapse: collapse; width: 100%;">
        <thead style="background-color: #f0f0f0;">
            <tr>
                <th style="border: 1px solid #ccc; padding: 8px; text-align: left;">Image</th>
                <th style="border: 1px solid #ccc; padding: 8px; text-align: left;">Type</th>
                <th style="border: 1px solid #ccc; padding: 8px; text-align: left;">Confidence</th>
                <th style="border: 1px solid #ccc; padding: 8px; text-align: left;">Time (s)</th>
                <th style="border: 1px solid #ccc; padding: 8px; text-align: left;">Notes</th>
            </tr>
        </thead>
        <tbody>
    """

    # Rows are collected and joined once rather than growing a string with +=.
    rows = []
    for result in batch_results:
        # Filenames and model/error text may contain '<' or '&'; escape them.
        image_name = html.escape(str(result['image_name']))

        if result.get('error'):
            error_text = html.escape(str(result.get('reasoning', 'Unknown error')))
            rows.append(f"""
            <tr>
                <td style="border: 1px solid #ccc; padding: 8px;">{image_name}</td>
                <td style="border: 1px solid #ccc; padding: 8px; color: red;">ERROR</td>
                <td style="border: 1px solid #ccc; padding: 8px;">-</td>
                <td style="border: 1px solid #ccc; padding: 8px;">-</td>
                <td style="border: 1px solid #ccc; padding: 8px; color: red;">{error_text}</td>
            </tr>
            """)
            continue

        # Same buckets as the single-image card: threshold / 0.5 / below.
        confidence = result['confidence']
        if confidence >= CONFIDENCE_THRESHOLD:
            confidence_color, confidence_icon = "green", "✅"
        elif confidence >= 0.5:
            confidence_color, confidence_icon = "orange", "⚠️"
        else:
            confidence_color, confidence_icon = "red", "❌"

        notes = []
        if result.get('fallback_used'):
            notes.append('Fallback')
        if result.get('manual_review_needed'):
            notes.append('Manual Review')
        notes_text = ', '.join(notes) if notes else '-'

        doc_type = html.escape(result['type'].upper())
        rows.append(f"""
            <tr>
                <td style="border: 1px solid #ccc; padding: 8px;">{image_name}</td>
                <td style="border: 1px solid #ccc; padding: 8px; color: {confidence_color}; font-weight: bold;">{doc_type}</td>
                <td style="border: 1px solid #ccc; padding: 8px; color: {confidence_color};">{confidence:.3f} {confidence_icon}</td>
                <td style="border: 1px solid #ccc; padding: 8px;">{result.get('processing_time', 0):.2f}</td>
                <td style="border: 1px solid #ccc; padding: 8px;">{notes_text}</td>
            </tr>
            """)

    table_html += "".join(rows)
    table_html += """
        </tbody>
    </table>
    """

    display(HTML(table_html))

## Ground Truth Analysis (if available)

In [None]:
# Simple ground truth analysis based on filenames
if batch_results:
    display(HTML("<h3>🔍 Ground Truth Analysis</h3>"))
    print("Note: This uses filename patterns to infer document types")

    # (filename keywords, expected type) pairs, checked in order: the first pair
    # with a keyword contained in the lower-cased filename decides the expected
    # type, so e.g. "invoice" takes precedence over "statement"/"bank".
    filename_type_hints = (
        (('invoice',), 'invoice'),
        (('statement', 'bank'), 'bank_statement'),
        (('receipt',), 'receipt'),
    )

    correct_predictions = 0
    total_predictions = 0
    accuracy_data = []

    for result in batch_results:
        if result.get('error'):
            continue  # errored images have no prediction to score

        filename = result['image_name'].lower()
        predicted_type = result['type']

        # Infer actual type from filename; None means the name carries no hint.
        actual_type = next(
            (expected for keywords, expected in filename_type_hints
             if any(keyword in filename for keyword in keywords)),
            None,
        )
        if actual_type is None:
            continue

        total_predictions += 1
        is_correct = predicted_type == actual_type
        if is_correct:
            correct_predictions += 1

        accuracy_data.append({
            'filename': result['image_name'],
            'predicted': predicted_type,
            'actual': actual_type,
            'correct': is_correct,
            'confidence': result['confidence']
        })

    if total_predictions > 0:
        accuracy = (correct_predictions / total_predictions) * 100

        print(f"\n📊 ACCURACY METRICS:")
        print(f"   Correct predictions: {correct_predictions}/{total_predictions}")
        print(f"   Overall accuracy: {accuracy:.1f}%")

        # Accuracy assessment
        if accuracy >= 95:
            print("🎉 Excellent accuracy! Ready for production.")
        elif accuracy >= 85:
            print("👍 Good accuracy. Consider minor tuning.")
        elif accuracy >= 70:
            print("⚠️ Moderate accuracy. Tuning recommended.")
        else:
            print("🚨 Low accuracy. Significant improvements needed.")

        # Show misclassifications
        misclassifications = [item for item in accuracy_data if not item['correct']]
        if misclassifications:
            print(f"\n❌ MISCLASSIFICATIONS ({len(misclassifications)}):")
            for item in misclassifications:
                print(f"   {item['filename']}: {item['predicted']} != {item['actual']} (conf: {item['confidence']:.2f})")

    else:
        print("💡 No ground truth data available from filenames")
        print("💡 Consider using more descriptive filenames for accuracy testing")

else:
    print("❌ No results available for ground truth analysis")

## Performance Tuning

In [None]:
# Performance analysis and tuning recommendations
if batch_results:
    display(HTML("<h3>⚡ Performance Analysis</h3>"))

    # Only non-errored classifications carry meaningful timing/confidence values.
    successful_results = [r for r in batch_results if not r.get('error')]
    processing_times = [r.get('processing_time', 0) for r in successful_results]
    confidences = [r['confidence'] for r in successful_results]
    fallback_count = sum(1 for r in batch_results if r.get('fallback_used'))

    # Reset the derived metrics so a run where every image errored neither
    # raises NameError in the recommendations nor reuses a previous run's values.
    avg_time = None
    avg_confidence = None
    high_conf_rate = None

    if processing_times:
        avg_time = sum(processing_times) / len(processing_times)
        max_time = max(processing_times)
        min_time = min(processing_times)

        print(f"⏱️ PROCESSING TIMES:")
        print(f"   Average: {avg_time:.2f}s")
        print(f"   Range: {min_time:.2f}s - {max_time:.2f}s")

        if avg_time < 2.0:
            print("🚀 Excellent speed performance")
        elif avg_time < 5.0:
            print("👍 Good speed performance")
        else:
            print("⚠️ Consider optimizing for better speed")

    if confidences:
        avg_confidence = sum(confidences) / len(confidences)
        high_conf_count = sum(1 for c in confidences if c >= CONFIDENCE_THRESHOLD)
        high_conf_rate = (high_conf_count / len(confidences)) * 100

        print(f"\n🎯 CONFIDENCE METRICS:")
        print(f"   Average confidence: {avg_confidence:.3f}")
        print(f"   High confidence rate: {high_conf_rate:.1f}% ({high_conf_count}/{len(confidences)})")
        print(f"   Fallback usage: {fallback_count} classifications")

    # Tuning recommendations (metrics that could not be computed are skipped)
    print(f"\n💡 TUNING RECOMMENDATIONS:")

    if avg_confidence is not None and avg_confidence < 0.8:
        print(f"   - Consider lowering confidence threshold to {avg_confidence:.1f}")

    if fallback_count > len(batch_results) * 0.2:
        print(f"   - High fallback usage - refine classification prompts")

    if avg_time is not None and avg_time > 3.0:
        print(f"   - Consider reducing max_tokens for faster classification")

    if high_conf_rate is not None and high_conf_rate < 80:
        print(f"   - Low high-confidence rate - improve prompt specificity")

else:
    print("❌ No results available for performance analysis")

## Summary and Next Steps

In [None]:
# Summary
display(HTML("<h3>📋 Phase 1 Summary</h3>"))

if batch_results:
    total = len(batch_results)
    successful = sum(1 for r in batch_results if not r.get('error'))
    # Share of images the detector processed without an error. This is NOT
    # accuracy against ground truth, so the verdict below is worded as a
    # success rate (filename-based accuracy is computed in its own cell).
    success_rate = (successful / total * 100) if total > 0 else 0

    print(f"✅ PHASE 1 COMPLETE: Document Type Detection")
    print(f"\n📊 RESULTS SUMMARY:")
    print(f"   Documents processed: {total}")
    print(f"   Successful classifications: {successful} ({success_rate:.1f}%)")
    print(f"   Model used: {MODEL_TYPE.upper()}")
    print(f"   Confidence threshold: {CONFIDENCE_THRESHOLD}")

    print(f"\n🎯 NEXT STEPS FOR PHASE 2:")
    print(f"   1. ✅ Document classification working")
    print(f"   2. 📝 Create document-specific schemas")
    print(f"   3. 🔄 Implement schema routing logic")
    print(f"   4. 🧪 Test type-specific extraction")

    if success_rate >= 90:
        print(f"\n🎉 Classification success rate looks excellent! Ready for Phase 2.")
    elif success_rate >= 80:
        print(f"\n👍 Good classification success rate. Minor tuning may help before Phase 2.")
    else:
        print(f"\n⚠️ Classification needs improvement before proceeding to Phase 2.")

else:
    print("❌ Phase 1 testing incomplete - no results available")
    print("💡 Make sure to run on a machine with GPU and model access")

print(f"\n📚 DOCUMENTATION:")
print(f"   Phase 1 implementation: common/document_type_detector.py")
print(f"   Testing script: test_document_classification.py")
print(f"   Full proposal: docs/document_type_specific_extraction_proposal.md")