# Minimal Vision Model Test

Direct model loading and testing without using the unified_vision_processor package.

All configuration is embedded in the notebook for easy modification.

In [1]:
# Configuration - Modify as needed
CONFIG = {
    # Model selection: "llama" or "internvl"
    "model_type": "llama",  # BACK TO LLAMA with working code patterns
    
    # Model paths
    "model_paths": {
        "llama": "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision",
        "internvl": "/home/jovyan/nfs_share/models/InternVL3-8B"
    },
    
    # Test image path
    "test_image": "datasets/image14.png",
    
    # WORKING prompt pattern from vision_processor (KEY-VALUE format)
    "prompt": "<|image|>Extract data from this receipt in KEY-VALUE format.\n\nOutput format:\nDATE: [date from receipt]\nSTORE: [store name]\nTOTAL: [total amount]\n\nExtract all visible text and format as KEY: VALUE pairs only.",
    
    # EXACT working generation parameters from LlamaVisionModel
    "max_new_tokens": 1024,
    "enable_quantization": True
}

print(f"Configuration loaded:")
print(f"Model: {CONFIG['model_type']} (using WORKING vision_processor patterns)")
print(f"Image: {CONFIG['test_image']}")
print(f"Prompt: {CONFIG['prompt'][:100]}...")
print("\n✅ Using PROVEN working patterns from vision_processor/models/llama_model.py")

Configuration loaded:
Model: llama (using WORKING vision_processor patterns)
Image: datasets/image14.png
Prompt: <|image|>Extract data from this receipt in KEY-VALUE format.

Output format:
DATE: [date from receip...

✅ Using PROVEN working patterns from vision_processor/models/llama_model.py


In [2]:
# Imports - Direct model loading
import time
import torch
from pathlib import Path
from PIL import Image

# Model-specific imports based on selection
if CONFIG["model_type"] == "llama":
    from transformers import AutoProcessor, MllamaForConditionalGeneration
elif CONFIG["model_type"] == "internvl":
    from transformers import AutoModel, AutoTokenizer
    import torchvision.transforms as T
    from torchvision.transforms.functional import InterpolationMode

print(f"Imports successful for {CONFIG['model_type']} ✓")

Imports successful for llama ✓


In [3]:
# Load model directly - USING WORKING VISION_PROCESSOR PATTERNS
model_path = CONFIG["model_paths"][CONFIG["model_type"]]
print(f"Loading {CONFIG['model_type']} model from {model_path}...")
start_time = time.time()

try:
    if CONFIG["model_type"] == "llama":
        # EXACT pattern from vision_processor/models/llama_model.py
        processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True
        )
        
        # Working quantization config from LlamaVisionModel
        quantization_config = None
        if CONFIG["enable_quantization"] and torch.cuda.is_available():
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_enable_fp32_cpu_offload=True,
                    llm_int8_skip_modules=["vision_tower", "multi_modal_projector"],
                    llm_int8_threshold=6.0,
                )
                print("✅ Using WORKING quantization config (skipping vision modules)")
            except ImportError:
                print("Quantization not available, using FP16")
                CONFIG["enable_quantization"] = False
        
        # Working model loading args from LlamaVisionModel
        model_loading_args = {
            "low_cpu_mem_usage": True,
            "torch_dtype": torch.float16,
            "device_map": "cuda:0" if torch.cuda.is_available() else "cpu",
            "local_files_only": True
        }
        
        if quantization_config:
            model_loading_args["quantization_config"] = quantization_config
        
        model = MllamaForConditionalGeneration.from_pretrained(
            model_path,
            **model_loading_args
        ).eval()
        
        # CRITICAL: Set working generation config exactly like LlamaVisionModel
        model.generation_config.max_new_tokens = CONFIG["max_new_tokens"]
        model.generation_config.do_sample = False
        model.generation_config.temperature = None  # Disable temperature
        model.generation_config.top_p = None        # Disable top_p  
        model.generation_config.top_k = None        # Disable top_k
        model.config.use_cache = True               # Enable KV cache
        
        print("✅ Applied WORKING generation config (no sampling parameters)")
        
    elif CONFIG["model_type"] == "internvl":
        # Load InternVL3
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True
        )
        
        model_kwargs = {
            "low_cpu_mem_usage": True,
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "local_files_only": True
        }
        
        if CONFIG["enable_quantization"] and torch.cuda.is_available():
            try:
                model_kwargs["load_in_8bit"] = True
                print("8-bit quantization enabled")
            except Exception:
                print("Quantization not available, using bfloat16")
                CONFIG["enable_quantization"] = False
        
        model = AutoModel.from_pretrained(
            model_path,
            **model_kwargs
        ).eval()
        
        if torch.cuda.is_available() and not CONFIG["enable_quantization"]:
            model = model.cuda()
    
    load_time = time.time() - start_time
    print(f"✅ Model loaded successfully in {load_time:.2f}s")
    print(f"Model device: {next(model.parameters()).device}")
    print(f"Quantization active: {CONFIG['enable_quantization']}")
    
except Exception as e:
    print(f"✗ Model loading failed: {e}")
    import traceback
    traceback.print_exc()
    raise e

Loading llama model from /home/jovyan/nfs_share/models/Llama-3.2-11B-Vision...
✅ Using WORKING quantization config (skipping vision modules)


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

✅ Applied WORKING generation config (no sampling parameters)
✅ Model loaded successfully in 5.92s
Model device: cuda:0
Quantization active: True


In [4]:
# Load and preprocess image
test_image_path = Path(CONFIG["test_image"])

if not test_image_path.exists():
    print(f"✗ Test image not found: {test_image_path}")
    available = list(Path("datasets").glob("*.png"))[:5]
    print(f"Available images: {[img.name for img in available]}")
    raise FileNotFoundError(f"Test image not found: {test_image_path}")

# Load image
image = Image.open(test_image_path)
if image.mode != "RGB":
    image = image.convert("RGB")

print(f"✓ Image loaded: {image.size}")
print(f"  File size: {test_image_path.stat().st_size / 1024:.1f} KB")

✓ Image loaded: (2048, 2048)
  File size: 211.1 KB


In [5]:
# Run inference - ULTRA-AGGRESSIVE REPETITION CONTROL
prompt = CONFIG["prompt"]
print(f"Running inference with {CONFIG['model_type']}...")
print(f"Prompt: {prompt[:100]}...")
print("-" * 50)

start_time = time.time()

class UltraAggressiveRepetitionController:
    """Ultra-aggressive repetition detection and control specifically for Llama-3.2-Vision."""
    
    def __init__(self, word_threshold: float = 0.15, phrase_threshold: int = 2):
        """
        Initialize ultra-aggressive repetition controller.
        
        Args:
            word_threshold: If any word appears more than this % of total words, it's repetitive (15% vs 30%)
            phrase_threshold: Minimum repetitions to trigger cleaning (2 vs 3)
        """
        self.word_threshold = word_threshold
        self.phrase_threshold = phrase_threshold
        
        # Known problematic patterns from Llama-3.2-Vision
        self.toxic_patterns = [
            r"THANK YOU FOR SHOPPING WITH US[^.]*",
            r"All prices include GST where applicable[^.]*",
            r"\\+[a-zA-Z]*\{[^}]*\}",  # LaTeX artifacts
            r"\(\s*\)",  # Empty parentheses
            r"[.-]\s*THANK YOU",  # Dash/period before thank you
        ]
    
    def detect_repetitive_generation(self, text: str, min_words: int = 3) -> bool:
        """Ultra-sensitive repetition detection."""
        words = text.split()
        
        # Much stricter minimum content requirement
        if len(words) < min_words:
            return True
        
        # Check for known toxic patterns first
        if self._has_toxic_patterns(text):
            return True
            
        # Ultra-aggressive word repetition check (15% threshold vs 30%)
        word_counts = {}
        for word in words:
            word_lower = word.lower().strip('.,!?()[]{}')
            if len(word_lower) > 2:  # Ignore very short words
                word_counts[word_lower] = word_counts.get(word_lower, 0) + 1
        
        total_words = len([w for w in words if len(w.strip('.,!?()[]{}')) > 2])
        if total_words > 0:
            for word, count in word_counts.items():
                if count > total_words * self.word_threshold:  # 15% threshold
                    return True
        
        # Ultra-aggressive phrase repetition
        if self._detect_aggressive_phrase_repetition(text):
            return True
            
        return False
    
    def _has_toxic_patterns(self, text: str) -> bool:
        """Check for known problematic patterns."""
        import re
        
        for pattern in self.toxic_patterns:
            matches = re.findall(pattern, text, flags=re.IGNORECASE)
            if len(matches) >= 2:  # Even 2 occurrences is too many
                return True
        
        return False
    
    def _detect_aggressive_phrase_repetition(self, text: str) -> bool:
        """Ultra-aggressive phrase repetition detection."""
        import re
        
        # Check for 3+ word phrases repeated even twice
        words = text.split()
        for i in range(len(words) - 6):  # Need at least 6 words for 3+3
            phrase = ' '.join(words[i:i+3]).lower()
            remainder = ' '.join(words[i+3:]).lower()
            if phrase in remainder:
                return True
        
        # Check sentences/segments
        segments = re.split(r'[.!?]+', text)
        segment_counts = {}
        
        for segment in segments:
            segment_clean = re.sub(r'\s+', ' ', segment.strip().lower())
            # Much shorter minimum segment length
            if len(segment_clean) > 5:  # Was 10, now 5
                segment_counts[segment_clean] = segment_counts.get(segment_clean, 0) + 1
        
        # Any segment appearing twice is problematic
        for count in segment_counts.values():
            if count >= self.phrase_threshold:  # Now 2 instead of 3
                return True
                
        return False
    
    def clean_response(self, response: str) -> str:
        """Ultra-aggressive cleaning with early truncation."""
        import re
        
        if not response or len(response.strip()) == 0:
            return ""
        
        original_length = len(response)
        
        # Step 1: Early truncation at first major repetition
        response = self._early_truncate_at_repetition(response)
        
        # Step 2: Remove toxic patterns aggressively
        response = self._remove_toxic_patterns(response)
        
        # Step 3: Remove safety warnings
        response = self._remove_safety_warnings(response)
        
        # Step 4: Ultra-aggressive repetition removal
        response = self._ultra_aggressive_word_removal(response)
        response = self._ultra_aggressive_phrase_removal(response)
        response = self._ultra_aggressive_sentence_removal(response)
        
        # Step 5: Clean artifacts
        response = self._clean_artifacts(response)
        
        # Step 6: Final validation and truncation
        response = self._final_validation_truncate(response)
        
        final_length = len(response)
        reduction = ((original_length - final_length) / original_length * 100) if original_length > 0 else 0
        
        print(f"🧹 Cleaning: {original_length} → {final_length} chars ({reduction:.1f}% reduction)")
        
        return response.strip()
    
    def _early_truncate_at_repetition(self, text: str) -> str:
        """Truncate immediately when repetition starts."""
        import re
        
        # Find first occurrence of toxic patterns and truncate there
        for pattern in self.toxic_patterns:
            match = re.search(pattern, text, flags=re.IGNORECASE)
            if match:
                # Find the SECOND occurrence and truncate before it
                remaining = text[match.end():]
                second_match = re.search(pattern, remaining, flags=re.IGNORECASE)
                if second_match:
                    truncate_point = match.end() + second_match.start()
                    print(f"🔪 Early truncation at repetition: {len(text)} → {truncate_point} chars")
                    return text[:truncate_point]
        
        return text
    
    def _remove_toxic_patterns(self, text: str) -> str:
        """Aggressively remove known toxic patterns."""
        import re
        
        for pattern in self.toxic_patterns:
            # Remove ALL occurrences, not just duplicates
            text = re.sub(pattern, "", text, flags=re.IGNORECASE)
        
        return text
    
    def _remove_safety_warnings(self, text: str) -> str:
        """Remove safety warnings."""
        import re
        
        safety_patterns = [
            r"I'm not able to provide.*?information\.?",
            r"I cannot provide.*?information\.?", 
            r"I'm unable to.*?\.?",
            r"I can't.*?\.?",
            r"Sorry, I cannot.*?\.?",
            r".*could compromise.*privacy.*",
        ]
        
        for pattern in safety_patterns:
            text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.DOTALL)
        
        return text
    
    def _ultra_aggressive_word_removal(self, text: str) -> str:
        """Ultra-aggressive word repetition removal."""
        import re
        
        # Remove 2+ consecutive identical words (was 3+)
        text = re.sub(r'\b(\w+)(\s+\1){1,}', r'\1', text, flags=re.IGNORECASE)
        
        # Remove any word appearing more than 3 times total
        words = text.split()
        word_counts = {}
        for word in words:
            word_lower = word.lower().strip('.,!?()[]{}')
            word_counts[word_lower] = word_counts.get(word_lower, 0) + 1
        
        # Rebuild text, limiting each word to max 3 occurrences
        result_words = []
        word_usage = {}
        
        for word in words:
            word_lower = word.lower().strip('.,!?()[]{}')
            current_count = word_usage.get(word_lower, 0)
            
            if current_count < 3:  # Allow max 3 occurrences
                result_words.append(word)
                word_usage[word_lower] = current_count + 1
        
        return ' '.join(result_words)
    
    def _ultra_aggressive_phrase_removal(self, text: str) -> str:
        """Ultra-aggressive phrase removal."""
        import re
        
        # Remove repeated 2-6 word phrases (expanded range)
        for phrase_length in range(2, 7):
            pattern = r'\b((?:\w+\s+){' + str(phrase_length-1) + r'}\w+)(\s+\1){1,}'  # 1+ repetitions vs 2+
            text = re.sub(pattern, r'\1', text, flags=re.IGNORECASE)
        
        return text
    
    def _ultra_aggressive_sentence_removal(self, text: str) -> str:
        """Ultra-aggressive sentence removal."""
        import re
        
        sentences = re.split(r'[.!?]+', text)
        
        # Keep only first occurrence of any sentence
        seen = set()
        unique_sentences = []
        
        for sentence in sentences:
            sentence_clean = re.sub(r'\s+', ' ', sentence.strip().lower())
            sentence_clean = re.sub(r'[^\w\s]', '', sentence_clean)  # Remove all punctuation for comparison
            
            if sentence_clean and len(sentence_clean) > 3:  # Very short minimum
                if sentence_clean not in seen:
                    seen.add(sentence_clean)
                    unique_sentences.append(sentence.strip())
        
        return '. '.join(unique_sentences)
    
    def _clean_artifacts(self, text: str) -> str:
        """Aggressive artifact cleaning."""
        import re
        
        # Remove whitespace
        text = re.sub(r'\s+', ' ', text)
        
        # Remove LaTeX/markdown aggressively
        text = re.sub(r'\\+[a-zA-Z]*\{[^}]*\}', '', text)
        text = re.sub(r'\\+[a-zA-Z]+', '', text)
        text = re.sub(r'```+[^`]*```+', '', text)
        text = re.sub(r'[{}]+', '', text)
        
        # Remove excessive punctuation
        text = re.sub(r'[.]{2,}', '.', text)
        text = re.sub(r'[!]{2,}', '!', text)
        text = re.sub(r'[?]{2,}', '?', text)
        text = re.sub(r'[,]{2,}', ',', text)
        
        # Remove empty parentheses and brackets
        text = re.sub(r'\(\s*\)', '', text)
        text = re.sub(r'\[\s*\]', '', text)
        
        # Remove standalone punctuation
        text = re.sub(r'\s+[.,!?;:]\s+', ' ', text)
        
        return text
    
    def _final_validation_truncate(self, text: str, max_length: int = 800) -> str:
        """Final validation with aggressive truncation."""
        # If still repetitive after all cleaning, something is very wrong
        if self.detect_repetitive_generation(text):
            print("⚠️ Still repetitive after ultra-aggressive cleaning - truncating heavily")
            # Find last good sentence in first half
            half_point = len(text) // 2
            truncated = text[:half_point]
            last_period = truncated.rfind('.')
            if last_period > half_point * 0.5:
                return truncated[:last_period + 1]
            else:
                return truncated[:half_point] + "..."
        
        # Aggressive length limit
        if len(text) > max_length:
            truncated = text[:max_length]
            last_period = truncated.rfind('.')
            if last_period > max_length * 0.7:
                return truncated[:last_period + 1]
            else:
                return truncated + "..."
        
        return text

# Initialize ultra-aggressive repetition controller
repetition_controller = UltraAggressiveRepetitionController(
    word_threshold=0.15,  # Much stricter: 15% vs 30%
    phrase_threshold=2    # Much stricter: 2 vs 3 repetitions
)

try:
    if CONFIG["model_type"] == "llama":
        # EXACT input preparation from LlamaVisionModel._prepare_inputs()
        prompt_with_image = prompt if prompt.startswith("<|image|>") else f"<|image|>{prompt}"
        
        inputs = processor(text=prompt_with_image, images=image, return_tensors="pt")
        
        # WORKING device handling from LlamaVisionModel
        device = next(model.parameters()).device
        if device.type != "cpu":
            device_target = str(device).split(":")[0] if ":" in str(device) else str(device)
            inputs = {k: v.to(device_target) if hasattr(v, "to") else v for k, v in inputs.items()}
        
        print(f"Input tensor shapes: {[(k, v.shape) for k, v in inputs.items() if hasattr(v, 'shape')]}")
        print(f"Device target: {device}")
        
        # ULTRA-AGGRESSIVE: Even shorter token limit
        effective_max_tokens = min(CONFIG["max_new_tokens"], 384)  # Further reduced: 384 vs 512
        print(f"Using ultra-short max_new_tokens: {effective_max_tokens} (was {CONFIG['max_new_tokens']})")
        
        # EXACT generation kwargs from LlamaVisionModel.generate()
        generation_kwargs = {
            **inputs,
            "max_new_tokens": effective_max_tokens,
            "do_sample": False,  # Deterministic generation
            "pad_token_id": processor.tokenizer.eos_token_id,
            "eos_token_id": processor.tokenizer.eos_token_id,
            "use_cache": True,
        }
        
        print("✅ Using ULTRA-AGGRESSIVE repetition control + shorter generation")
        
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)
        
        raw_response = processor.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        
        print(f"Raw response (first 200 chars): {raw_response[:200]}...")
        print(f"Raw response length: {len(raw_response)} characters")
        
        # ULTRA-AGGRESSIVE: Enhanced repetition control
        response = repetition_controller.clean_response(raw_response)
        
        # Final check with stricter detection
        if repetition_controller.detect_repetitive_generation(response):
            print("❌ STILL REPETITIVE after ultra-aggressive cleaning!")
            print("   This indicates a fundamental issue with the model's generation pattern")
        else:
            print("✅ Ultra-aggressive cleaning successful - repetition eliminated")
        
    elif CONFIG["model_type"] == "internvl":
        # InternVL inference with ultra-aggressive repetition control
        image_size = 448
        transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        
        pixel_values = transform(image).unsqueeze(0)
        
        if torch.cuda.is_available():
            pixel_values = pixel_values.cuda().to(torch.bfloat16).contiguous()
        else:
            pixel_values = pixel_values.contiguous()
        
        generation_config = {
            "max_new_tokens": min(CONFIG["max_new_tokens"], 384),
            "do_sample": False,
            "pad_token_id": tokenizer.eos_token_id
        }
        
        raw_response = model.chat(
            tokenizer=tokenizer,
            pixel_values=pixel_values,
            question=prompt,
            generation_config=generation_config
        )
        
        if isinstance(raw_response, tuple):
            raw_response = raw_response[0]
        
        # Apply ultra-aggressive repetition control
        response = repetition_controller.clean_response(raw_response)
    
    inference_time = time.time() - start_time
    print(f"✅ Inference completed in {inference_time:.2f}s")
    print(f"Final response length: {len(response)} characters")
    
except Exception as e:
    print(f"✗ Inference failed: {e}")
    import traceback
    traceback.print_exc()
    
    response = f"Error: Inference failed - {str(e)}"
    inference_time = time.time() - start_time

print(f"Final response ready for display (length: {len(response) if 'response' in locals() else 0} characters)")

Running inference with llama...
Prompt: <|image|>Extract data from this receipt in KEY-VALUE format.

Output format:
DATE: [date from receip...
--------------------------------------------------
Input tensor shapes: [('input_ids', torch.Size([1, 49])), ('attention_mask', torch.Size([1, 49])), ('pixel_values', torch.Size([1, 1, 4, 3, 448, 448])), ('aspect_ratio_ids', torch.Size([1, 1])), ('aspect_ratio_mask', torch.Size([1, 1, 4])), ('cross_attention_mask', torch.Size([1, 49, 1, 4]))]
Device target: cuda:0
Using ultra-short max_new_tokens: 384 (was 1024)
✅ Using ULTRA-AGGRESSIVE repetition control + shorter generation
Raw response (first 200 chars):  
DATE: 11-07-2022
STORE: SPOTLIGHT
TOTAL: $22.45
ITEM: Apples (kg)
QUANTITY: 1
PRICE: $3.96
TOTAL: $3.96
ITEM: Tea Bags (box)
QUANTITY: 1
PRICE: $4.53
TOTAL: $4.53
ITEM: Free Range Eggs (d)
QUANTITY:...
Raw response length: 1151 characters
⚠️ Still repetitive after ultra-aggressive cleaning - truncating heavily
🧹 Cleaning: 1151 → 164 chars 

In [6]:
# Display results
print("=" * 60)
print("EXTRACTED TEXT:")
print("=" * 60)
print(response)
print("=" * 60)

# Summary
print(f"\nSUMMARY:")
print(f"Model: {CONFIG['model_type']}")
print(f"Response length: {len(response)} characters")
print(f"Processing time: {inference_time:.2f}s")
print(f"Quantization enabled: {CONFIG['enable_quantization']}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

# Enhanced JSON parsing with validation
print(f"\nRESPONSE ANALYSIS:")
if response.strip().startswith('{') and response.strip().endswith('}'):
    try:
        import json
        parsed = json.loads(response.strip())
        print(f"✅ VALID JSON EXTRACTED:")
        for key, value in parsed.items():
            print(f"  {key}: {value}")
        
        # Validate completeness
        expected_fields = ["DATE", "STORE", "TOTAL"]
        missing = [field for field in expected_fields if field not in parsed or not parsed[field]]
        if missing:
            print(f"⚠️ Missing fields: {missing}")
        else:
            print(f"✅ All expected fields present")
            
    except json.JSONDecodeError as e:
        print(f"❌ Invalid JSON: {e}")
        print(f"Raw response: {response}")
        
elif any(keyword in response for keyword in ["DATE:", "STORE:", "TOTAL:"]):
    print(f"✅ KEY-VALUE format detected")
    # Try to extract key-value pairs
    import re
    matches = re.findall(r'([A-Z]+):\s*([^\n]+)', response)
    if matches:
        print(f"Extracted fields:")
        for key, value in matches:
            print(f"  {key}: {value.strip()}")
            
elif any(phrase in response.lower() for phrase in ["not able", "cannot provide", "sorry"]):
    print(f"❌ SAFETY MODE TRIGGERED")
    print(f"This indicates the prompt triggered Llama's safety restrictions")
    print(f"Solution: Use simpler JSON format prompts")
    
else:
    print(f"⚠️ UNSTRUCTURED RESPONSE")
    print(f"Response doesn't match expected patterns")
    print(f"Consider using different prompt format")

# Performance assessment
if inference_time < 30:
    print(f"\n⚡ GOOD performance: {inference_time:.1f}s")
elif inference_time < 60:
    print(f"\n⚠️ ACCEPTABLE performance: {inference_time:.1f}s") 
else:
    print(f"\n❌ SLOW performance: {inference_time:.1f}s")

print(f"\n🎯 For production use:")
print(f"- Llama-3.2-Vision: Use simple JSON prompts only")
print(f"- InternVL3: More flexible, handles complex prompts better")
print(f"- Both models: Shorter max_new_tokens prevents issues")

EXTRACTED TEXT:
DATE: 11-07-2022 STORE: SPOTLIGHT TOTAL: $22. 45 ITEM: Apples (kg) QUANTITY: 1 PRICE: $3. 96 TOTAL: $3. 96 ITEM: Tea Bags (box) QUANTITY: 1 PRICE: $4. 53 TOTAL: $4.

SUMMARY:
Model: llama
Response length: 164 characters
Processing time: 32.43s
Quantization enabled: True
Device: CUDA

RESPONSE ANALYSIS:
✅ KEY-VALUE format detected
Extracted fields:
  DATE: 11-07-2022 STORE: SPOTLIGHT TOTAL: $22. 45 ITEM: Apples (kg) QUANTITY: 1 PRICE: $3. 96 TOTAL: $3. 96 ITEM: Tea Bags (box) QUANTITY: 1 PRICE: $4. 53 TOTAL: $4.

⚠️ ACCEPTABLE performance: 32.4s

🎯 For production use:
- Llama-3.2-Vision: Use simple JSON prompts only
- InternVL3: More flexible, handles complex prompts better
- Both models: Shorter max_new_tokens prevents issues


In [7]:
# Test additional prompts - WITH ULTRA-AGGRESSIVE REPETITION CONTROL
working_test_prompts = [
    "<|image|>Extract store name and total amount in KEY-VALUE format.\n\nOutput format:\nSTORE: [store name]\nTOTAL: [total amount]",
    "<|image|>What type of business document is this? Answer: receipt, invoice, or statement.",
    "<|image|>Extract the date from this document in format DD/MM/YYYY."
]

print("Testing additional prompts with ULTRA-AGGRESSIVE REPETITION CONTROL...\n")

for i, test_prompt in enumerate(working_test_prompts, 1):
    print(f"Test {i}: {test_prompt[:60]}...")
    try:
        start = time.time()
        
        if CONFIG["model_type"] == "llama":
            # Use EXACT same pattern as main inference
            prompt_with_image = test_prompt if test_prompt.startswith("<|image|>") else f"<|image|>{test_prompt}"
            
            inputs = processor(text=prompt_with_image, images=image, return_tensors="pt")
            
            # Same device handling
            device = next(model.parameters()).device
            if device.type != "cpu":
                device_target = str(device).split(":")[0] if ":" in str(device) else str(device)
                inputs = {k: v.to(device_target) if hasattr(v, "to") else v for k, v in inputs.items()}
            
            # ULTRA-AGGRESSIVE: Extremely short tokens for tests
            generation_kwargs = {
                **inputs,
                "max_new_tokens": 96,  # Even shorter: 96 vs 128
                "do_sample": False,
                "pad_token_id": processor.tokenizer.eos_token_id,
                "eos_token_id": processor.tokenizer.eos_token_id,
                "use_cache": True,
            }
            
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs)
            
            raw_result = processor.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True
            )
            
            # Apply ultra-aggressive repetition control
            result = repetition_controller.clean_response(raw_result)
            
        elif CONFIG["model_type"] == "internvl":
            result = model.chat(
                tokenizer=tokenizer,
                pixel_values=pixel_values,
                question=test_prompt,
                generation_config={
                    "max_new_tokens": 96, 
                    "do_sample": False
                }
            )
            if isinstance(result, tuple):
                result = result[0]
            result = repetition_controller.clean_response(result)
        
        elapsed = time.time() - start
        
        # Ultra-strict analysis of results
        if repetition_controller.detect_repetitive_generation(result):
            print(f"❌ STILL REPETITIVE ({elapsed:.1f}s): {result[:60]}...")
            print(f"   Even ultra-aggressive cleaning failed - model has fundamental repetition issue")
        elif any(phrase in result.lower() for phrase in ["not able", "cannot provide", "sorry"]):
            print(f"⚠️ Safety mode triggered ({elapsed:.1f}s): {result[:60]}...")
        elif len(result.strip()) < 3:
            print(f"⚠️ Over-cleaned ({elapsed:.1f}s): '{result}' - may be too aggressive")
        else:
            print(f"✅ SUCCESS ({elapsed:.1f}s): {result[:80]}...")
            print(f"   Length: {len(result)} chars - repetition eliminated")
        
    except Exception as e:
        print(f"❌ Error: {str(e)[:100]}...")
    print("-" * 50)

print("\n🎯 ULTRA-AGGRESSIVE REPETITION CONTROL FEATURES:")
print("🔥 UltraAggressiveRepetitionController - Nuclear option for repetition")
print("🔥 Stricter thresholds:")
print("   - Word repetition: 15% threshold (was 30%)")  
print("   - Phrase repetition: 2 occurrences trigger (was 3)")
print("   - Sentence repetition: Any duplicate removed")
print("🔥 Toxic pattern targeting:")
print("   - 'THANK YOU FOR SHOPPING...' pattern recognition")
print("   - 'All prices include GST...' pattern recognition")
print("   - LaTeX artifact removal")
print("🔥 Early truncation at first repetition detection")
print("🔥 Max 3 occurrences per word across entire text")
print("🔥 Ultra-short token limits (384 main, 96 tests)")
print("🔥 Aggressive artifact cleaning (punctuation, parentheses, etc.)")
print("\n💡 If this still shows repetition, the issue is in the model's generation")
print("   pattern itself, not the post-processing cleaning.")

Testing additional prompts with ULTRA-AGGRESSIVE REPETITION CONTROL...

Test 1: <|image|>Extract store name and total amount in KEY-VALUE fo...
🧹 Cleaning: 184 → 49 chars (73.4% reduction)
✅ SUCCESS (7.9s): <OCR/> SPOTLIGHT TAX INVOICE 888Park 3:53PM QTY 1...
   Length: 49 chars - repetition eliminated
--------------------------------------------------
Test 2: <|image|>What type of business document is this? Answer: rec...
⚠️ Still repetitive after ultra-aggressive cleaning - truncating heavily
🧹 Cleaning: 511 → 3 chars (99.4% reduction)
❌ STILL REPETITIVE (8.1s): ......
   Even ultra-aggressive cleaning failed - model has fundamental repetition issue
--------------------------------------------------
Test 3: <|image|>Extract the date from this document in format DD/MM...
⚠️ Still repetitive after ultra-aggressive cleaning - truncating heavily
🧹 Cleaning: 173 → 20 chars (88.4% reduction)
❌ STILL REPETITIVE (8.2s): 11-07-2022, 11-07......
   Even ultra-aggressive cleaning failed - model

In [8]:
# Memory cleanup
print("Cleaning up memory...")

# Safe cleanup with existence checks
if 'model' in locals() or 'model' in globals():
    try:
        del model
        print("✓ Model deleted")
    except:
        pass

if CONFIG["model_type"] == "llama":
    if 'processor' in locals() or 'processor' in globals():
        try:
            del processor
            print("✓ Processor deleted")
        except:
            pass
elif CONFIG["model_type"] == "internvl":
    if 'tokenizer' in locals() or 'tokenizer' in globals():
        try:
            del tokenizer
            print("✓ Tokenizer deleted")
        except:
            pass

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("✓ CUDA cache cleared")

print("✓ Memory cleanup completed")
print("\n🎉 Test completed!")
print("\n📋 SUMMARY OF FIXES APPLIED:")
print("1. ❌ FIXED: Removed repetition_penalty (causes CUDA assert errors)")
print("2. ✅ SAFE: Using minimal generation parameters")
print("3. 🔧 ROBUST: Added proper error handling")
print("4. 🧹 CLEAN: Safe memory cleanup with existence checks")
print("\n🚀 Ready for testing on remote machine!")

Cleaning up memory...
✓ Model deleted
✓ Processor deleted
✓ CUDA cache cleared
✓ Memory cleanup completed

🎉 Test completed!

📋 SUMMARY OF FIXES APPLIED:
1. ❌ FIXED: Removed repetition_penalty (causes CUDA assert errors)
2. ✅ SAFE: Using minimal generation parameters
3. 🔧 ROBUST: Added proper error handling
4. 🧹 CLEAN: Safe memory cleanup with existence checks

🚀 Ready for testing on remote machine!


In [None]:
# Document Type Classification for Taxpayer Work-Related Expense Substantiation
print("🏛️ TAXPAYER WORK-RELATED EXPENSE DOCUMENT CLASSIFICATION")
print("=" * 70)

# Standard document types for taxpayer substantiation
STANDARD_DOCUMENT_TYPES = [
    "FUEL_RECEIPT",           # Fuel and automotive expenses
    "BUSINESS_RECEIPT",       # General business purchases  
    "TAX_INVOICE",           # Business-to-business transactions
    "BANK_STATEMENT",        # Financial transaction records
    "MEAL_RECEIPT",          # Business meal expenses
    "ACCOMMODATION_RECEIPT", # Travel accommodation
    "TRAVEL_DOCUMENT",       # Transport tickets, boarding passes
    "PARKING_TOLL_RECEIPT",  # Parking and toll expenses
    "PROFESSIONAL_SERVICES", # Consultancy, legal, accounting
    "EQUIPMENT_SUPPLIES",    # Office supplies, equipment purchases
    "OTHER_BUSINESS"         # Other legitimate business expenses
]

# Create classification prompt
classification_prompt = f"""<|image|>Classify this business document for taxpayer work-related expense substantiation.

Choose EXACTLY ONE category from this list:
{chr(10).join([f"- {doc_type}" for doc_type in STANDARD_DOCUMENT_TYPES])}

Requirements:
- Analyze document content, layout, and business purpose
- Choose the MOST SPECIFIC applicable category
- If document doesn't clearly fit any category, use OTHER_BUSINESS
- Provide brief justification (max 20 words)

Format:
CLASSIFICATION: [CATEGORY_NAME]
JUSTIFICATION: [brief reason]"""

print(f"Classification Categories ({len(STANDARD_DOCUMENT_TYPES)} types):")
for i, doc_type in enumerate(STANDARD_DOCUMENT_TYPES, 1):
    print(f"  {i:2d}. {doc_type}")

print(f"\nPrompt length: {len(classification_prompt)} characters")
print(f"Test image: {CONFIG['test_image']}")

# Test classification with both models if available
classification_results = {}

for model_name in ["llama", "internvl"]:
    print(f"\n{'-' * 50}")
    print(f"🔍 Testing {model_name.upper()} Classification...")
    
    try:
        # Temporarily switch model for testing
        original_model_type = CONFIG["model_type"]
        CONFIG["model_type"] = model_name
        model_path = CONFIG["model_paths"][model_name]
        
        print(f"Loading {model_name} model...")
        start_load = time.time()
        
        if model_name == "llama":
            # Load Llama model for classification
            if 'model' in locals():
                del model
                torch.cuda.empty_cache()
                
            processor = AutoProcessor.from_pretrained(
                model_path, trust_remote_code=True, local_files_only=True
            )
            
            model_loading_args = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float16,
                "device_map": "cuda:0" if torch.cuda.is_available() else "cpu",
                "local_files_only": True
            }
            
            if CONFIG["enable_quantization"] and torch.cuda.is_available():
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_8bit=True,
                        llm_int8_enable_fp32_cpu_offload=True,
                        llm_int8_skip_modules=["vision_tower", "multi_modal_projector"],
                    )
                    model_loading_args["quantization_config"] = quantization_config
                except ImportError:
                    pass
            
            model = MllamaForConditionalGeneration.from_pretrained(
                model_path, **model_loading_args
            ).eval()
            
            load_time = time.time() - start_load
            print(f"✅ Llama loaded in {load_time:.1f}s")
            
            # Run classification
            start_inference = time.time()
            
            inputs = processor(text=classification_prompt, images=image, return_tensors="pt")
            device = next(model.parameters()).device
            if device.type != "cpu":
                device_target = str(device).split(":")[0] if ":" in str(device) else str(device)
                inputs = {k: v.to(device_target) if hasattr(v, "to") else v for k, v in inputs.items()}
            
            generation_kwargs = {
                **inputs,
                "max_new_tokens": 64,  # Short response for classification
                "do_sample": False,
                "pad_token_id": processor.tokenizer.eos_token_id,
                "eos_token_id": processor.tokenizer.eos_token_id,
                "use_cache": True,
            }
            
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs)
            
            raw_response = processor.decode(
                outputs[0][inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True
            )
            
            # Clean response (minimal cleaning for classification)
            response = raw_response.strip()
            if len(response) > 200:  # If too long, truncate
                response = response[:200] + "..."
                
            inference_time = time.time() - start_inference
            
        elif model_name == "internvl":
            # Load InternVL model for classification  
            if 'model' in locals():
                del model
                torch.cuda.empty_cache()
                
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True, local_files_only=True
            )
            
            model_kwargs = {
                "low_cpu_mem_usage": True,
                "trust_remote_code": True,
                "torch_dtype": torch.bfloat16,
                "local_files_only": True
            }
            
            if CONFIG["enable_quantization"] and torch.cuda.is_available():
                try:
                    model_kwargs["load_in_8bit"] = True
                except Exception:
                    pass
            
            model = AutoModel.from_pretrained(model_path, **model_kwargs).eval()
            
            if torch.cuda.is_available() and not CONFIG["enable_quantization"]:
                model = model.cuda()
                
            load_time = time.time() - start_load
            print(f"✅ InternVL loaded in {load_time:.1f}s")
            
            # Run classification
            start_inference = time.time()
            
            # Prepare image for InternVL
            image_size = 448
            transform = T.Compose([
                T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])
            
            pixel_values = transform(image).unsqueeze(0)
            if torch.cuda.is_available():
                pixel_values = pixel_values.cuda().to(torch.bfloat16).contiguous()
            
            generation_config = {
                "max_new_tokens": 64,
                "do_sample": False,
                "pad_token_id": tokenizer.eos_token_id
            }
            
            response = model.chat(
                tokenizer=tokenizer,
                pixel_values=pixel_values,
                question=classification_prompt,
                generation_config=generation_config
            )
            
            if isinstance(response, tuple):
                response = response[0]
                
            inference_time = time.time() - start_inference
        
        # Analyze classification result
        print(f"✅ Classification completed in {inference_time:.1f}s")
        print(f"Response length: {len(response)} characters")
        
        # Extract classification and justification
        import re
        classification_match = re.search(r'CLASSIFICATION:\s*([A-Z_]+)', response)
        justification_match = re.search(r'JUSTIFICATION:\s*([^\n]+)', response)
        
        extracted_classification = classification_match.group(1) if classification_match else "UNKNOWN"
        extracted_justification = justification_match.group(1) if justification_match else "No justification provided"
        
        # Validate classification
        is_valid = extracted_classification in STANDARD_DOCUMENT_TYPES
        
        classification_results[model_name] = {
            "classification": extracted_classification,
            "justification": extracted_justification,
            "valid": is_valid,
            "inference_time": inference_time,
            "load_time": load_time,
            "raw_response": response
        }
        
        print(f"📋 CLASSIFICATION: {extracted_classification}")
        print(f"📝 JUSTIFICATION: {extracted_justification}")
        print(f"✅ VALID: {'Yes' if is_valid else 'No'}")
        
        if not is_valid:
            print(f"⚠️  '{extracted_classification}' not in standard categories")
            # Find closest match
            from difflib import get_close_matches
            close_matches = get_close_matches(extracted_classification, STANDARD_DOCUMENT_TYPES, n=3, cutoff=0.6)
            if close_matches:
                print(f"🔍 Similar categories: {', '.join(close_matches)}")
        
        # Memory cleanup
        if 'model' in locals():
            del model
        if model_name == "llama" and 'processor' in locals():
            del processor
        elif model_name == "internvl" and 'tokenizer' in locals():
            del tokenizer
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"❌ {model_name.upper()} classification failed: {str(e)[:100]}...")
        classification_results[model_name] = {
            "error": str(e),
            "inference_time": 0,
            "load_time": 0
        }
    
    # Restore original model type
    CONFIG["model_type"] = original_model_type

# Final comparison
print(f"\n{'=' * 70}")
print("🏆 CLASSIFICATION COMPARISON RESULTS")
print(f"{'=' * 70}")

comparison_table = []
comparison_table.append(["Model", "Classification", "Valid", "Time (s)", "Justification"])
comparison_table.append(["-" * 10, "-" * 15, "-" * 5, "-" * 8, "-" * 30])

for model_name, result in classification_results.items():
    if "error" not in result:
        comparison_table.append([
            model_name.upper(),
            result["classification"],
            "✅" if result["valid"] else "❌",
            f"{result['inference_time']:.1f}",
            result["justification"][:30] + "..." if len(result["justification"]) > 30 else result["justification"]
        ])
    else:
        comparison_table.append([
            model_name.upper(),
            "ERROR",
            "❌",
            "0.0",
            result["error"][:30] + "..."
        ])

# Print table
for row in comparison_table:
    print(f"{row[0]:<10} {row[1]:<15} {row[2]:<5} {row[3]:<8} {row[4]}")

# Analysis
successful_results = [result for result in classification_results.values() if "error" not in result]
if len(successful_results) >= 2:
    print(f"\n🔍 ANALYSIS:")
    
    # Check agreement
    classifications = [result["classification"] for result in successful_results]
    if len(set(classifications)) == 1:
        print(f"✅ MODELS AGREE: Both classified as {classifications[0]}")
    else:
        print(f"⚠️  MODELS DISAGREE: {', '.join(classifications)}")
    
    # Performance comparison
    times = [result["inference_time"] for result in successful_results]
    fastest_idx = times.index(min(times))
    model_names = [name for name in classification_results.keys() if "error" not in classification_results[name]]
    
    print(f"⚡ FASTEST: {model_names[fastest_idx].upper()} ({min(times):.1f}s)")
    
    # Validity check
    valid_results = [result for result in successful_results if result["valid"]]
    if valid_results:
        print(f"✅ VALID CLASSIFICATIONS: {len(valid_results)}/{len(successful_results)}")
    else:
        print(f"⚠️  NO VALID CLASSIFICATIONS PRODUCED")

print(f"\n📚 DOCUMENT TYPE STANDARD:")
print(f"This test validates compliance with taxpayer work-related expense")
print(f"substantiation requirements using {len(STANDARD_DOCUMENT_TYPES)} standard categories.")
print(f"\n🎯 For production use:")
print(f"- Use the model that produces valid classifications consistently")
print(f"- Consider ensemble approach if models disagree frequently")
print(f"- Monitor classification accuracy against manual validation")

print(f"\n✅ Document classification test completed!")