In [None]:
# Configuration - All settings at top of notebook
# Every later cell reads its model paths, prompt, token limit and image list
# from the CONFIG dict defined here.
import textwrap

print("🏆 INFORMATION EXTRACTION COMPARISON: Llama 3.2 Vision vs InternVL3")
print("🎯 Focus: Information extraction performance with simplified unified prompts")
print("=" * 80)

# CONFIGURATION - All settings defined here
CONFIG = {
    # Local directories holding the model weights, keyed by the short model
    # name that is also used in "test_models".
    "model_paths": {
        "llama": "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision",
        "internvl": "/home/jovyan/nfs_share/models/InternVL3-8B"
    },
    # SIMPLIFIED: Back to proven KEY-VALUE format with Australian business requirements
    # The same prompt text is sent to both models. It starts with the literal
    # "<|image|>" marker; textwrap.dedent + strip remove the source indentation
    # and the leading/trailing blank lines of the triple-quoted literal.
    "extraction_prompt": textwrap.dedent('''
        <|image|>Extract data from this Australian business document in KEY-VALUE format.

        Output format:
        STORE: [business name]
        ABN: [11-digit Australian Business Number]
        DATE: [date in DD/MM/YYYY format]
        TOTAL: [total amount in AUD]
        SUBTOTAL: [subtotal amount]
        GST: [GST amount]
        ITEMS: [item names separated by |]

        ABN is crucial - look for 11-digit numbers formatted as XX XXX XXX XXX or XXXXXXXXXXX. Use Australian date format (DD/MM/YYYY) and include currency symbols. Extract all visible text and format as KEY: VALUE pairs only. Stop after completion.
    ''').strip(),

    # Generation cap passed to both models' inference calls.
    # NOTE(review): the prompt requests seven fields; 64 new tokens may truncate
    # the answer before all of them are emitted - confirm this is intended.
    "max_new_tokens": 64,  # Back to original limit
    "enable_quantization": True,
    "test_models": ["llama", "internvl"],
    # (file name, document type) pairs; file names are resolved against the
    # DatasetManager's datasets directory in later cells.
    # NOTE(review): "image71.png" is listed twice (3rd and last entry), so it is
    # processed twice and counted twice in every "x/N" statistic - confirm.
    "test_images": [
        ("image14.png", "TAX_INVOICE"),
        ("image65.png", "TAX_INVOICE"),
        ("image71.png", "TAX_INVOICE"),
        ("image74.png", "TAX_INVOICE"),
        ("image205.png", "FUEL_RECEIPT"),
        ("image23.png", "TAX_INVOICE"),
        ("image45.png", "TAX_INVOICE"),
        ("image1.png", "BANK_STATEMENT"),
        ("image39.png", "TAX_INVOICE"),
        ("image76.png", "TAX_INVOICE"),
        ("image71.png", "TAX_INVOICE"),
    ]
}

# Echo the effective settings so the notebook output records what was run.
print("✅ Configuration loaded:")
print(f"   - Models: {', '.join(CONFIG['test_models'])}")
print(f"   - Documents: {len(CONFIG['test_images'])} test images")
print("   - Prompt: SIMPLIFIED KEY-VALUE format (Australian ABN + dates DD/MM/YYYY)")
print(f"   - Max tokens: {CONFIG['max_new_tokens']}")
print(f"   - Quantization: {CONFIG['enable_quantization']}")
print("\n📋 Ready for step-by-step information extraction comparison")

In [None]:
# Imports and Modular Classes
import gc
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image

# CUDA DIAGNOSTICS - report whether a GPU is usable and, if so, its details
print("🔍 CUDA DIAGNOSTICS:")
print(f"   CUDA Available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    # Without CUDA the later model loads fall back to the CPU.
    print("   ❌ CUDA not available - models will load on CPU!")
else:
    # Toolkit / framework versions first, then the device inventory.
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
    print(f"   GPU Count: {torch.cuda.device_count()}")
    print(f"   Current Device: {torch.cuda.current_device()}")
    print(f"   Device Name: {torch.cuda.get_device_name(0)}")
    print(f"   Device Capability: {torch.cuda.get_device_capability(0)}")

class MemoryManager:
    """GPU memory housekeeping and reporting helpers used while testing models."""

    @staticmethod
    def cleanup_gpu_memory():
        """Collect Python garbage, then release cached CUDA memory when a GPU exists."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @staticmethod
    def get_memory_usage() -> Dict[str, float]:
        """Return allocated / reserved / free GPU memory in GB.

        All three values are 0.0 when CUDA is unavailable. "free" is the total
        memory of device 0 minus the memory currently reserved by PyTorch.
        """
        if not torch.cuda.is_available():
            return {"allocated": 0.0, "reserved": 0.0, "free": 0.0}

        bytes_per_gb = 1024**3
        allocated_bytes = torch.cuda.memory_allocated()
        reserved_bytes = torch.cuda.memory_reserved()
        capacity_bytes = torch.cuda.get_device_properties(0).total_memory
        return {
            "allocated": allocated_bytes / bytes_per_gb,
            "reserved": reserved_bytes / bytes_per_gb,
            "free": (capacity_bytes - reserved_bytes) / bytes_per_gb,
        }

    @staticmethod
    def print_memory_usage(label: str = "Memory"):
        """Print a one-line memory summary prefixed with ``label``."""
        if not torch.cuda.is_available():
            print(f"   💾 {label}: No CUDA available")
            return

        memory = MemoryManager.get_memory_usage()
        total_gpu = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"   💾 {label}: {memory['allocated']:.1f}GB allocated | {memory['reserved']:.1f}GB reserved | {memory['free']:.1f}GB free | {total_gpu:.1f}GB total")

    @staticmethod
    def get_memory_delta(before: Dict[str, float], after: Dict[str, float]) -> Dict[str, float]:
        """Return ``after - before`` for the "allocated" and "reserved" readings."""
        return {
            f"{metric}_delta": after[metric] - before[metric]
            for metric in ("allocated", "reserved")
        }

class UltraAggressiveRepetitionController:
    """Strip boilerplate and repeated words/phrases from extraction responses.

    ``clean_response`` is the entry point; the underscore-prefixed methods are
    the individual cleanup passes it applies in order.
    """

    def __init__(self, word_threshold: float = 0.15, phrase_threshold: int = 2):
        """Store the thresholds and build the boilerplate pattern list.

        Args:
            word_threshold: Stored on the instance; not consulted by any method
                of this class.
            phrase_threshold: Stored on the instance; not consulted by any
                method of this class.
        """
        self.word_threshold = word_threshold
        self.phrase_threshold = phrase_threshold

        # Business document specific repetition patterns (applied case-insensitively)
        self.toxic_patterns = [
            r"THANK YOU FOR SHOPPING WITH US[^.]*",
            r"All prices include GST where applicable[^.]*",
            r"applicable\.\s*applicable\.",
            r"GST where applicable[^.]*applicable",
            r"\\+[a-zA-Z]*\{[^}]*\}",  # LaTeX artifacts
            r"\(\s*\)",  # Empty parentheses
            r"[.-]\s*THANK YOU",
        ]

    def clean_response(self, response: str) -> str:
        """Return ``response`` with boilerplate, repetition and artifacts removed.

        Passes, in order: boilerplate patterns, consecutive duplicate words,
        consecutive duplicate phrases, then whitespace/punctuation tidy-up.
        All whitespace runs (including newlines) collapse to a single space.

        Args:
            response: Raw model output; may be empty or ``None``.

        Returns:
            The cleaned, stripped text, or ``""`` for empty/blank input.
        """
        if not response or not response.strip():
            return ""

        # Remove toxic business document patterns
        response = self._remove_business_patterns(response)

        # Remove repetitive words and phrases
        response = self._remove_word_repetition(response)
        response = self._remove_phrase_repetition(response)

        # Clean artifacts
        response = re.sub(r'\s+', ' ', response)
        response = re.sub(r'[.]{2,}', '.', response)
        response = re.sub(r'[!]{2,}', '!', response)

        return response.strip()

    def _remove_business_patterns(self, text: str) -> str:
        """Delete every match of ``self.toxic_patterns`` and squash repeated "applicable."."""
        for pattern in self.toxic_patterns:
            text = re.sub(pattern, "", text, flags=re.IGNORECASE)

        # Remove excessive "applicable" repetition
        text = re.sub(r'(applicable\.\s*){2,}', 'applicable. ', text, flags=re.IGNORECASE)

        return text

    def _remove_word_repetition(self, text: str) -> str:
        """Collapse runs of the same word ("the the the" -> "the").

        The trailing ``\\b`` after the back-reference is essential: without it
        the repeat may match only the *prefix* of the next token, so an ABN
        such as "12 123 456 789" would be rewritten to "123 456 789" and
        "5 50" to "50".
        """
        return re.sub(r'\b(\w+)(?:\s+\1\b)+', r'\1', text, flags=re.IGNORECASE)

    def _remove_phrase_repetition(self, text: str) -> str:
        """Collapse immediately repeated phrases of 2 to 6 words.

        As in ``_remove_word_repetition``, the repeat must end on a word
        boundary so a phrase is never matched against the prefix of a longer
        token sequence.
        """
        for phrase_length in range(2, 7):
            pattern = (
                r'\b((?:\w+\s+){' + str(phrase_length - 1) + r'}\w+)(?:\s+\1\b)+'
            )
            text = re.sub(pattern, r'\1', text, flags=re.IGNORECASE)

        return text

class KeyValueExtractionAnalyzer:
    """Score KEY-VALUE extraction responses for Australian business documents."""

    @staticmethod
    def analyze(response: str, img_name: str) -> Dict[str, Any]:
        """Detect which fields a model response contains and score it.

        Each of STORE, ABN, DATE and TOTAL is first looked up as a
        ``KEY: value`` pair; if that fails, a free-text fallback pattern is
        tried (known store names, 11-digit ABN, numeric date, dollar amount).

        Key matching is anchored on a word boundary so ``SUBTOTAL:`` is not
        mistaken for ``TOTAL:``. The value must start on the same line as its
        key with a non-blank, non-quote character, so an empty field such as
        ``STORE:`` followed by a newline is not credited with the next line's
        content.

        Args:
            response: Model output, normally already cleaned.
            img_name: Image file name, echoed back in the result.

        Returns:
            Dict with ``img_name``, the stripped ``response``,
            ``is_structured``, the four ``has_*`` booleans,
            ``extraction_score`` (0-4, all fields), ``core_score`` (0-3,
            STORE/DATE/TOTAL) and ``successful`` (``core_score >= 2``).
        """
        response_clean = response.strip()

        # Detect KEY-VALUE format (including ABN patterns)
        is_structured = bool(re.search(r'(STORE:|ABN:|DATE:|TOTAL:|store_name:|abn:|date:|total:)', response_clean, re.IGNORECASE))

        # Shared value pattern: optional spaces/tabs and opening quote, then a
        # value whose first character is neither whitespace nor a quote.
        value = r':[ \t]*"?([^"\s][^"\n]*)"?'

        # Extract data from KEY-VALUE format
        store_match = re.search(r'\b(?:STORE|store_name)' + value, response_clean, re.IGNORECASE)
        abn_match = re.search(r'\b(?:ABN|abn)' + value, response_clean, re.IGNORECASE)
        date_match = re.search(r'\b(?:DATE|date)' + value, response_clean, re.IGNORECASE)
        total_match = re.search(r'\b(?:TOTAL|total_amount|total)' + value, response_clean, re.IGNORECASE)

        # Fallback detection for non-structured responses
        if not store_match:
            # Australian store patterns
            store_match = re.search(r'(spotlight|woolworths|coles|bunnings|officeworks|kmart|target|harvey norman|jb hi-fi)', response_clean, re.IGNORECASE)

        if not abn_match:
            # ABN patterns: 11 digits, formatted as XX XXX XXX XXX or XXXXXXXXXXX
            abn_match = re.search(r'\b(\d{2}\s*\d{3}\s*\d{3}\s*\d{3}|\d{11})\b', response_clean)
            # Also look for "ABN" / "A.B.N." prefix patterns (colon optional)
            if not abn_match:
                abn_match = re.search(r'(?:ABN|A\.B\.N\.?)\s*:?\s*(\d{2}\s*\d{3}\s*\d{3}\s*\d{3}|\d{11})', response_clean, re.IGNORECASE)

        if not date_match:
            # Australian date format patterns: DD/MM/YYYY, DD-MM-YYYY, DD.MM.YYYY
            date_match = re.search(r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{1,2}\.\d{1,2}\.\d{2,4})\b', response_clean)

        if not total_match:
            # Australian currency patterns: $X.XX, AUD X.XX
            total_match = re.search(r'(\$\d+\.\d{2}|\$\d+|AUD\s*\d+\.\d{2})', response_clean)

        has_store = bool(store_match)
        has_abn = bool(abn_match)
        has_date = bool(date_match)
        has_total = bool(total_match)

        # Core business fields for extraction (STORE, DATE, TOTAL are essential)
        core_score = sum([has_store, has_date, has_total])
        extraction_score = sum([has_store, has_abn, has_date, has_total])

        # REALISTIC SUCCESS CRITERIA:
        # Success = at least 2/3 core fields (STORE, DATE, TOTAL)
        # ABN is bonus but not required for all document types
        successful = core_score >= 2

        return {
            "img_name": img_name,
            "response": response_clean,
            "is_structured": is_structured,
            "has_store": has_store,
            "has_abn": has_abn,
            "has_date": has_date,
            "has_total": has_total,
            "extraction_score": extraction_score,
            "core_score": core_score,
            "successful": successful  # Based on core fields, not ABN requirement
        }

class DatasetManager:
    """Locate test images on disk and report which ones are available."""

    def __init__(self, datasets_path: str = "datasets"):
        """Remember the directory that holds the test images."""
        self.datasets_path = Path(datasets_path)

    def verify_images(self, test_images: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Return the ``(name, doc_type)`` pairs whose file exists, in input order."""
        return [
            (img_name, doc_type)
            for img_name, doc_type in test_images
            if (self.datasets_path / img_name).exists()
        ]

    def print_verification_report(self, test_images: List[Tuple[str, str]], verified_images: List[Tuple[str, str]]):
        """Print a per-image status list and a summary.

        Raises:
            FileNotFoundError: if ``verified_images`` is empty.
        """
        print("📊 DATASET VERIFICATION")
        print("=" * 50)

        # One status line per expected image, re-checking the file system.
        for img_name, doc_type in test_images:
            found = (self.datasets_path / img_name).exists()
            icon, note = ("✅", "") if found else ("❌", " (MISSING)")
            print(f"   {icon} {img_name:<12} → {doc_type}{note}")

        expected_total = len(test_images)
        found_total = len(verified_images)

        print("\n📋 Dataset Summary:")
        print(f"   - Expected: {expected_total} documents")
        print(f"   - Found: {found_total} documents")
        print(f"   - Missing: {expected_total - found_total} documents")

        # Nothing to test at all is fatal; a partial set only warns.
        if found_total == 0:
            print("❌ No test images found! Check datasets/ directory")
            raise FileNotFoundError("No test images found")

        if found_total < expected_total:
            print("⚠️ Some test images missing but proceeding with available images")
        else:
            print("✅ All test images found")

# Shared helper instances used by every later cell
memory_manager = MemoryManager()
repetition_controller = UltraAggressiveRepetitionController()
extraction_analyzer = KeyValueExtractionAnalyzer()  # scoring with realistic ABN requirements
dataset_manager = DatasetManager()

# Announce what was set up, then record the GPU baseline before any model loads
print("\n".join([
    "\n✅ Modular classes initialized:",
    "   - MemoryManager for GPU cleanup and monitoring",
    "   - UltraAggressiveRepetitionController for text cleanup",
    "   - KeyValueExtractionAnalyzer for REALISTIC Australian business analysis",
    "   - DatasetManager for image verification",
    "\n💾 Initial GPU Status:",
]))
memory_manager.print_memory_usage("Baseline")

In [None]:
# Dataset Verification
# Keep only the configured images that exist on disk, then print the report
# (the report raises FileNotFoundError when none were found).

verified_extraction_images = dataset_manager.verify_images(CONFIG["test_images"])
dataset_manager.print_verification_report(CONFIG["test_images"], verified_extraction_images)

# Show the first 80 characters of the shared prompt as a sanity check
print(
    "\n🔬 Unified Extraction Prompt:\n"
    f"   {CONFIG['extraction_prompt'][:80]}...\n"
    "\n📋 Ready for sequential model testing with unified prompt"
)

In [None]:
# Model Loading Classes
class LlamaModelLoader:
    """Load the Llama 3.2 Vision (Mllama) model and run single-image inference."""

    @staticmethod
    def load_model(model_path: str, enable_quantization: bool = True):
        """Load the model and its processor from a local directory.

        Args:
            model_path: Directory containing the model files; nothing is
                downloaded (``local_files_only=True``).
            enable_quantization: When true, load with 8-bit bitsandbytes
                quantization.

        Returns:
            ``(model, processor)`` with the model already in eval mode.
        """
        # Imported lazily so the notebook cell can be defined without transformers loaded
        from transformers import AutoProcessor, BitsAndBytesConfig, MllamaForConditionalGeneration

        processor = AutoProcessor.from_pretrained(
            model_path, trust_remote_code=True, local_files_only=True
        )

        model_kwargs = {
            "torch_dtype": torch.float16,
            "local_files_only": True
        }

        if enable_quantization:
            # NOTE(review): confirm these skip-module names match the
            # submodule names of the Mllama architecture.
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
                llm_int8_skip_modules=["vision_tower", "multi_modal_projector"],
            )
            model_kwargs["quantization_config"] = quantization_config

        model = MllamaForConditionalGeneration.from_pretrained(
            model_path, **model_kwargs
        )

        # EXPLICIT: Set model to eval mode for inference
        model.eval()

        return model, processor

    @staticmethod
    def run_inference(model, processor, prompt: str, image, max_new_tokens: int = 64):
        """Generate a greedy-decoded response for one prompt/image pair.

        Args:
            model: Loaded model exposing ``parameters()`` and ``generate()``.
            processor: Matching processor; called to build the inputs and used
                to decode the generated tokens.
            prompt: Prompt text passed to the processor.
            image: Image passed to the processor.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The decoded text of the newly generated tokens only (the prompt
            tokens are sliced off before decoding).
        """
        # SAFETY: Ensure model is in eval mode
        model.eval()

        inputs = processor(text=prompt, images=image, return_tensors="pt")
        device = next(model.parameters()).device
        if device.type != "cpu":
            # Move inputs to the model's exact device. Keeping the device index
            # matters: a bare "cuda" means the current default GPU, which is
            # the wrong one when the model lives on e.g. cuda:1.
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

        # Remember where the prompt ends so only new tokens are decoded
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
            )

        raw_response = processor.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )

        # Cleanup tensors immediately
        del inputs, outputs

        return raw_response

class InternVLModelLoader:
    """Load the InternVL model and run single-image inference via ``model.chat``."""

    @staticmethod
    def load_model(model_path: str, enable_quantization: bool = True):
        """Load the model and its tokenizer from a local directory.

        Args:
            model_path: Directory containing the model files; nothing is
                downloaded (``local_files_only=True``).
            enable_quantization: When true, load with 8-bit bitsandbytes
                quantization.

        Returns:
            ``(model, tokenizer)`` with the model already in eval mode.
        """
        import warnings

        from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

        # Comprehensive warning suppression for InternVL
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            warnings.simplefilter("ignore", FutureWarning)

            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True, local_files_only=True
            )

            # Set pad_token_id to eos_token_id to prevent warnings
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "local_files_only": True
        }

        if enable_quantization:
            # Use BitsAndBytesConfig for proper quantization setup
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True,
            )
            model_kwargs["quantization_config"] = quantization_config

        # Suppress additional loading warnings
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            warnings.simplefilter("ignore", FutureWarning)

            model = AutoModel.from_pretrained(
                model_path, **model_kwargs
            )

        # EXPLICIT: Set model to eval mode for inference
        model.eval()

        return model, tokenizer

    @staticmethod
    def run_inference(model, tokenizer, prompt: str, image, max_new_tokens: int = 64):
        """Generate a greedy-decoded response for one prompt/image pair.

        The image is resized to 448x448, converted to a tensor, normalized
        with the mean/std below, and passed to ``model.chat`` as a batch of
        one. Warnings and anything written to stderr during the call are
        suppressed.

        Args:
            model: Loaded model exposing ``parameters()`` and ``chat()``.
            tokenizer: Tokenizer passed through to ``model.chat``.
            prompt: Question text passed to ``model.chat``.
            image: Image accepted by the torchvision transforms
                (presumably a PIL RGB image - verify against callers).
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The response from ``model.chat``; if that is a tuple, its first
            element.
        """
        import io
        import sys
        import warnings

        # SAFETY: Ensure model is in eval mode
        model.eval()

        import torchvision.transforms as T
        from torchvision.transforms.functional import InterpolationMode

        transform = T.Compose([
            T.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])

        pixel_values = transform(image).unsqueeze(0)

        # Put the pixels where the model actually lives and in its dtype,
        # rather than assuming the default CUDA device and leaving float32 on
        # CPU (which mismatches a bfloat16 model).
        reference_param = next(model.parameters())
        target_device = reference_param.device
        if target_device.type == "meta":
            # Offloaded weights report a placeholder device; fall back to the
            # default accelerator (or CPU) in that case.
            target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Quantized weights may be integer-typed; bfloat16 is what load_model requests.
        target_dtype = reference_param.dtype if reference_param.dtype.is_floating_point else torch.bfloat16
        pixel_values = pixel_values.to(device=target_device, dtype=target_dtype).contiguous()

        # COMPREHENSIVE warning suppression for model.chat()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            warnings.simplefilter("ignore", FutureWarning)
            warnings.filterwarnings("ignore", message="Setting `pad_token_id`")
            warnings.filterwarnings("ignore", message=".*pad_token_id.*")

            # Temporarily swap stderr for an in-memory sink to hide messages
            # written there during generation
            old_stderr = sys.stderr
            sys.stderr = io.StringIO()

            try:
                raw_response = model.chat(
                    tokenizer=tokenizer,
                    pixel_values=pixel_values,
                    question=prompt,
                    generation_config={"max_new_tokens": max_new_tokens, "do_sample": False}
                )
            finally:
                # Restore stderr even if generation fails
                sys.stderr = old_stderr

        if isinstance(raw_response, tuple):
            raw_response = raw_response[0]

        # Cleanup tensors immediately
        del pixel_values

        return raw_response

def validate_model(model_loader_class, model_path: str, config: Dict, model_name: str) -> Tuple[bool, Optional[Any], Optional[Any], float]:
    """Load a model, then smoke-test it with one short inference.

    STEP 1 loads the model via ``model_loader_class.load_model`` and reports
    load time and GPU memory before/after. STEP 2 runs a 32-token inference on
    ``image14.png`` from the dataset directory with a trivial prompt, only to
    confirm that inference produces non-empty output.

    Args:
        model_loader_class: Class providing ``load_model(path, quantize)`` and
            ``run_inference(model, processor_or_tokenizer, prompt, image, max_new_tokens)``.
        model_path: Path handed to ``load_model``.
        config: Settings dict; only ``config["enable_quantization"]`` is read.
        model_name: Short name used in log messages and memory labels.

    Returns:
        ``(ok, model, processor_or_tokenizer, load_time_seconds)``. On any
        failure ``ok`` is False and both objects are None; the load time is
        0.0 when loading itself raised, otherwise the measured load time.
    """

    # Clean up before loading
    memory_manager.cleanup_gpu_memory()
    memory_before = memory_manager.get_memory_usage()
    memory_manager.print_memory_usage(f"Pre-{model_name}")

    model_start_time = time.time()

    try:
        print(f"🔄 STEP 1: Loading {model_name.upper()} model from {model_path}...")

        # LOAD MODEL FIRST - no prompts yet
        model, processor_or_tokenizer = model_loader_class.load_model(
            model_path, config["enable_quantization"]
        )

        # EXPLICIT: Ensure model is in eval mode
        model.eval()

        model_load_time = time.time() - model_start_time

        # Monitor memory after loading
        memory_after = memory_manager.get_memory_usage()
        memory_delta = memory_manager.get_memory_delta(memory_before, memory_after)

        print(f"✅ {model_name.upper()} model loaded successfully in {model_load_time:.1f}s (eval mode)")
        memory_manager.print_memory_usage(f"Post-{model_name}")
        print(f"   📊 Memory usage: +{memory_delta['allocated_delta']:.1f}GB allocated | +{memory_delta['reserved_delta']:.1f}GB reserved")

        # STEP 2: Simple validation that model can run basic inference
        print(f"🔍 STEP 2: Testing basic {model_name.upper()} model functionality...")
        # The smoke-test image name is fixed; validation fails if it is absent
        # even when other test images exist.
        img_path = dataset_manager.datasets_path / "image14.png"

        if not img_path.exists():
            print(f"❌ Test image not found: {img_path}")
            # Drop the local references first so the cleanup below can
            # actually reclaim the model's memory.
            del model, processor_or_tokenizer
            memory_manager.cleanup_gpu_memory()
            return False, None, None, model_load_time

        image = Image.open(img_path).convert("RGB")

        # Use the simplest possible prompt to test model loading (not extraction quality)
        # The same "<|image|>"-prefixed prompt is used for whichever loader is passed in.
        simple_test_prompt = "<|image|>What do you see?"

        try:
            validation_start = time.time()
            raw_response = model_loader_class.run_inference(
                model, processor_or_tokenizer, simple_test_prompt,
                image, 32  # Short response for validation
            )
            validation_time = time.time() - validation_start

            # MODEL VALIDATION: Just check that inference works
            if raw_response and len(raw_response.strip()) > 0:
                print(f"✅ {model_name.upper()} model validation passed in {validation_time:.1f}s - inference works")
                print(f"   Test response: {raw_response[:50]}...")
                memory_manager.print_memory_usage(f"Ready-{model_name}")
                # Success: ownership of the loaded objects passes to the caller
                return True, model, processor_or_tokenizer, model_load_time
            else:
                print(f"❌ {model_name.upper()} model validation failed - no response")
                del model, processor_or_tokenizer
                memory_manager.cleanup_gpu_memory()
                return False, None, None, model_load_time

        except Exception as inference_error:
            # Inference failed after a successful load: report (truncated to
            # 100 chars), free the model, and keep the measured load time.
            print(f"❌ {model_name.upper()} model validation failed - inference error: {str(inference_error)[:100]}...")
            del model, processor_or_tokenizer
            memory_manager.cleanup_gpu_memory()
            return False, None, None, model_load_time

    except Exception as e:
        # Any other failure in the outer block (loading, eval(), memory
        # reporting, opening the image) lands here and is reported as a
        # loading failure with a load time of 0.0.
        print(f"❌ {model_name.upper()} model loading failed: {str(e)[:100]}...")
        memory_manager.cleanup_gpu_memory()
        return False, None, None, 0.0

# Summarize what this cell defined
print("\n".join([
    "✅ Model loader classes defined:",
    "   - LlamaModelLoader with validation",
    "   - InternVLModelLoader with COMPREHENSIVE warning suppression",
    "   - validate_model() - STEP 1: Load model, STEP 2: Test basic inference",
    "   - 💾 MEMORY MONITORING: Before/after loading + usage deltas",
    "   - 🔇 SILENCE: Complete suppression of InternVL pad_token_id warnings",
    "   - EXPLICIT: Models set to .eval() mode for inference",
    "   - UNIFIED: Both models use same prompt (robust for both)",
    "   - SEPARATED: Model loading from prompt application",
]))

In [None]:
# Sequential Model Testing - Llama First
# Loads Llama, runs the extraction prompt on every verified image, prints
# per-image and aggregate results, then frees the model before InternVL loads.
print("🔬 SEQUENTIAL MODEL TESTING: LLAMA → CLEANUP → INTERNVL")
print("=" * 70)

# Initialize results storage
# Per model: list of per-document analysis dicts, success counter, total
# inference seconds ("avg_time" is added later once a model has run).
extraction_results = {
    "llama": {"documents": [], "successful": 0, "total_time": 0},
    "internvl": {"documents": [], "successful": 0, "total_time": 0}
}

print("🔥 STEP 1: LOAD LLAMA → RUN ALL INFERENCE → CLEANUP")
print("=" * 50)

# Load Llama model with memory monitoring
llama_valid, llama_model, llama_processor, llama_load_time = validate_model(
    LlamaModelLoader,
    CONFIG["model_paths"]["llama"],
    CONFIG,
    "llama"
)

if llama_valid:
    print("✅ Llama model loaded - running full extraction test")
    print("🎯 Using SIMPLIFIED KEY-VALUE prompt (Australian ABN + dates)")

    # SHOW MEMORY USAGE AFTER MODEL IS LOADED
    print("\n📊 LLAMA MODEL MEMORY FOOTPRINT:")
    memory_manager.print_memory_usage("Llama-loaded")
    llama_memory = memory_manager.get_memory_usage()
    print(f"   🔧 Model size: ~{llama_memory['allocated']:.1f}GB allocated")
    print(f"   🔧 Reserved: ~{llama_memory['reserved']:.1f}GB reserved")
    print(f"   🔧 Available: {llama_memory['free']:.1f}GB remaining")
    print(f"   ⏱️ Load time: {llama_load_time:.1f}s")
    print()

    total_inference_time = 0

    for i, (img_name, doc_type) in enumerate(verified_extraction_images, 1):
        try:
            img_path = dataset_manager.datasets_path / img_name
            image = Image.open(img_path).convert("RGB")

            # Only the model call itself is timed (image loading and analysis are excluded)
            inference_start = time.time()

            # Use simplified KEY-VALUE extraction prompt with ABN
            raw_response = LlamaModelLoader.run_inference(
                llama_model, llama_processor, CONFIG["extraction_prompt"],
                image, CONFIG["max_new_tokens"]
            )

            inference_time = time.time() - inference_start
            total_inference_time += inference_time

            # Clean the raw text first; field detection runs on the cleaned text
            cleaned_response = repetition_controller.clean_response(raw_response)
            analysis = extraction_analyzer.analyze(cleaned_response, img_name)
            analysis["inference_time"] = inference_time
            analysis["doc_type"] = doc_type

            extraction_results["llama"]["documents"].append(analysis)

            if analysis["successful"]:
                extraction_results["llama"]["successful"] += 1

            # DETAILED field-by-field output
            # Flags: S = structured KEY-VALUE output, T = not structured; A = ABN found
            status = "✅" if analysis["successful"] else "❌"
            structured_status = "S" if analysis["is_structured"] else "T"
            abn_status = "A" if analysis["has_abn"] else "-"

            # Show which specific fields were detected
            fields_detected = []
            if analysis["has_store"]: fields_detected.append("STORE")
            if analysis["has_abn"]: fields_detected.append("ABN")
            if analysis["has_date"]: fields_detected.append("DATE")
            if analysis["has_total"]: fields_detected.append("TOTAL")

            fields_str = "|".join(fields_detected) if fields_detected else "none"

            print(f"   {i:2d}. {img_name:<12} {status} {inference_time:.1f}s | {structured_status}{abn_status} | {analysis['core_score']}/3 core | Fields: {fields_str}")

            # Show raw response for key images that should have ABN
            # (what is printed is the cleaned response, truncated to 100 chars)
            if img_name in ["image39.png", "image76.png", "image71.png"]:
                print(f"       Raw response: {cleaned_response[:100]}...")

            # Immediate tensor cleanup
            del image

            # Periodic GPU cleanup every 3 images
            if i % 3 == 0:
                memory_manager.cleanup_gpu_memory()

        except Exception as e:
            # A failing image is reported and skipped; no entry is appended to
            # the results for it and the loop continues.
            print(f"   {i:2d}. {img_name:<12} ❌ Error: {str(e)[:30]}...")

    # Calculate Llama results
    # The average divides by the number of verified images, including any that
    # raised above and therefore contributed no inference time.
    extraction_results["llama"]["total_time"] = total_inference_time
    extraction_results["llama"]["avg_time"] = total_inference_time / len(verified_extraction_images)

    # Count ABN detection rate
    llama_abn_count = sum(1 for doc in extraction_results["llama"]["documents"] if doc.get("has_abn", False))

    print("\n📊 Llama Results:")
    print(f"   Success rate: {extraction_results['llama']['successful']}/{len(verified_extraction_images)}")
    print(f"   ABN detection: {llama_abn_count}/{len(verified_extraction_images)} ({llama_abn_count/len(verified_extraction_images)*100:.1f}%)")
    print(f"   Average time: {extraction_results['llama']['avg_time']:.1f}s per document")
    print(f"   Total time: {extraction_results['llama']['total_time']:.1f}s for {len(verified_extraction_images)} documents")

    # DETAILED FIELD ANALYSIS
    # Per-field hit counts; abn_count is computed the same way as llama_abn_count above.
    print("\n🔍 DETAILED FIELD ANALYSIS:")
    store_count = sum(1 for doc in extraction_results["llama"]["documents"] if doc.get("has_store", False))
    abn_count = sum(1 for doc in extraction_results["llama"]["documents"] if doc.get("has_abn", False))
    date_count = sum(1 for doc in extraction_results["llama"]["documents"] if doc.get("has_date", False))
    total_count = sum(1 for doc in extraction_results["llama"]["documents"] if doc.get("has_total", False))

    print(f"   STORE detection: {store_count}/{len(verified_extraction_images)} ({store_count/len(verified_extraction_images)*100:.1f}%)")
    print(f"   ABN detection: {abn_count}/{len(verified_extraction_images)} ({abn_count/len(verified_extraction_images)*100:.1f}%)")
    print(f"   DATE detection: {date_count}/{len(verified_extraction_images)} ({date_count/len(verified_extraction_images)*100:.1f}%)")
    print(f"   TOTAL detection: {total_count}/{len(verified_extraction_images)} ({total_count/len(verified_extraction_images)*100:.1f}%)")

    # Monitor memory before cleanup
    memory_manager.print_memory_usage("Before-cleanup")

    # COMPLETE LLAMA CLEANUP
    print("\n🧹 STEP 2: COMPLETE LLAMA CLEANUP")
    memory_before_cleanup = memory_manager.get_memory_usage()

    # Drop the only references to the model so the cleanup can release its memory
    del llama_model, llama_processor
    memory_manager.cleanup_gpu_memory()

    memory_after_cleanup = memory_manager.get_memory_usage()
    # Arguments are passed as (after, before), so the delta is before - after;
    # abs() below makes the printed figure independent of that order.
    memory_freed = memory_manager.get_memory_delta(memory_after_cleanup, memory_before_cleanup)

    print("✅ Llama model cleaned up - memory freed")
    memory_manager.print_memory_usage("After-cleanup")
    print(f"   📊 Memory freed: {abs(memory_freed['allocated_delta']):.1f}GB allocated | {abs(memory_freed['reserved_delta']):.1f}GB reserved")

else:
    print("❌ Llama model validation failed - skipping")

print(f"\n{'='*70}")
print("🎯 STEP 3: LOAD INTERNVL → RUN ALL INFERENCE → CLEANUP")
print("=" * 50)

In [None]:
# Sequential Model Testing - InternVL Second
# Mirrors the Llama cell: load, run the same prompt on every verified image,
# print results, then free the model. Results go into extraction_results["internvl"].
# Load InternVL model with memory monitoring
internvl_valid, internvl_model, internvl_tokenizer, internvl_load_time = validate_model(
    InternVLModelLoader,
    CONFIG["model_paths"]["internvl"],
    CONFIG,
    "internvl"
)

if internvl_valid:
    print("✅ InternVL model loaded - running full extraction test")
    print("🎯 Using SIMPLIFIED KEY-VALUE prompt (Australian ABN + dates)")

    # SHOW MEMORY USAGE AFTER MODEL IS LOADED
    print("\n📊 INTERNVL MODEL MEMORY FOOTPRINT:")
    memory_manager.print_memory_usage("InternVL-loaded")
    internvl_memory = memory_manager.get_memory_usage()
    print(f"   🔧 Model size: ~{internvl_memory['allocated']:.1f}GB allocated")
    print(f"   🔧 Reserved: ~{internvl_memory['reserved']:.1f}GB reserved")
    print(f"   🔧 Available: {internvl_memory['free']:.1f}GB remaining")
    print(f"   ⏱️ Load time: {internvl_load_time:.1f}s")
    print()

    total_inference_time = 0

    for i, (img_name, doc_type) in enumerate(verified_extraction_images, 1):
        try:
            img_path = dataset_manager.datasets_path / img_name
            image = Image.open(img_path).convert("RGB")

            # Only the model call itself is timed
            inference_start = time.time()

            # Use simplified KEY-VALUE extraction prompt with ABN
            # (the identical CONFIG prompt, including its "<|image|>" prefix)
            raw_response = InternVLModelLoader.run_inference(
                internvl_model, internvl_tokenizer, CONFIG["extraction_prompt"],
                image, CONFIG["max_new_tokens"]
            )

            inference_time = time.time() - inference_start
            total_inference_time += inference_time

            cleaned_response = repetition_controller.clean_response(raw_response)
            analysis = extraction_analyzer.analyze(cleaned_response, img_name)
            analysis["inference_time"] = inference_time
            analysis["doc_type"] = doc_type

            extraction_results["internvl"]["documents"].append(analysis)

            if analysis["successful"]:
                extraction_results["internvl"]["successful"] += 1

            # Updated output format to include ABN tracking
            # NOTE: this line reports extraction_score out of 4, whereas the
            # Llama cell reports core_score out of 3 plus the field list.
            status = "✅" if analysis["successful"] else "❌"
            structured_status = "S" if analysis["is_structured"] else "T"
            abn_status = "A" if analysis["has_abn"] else "-"
            print(f"   {i:2d}. {img_name:<12} {status} {inference_time:.1f}s | {structured_status}{abn_status} | {analysis['extraction_score']}/4")

            # Immediate tensor cleanup
            del image

            # Periodic GPU cleanup every 3 images
            if i % 3 == 0:
                memory_manager.cleanup_gpu_memory()

        except Exception as e:
            # A failing image is reported and skipped; no result entry is stored for it
            print(f"   {i:2d}. {img_name:<12} ❌ Error: {str(e)[:30]}...")

    # Calculate InternVL results
    # The average divides by the number of verified images, including any that raised above.
    extraction_results["internvl"]["total_time"] = total_inference_time
    extraction_results["internvl"]["avg_time"] = total_inference_time / len(verified_extraction_images)

    # Count ABN detection rate
    internvl_abn_count = sum(1 for doc in extraction_results["internvl"]["documents"] if doc.get("has_abn", False))

    print("\n📊 InternVL Results:")
    print(f"   Success rate: {extraction_results['internvl']['successful']}/{len(verified_extraction_images)}")
    print(f"   ABN detection: {internvl_abn_count}/{len(verified_extraction_images)} ({internvl_abn_count/len(verified_extraction_images)*100:.1f}%)")
    print(f"   Average time: {extraction_results['internvl']['avg_time']:.1f}s per document")
    print(f"   Total time: {extraction_results['internvl']['total_time']:.1f}s for {len(verified_extraction_images)} documents")

    # Monitor memory before cleanup
    memory_manager.print_memory_usage("Before-cleanup")

    # COMPLETE INTERNVL CLEANUP
    print("\n🧹 STEP 4: COMPLETE INTERNVL CLEANUP")
    memory_before_cleanup = memory_manager.get_memory_usage()

    # Drop the only references to the model so the cleanup can release its memory
    del internvl_model, internvl_tokenizer
    memory_manager.cleanup_gpu_memory()

    memory_after_cleanup = memory_manager.get_memory_usage()
    # Arguments are passed as (after, before); abs() below hides the resulting sign
    memory_freed = memory_manager.get_memory_delta(memory_after_cleanup, memory_before_cleanup)

    print("✅ InternVL model cleaned up - memory freed")
    memory_manager.print_memory_usage("After-cleanup")
    print(f"   📊 Memory freed: {abs(memory_freed['allocated_delta']):.1f}GB allocated | {abs(memory_freed['reserved_delta']):.1f}GB reserved")

else:
    print("❌ InternVL model validation failed - skipping")

print(f"\n{'='*70}")
print("🏆 STEP 5: FINAL COMPARISON")
print("=" * 50)

In [None]:
# Final Comparison with Comprehensive Analytics
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import f1_score, precision_score, recall_score


class ComprehensiveResultsAnalyzer:
    """Advanced results analysis with statistical metrics and visualizations"""

    def __init__(self):
        plt.style.use('default')
        sns.set_palette("husl")

    def create_detailed_dataframe(self, extraction_results: Dict, verified_images: List) -> pd.DataFrame:
        """Create comprehensive DataFrame for analysis"""
        all_results = []

        for model_name, results in extraction_results.items():
            if not results["documents"]:
                continue

            for doc in results["documents"]:
                all_results.append({
                    'model': model_name.upper(),
                    'image': doc['img_name'],
                    'doc_type': doc['doc_type'],
                    'inference_time': doc['inference_time'],
                    'is_structured': doc['is_structured'],
                    'has_store': doc['has_store'],
                    'has_abn': doc['has_abn'],
                    'has_date': doc['has_date'],
                    'has_total': doc['has_total'],
                    'extraction_score': doc['extraction_score'],
                    'core_score': doc['core_score'],
                    'successful': doc['successful']
                })

        return pd.DataFrame(all_results)

    def calculate_field_f1_scores(self, df: pd.DataFrame) -> Dict:
        """Calculate F1 scores for each field and model"""
        fields = ['has_store', 'has_abn', 'has_date', 'has_total']
        f1_results = {}

        # Ground truth: assume all documents should have these fields (except ABN)
        # For synthetic data, we'll use a simplified approach
        ground_truth = {
            'has_store': [1] * len(df),  # All should have store
            'has_abn': [1 if img in ['image39.png', 'image76.png', 'image71.png'] else 0
                       for img in df['image']],  # Only specific images have ABN
            'has_date': [1] * len(df),   # All should have date
            'has_total': [1] * len(df)   # All should have total
        }

        for model in df['model'].unique():
            model_df = df[df['model'] == model]
            f1_results[model] = {}

            for field in fields:
                if len(model_df) > 0:
                    # Get predictions for this model
                    predictions = model_df[field].values.astype(int)
                    # Get corresponding ground truth
                    gt_indices = model_df.index
                    gt = [ground_truth[field][i] for i in range(len(predictions))]

                    # Calculate metrics
                    f1 = f1_score(gt, predictions, zero_division=0)
                    precision = precision_score(gt, predictions, zero_division=0)
                    recall = recall_score(gt, predictions, zero_division=0)

                    f1_results[model][field] = {
                        'f1': f1,
                        'precision': precision,
                        'recall': recall
                    }

        return f1_results

    def create_performance_visualizations(self, df: pd.DataFrame, f1_results: Dict):
        """Create comprehensive performance visualizations"""

        # Set up the plotting environment
        fig = plt.figure(figsize=(20, 15))

        # 1. Field Detection Rates Comparison
        plt.subplot(2, 3, 1)
        fields = ['has_store', 'has_abn', 'has_date', 'has_total']
        field_names = ['STORE', 'ABN', 'DATE', 'TOTAL']

        detection_rates = []
        models = df['model'].unique()

        for model in models:
            model_df = df[df['model'] == model]
            rates = [model_df[field].mean() * 100 for field in fields]
            detection_rates.append(rates)

        x = np.arange(len(field_names))
        width = 0.35

        for i, (model, rates) in enumerate(zip(models, detection_rates, strict=False)):
            plt.bar(x + i*width, rates, width, label=model, alpha=0.8)

        plt.xlabel('Fields')
        plt.ylabel('Detection Rate (%)')
        plt.title('Field Detection Rates by Model')
        plt.xticks(x + width/2, field_names)
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 2. F1 Scores Heatmap
        plt.subplot(2, 3, 2)
        f1_matrix = []
        for model in models:
            if model in f1_results:
                f1_scores = [f1_results[model][field]['f1'] for field in fields]
                f1_matrix.append(f1_scores)

        if f1_matrix:
            sns.heatmap(f1_matrix, annot=True, fmt='.3f',
                       xticklabels=field_names, yticklabels=models,
                       cmap='RdYlGn', vmin=0, vmax=1, cbar_kws={'label': 'F1 Score'})
            plt.title('F1 Scores by Model and Field')

        # 3. Inference Time Distribution
        plt.subplot(2, 3, 3)
        for model in models:
            model_df = df[df['model'] == model]
            if len(model_df) > 0:
                plt.hist(model_df['inference_time'], alpha=0.7, label=model, bins=10, density=True)

        plt.xlabel('Inference Time (seconds)')
        plt.ylabel('Density')
        plt.title('Inference Time Distribution')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 4. Success Rate by Document Type
        plt.subplot(2, 3, 4)
        success_by_type = df.groupby(['model', 'doc_type'])['successful'].mean().unstack(fill_value=0)
        success_by_type.plot(kind='bar', ax=plt.gca(), width=0.8)
        plt.xlabel('Model')
        plt.ylabel('Success Rate')
        plt.title('Success Rate by Document Type')
        plt.xticks(rotation=0)
        plt.legend(title='Document Type', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)

        # 5. Core Score Distribution
        plt.subplot(2, 3, 5)
        for model in models:
            model_df = df[df['model'] == model]
            if len(model_df) > 0:
                scores = model_df['core_score'].value_counts().sort_index()
                plt.plot(scores.index, scores.values, marker='o', label=model, linewidth=2)

        plt.xlabel('Core Score (0-3)')
        plt.ylabel('Number of Documents')
        plt.title('Core Score Distribution')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 6. Structured vs Unstructured Output
        plt.subplot(2, 3, 6)
        structured_rates = df.groupby('model')['is_structured'].mean() * 100
        colors = sns.color_palette("husl", len(structured_rates))
        bars = plt.bar(structured_rates.index, structured_rates.values, color=colors, alpha=0.8)
        plt.xlabel('Model')
        plt.ylabel('Structured Output Rate (%)')
        plt.title('Structured Output Rate by Model')
        plt.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 1,
                    f'{height:.1f}%', ha='center', va='bottom')

        plt.tight_layout()
        plt.show()

    def print_comprehensive_summary(self, df: pd.DataFrame, f1_results: Dict,
                                  extraction_results: Dict, llama_memory: Dict = None,
                                  internvl_memory: Dict = None):
        """Print detailed statistical summary"""

        print(f"\n{'='*80}")
        print("📊 COMPREHENSIVE PERFORMANCE ANALYSIS")
        print(f"{'='*80}")

        # Overall Statistics
        print("\n🎯 OVERALL PERFORMANCE METRICS:")
        print("-" * 50)

        for model in df['model'].unique():
            model_df = df[df['model'] == model]
            if len(model_df) > 0:
                success_rate = model_df['successful'].mean() * 100
                avg_time = model_df['inference_time'].mean()
                std_time = model_df['inference_time'].std()
                structured_rate = model_df['is_structured'].mean() * 100

                print(f"{model}:")
                print(f"  Success Rate: {success_rate:.1f}%")
                print(f"  Avg Inference Time: {avg_time:.2f}s ± {std_time:.2f}s")
                print(f"  Structured Output: {structured_rate:.1f}%")
                print()

        # Field-specific F1 Scores
        print("🔍 FIELD-SPECIFIC F1 SCORES:")
        print("-" * 50)

        f1_df_data = []
        for model, fields in f1_results.items():
            for field, metrics in fields.items():
                f1_df_data.append({
                    'Model': model,
                    'Field': field.replace('has_', '').upper(),
                    'F1': metrics['f1'],
                    'Precision': metrics['precision'],
                    'Recall': metrics['recall']
                })

        if f1_df_data:
            f1_df = pd.DataFrame(f1_df_data)
            print(f1_df.pivot(index='Field', columns='Model', values='F1').round(3))
            print()

        # Memory and Performance Summary
        print("💾 MEMORY AND PERFORMANCE SUMMARY:")
        print("-" * 50)

        if llama_memory:
            print(f"LLAMA Memory Footprint: {llama_memory.get('allocated', 0):.1f}GB allocated")
        if internvl_memory:
            print(f"INTERNVL Memory Footprint: {internvl_memory.get('allocated', 0):.1f}GB allocated")

        for model, results in extraction_results.items():
            if results.get("documents"):
                total_time = results.get("total_time", 0)
                avg_time = results.get("avg_time", 0)
                num_docs = len(results["documents"])

                print(f"{model.upper()}:")
                print(f"  Total Processing Time: {total_time:.1f}s")
                print(f"  Throughput: {num_docs/total_time:.2f} docs/sec")
                print(f"  Avg Time per Document: {avg_time:.2f}s")

        # Recommendations
        print("\n🥇 RECOMMENDATIONS:")
        print("-" * 50)

        if len(df['model'].unique()) >= 2:
            # Compare models
            model_summary = df.groupby('model').agg({
                'successful': 'mean',
                'inference_time': 'mean',
                'is_structured': 'mean',
                'has_abn': 'mean'
            }).round(3)

            best_accuracy = model_summary['successful'].idxmax()
            best_speed = model_summary['inference_time'].idxmin()
            best_abn = model_summary['has_abn'].idxmax()

            print(f"Best Overall Accuracy: {best_accuracy}")
            print(f"Fastest Inference: {best_speed}")
            print(f"Best ABN Detection: {best_abn}")

            # Production recommendation
            if best_accuracy == best_speed:
                print(f"\n🎯 PRODUCTION RECOMMENDATION: {best_accuracy}")
                print("   Reason: Best accuracy AND fastest inference")
            else:
                acc_score = model_summary.loc[best_accuracy, 'successful']
                speed_score = 1 / model_summary.loc[best_speed, 'inference_time']

                print(f"\n🎯 PRODUCTION RECOMMENDATION: {best_accuracy}")
                print(f"   Reason: Higher accuracy ({acc_score:.1%}) is more important than speed")

# Drive the full analysis over whatever the extraction steps produced.
analyzer = ComprehensiveResultsAnalyzer()

# One row per (model, document) result.
results_df = analyzer.create_detailed_dataframe(extraction_results, verified_extraction_images)

if results_df.empty:
    print("❌ No results available for analysis")
else:
    # Per-field precision / recall / F1 for every model.
    f1_scores = analyzer.calculate_field_f1_scores(results_df)

    # No memory figures are gathered in this cell, so both stay None and the
    # summary omits the memory-footprint lines.
    llama_mem = None
    internvl_mem = None

    # Charts first, then the textual report.
    analyzer.create_performance_visualizations(results_df, f1_scores)
    analyzer.print_comprehensive_summary(
        results_df,
        f1_scores,
        extraction_results,
        llama_memory=llama_mem,
        internvl_memory=internvl_mem,
    )

    # Preview of the per-document results table.
    print("\n📊 DETAILED RESULTS DATAFRAME:")
    print("-" * 50)
    print(results_df.head(10))

print("\n✅ COMPREHENSIVE ANALYSIS COMPLETE!")
print("📊 Generated: Performance metrics, F1 scores, visualizations, and recommendations")