# Llama Document-Type-Aware Adaptive Extraction

Processes all images using multi-stage adaptive extraction:
1. **Stage 0**: Classify document type (INVOICE/RECEIPT/BANK_STATEMENT)
2. **Stage 1**: Classify structure (if BANK_STATEMENT: FLAT/GROUPED)
3. **Stage 2**: Apply document-type and structure-specific extraction prompt

Outputs are compatible with `model_comparison.ipynb`.

In [None]:
import gc
import json
import random
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from rich import print as rprint
from rich.console import Console
from rich.progress import track
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Initialize console for rich output.
# This shared Console is used for console.rule(...) section dividers; ordinary
# status lines are printed with rprint instead.
console = Console()

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The CUDA generators are only seeded when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed every RNG once, right after the imports, before any model work starts.
set_seed(42)
rprint("[green]✅ Imports loaded[/green]")

## Configuration

In [None]:
# Environment-specific base paths
ENVIRONMENT_BASES = {
    'sandbox': '/home/jovyan/nfs_share/tod',
    'efs': '/efs/shared/PoC_data'
}
# Every data/prompt/output path below is derived from this single base path.
base_data_path = ENVIRONMENT_BASES['sandbox']

CONFIG = {
    # Model settings
    # 'MODEL_PATH': "/efs/shared/PTM/Llama-3.2-11B-Vision-Instruct",
    'MODEL_PATH': "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision-Instruct",

    # Data paths - Using base path for consistency
    'DATA_DIR': f'{base_data_path}/LMM_POC/evaluation_data',
    'GROUND_TRUTH': f'{base_data_path}/LMM_POC/evaluation_data/ground_truth.csv',

    # Prompt files - Using base path for consistency
    'PROMPT_FILE_DOCTYPE': f'{base_data_path}/LMM_POC/prompts/document_type_detection.yaml',
    'PROMPT_FILE_INVOICE': f'{base_data_path}/LMM_POC/prompts/generated/llama_invoice_prompt.yaml',
    'PROMPT_FILE_RECEIPT': f'{base_data_path}/LMM_POC/prompts/generated/llama_receipt_prompt.yaml',
    'PROMPT_FILE_BANK': f'{base_data_path}/LMM_POC/prompts/generated/llama_bank_statement_prompt.yaml',

    # Output directory - Using base path for consistency
    'OUTPUT_DIR': f'{base_data_path}/LMM_POC/output',

    # Token limits (one per pipeline stage)
    'MAX_NEW_TOKENS_DOCTYPE': 50,
    'MAX_NEW_TOKENS_STRUCTURE': 50,
    'MAX_NEW_TOKENS_EXTRACT': 2000,

    # Verbosity control
    'VERBOSE': True,  # Show stage-by-stage progress
    'SHOW_PROMPTS': False,  # Show actual prompts being used
}

# Create output directory. parents=True lets this succeed on a fresh base path
# where the intermediate directories do not exist yet.
output_dir = Path(CONFIG['OUTPUT_DIR'])
output_dir.mkdir(parents=True, exist_ok=True)

# Timestamp for output files
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# Reverse-lookup the environment label for display. Falls back to 'custom'
# instead of raising IndexError when base_data_path was set to a path that is
# not listed in ENVIRONMENT_BASES.
environment_name = next(
    (name for name, base in ENVIRONMENT_BASES.items() if base == base_data_path),
    'custom',
)

rprint("[green]✅ Configuration loaded[/green]")
rprint(f"[cyan]  Environment: {environment_name}[/cyan]")
rprint(f"[cyan]  Base path: {base_data_path}[/cyan]")
rprint(f"[cyan]  Output directory: {output_dir}[/cyan]")
rprint(f"[cyan]  Timestamp: {TIMESTAMP}[/cyan]")
rprint(f"[cyan]  Verbosity: {'ON' if CONFIG['VERBOSE'] else 'OFF'}[/cyan]")

## Load Model

In [3]:
# Load model
# Loads the Llama vision-language model and its matching processor from the
# local checkpoint directory CONFIG['MODEL_PATH']. Both are module-level
# globals that the batch-processing loop passes to chat_with_mllm.
rprint("[bold green]🔧 Loading Llama model...[/bold green]")

model = MllamaForConditionalGeneration.from_pretrained(
    CONFIG['MODEL_PATH'],
    # Weights are loaded in bfloat16 rather than the checkpoint's default dtype.
    torch_dtype=torch.bfloat16,
    # NOTE(review): "auto" delegates device placement to the library
    # (presumably spreading layers over the available GPUs) - confirm on the
    # target machine.
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(CONFIG['MODEL_PATH'])

rprint("[green]✅ Model loaded[/green]")

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

## Load All Prompts

Loading prompts for:
- Document type detection
- Invoice extraction
- Receipt extraction  
- Bank statement extraction (flat and grouped variants)

In [4]:
# Load all prompts


def _load_prompt_yaml(path):
    """Parse one prompt YAML file and return its contents.

    The file is opened explicitly as UTF-8 so prompts containing non-ASCII
    characters load the same way regardless of the machine's locale encoding.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The object produced by yaml.safe_load for that file.
    """
    with open(path, encoding='utf-8') as f:
        return yaml.safe_load(f)


# Document type detection prompt
doctype_data = _load_prompt_yaml(CONFIG['PROMPT_FILE_DOCTYPE'])
DOCTYPE_PROMPT = doctype_data['prompts']['detection']['prompt']

# Invoice extraction prompt
invoice_data = _load_prompt_yaml(CONFIG['PROMPT_FILE_INVOICE'])
INVOICE_PROMPT = invoice_data['prompts']['invoice']['prompt']

# Receipt extraction prompt
receipt_data = _load_prompt_yaml(CONFIG['PROMPT_FILE_RECEIPT'])
RECEIPT_PROMPT = receipt_data['prompts']['receipt']['prompt']

# Bank statement extraction prompts, keyed by the structure labels that
# parse_structure_type returns ('flat' / 'date_grouped').
bank_data = _load_prompt_yaml(CONFIG['PROMPT_FILE_BANK'])
BANK_PROMPTS = {
    'flat': bank_data['prompts']['bank_statement_flat']['prompt'],
    'date_grouped': bank_data['prompts']['bank_statement_date_grouped']['prompt']
}

# Bank statement structure classification prompt
STRUCTURE_CLASSIFICATION_PROMPT = """Look at how dates are displayed in this bank statement's transaction list.

Answer with ONLY one word:
- FLAT (if dates appear as the FIRST COLUMN in a table row, like: "05/05/2025 | Purchase | $22.50")
- GROUPED (if dates appear as SECTION HEADERS above transactions, like: "Thu 05 Sep 2025" followed by indented transaction details below)

The key difference: FLAT has dates IN the table columns, GROUPED has dates AS headers ABOVE the rows.

Answer (one word only):"""

rprint("[green]✅ All prompts loaded[/green]")
rprint(f"[cyan]  Document type detection: {len(DOCTYPE_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Invoice extraction: {len(INVOICE_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Receipt extraction: {len(RECEIPT_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Bank flat extraction: {len(BANK_PROMPTS['flat'])} chars[/cyan]")
rprint(f"[cyan]  Bank grouped extraction: {len(BANK_PROMPTS['date_grouped'])} chars[/cyan]")

## Multi-Turn Chat Function

In [5]:
def chat_with_mllm(model, processor, prompt, images, messages=None, max_new_tokens=2000, do_sample=False):
    """Run one turn of a multi-turn chat with the vision-language model.

    On the first turn (no history) the user message carries an image
    placeholder plus the prompt; on later turns a text-only user message is
    appended to the existing history, so the image is referenced once per
    conversation.

    Args:
        model: Loaded generation model; must provide ``generate`` and ``device``.
        processor: Matching processor; must provide ``apply_chat_template``,
            ``decode`` and be callable on ``images``/``text``.
        prompt: Text of the user turn.
        images: List of images passed to the processor on every turn.
        messages: Conversation history from earlier turns, or None/empty to
            start a new conversation. A non-empty list is extended in place.
        max_new_tokens: Upper bound on generated tokens for this turn.
        do_sample: If False (default) generation is deterministic; if True,
            sampling uses temperature 0.6 and top_p 0.9.

    Returns:
        Tuple ``(generated_text, messages)`` where ``messages`` includes the
        new user turn and the assistant reply.
    """
    if not messages:
        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=images, text=text, return_tensors="pt").to(model.device)

    # Deterministic generation: explicitly disable sampling parameters
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": None if not do_sample else 0.6,
        "top_p": None if not do_sample else 0.9
    }
    generate_ids = model.generate(**inputs, **generation_args)

    # Keep only the newly generated tokens. The end-of-turn marker is removed
    # by skip_special_tokens rather than by slicing off the last position:
    # when generation stops at max_new_tokens there is no end marker, and a
    # blind [:-1] would discard a real output token.
    prompt_length = inputs['input_ids'].shape[1]
    new_token_ids = generate_ids[:, prompt_length:]
    generated_texts = processor.decode(
        new_token_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    messages.append({"role": "assistant", "content": [{"type": "text", "text": generated_texts}]})

    return generated_texts, messages

# Status marker: printed once this cell has run and chat_with_mllm is defined.
rprint("[green]✅ Chat function defined[/green]")

## Parser Functions

Functions to parse VLM responses:
- Document type classification
- Bank statement structure classification
- Field extraction parsing

In [6]:
def parse_document_type(response):
    """Map a free-text classification answer to a document type label.

    Matching is case-insensitive and keyword based. Labels are tried in the
    order INVOICE, RECEIPT, BANK_STATEMENT, and the first label with a keyword
    present in the answer wins.

    Args:
        response: Raw text returned by the model.

    Returns:
        "INVOICE", "RECEIPT" or "BANK_STATEMENT"; "INVOICE" when no keyword
        matches.
    """
    normalized = response.strip().upper()

    # Order matters: earlier entries take precedence over later ones.
    keyword_table = (
        ("INVOICE", ("INVOICE",)),
        ("RECEIPT", ("RECEIPT",)),
        ("BANK_STATEMENT", ("BANK", "STATEMENT")),
    )
    for label, keywords in keyword_table:
        if any(keyword in normalized for keyword in keywords):
            return label

    return "INVOICE"  # Default fallback

def parse_structure_type(response):
    """Map a free-text structure answer to a bank statement layout key.

    Matching is case-insensitive. "FLAT" is checked first, then "GROUPED" or
    "DATE".

    Args:
        response: Raw text returned by the model.

    Returns:
        "flat" or "date_grouped"; "flat" when no keyword matches.
    """
    normalized = response.strip().upper()

    if "FLAT" in normalized:
        return "flat"

    is_grouped = any(keyword in normalized for keyword in ("GROUPED", "DATE"))
    # Anything unrecognised falls back to the flat layout.
    return "date_grouped" if is_grouped else "flat"

def parse_extraction(extraction_text):
    """Turn "NAME: value" lines of model output into a field dictionary.

    Each line is stripped. Lines starting with '#' and lines without a colon
    are ignored. Only the first colon splits name from value, so values may
    contain colons themselves. An empty value is stored as 'NOT_FOUND'. When a
    name repeats, the last occurrence wins.

    Args:
        extraction_text: Raw multi-line extraction response.

    Returns:
        Dict mapping stripped field names to stripped values.
    """
    fields = {}

    for raw_line in extraction_text.split('\n'):
        entry = raw_line.strip()
        if entry.startswith('#'):
            continue

        name, separator, value = entry.partition(':')
        if not separator:
            # No colon on this line, so it is not a field line.
            continue

        fields[name.strip()] = value.strip() or 'NOT_FOUND'

    return fields

# Status marker: printed once this cell has run and the three parsers are defined.
rprint("[green]✅ Parser functions defined[/green]")

## Discover Images

In [7]:
# Discover all images (no filtering by document type)
# Only *.png files directly inside DATA_DIR are matched (the glob pattern is
# not recursive); sorting gives a stable processing order between runs.
data_dir = Path(CONFIG['DATA_DIR'])
image_files = sorted(data_dir.glob("*.png"))

rprint(f"[green]✅ Found {len(image_files)} images to process[/green]")

# List every discovered file name so the batch contents can be checked before
# the processing cell is run.
rprint("\n[bold blue]Images to process:[/bold blue]")
for img in image_files:
    rprint(f"[cyan]  - {img.name}[/cyan]")

## Multi-Stage Batch Processing

 - **Stage 0**: Document Type Classification (INVOICE/RECEIPT/BANK_STATEMENT)
 - **Stage 1**: Structure Classification (for BANK_STATEMENT only: FLAT/GROUPED)
 - **Stage 2**: Document-Type-Aware Extraction (using appropriate prompt)

In [None]:
# Multi-stage adaptive extraction with explicit stages (InternVL3-style)
#
# For every image: Stage 0 classifies the document type, Stage 1 (bank
# statements only) classifies the layout, and Stage 2 runs the matching
# extraction prompt. All stages of one image share a single `messages`
# history. A failure on one image is recorded as an ERROR row and the loop
# continues with the next image.
results = []
processing_times = []
doctype_counts = {'INVOICE': 0, 'RECEIPT': 0, 'BANK_STATEMENT': 0}
structure_counts = {'flat': 0, 'date_grouped': 0}

rprint("\n[bold green]🚀 Starting multi-stage adaptive extraction...[/bold green]\n")

for idx, image_path in enumerate(track(image_files, description="Processing images"), 1):
    image_name = image_path.name

    try:
        start_time = time.time()

        # The context manager closes this iteration's image even when a stage
        # raises. It replaces a `'image' in locals()` check that, at module
        # level, also matched the stale handle left by the previous iteration.
        with Image.open(image_path) as image:
            images = [image]
            messages = []

            # ===============================================================
            # STAGE 0: Document Type Classification
            # ===============================================================
            if CONFIG['VERBOSE']:
                rprint(f"\n[bold blue]Processing [{idx}/{len(image_files)}]: {image_name}[/bold blue]")
                rprint("[dim]Stage 0: Document type detection...[/dim]")

            doctype_answer, messages = chat_with_mllm(
                model, processor, DOCTYPE_PROMPT, images, messages,
                max_new_tokens=CONFIG['MAX_NEW_TOKENS_DOCTYPE']
            )

            # Parse document type
            document_type = parse_document_type(doctype_answer)
            doctype_counts[document_type] += 1

            # ===============================================================
            # STAGE 1: Structure Classification (only for BANK_STATEMENT)
            # ===============================================================
            structure_type = "N/A"
            structure_answer = "N/A"

            if document_type == "BANK_STATEMENT":
                if CONFIG['VERBOSE']:
                    rprint("[dim]Stage 1: Bank statement structure classification...[/dim]")

                structure_answer, messages = chat_with_mllm(
                    model, processor, STRUCTURE_CLASSIFICATION_PROMPT, images, messages,
                    max_new_tokens=CONFIG['MAX_NEW_TOKENS_STRUCTURE']
                )
                structure_type = parse_structure_type(structure_answer)
                structure_counts[structure_type] += 1
                extraction_prompt = BANK_PROMPTS[structure_type]
                prompt_key = f"bank_statement_{structure_type}"

            elif document_type == "INVOICE":
                extraction_prompt = INVOICE_PROMPT
                prompt_key = "invoice"

            elif document_type == "RECEIPT":
                extraction_prompt = RECEIPT_PROMPT
                prompt_key = "receipt"

            # ===============================================================
            # STAGE 2: Document-Type-Aware Extraction
            # ===============================================================
            if CONFIG['VERBOSE']:
                rprint(f"[dim]Stage 2: Extraction using {prompt_key}...[/dim]")

            extraction_result, messages = chat_with_mllm(
                model, processor, extraction_prompt, images, messages,
                max_new_tokens=CONFIG['MAX_NEW_TOKENS_EXTRACT']
            )

        # Parse extraction (text only, the image is no longer needed)
        extracted_fields = parse_extraction(extraction_result)

        # Store results: pipeline metadata first, then one key per parsed field
        result = {
            'image_file': image_name,
            'document_type': document_type,
            'structure_type': structure_type,
            'prompt_used': prompt_key,
            'doctype_classification': doctype_answer.strip(),
            'structure_classification': structure_answer.strip(),
            'extraction_raw': extraction_result,
            **extracted_fields
        }
        results.append(result)

        processing_time = time.time() - start_time
        processing_times.append(processing_time)

        structure_display = structure_type if structure_type != 'N/A' else 'direct'
        rprint(f"[green]✅ {image_name}: {document_type} ({structure_display}) - {processing_time:.2f}s[/green]")

    except Exception as e:
        # Best-effort batch: record the failure and keep going. Error rows
        # carry an 'error' key and no 'prompt_used' key.
        rprint(f"[red]❌ {image_name}: Error - {e}[/red]")
        results.append({
            'image_file': image_name,
            'document_type': 'ERROR',
            'structure_type': 'ERROR',
            'error': str(e)
        })
        # Keeps processing_times the same length as results.
        processing_times.append(0)

    finally:
        # Clear GPU cache to prevent OOM on large batches
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Periodic garbage collection every 3 images
        if idx % 3 == 0:
            gc.collect()

console.rule("[bold green]Batch Processing Complete[/bold green]")

# Display summary statistics
rprint("\n[bold blue]📊 Document Type Classification Summary:[/bold blue]")
rprint(f"[cyan]  Invoices: {doctype_counts['INVOICE']}[/cyan]")
rprint(f"[cyan]  Receipts: {doctype_counts['RECEIPT']}[/cyan]")
rprint(f"[cyan]  Bank Statements: {doctype_counts['BANK_STATEMENT']}[/cyan]")

if doctype_counts['BANK_STATEMENT'] > 0:
    rprint("\n[bold blue]📊 Bank Statement Structure Summary:[/bold blue]")
    rprint(f"[cyan]  Flat table: {structure_counts['flat']}[/cyan]")
    rprint(f"[cyan]  Date-grouped: {structure_counts['date_grouped']}[/cyan]")

## Save Results

In [9]:
# Build one DataFrame row per processed image (successful or failed)
df = pd.DataFrame(results)

# CSV export - this is the file consumed by model_comparison.ipynb
csv_output = output_dir / f"llama_adaptive_batch_results_{TIMESTAMP}.csv"
df.to_csv(csv_output, index=False)

n_rows, n_cols = df.shape
rprint(f"[green]✅ CSV saved to: {csv_output}[/green]")
rprint(f"[cyan]  Rows: {n_rows}[/cyan]")
rprint(f"[cyan]  Columns: {n_cols}[/cyan]")

# JSON export - the raw result dictionaries, written next to the CSV
json_output = output_dir / f"llama_adaptive_batch_results_{TIMESTAMP}.json"
json_output.write_text(json.dumps(results, indent=2))

rprint(f"[green]✅ JSON saved to: {json_output}[/green]")

## Display Sample Results

In [10]:
# Display sample results
console.rule("[bold blue]Sample Results[/bold blue]")

display_cols = ['image_file', 'document_type', 'structure_type', 'prompt_used']

# reindex() instead of df[display_cols]: error rows have no 'prompt_used' key
# and an empty run has no columns at all, so plain column selection would
# raise KeyError. Columns that are missing are shown as NaN instead.
rprint(df.reindex(columns=display_cols).to_string(index=False))

## Summary statistics

In [None]:
# Run summary: success/error totals, classification counts, prompt usage and
# per-field extraction coverage.
from collections import Counter

print("\n📊 DOCUMENT-TYPE-AWARE ADAPTIVE EXTRACTION SUMMARY")
print("="*80)
# Failed images are the result rows that carry an 'error' key.
error_count = sum('error' in r for r in results)
print(f"Total images processed: {len(results)}")
print(f"Successful extractions: {len(results) - error_count}")
print(f"Errors: {error_count}")

print("\nDocument Type Classification:")
print(f"  Invoices: {doctype_counts['INVOICE']}")
print(f"  Receipts: {doctype_counts['RECEIPT']}")
print(f"  Bank Statements: {doctype_counts['BANK_STATEMENT']}")

if doctype_counts['BANK_STATEMENT'] > 0:
    print("\nBank Statement Structure Classification:")
    print(f"  Flat table format: {structure_counts['flat']}")
    print(f"  Date-grouped format: {structure_counts['date_grouped']}")

print("\nPrompts Used:")
# Error rows have no 'prompt_used' key and are skipped.
prompt_usage = Counter(r['prompt_used'] for r in results if 'prompt_used' in r)

for prompt, count in sorted(prompt_usage.items()):
    print(f"  {prompt}: {count}")

print("="*80)

# Field extraction statistics
if len(df) > 0:
    # Pipeline metadata columns; every other column is an extracted field.
    metadata_cols = {
        'image_file', 'document_type', 'structure_type', 'prompt_used',
        'doctype_classification', 'structure_classification', 'extraction_raw', 'error'
    }
    field_cols = [col for col in df.columns if col not in metadata_cols]

    if field_cols:
        print("\n📈 Field Extraction Coverage:")
        for field in field_cols:
            # A field counts as found only when it holds a real value:
            # parse_extraction stores 'NOT_FOUND' for empty values, and rows
            # of other document types have NaN in this column.
            has_value = df[field].notna() & (df[field] != 'NOT_FOUND')
            found_count = has_value.sum()
            coverage = (found_count / len(df)) * 100
            print(f"  {field}: {found_count}/{len(df)} ({coverage:.1f}%)")


📊 DOCUMENT-TYPE-AWARE ADAPTIVE EXTRACTION SUMMARY
Total images processed: 9
Successful extractions: 9
Errors: 0

Document Type Classification:
  Invoices: 3
  Receipts: 3
  Bank Statements: 3

Bank Statement Structure Classification:
  Flat table format: 1
  Date-grouped format: 2

Prompts Used:
  bank_statement_date_grouped: 2
  bank_statement_flat: 1
  invoice: 3
  receipt: 3

📈 Field Extraction Coverage:
  DOCUMENT_TYPE: 9/9 (100.0%)
  BUSINESS_ABN: 6/9 (66.7%)
  SUPPLIER_NAME: 6/9 (66.7%)
  BUSINESS_ADDRESS: 6/9 (66.7%)
  PAYER_NAME: 6/9 (66.7%)
  PAYER_ADDRESS: 6/9 (66.7%)
  INVOICE_DATE: 6/9 (66.7%)
  LINE_ITEM_DESCRIPTIONS: 9/9 (100.0%)
  LINE_ITEM_QUANTITIES: 6/9 (66.7%)
  LINE_ITEM_PRICES: 6/9 (66.7%)
  LINE_ITEM_TOTAL_PRICES: 6/9 (66.7%)
  IS_GST_INCLUDED: 6/9 (66.7%)
  GST_AMOUNT: 6/9 (66.7%)
  TOTAL_AMOUNT: 6/9 (66.7%)
  STATEMENT_DATE_RANGE: 3/9 (33.3%)
  TRANSACTION_DATES: 3/9 (33.3%)
  TRANSACTION_AMOUNTS_PAID: 3/9 (33.3%)


## View Individual Extraction

Set `image_to_view` in the cell below and run it to view the detailed extraction for a specific image:

In [13]:
# View detailed extraction for specific image
image_to_view = "image_003.png"  # Change this

# First result row whose file name matches, or None when there is no such row.
result = next((r for r in results if r['image_file'] == image_to_view), None)

if result:
    print(f"\n🔍 Detailed Extraction: {image_to_view}")
    print("="*80)
    print(f"Document Type: {result['document_type']}")
    print(f"Structure Type: {result['structure_type']}")
    # Error rows have no 'prompt_used' key; .get avoids a KeyError for them.
    print(f"Prompt Used: {result.get('prompt_used', 'N/A')}")
    if 'error' in result:
        # Show why processing failed instead of only printing N/A placeholders.
        print(f"Error: {result['error']}")
    print("\nDocument Type Classification Response:")
    print(result.get('doctype_classification', 'N/A'))
    print("\nStructure Classification Response:")
    print(result.get('structure_classification', 'N/A'))
    print("\nExtraction Result:")
    print(result.get('extraction_raw', 'N/A'))
    print("="*80)
else:
    print(f"Image {image_to_view} not found in results")


🔍 Detailed Extraction: image_003.png
Document Type: BANK_STATEMENT
Structure Type: flat
Prompt Used: bank_statement_flat

Document Type Classification Response:
This is a bank statement.

Structure Classification Response:
FLAT

Extraction Result:
DOCUMENT_TYPE: BANK_STATEMENT
STATEMENT_DATE_RANGE: 03/05/2025 to 10/05/2025
TRANSACTION_DATES: 03/05/2025 | 04/05/2025 | 05/05/2025 | 06/05/2025 | 07/05/2025 | 08/05/2025 | 09/05/2025 | 10/05/2025
LINE_ITEM_DESCRIPTIONS: ONLINE PURCHASE AMAZON AU | EFTPOS PURCHASE COLES EXP | EFTPOS PURCHASE COLES EXP | DIRECT CREDIT SALARY | ATM WITHDRAWAL ANZ ATM | EFTPOS PURCHASE COLES EXP | INTEREST PAYMENT | ATM WITHDRAWAL ANZ ATM
TRANSACTION_AMOUNTS_PAID: $288.03 | $22.50 | $114.66 | $187.59 | $112.50 | $5.16 | $146.72
