# Llama Pipeline-Oriented Batch Processing

**Transparent pandas-based pipeline for processing images**

The pipeline separates classification, extraction, parsing, cleaning, and evaluation into six stages:
1. **Stage 0**: Classify document type (INVOICE/RECEIPT/BANK_STATEMENT)
2. **Stage 1**: Classify structure (if BANK_STATEMENT: FLAT/GROUPED)
3. **Stage 2**: Extract field values using document-type-aware prompts
4. **Stage 3**: Parse responses (text → structured fields) using `hybrid_parse_response`
5. **Stage 4**: Clean and normalize field values using `ExtractionCleaner`
6. **Stage 5**: Evaluate against ground truth (optional)

**Key Features:**
- Inspectable at every stage
- Checkpoint saved after every batch, plus a final combined checkpoint
- Batched processing (`BATCH_SIZE`) intended to scale to 10,000+ images
- Compatible with model_comparison.ipynb
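
The stage layout above can be sketched without a model: each stage reads what earlier stages produced and stores its own result under a new key, so every intermediate value stays available for inspection. The stage functions below are hypothetical stand-ins, not the notebook's real stages, and the sample text is invented for illustration.

```python
# Hypothetical stand-in stages; the real ones call the vision model or pipeline_lib.
def classify(record):
    text = record["raw_text"].upper()
    return "RECEIPT" if "RECEIPT" in text else "INVOICE"

def extract(record):
    # Pretend the model answered with "FIELD: value" lines.
    return record["raw_text"]

def parse(record):
    fields = {}
    for line in record["extraction_response"].splitlines():
        if ":" in line:
            name, value = line.split(":", 1)
            fields[name.strip()] = value.strip()
    return fields

# Each entry is (output key, stage function), applied in order.
STAGES = [
    ("document_type", classify),
    ("extraction_response", extract),
    ("parsed_fields", parse),
]

def run_stages(record):
    """Apply each stage in order, storing its output under its own key."""
    for key, stage in STAGES:
        record[key] = stage(record)
    return record

result = run_stages({"raw_text": "RECEIPT\nTOTAL_AMOUNT: $22.50"})
print(result["document_type"])
print(result["parsed_fields"])
```

The notebook applies the same idea with one DataFrame column per stage output instead of one dictionary key.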

In [1]:
#Cell 1
import gc
import json
import random
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from rich import print as rprint
from rich.console import Console
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Pipeline library imports - NEW lightweight library (replacing common/)
from pipeline_lib import (
    hybrid_parse_response,
    ExtractionCleaner,
    load_ground_truth,
    calculate_field_accuracy,
    stage_3_parsing,
    stage_4_cleaning,
    stage_5_evaluation,
    show_pipeline_memory,
)

# Initialize console for rich output
console = Console()

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
rprint("[green]✅ Imports loaded (using pipeline_lib)[/green]")

## Configuration

In [2]:
#Cell 2
# Environment-specific base paths
ENVIRONMENT_BASES = {
    'sandbox': '/home/jovyan/nfs_share/tod',
    'efs': '/efs/shared/PoC_data'
}
base_data_path = ENVIRONMENT_BASES['sandbox']

CONFIG = {
    # Model settings
    # 'MODEL_PATH': "/efs/shared/PTM/Llama-3.2-11B-Vision-Instruct",
    'MODEL_PATH': "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision-Instruct",
    
    # Data paths - Using base path for consistency
    'DATA_DIR': f'{base_data_path}/LMM_POC/evaluation_data',
    'GROUND_TRUTH': f'{base_data_path}/LMM_POC/evaluation_data/ground_truth.csv',
    
    # Prompt files - Using base path for consistency
    'PROMPT_FILE_DOCTYPE': f'{base_data_path}/LMM_POC/prompts/document_type_detection.yaml',
    'PROMPT_FILE_INVOICE': f'{base_data_path}/LMM_POC/prompts/generated/llama_invoice_prompt.yaml',
    'PROMPT_FILE_RECEIPT': f'{base_data_path}/LMM_POC/prompts/generated/llama_receipt_prompt.yaml',
    'PROMPT_FILE_BANK': f'{base_data_path}/LMM_POC/prompts/generated/llama_bank_statement_prompt.yaml',
    
    # Output directory - Using base path for consistency
    'OUTPUT_DIR': f'{base_data_path}/LMM_POC/output',
    
    # Token limits
    'MAX_NEW_TOKENS_DOCTYPE': 50,
    'MAX_NEW_TOKENS_STRUCTURE': 50,
    'MAX_NEW_TOKENS_EXTRACT': 2000,
    
    # Pipeline control
    'INFERENCE_ONLY': False,  # Set to True to skip ground truth evaluation
    
    # Verbosity control
    'VERBOSE': True,  # Show stage-by-stage progress
    'SHOW_PROMPTS': False,  # Show actual prompts being used
}

# Make GROUND_TRUTH conditional based on INFERENCE_ONLY mode
if CONFIG['INFERENCE_ONLY']:
    CONFIG['GROUND_TRUTH'] = None

# Define expected fields (matching ground truth)
FIELD_COLUMNS = [
    'DOCUMENT_TYPE', 'BUSINESS_ABN', 'SUPPLIER_NAME', 'BUSINESS_ADDRESS',
    'PAYER_NAME', 'PAYER_ADDRESS', 'INVOICE_DATE', 'LINE_ITEM_DESCRIPTIONS',
    'LINE_ITEM_QUANTITIES', 'LINE_ITEM_PRICES', 'LINE_ITEM_TOTAL_PRICES',
    'IS_GST_INCLUDED', 'GST_AMOUNT', 'TOTAL_AMOUNT', 'STATEMENT_DATE_RANGE',
    'TRANSACTION_DATES', 'TRANSACTION_AMOUNTS_PAID'
]

# Create output directory
output_dir = Path(CONFIG['OUTPUT_DIR'])
output_dir.mkdir(parents=True, exist_ok=True)  # parents=True: also create missing parent directories

# Create checkpoint directory
checkpoint_dir = output_dir / 'checkpoints'
checkpoint_dir.mkdir(exist_ok=True)

# Timestamp for output files
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# Initialize extraction cleaner
cleaner = ExtractionCleaner(debug=CONFIG['VERBOSE'])

rprint("[green]✅ Configuration loaded[/green]")
rprint(f"[cyan]  Environment: {[k for k, v in ENVIRONMENT_BASES.items() if v == base_data_path][0]}[/cyan]")
rprint(f"[cyan]  Base path: {base_data_path}[/cyan]")
rprint(f"[cyan]  Output directory: {output_dir}[/cyan]")
rprint(f"[cyan]  Checkpoint directory: {checkpoint_dir}[/cyan]")
rprint(f"[cyan]  Timestamp: {TIMESTAMP}[/cyan]")
rprint(f"[cyan]  Mode: {'Inference-only' if CONFIG['INFERENCE_ONLY'] else 'Evaluation mode'}[/cyan]")
rprint(f"[cyan]  Expected fields: {len(FIELD_COLUMNS)}[/cyan]")
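
Because every path in `CONFIG` is built from one base path, a wrong environment choice only surfaces later, when the model or a prompt file fails to load. A minimal sketch of an early check follows; `missing_paths` is a hypothetical helper, not part of `pipeline_lib`, and the demo dictionary is invented.

```python
from pathlib import Path

def missing_paths(config, keys):
    """Return {key: path} for every listed key whose path is set but absent on disk."""
    missing = {}
    for key in keys:
        value = config.get(key)
        if value is None:  # e.g. GROUND_TRUTH in inference-only mode
            continue
        if not Path(value).exists():
            missing[key] = value
    return missing

# Invented demo values; in the notebook the first argument would be CONFIG.
demo_config = {"DATA_DIR": ".", "GROUND_TRUTH": None, "MODEL_PATH": "/no/such/model/dir"}
print(missing_paths(demo_config, ["DATA_DIR", "GROUND_TRUTH", "MODEL_PATH"]))
```

Running such a check before the model-loading cell would report a bad base path in seconds rather than after the checkpoint shards have been read.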

## Load Model

In [3]:
#Cell 3
# Load model with diagnostics
from pipeline_lib.model_diagnostics import show_model_diagnostics

rprint("[bold green]🔧 Loading Llama model...[/bold green]")

model = MllamaForConditionalGeneration.from_pretrained(
    CONFIG['MODEL_PATH'],
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained(CONFIG['MODEL_PATH'])

# Show diagnostics (only if VERBOSE is True)
show_model_diagnostics(
    model, 
    processor, 
    CONFIG['MODEL_PATH'], 
    CONFIG['MAX_NEW_TOKENS_EXTRACT'],
    verbose=CONFIG['VERBOSE']
)

rprint("[green]✅ Model loaded[/green]")

## Load All Prompts

Loading prompts for:
- Document type detection
- Invoice extraction
- Receipt extraction  
- Bank statement extraction (flat and grouped variants)

In [4]:
#Cell 4
# Load all prompts

# Document type detection prompt
with open(CONFIG['PROMPT_FILE_DOCTYPE'], 'r') as f:
    doctype_data = yaml.safe_load(f)
    DOCTYPE_PROMPT = doctype_data['prompts']['detection']['prompt']

# Invoice extraction prompt
with open(CONFIG['PROMPT_FILE_INVOICE'], 'r') as f:
    invoice_data = yaml.safe_load(f)
    INVOICE_PROMPT = invoice_data['prompts']['invoice']['prompt']

# Receipt extraction prompt
with open(CONFIG['PROMPT_FILE_RECEIPT'], 'r') as f:
    receipt_data = yaml.safe_load(f)
    RECEIPT_PROMPT = receipt_data['prompts']['receipt']['prompt']

# Bank statement extraction prompts
with open(CONFIG['PROMPT_FILE_BANK'], 'r') as f:
    bank_data = yaml.safe_load(f)
    BANK_PROMPTS = {
        'flat': bank_data['prompts']['bank_statement_flat']['prompt'],
        'date_grouped': bank_data['prompts']['bank_statement_date_grouped']['prompt']
    }

# Bank statement structure classification prompt
STRUCTURE_CLASSIFICATION_PROMPT = """Look at how dates are displayed in this bank statement's transaction list.

Answer with ONLY one word:
- FLAT (if dates appear as the FIRST COLUMN in a table row, like: "05/05/2025 | Purchase | $22.50")
- GROUPED (if dates appear as SECTION HEADERS above transactions, like: "Thu 05 Sep 2025" followed by indented transaction details below)

The key difference: FLAT has dates IN the table columns, GROUPED has dates AS headers ABOVE the rows.

Answer (one word only):"""

rprint("[green]✅ All prompts loaded[/green]")
rprint(f"[cyan]  Document type detection: {len(DOCTYPE_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Invoice extraction: {len(INVOICE_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Receipt extraction: {len(RECEIPT_PROMPT)} chars[/cyan]")
rprint(f"[cyan]  Bank flat extraction: {len(BANK_PROMPTS['flat'])} chars[/cyan]")
rprint(f"[cyan]  Bank grouped extraction: {len(BANK_PROMPTS['date_grouped'])} chars[/cyan]")
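
The cell above indexes each loaded YAML document as `data['prompts'][name]['prompt']`, so a renamed or missing entry fails with a bare `KeyError`. The sketch below shows one way to make that failure more descriptive. `get_prompt` is a hypothetical helper and the prompt text is invented; only the nesting is taken from the cell above.

```python
def get_prompt(data, name, source="<prompt file>"):
    """Return data['prompts'][name]['prompt'], or raise a KeyError naming what is missing."""
    prompts = data.get("prompts") if isinstance(data, dict) else None
    if not isinstance(prompts, dict):
        raise KeyError(f"{source}: top-level 'prompts' mapping is missing")
    entry = prompts.get(name)
    if not isinstance(entry, dict) or "prompt" not in entry:
        available = ", ".join(sorted(prompts)) or "none"
        raise KeyError(f"{source}: no prompt named '{name}' (available: {available})")
    return entry["prompt"]

# Same nesting the notebook reads from its YAML files; the text is a placeholder.
loaded = {"prompts": {"detection": {"prompt": "What type of document is this?"}}}
print(get_prompt(loaded, "detection"))
```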

## Multi-Turn Chat Function

In [5]:
#Cell 5
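def chat_with_mllm(model, processor, prompt, images, messages=None, max_new_tokens=2000, do_sample=False):
    """Run one chat turn against the vision model and return (reply, messages).

    The first turn attaches the image placeholder; later turns append text only,
    so the same `images` list must be passed on every call. Pattern adapted from
    a Medium article.
    """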
    if messages is None:
        messages = []
    
    if len(messages) == 0:
        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": prompt}]})
    
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=images, text=text, return_tensors="pt").to(model.device)
    
    # Deterministic generation: explicitly disable sampling parameters
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": None if not do_sample else 0.6,
        "top_p": None if not do_sample else 0.9
    }
    generate_ids = model.generate(**inputs, **generation_args)
    
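    # Keep only the newly generated tokens. Special tokens (e.g. end-of-turn) are
    # skipped at decode time, so no real token is lost when max_new_tokens is hit.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    generated_texts = processor.decode(
        generate_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )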
    
    messages.append({"role": "assistant", "content": [{"type": "text", "text": generated_texts}]})
    
    return generated_texts, messages

rprint("[green]✅ Chat function defined[/green]")
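
The message history that `chat_with_mllm` builds can be illustrated without a model. In this sketch the two helpers are hypothetical and, unlike the function above, return new lists instead of appending in place; the prompts and reply are invented.

```python
def add_user_turn(messages, prompt):
    """Return a new history with one user turn appended (image only on the first turn)."""
    content = [{"type": "text", "text": prompt}]
    if not messages:
        content.insert(0, {"type": "image"})
    return messages + [{"role": "user", "content": content}]

def add_assistant_turn(messages, reply):
    """Return a new history with the model's reply appended."""
    return messages + [{"role": "assistant", "content": [{"type": "text", "text": reply}]}]

history = add_user_turn([], "What type of document is this?")
history = add_assistant_turn(history, "BANK_STATEMENT")
history = add_user_turn(history, "FLAT or GROUPED?")

# Show the role and content types of each turn.
for turn in history:
    print(turn["role"], [part["type"] for part in turn["content"]])
```

Only the first user turn carries the image placeholder, which is why later stages can reuse the history from an earlier stage and still pass the same single image.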

## Parser Functions

Functions to parse VLM responses:
- Document type classification
- Bank statement structure classification
- Field extraction parsing

In [6]:
#Cell 6
def parse_document_type(response):
    """Parse document type from VLM response."""
    response = response.strip().upper()
    if "INVOICE" in response:
        return "INVOICE"
    elif "RECEIPT" in response:
        return "RECEIPT"
    elif "BANK" in response or "STATEMENT" in response:
        return "BANK_STATEMENT"
    else:
        return "INVOICE"  # Default fallback

def parse_structure_type(response):
    """Parse bank statement structure type from VLM response."""
    response = response.strip().upper()
    if "FLAT" in response:
        return "flat"
    elif "GROUPED" in response or "DATE" in response:
        return "date_grouped"
    else:
        return "flat"  # Default fallback

def parse_extraction(extraction_text):
    """Parse extraction text into field dictionary."""
    extracted_fields = {}
    
    for line in extraction_text.split('\n'):
        line = line.strip()
        if ':' in line and not line.startswith('#'):
            parts = line.split(':', 1)
            if len(parts) == 2:
                field_name = parts[0].strip()
                field_value = parts[1].strip()
                extracted_fields[field_name] = field_value if field_value else 'NOT_FOUND'
    
    return extracted_fields

rprint("[green]✅ Parser functions defined[/green]")
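
`parse_document_type` falls back to `INVOICE` when no keyword matches, which keeps the pipeline running but hides unrecognized answers. A possible variant, shown here only as a sketch, returns `None` instead so that such rows can be flagged. `parse_document_type_strict` and `DOC_TYPE_KEYWORDS` are hypothetical names; the keyword order mirrors the function above.

```python
# (keyword to search for, document type to return), checked in this order.
DOC_TYPE_KEYWORDS = [
    ("INVOICE", "INVOICE"),
    ("RECEIPT", "RECEIPT"),
    ("BANK", "BANK_STATEMENT"),
    ("STATEMENT", "BANK_STATEMENT"),
]

def parse_document_type_strict(response):
    """Like parse_document_type, but return None when no keyword matches."""
    text = response.strip().upper()
    for keyword, doc_type in DOC_TYPE_KEYWORDS:
        if keyword in text:
            return doc_type
    return None

for answer in ["receipt", "This is a Bank Statement.", "I cannot tell"]:
    print(repr(answer), "->", parse_document_type_strict(answer))
```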

In [7]:
#Cell 7
# Discover all images (no filtering by document type)
data_dir = Path(CONFIG["DATA_DIR"])
image_files = sorted(data_dir.glob("*.png"))

rprint(f"[green]✅ Found {len(image_files)} images to process[/green]")

rprint("[bold blue]Images to process:[/bold blue]")
for img in image_files:
    rprint(f"[cyan]  - {img.name}[/cyan]")
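
The discovery cell matches `*.png` only. If the data directory may also hold other formats, a sketch like the following could be used instead. The suffix set is an assumption, and `discover_images` is a hypothetical helper.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}  # assumed set; the notebook itself uses *.png only

def discover_images(data_dir, suffixes=IMAGE_SUFFIXES):
    """Return sorted image paths in data_dir whose suffix matches, case-insensitively."""
    return sorted(
        p for p in Path(data_dir).iterdir()
        if p.is_file() and p.suffix.lower() in suffixes
    )
```

Sorting keeps the processing order stable across runs, which matters when batch numbers are used in checkpoint file names.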

## Pipeline Stage 0-2 Functions

Each stage is a function that takes one DataFrame row, so it can be used with pandas `.apply()` or called in an explicit loop, as the batch runner does for the GPU stages.

In [8]:
#Cell 8
def stage_0_doctype_detection(row):
    """
    Stage 0: Document type detection.
    
    Args:
        row: DataFrame row with 'image_path' column
    
    Returns:
        dict: {'raw_response': str, 'processing_time': float, 'messages': list}
    """
    image_path = row['image_path']
    image = Image.open(image_path)
    images = [image]
    messages = []
    
    start_time = time.time()
    
    doctype_answer, messages = chat_with_mllm(
        model, processor, DOCTYPE_PROMPT, images, messages,
        max_new_tokens=CONFIG['MAX_NEW_TOKENS_DOCTYPE']
    )
    
    processing_time = time.time() - start_time
    image.close()
    
    # GPU cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return {
        'raw_response': doctype_answer,
        'processing_time': processing_time,
        'messages': messages
    }


def stage_1_structure_classification(row):
    """
    Stage 1: Structure classification for bank statements.
    
    Args:
        row: DataFrame row with 'image_path', 'document_type', 'messages_after_stage0'
    
    Returns:
        dict or None: {'raw_response': str, 'processing_time': float, 'messages': list}
                     or None if not a bank statement
    """
    if row['document_type'] != 'BANK_STATEMENT':
        return None
    
    image_path = row['image_path']
    image = Image.open(image_path)
    images = [image]
    messages = row['messages_after_stage0'].copy()
    
    start_time = time.time()
    
    structure_answer, messages = chat_with_mllm(
        model, processor, STRUCTURE_CLASSIFICATION_PROMPT, images, messages,
        max_new_tokens=CONFIG['MAX_NEW_TOKENS_STRUCTURE']
    )
    
    processing_time = time.time() - start_time
    image.close()
    
    # GPU cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return {
        'raw_response': structure_answer,
        'processing_time': processing_time,
        'messages': messages
    }


def stage_2_extraction(row):
    """
    Stage 2: Field extraction with document-type-aware prompts.
    
    Args:
        row: DataFrame row with all previous stage data
    
    Returns:
        dict: {'raw_response': str, 'processing_time': float, 'prompt_used': str}
    """
    image_path = row['image_path']
    document_type = row['document_type']
    structure_type = row['structure_type']
    
    # Determine which prompt to use
    if document_type == 'BANK_STATEMENT':
        extraction_prompt = BANK_PROMPTS[structure_type]
        prompt_key = f"bank_statement_{structure_type}"
        messages = row['messages_after_stage1'].copy()
    elif document_type == 'INVOICE':
        extraction_prompt = INVOICE_PROMPT
        prompt_key = "invoice"
        messages = row['messages_after_stage0'].copy()
    elif document_type == 'RECEIPT':
        extraction_prompt = RECEIPT_PROMPT
        prompt_key = "receipt"
        messages = row['messages_after_stage0'].copy()
    else:
        # Fallback
        extraction_prompt = INVOICE_PROMPT
        prompt_key = "invoice_fallback"
        messages = row['messages_after_stage0'].copy()
    
    image = Image.open(image_path)
    images = [image]
    
    start_time = time.time()
    
    extraction_result, messages = chat_with_mllm(
        model, processor, extraction_prompt, images, messages,
        max_new_tokens=CONFIG['MAX_NEW_TOKENS_EXTRACT']
    )
    
    processing_time = time.time() - start_time
    image.close()
    
    # GPU cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    return {
        'raw_response': extraction_result,
        'processing_time': processing_time,
        'prompt_used': prompt_key
    }


# rprint("[green]✅ Pipeline stage functions 0-2 defined[/green]")
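
The branch in `stage_2_extraction` decides two things at once: which prompt to use and which message history to continue from. The sketch below isolates that decision as a pure function so it can be checked without a model. `select_prompt_key` is a hypothetical helper; the keys and column names are the ones used above.

```python
def select_prompt_key(document_type, structure_type=None):
    """Mirror the branch order in stage_2_extraction; return (prompt_key, history_column)."""
    if document_type == "BANK_STATEMENT":
        # Bank statements continue from the structure-classification turn.
        return f"bank_statement_{structure_type}", "messages_after_stage1"
    if document_type == "INVOICE":
        return "invoice", "messages_after_stage0"
    if document_type == "RECEIPT":
        return "receipt", "messages_after_stage0"
    # Anything unrecognized is treated as an invoice, as in the stage function.
    return "invoice_fallback", "messages_after_stage0"

print(select_prompt_key("BANK_STATEMENT", "date_grouped"))
print(select_prompt_key("RECEIPT"))
```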

## Pipeline Stage 3-5 Functions
- now imported from pipeline_lib

In [9]:
#Cell 9
# Pipeline Stage Functions
#
# Stage 3-5 functions are now imported from pipeline_lib.stages
#
# Usage:
#   df["parsed_fields"] = df["extraction_response"].apply(
#       lambda x: stage_3_parsing(x, expected_fields=FIELD_COLUMNS)
#   )
#
#   df["cleaned_fields"] = df["parsed_fields"].apply(
#       lambda x: stage_4_cleaning(x, cleaner=cleaner)
#   )
#
#   df["evaluation"] = df.apply(
#       lambda row: stage_5_evaluation(row, ground_truth, expected_fields=FIELD_COLUMNS),
#       axis=1
#   )

rprint("[green]✅ Pipeline stage functions imported from pipeline_lib[/green]")

## Full Pipeline Function 
- runs all stages sequentially, building one DataFrame per batch of images

In [10]:
#Cell 10
# ============================================================================
# BATCH PROCESSING CONFIGURATION
# ============================================================================
BATCH_SIZE = 1000  # Process images in batches (configurable)

# ============================================================================
# INITIALIZE PROCESSING
# ============================================================================
from tqdm.auto import tqdm

all_image_files = image_files  # All images to process
total_images = len(all_image_files)
num_batches = (total_images + BATCH_SIZE - 1) // BATCH_SIZE

rprint(f"\n[bold green]🚀 Starting batch pipeline processing...[/bold green]")
rprint(f"[cyan]  Total images: {total_images}[/cyan]")
rprint(f"[cyan]  Batch size: {BATCH_SIZE}[/cyan]")
rprint(f"[cyan]  Number of batches: {num_batches}[/cyan]\n")

# Store all batch results
all_results = []

# ============================================================================
# PROCESS EACH BATCH
# ============================================================================
for batch_num in range(num_batches):
    batch_start = batch_num * BATCH_SIZE
    batch_end = min(batch_start + BATCH_SIZE, total_images)
    batch_files = all_image_files[batch_start:batch_end]
    batch_size = len(batch_files)
    
    console.rule(f"[bold magenta]Batch {batch_num + 1}/{num_batches} ({batch_size} images)[/bold magenta]")
    
    # Initialize DataFrame for this batch
    df = pd.DataFrame({'image_path': [str(p) for p in batch_files]})
    df['image_name'] = df['image_path'].apply(lambda x: Path(x).name)
    
    # ------------------------------------------------------------------------
    # STAGE 0: Document Type Detection (GPU - manual loop with tqdm, for OOM control)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 0: Document Type Detection[/bold cyan]")
    
    doctype_results = []
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Stage 0: Doc Type"):
        doctype_results.append(stage_0_doctype_detection(row))
    df['doctype_raw'] = doctype_results
    
    # Extract components
    df['doctype_response'] = df['doctype_raw'].apply(lambda x: x['raw_response'])
    df['doctype_time'] = df['doctype_raw'].apply(lambda x: x['processing_time'])
    df['messages_after_stage0'] = df['doctype_raw'].apply(lambda x: x['messages'])
    
    # Parse document type
    df['document_type'] = df['doctype_response'].apply(parse_document_type)
    
    # Summary
    doctype_counts = df['document_type'].value_counts().to_dict()
    rprint(f"[green]✅ Stage 0 complete: {len(df)} documents classified[/green]")
    rprint(f"[cyan]   Invoices: {doctype_counts.get('INVOICE', 0)}[/cyan]")
    rprint(f"[cyan]   Receipts: {doctype_counts.get('RECEIPT', 0)}[/cyan]")
    rprint(f"[cyan]   Bank Statements: {doctype_counts.get('BANK_STATEMENT', 0)}[/cyan]")
    show_pipeline_memory(df, "Stage 0")
    
    # ------------------------------------------------------------------------
    # STAGE 1: Structure Classification (GPU - manual loop with tqdm, for OOM control)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 1: Structure Classification[/bold cyan]")
    
    structure_results = []
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Stage 1: Structure"):
        structure_results.append(stage_1_structure_classification(row))
    df['structure_raw'] = structure_results
    
    # Extract components (handle None for non-bank-statements)
    df['structure_response'] = df['structure_raw'].apply(
        lambda x: x['raw_response'] if x else 'N/A'
    )
    df['structure_time'] = df['structure_raw'].apply(
        lambda x: x['processing_time'] if x else 0
    )
    df['messages_after_stage1'] = df['structure_raw'].apply(
        lambda x: x['messages'] if x else None
    )
    
    # Parse structure type
    df['structure_type'] = df['structure_response'].apply(parse_structure_type)
    
    # Summary
    bank_count = (df['document_type'] == 'BANK_STATEMENT').sum()
    structure_counts = df[df['document_type'] == 'BANK_STATEMENT']['structure_type'].value_counts().to_dict()
    rprint(f"[green]✅ Stage 1 complete: {bank_count} bank statements classified[/green]")
    if bank_count > 0:
        rprint(f"[cyan]   Flat: {structure_counts.get('flat', 0)}[/cyan]")
        rprint(f"[cyan]   Date-grouped: {structure_counts.get('date_grouped', 0)}[/cyan]")
    show_pipeline_memory(df, "Stage 1")
    
    # ------------------------------------------------------------------------
    # STAGE 2: Extraction (GPU - CRITICAL: Manual loop for OOM control)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 2: Document-Type-Aware Extraction[/bold cyan]")
    
    extraction_results = []
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Stage 2: Extraction"):
        extraction_results.append(stage_2_extraction(row))
        
        # CRITICAL: Periodic garbage collection for GPU OOM control
        if (idx + 1) % 3 == 0:
            gc.collect()
    
    df['extraction_raw'] = extraction_results
    
    # Extract components
    df['extraction_response'] = df['extraction_raw'].apply(lambda x: x['raw_response'])
    df['extraction_time'] = df['extraction_raw'].apply(lambda x: x['processing_time'])
    df['prompt_used'] = df['extraction_raw'].apply(lambda x: x['prompt_used'])
    
    # Calculate total processing time
    df['total_time'] = df['doctype_time'] + df['structure_time'] + df['extraction_time']
    
    rprint(f"[green]✅ Stage 2 complete: {len(df)} extractions done[/green]")
    rprint(f"[cyan]   Average extraction time: {df['extraction_time'].mean():.2f}s[/cyan]")
    rprint(f"[cyan]   Total time (avg): {df['total_time'].mean():.2f}s[/cyan]")
    show_pipeline_memory(df, "Stage 2")
    
    # ------------------------------------------------------------------------
    # STAGE 3: Parsing (CPU - safe to use progress_apply)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 3: Parsing (Text → Fields)[/bold cyan]")
    
    tqdm.pandas(desc="Stage 3: Parsing")
    df['parsed_fields'] = df['extraction_response'].progress_apply(
        lambda x: stage_3_parsing(x, expected_fields=FIELD_COLUMNS)
    )
    
    # Count fields found
    df['fields_found'] = df['parsed_fields'].apply(
        lambda x: sum(1 for v in x.values() if v != 'NOT_FOUND')
    )
    
    rprint(f"[green]✅ Stage 3 complete: {len(df)} responses parsed[/green]")
    rprint(f"[cyan]   Average fields found: {df['fields_found'].mean():.1f}/{len(FIELD_COLUMNS)}[/cyan]")
    show_pipeline_memory(df, "Stage 3")
    
    # ------------------------------------------------------------------------
    # STAGE 4: Cleaning (CPU - safe to use progress_apply)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 4: Cleaning & Normalization[/bold cyan]")
    
    tqdm.pandas(desc="Stage 4: Cleaning")
    df['cleaned_fields'] = df['parsed_fields'].progress_apply(
        lambda x: stage_4_cleaning(x, cleaner=cleaner)
    )
    
    # Count fields after cleaning
    df['fields_cleaned'] = df['cleaned_fields'].apply(
        lambda x: sum(1 for v in x.values() if v != 'NOT_FOUND')
    )
    
    rprint(f"[green]✅ Stage 4 complete: {len(df)} field sets cleaned[/green]")
    rprint(f"[cyan]   Average fields cleaned: {df['fields_cleaned'].mean():.1f}/{len(FIELD_COLUMNS)}[/cyan]")
    show_pipeline_memory(df, "Stage 4")
    
    # ------------------------------------------------------------------------
    # STAGE 5: Evaluation (CPU - safe to use progress_apply)
    # ------------------------------------------------------------------------
    console.rule("[bold cyan]Stage 5: Evaluation (Optional)[/bold cyan]")
    
    # Load ground truth if available
    if not CONFIG['INFERENCE_ONLY'] and CONFIG.get('GROUND_TRUTH'):
        if batch_num == 0:  # Load ground truth only once
            rprint("[cyan]Loading ground truth for evaluation...[/cyan]")
            ground_truth = load_ground_truth(CONFIG['GROUND_TRUTH'], verbose=False)
            rprint(f"[green]✅ Ground truth loaded for {len(ground_truth)} images[/green]")
        
        # Apply evaluation with progress bar
        tqdm.pandas(desc="Stage 5: Evaluation")
        df['evaluation'] = df.progress_apply(
            lambda row: stage_5_evaluation(row, ground_truth, expected_fields=FIELD_COLUMNS),
            axis=1
        )
        
        # Extract accuracy metrics
        df['overall_accuracy'] = df['evaluation'].apply(
            lambda x: x.get('overall_accuracy', 0) * 100 if x and 'error' not in x else None
        )
        df['fields_matched'] = df['evaluation'].apply(
            lambda x: x.get('fields_matched', 0) if x and 'error' not in x else None
        )
        df['fields_extracted'] = df['evaluation'].apply(
            lambda x: x.get('fields_extracted', 0) if x and 'error' not in x else None
        )
        
        rprint(f"[green]✅ Stage 5 complete: {len(df)} extractions evaluated[/green]")
        rprint(f"[cyan]   Average accuracy: {df['overall_accuracy'].mean():.2f}%[/cyan]")
        rprint(f"[cyan]   Median accuracy: {df['overall_accuracy'].median():.2f}%[/cyan]")
    else:
        df['evaluation'] = None
        df['overall_accuracy'] = None
        df['fields_matched'] = None
        df['fields_extracted'] = None
        if batch_num == 0:
            rprint("[yellow]⚠️  Inference-only mode - skipping evaluation[/yellow]")
    
    show_pipeline_memory(df, "Stage 5")
    
    # ------------------------------------------------------------------------
    # SAVE BATCH CHECKPOINT
    # ------------------------------------------------------------------------
    batch_checkpoint = checkpoint_dir / f'batch_{batch_num + 1:04d}_{TIMESTAMP}.pkl'
    df.to_pickle(batch_checkpoint)
    rprint(f"[green]✅ Batch {batch_num + 1} checkpoint saved: {batch_checkpoint.name}[/green]")
    
    # Store batch result
    all_results.append(df)
    
    # ------------------------------------------------------------------------
    # MEMORY CLEANUP BETWEEN BATCHES
    # ------------------------------------------------------------------------
    if batch_num < num_batches - 1:  # Not the last batch
        rprint("[yellow]🧹 Cleaning up memory before next batch...[/yellow]")
        del df
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        rprint("[green]✅ Memory cleanup complete[/green]\n")

# ============================================================================
# COMBINE ALL BATCHES
# ============================================================================
console.rule("[bold green]Combining All Batches[/bold green]")

df = pd.concat(all_results, ignore_index=True)

rprint(f"[bold green]✅ All {len(df)} images processed through pipeline[/bold green]")
rprint(f"[cyan]   Total processing time: {df['total_time'].sum():.2f}s[/cyan]")
rprint(f"[cyan]   Average per image: {df['total_time'].mean():.2f}s[/cyan]")

# Final checkpoint with combined data
final_checkpoint = checkpoint_dir / f'final_combined_{TIMESTAMP}.pkl'
df.to_pickle(final_checkpoint)
rprint(f"[green]✅ Final combined checkpoint: {final_checkpoint.name}[/green]")

console.rule("[bold green]Pipeline Complete[/bold green]")

🧹 CLEANER CALLED: DOCUMENT_TYPE: 'RECEIPT' -> 'RECEIPT'
🧹 CLEANER CALLED: BUSINESS_ABN: '08 082 698 025' -> '08 082 698 025'
🧹 CLEANER CALLED: SUPPLIER_NAME: 'Liberty Oil' -> 'Liberty Oil'
🧹 CLEANER CALLED: BUSINESS_ADDRESS: '481 Bourke Street  Perth WA 6000' -> 🏠 Address cleaned: '481 Bourke Street  Perth WA 6000' -> '481 Bourke Street Perth WA 6000'
'481 Bourke Street Perth WA 6000'
🧹 CLEANER CALLED: PAYER_NAME: 'Robert Taylor' -> 'Robert Taylor'
🧹 CLEANER CALLED: PAYER_ADDRESS: '243 Adelaide Street  Perth WA 6000' -> 🏠 Address cleaned: '243 Adelaide Street  Perth WA 6000' -> '243 Adelaide Street Perth WA 6000'
'243 Adelaide Street Perth WA 6000'
🧹 CLEANER CALLED: INVOICE_DATE: '05/08/2025' -> '05/08/2025'
🧹 CLEANER CALLED: LINE_ITEM_DESCRIPTIONS: 'Car Wash | Coffee Large | Unleaded Petrol | Car Wash | Diesel' -> 'Car Wash | Coffee Large | Unleaded Petrol | Car Wash | Diesel'
🧹 CLEANER CALLED: LINE_ITEM_QUANTITIES: '3 | 1 | 1 | 2 | 3' -> '3 | 1 | 1 | 2 | 3'
🧹 CLEANER CALLED: LINE_ITE
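
The batch loop slices the image list using ceiling division and writes one checkpoint per batch. The sketch below reproduces that arithmetic and shows how the checkpoint names could be used to skip batches that are already on disk. Both helpers are hypothetical, and resuming this way assumes the timestamp of the interrupted run is reused, whereas the notebook generates a new `TIMESTAMP` on every run.

```python
from pathlib import Path

def batch_ranges(total, batch_size):
    """Yield (batch_number, start, end) with 1-based batch numbers and exclusive end."""
    num_batches = (total + batch_size - 1) // batch_size  # ceiling division
    for i in range(num_batches):
        start = i * batch_size
        yield i + 1, start, min(start + batch_size, total)

def pending_batches(total, batch_size, checkpoint_dir, timestamp):
    """Return the ranges whose checkpoint file does not exist yet."""
    checkpoint_dir = Path(checkpoint_dir)
    return [
        (number, start, end)
        for number, start, end in batch_ranges(total, batch_size)
        # Same file name pattern as the batch checkpoint written by the loop.
        if not (checkpoint_dir / f"batch_{number:04d}_{timestamp}.pkl").exists()
    ]

print(list(batch_ranges(2500, 1000)))
```

With 2,500 images and a batch size of 1,000, the last batch covers indices 2000 to 2499, matching `min(batch_start + BATCH_SIZE, total_images)` in the loop.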

## Export results

In [11]:
#Cell 11
# ============================================================================
# EXPORT RESULTS WITH EXPANDED FIELDS
# ============================================================================
console.rule("[bold blue]Exporting Results[/bold blue]")

# Create export data with individual field columns
export_data = []

for idx, row in df.iterrows():
    record = {
        'image_file': row['image_name'],
        'document_type': row['document_type'],
        'structure_type': row['structure_type'],
        'prompt_used': row['prompt_used'],
        'total_time': row['total_time'],
        'fields_found': row['fields_found'],
        'fields_cleaned': row['fields_cleaned'],
    }
    
    # Add all cleaned field values as individual columns
    cleaned = row['cleaned_fields']
    for field in FIELD_COLUMNS:
        record[field] = cleaned.get(field, 'NOT_FOUND')
    
    # Add evaluation metrics if available
    if row['evaluation'] is not None and row['overall_accuracy'] is not None:
        record['overall_accuracy'] = row['overall_accuracy']
        record['fields_matched'] = row['fields_matched']
        record['fields_extracted'] = row['fields_extracted']
    
    export_data.append(record)

export_df = pd.DataFrame(export_data)

# Save CSV (compatible with model_comparison.ipynb)
csv_output = output_dir / f"llama_pipeline_results_{TIMESTAMP}.csv"
export_df.to_csv(csv_output, index=False)

rprint(f"[green]✅ CSV exported: {csv_output}[/green]")
rprint(f"[cyan]   Rows: {len(export_df)}[/cyan]")
rprint(f"[cyan]   Columns: {len(export_df.columns)}[/cyan]")
rprint(f"[cyan]   Pattern: *llama*pipeline*results*.csv (compatible with model_comparison.ipynb)[/cyan]")

# Save detailed JSON with full pipeline data
json_data = []

for idx, row in df.iterrows():
    record = {
        'image_file': row['image_name'],
        'pipeline_stages': {
            'stage_0_doctype': {
                'raw_response': row['doctype_response'],
                'document_type': row['document_type'],
                'processing_time': row['doctype_time']
            },
            'stage_1_structure': {
                'raw_response': row['structure_response'],
                'structure_type': row['structure_type'],
                'processing_time': row['structure_time']
            },
            'stage_2_extraction': {
                'raw_response': row['extraction_response'],
                'prompt_used': row['prompt_used'],
                'processing_time': row['extraction_time']
            },
            'stage_3_parsing': {
                'parsed_fields': row['parsed_fields']
            },
            'stage_4_cleaning': {
                'cleaned_fields': row['cleaned_fields']
            }
        },
        'total_processing_time': row['total_time']
    }
    
    if row['evaluation'] is not None:
        record['stage_5_evaluation'] = row['evaluation']
    
    json_data.append(record)

json_output = output_dir / f"llama_pipeline_full_{TIMESTAMP}.json"
with open(json_output, 'w') as f:
    json.dump(json_data, f, indent=2)

rprint(f"[green]✅ JSON exported: {json_output}[/green]")
rprint(f"[cyan]   Full pipeline data saved for all {len(json_data)} images[/cyan]")

# Display sample of exported data
rprint("\n[bold blue]📋 Sample Exported Data:[/bold blue]")
if not CONFIG['INFERENCE_ONLY']:
    sample_cols = ['image_file', 'document_type', 'overall_accuracy', 'total_time', 'fields_cleaned']
else:
    sample_cols = ['image_file', 'document_type', 'total_time', 'fields_cleaned']

rprint(export_df[sample_cols].head(3).to_string(index=False))
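
The CSV export expands the `cleaned_fields` dictionary into one column per expected field, filling gaps with `NOT_FOUND`. A standard-library sketch of that flattening step follows. `FIELDS` is a shortened stand-in for `FIELD_COLUMNS`, `flatten_record` is a hypothetical helper, and the sample values are illustrative.

```python
import csv
import io

FIELDS = ["SUPPLIER_NAME", "TOTAL_AMOUNT", "GST_AMOUNT"]  # shortened stand-in for FIELD_COLUMNS

def flatten_record(image_name, cleaned_fields, fields=FIELDS):
    """Build one flat row: metadata first, then one column per expected field."""
    record = {"image_file": image_name}
    for field in fields:
        # Fields the model did not return are filled in; unexpected keys are dropped.
        record[field] = cleaned_fields.get(field, "NOT_FOUND")
    return record

row = flatten_record("image_003.png", {"SUPPLIER_NAME": "Liberty Oil", "EXTRA": "ignored"})

# Write the row as CSV into an in-memory buffer.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(row))
writer.writeheader()
writer.writerow(row)
print(buffer.getvalue())
```

Iterating over the expected field list, rather than over the dictionary, is what guarantees that every exported row has the same columns in the same order.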

## Pipeline Debugging Utilities

Functions to inspect and debug the pipeline at any stage

In [12]:
#Cell 12
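def inspect_pipeline(df, image_name, stages=('all',)):
    """
    Inspect pipeline stages for a specific image.
    
    Args:
        df: DataFrame with pipeline results
        image_name: Name of image to inspect (e.g., 'image_003.png')
        stages: Stage names to show ('doctype', 'structure', 'extraction',
            'parsing', 'cleaning', 'evaluation') or 'all'
    """
    # Guard against an unknown image name instead of raising IndexError
    matches = df[df['image_name'] == image_name]
    if matches.empty:
        print(f"Image {image_name} not found in DataFrame")
        return
    row = matches.iloc[0]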
    
    print(f"\n{'='*80}")
    print(f"🔍 Pipeline Inspection: {image_name}")
    print(f"{'='*80}")
    
    if 'all' in stages or 'doctype' in stages:
        print(f"\n[STAGE 0: Document Type Detection]")
        print(f"Response: {row['doctype_response']}")
        print(f"Parsed: {row['document_type']}")
        print(f"Time: {row['doctype_time']:.2f}s")
    
    if 'all' in stages or 'structure' in stages:
        print(f"\n[STAGE 1: Structure Classification]")
        print(f"Response: {row['structure_response']}")
        print(f"Parsed: {row['structure_type']}")
        print(f"Time: {row['structure_time']:.2f}s")
    
    if 'all' in stages or 'extraction' in stages:
        print(f"\n[STAGE 2: Extraction]")
        print(f"Prompt Used: {row['prompt_used']}")
        response = row['extraction_response']
        print(f"Response ({len(response)} chars):")
        print(response[:500] + "..." if len(response) > 500 else response)
        print(f"Time: {row['extraction_time']:.2f}s")
    
    if 'all' in stages or 'parsing' in stages:
        print(f"\n[STAGE 3: Parsing]")
        parsed = row['parsed_fields']
        found_fields = [k for k, v in parsed.items() if v != 'NOT_FOUND']
        print(f"Found {len(found_fields)}/{len(parsed)} fields:")
        for field in found_fields[:5]:
            value = parsed[field]
            print(f"  {field}: {value[:50]}..." if len(value) > 50 else f"  {field}: {value}")
    
    if 'all' in stages or 'cleaning' in stages:
        print(f"\n[STAGE 4: Cleaning]")
        cleaned = row['cleaned_fields']
        found_fields = [k for k, v in cleaned.items() if v != 'NOT_FOUND']
        print(f"Cleaned {len(found_fields)}/{len(cleaned)} fields:")
        
        # Show before/after for changed fields
        changes_shown = 0
        for field in found_fields:
            parsed_val = row['parsed_fields'].get(field, 'NOT_FOUND')
            cleaned_val = cleaned[field]
            if parsed_val != cleaned_val:
                print(f"  {field}:")
                print(f"    Before: {parsed_val[:50]}..." if len(parsed_val) > 50 else f"    Before: {parsed_val}")
                print(f"    After:  {cleaned_val[:50]}..." if len(cleaned_val) > 50 else f"    After:  {cleaned_val}")
                changes_shown += 1
                if changes_shown >= 5:
                    break
    
    if 'all' in stages or 'evaluation' in stages:
        if row['evaluation'] is not None:
            print(f"\n[STAGE 5: Evaluation]")
            print(f"Overall Accuracy: {row['overall_accuracy']:.2f}%")
            print(f"Fields Matched: {row['fields_matched']}/{row['fields_extracted']}")
    
    print(f"\n{'='*80}\n")


def compare_parsing_cleaning(df, image_name):
    """Show side-by-side comparison of parsed vs cleaned fields."""
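    # Guard against an unknown image name instead of raising IndexError
    matches = df[df['image_name'] == image_name]
    if matches.empty:
        print(f"Image {image_name} not found in DataFrame")
        return
    row = matches.iloc[0]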
    
    parsed = row['parsed_fields']
    cleaned = row['cleaned_fields']
    
    print(f"\n📊 Parsing vs Cleaning Comparison: {image_name}")
    print(f"{'Field':<30} {'Parsed':<40} {'Cleaned':<40}")
    print("="*110)
    
    for field in FIELD_COLUMNS:
        parsed_val = parsed.get(field, 'NOT_FOUND')
        cleaned_val = cleaned.get(field, 'NOT_FOUND')
        
        if parsed_val != cleaned_val:
            p_display = parsed_val[:37] + "..." if len(parsed_val) > 40 else parsed_val
            c_display = cleaned_val[:37] + "..." if len(cleaned_val) > 40 else cleaned_val
            
            print(f"{field:<30} {p_display:<40} {c_display:<40}")


def field_coverage_report(df):
    """Generate field coverage statistics across all images."""
    print("\n" + "="*80)
    print("📈 FIELD COVERAGE REPORT")
    print("="*80)
    
    coverage_data = []
    
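    n_images = len(df)
    if n_images == 0:
        # Avoid dividing by zero when no images were processed
        print("No images in DataFrame")
        return
    
    def count_found(column, field):
        # Number of images whose dict in `column` holds a real value for `field`
        return int(df[column].apply(lambda d: d.get(field, 'NOT_FOUND') != 'NOT_FOUND').sum())
    
    for field in FIELD_COLUMNS:
        parsed_count = count_found('parsed_fields', field)
        cleaned_count = count_found('cleaned_fields', field)
        
        parsed_pct = (parsed_count / n_images) * 100
        cleaned_pct = (cleaned_count / n_images) * 100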
        
        coverage_data.append({
            'Field': field,
            'Parsed': f"{parsed_count}/{len(df)} ({parsed_pct:.1f}%)",
            'Cleaned': f"{cleaned_count}/{len(df)} ({cleaned_pct:.1f}%)",
            'Change': cleaned_count - parsed_count
        })
    
    coverage_df = pd.DataFrame(coverage_data)
    print(coverage_df.to_string(index=False))
    print("="*80 + "\n")


rprint("[green]✅ Debugging utilities defined[/green]")
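
For very large result sets, it can help to expand the dict column into a wide table once and compute coverage for all fields in a single pass. This is a rough sketch of that idea on a toy DataFrame. The column name `parsed_fields` and the `'NOT_FOUND'` sentinel follow the pipeline; `TOTAL_AMOUNT` and all values are invented for illustration.

```python
import pandas as pd

# Toy results frame; column names follow the pipeline, values are invented.
toy_df = pd.DataFrame(
    {
        "image_name": ["a.png", "b.png"],
        "parsed_fields": [
            {"SUPPLIER_NAME": "Acme", "TOTAL_AMOUNT": "NOT_FOUND"},
            {"SUPPLIER_NAME": "NOT_FOUND", "TOTAL_AMOUNT": "$10.00"},
        ],
    }
)

# Expand each dict into its own columns; a field missing from a dict becomes
# NaN, which is treated here the same as 'NOT_FOUND'.
wide = pd.DataFrame(toy_df["parsed_fields"].tolist(), index=toy_df["image_name"])
found = wide.notna() & wide.ne("NOT_FOUND")

# Share of images with a real value, per field.
coverage_pct = found.mean() * 100
print(coverage_pct.round(1))
```

The boolean `found` table can also be reused to list which images are missing a given field, for example `found.index[~found["SUPPLIER_NAME"]]`.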

## Pipeline Summary and Next Steps

Final summary of pipeline execution and tips for working with large batches

In [13]:
#Cell 13
# ============================================================================
# PIPELINE SUMMARY
# ============================================================================
console.rule("[bold green]Pipeline Execution Summary[/bold green]")

print("\n📊 LLAMA PIPELINE PROCESSING SUMMARY")
print("="*80)
print(f"Total images processed: {len(df)}")
print(f"Mode: {'Inference-only' if CONFIG['INFERENCE_ONLY'] else 'Evaluation mode'}")
print()

# Document type distribution
print("Document Type Distribution:")
for doc_type, count in df['document_type'].value_counts().items():
    print(f"  {doc_type}: {count}")
print()

# Bank statement structure distribution
bank_count = (df['document_type'] == 'BANK_STATEMENT').sum()
if bank_count > 0:
    print("Bank Statement Structure Distribution:")
    for struct_type, count in df[df['document_type'] == 'BANK_STATEMENT']['structure_type'].value_counts().items():
        print(f"  {struct_type}: {count}")
    print()

# Processing time statistics
print("Processing Time Statistics:")
print(f"  Total time: {df['total_time'].sum():.2f}s")
print(f"  Average per image: {df['total_time'].mean():.2f}s")
print(f"  Min: {df['total_time'].min():.2f}s")
print(f"  Max: {df['total_time'].max():.2f}s")
print()

# Field extraction statistics
print("Field Extraction Statistics:")
print(f"  Average fields parsed: {df['fields_found'].mean():.1f}/{len(FIELD_COLUMNS)}")
print(f"  Average fields cleaned: {df['fields_cleaned'].mean():.1f}/{len(FIELD_COLUMNS)}")
print()

# Accuracy statistics (if available)
if not CONFIG['INFERENCE_ONLY']:
    print("Accuracy Statistics:")
    print(f"  Average accuracy: {df['overall_accuracy'].mean():.2f}%")
    print(f"  Median accuracy: {df['overall_accuracy'].median():.2f}%")
    print(f"  Min accuracy: {df['overall_accuracy'].min():.2f}%")
    print(f"  Max accuracy: {df['overall_accuracy'].max():.2f}%")
    print()

print("="*80)

# ============================================================================
# TIPS FOR LARGE BATCHES (10,000+ images)
# ============================================================================
print("\n💡 TIPS FOR PROCESSING LARGE BATCHES (10,000+ images)")
print("="*80)
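# f-string so that csv_output, json_output and checkpoint_dir under item 6 are
# interpolated; this assumes all three names were defined by earlier cells.
print(f"""
1. CHECKPOINTING:
   - Checkpoints are automatically saved after stages 2, 3, and 4
   - Resume from checkpoint:
     df = pd.read_pickle('checkpoints/stage3_parsed_TIMESTAMP.pkl')
     df['cleaned_fields'] = df['parsed_fields'].apply(
         lambda x: stage_4_cleaning(x, cleaner=cleaner)
     )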

2. BATCH PROCESSING:
   - Process in batches of 1000 images to manage memory
   - Use df.iloc[start:end] to process subsets
   
3. PARALLEL PROCESSING:
   - Install pandarallel: pip install pandarallel
   - Use parallel_apply for stages 3 and 4 (CPU-bound)
   - Stages 0-2 (GPU-bound) must remain sequential

4. MEMORY MANAGEMENT:
   - Periodic garbage collection already enabled (every 3 images)
   - GPU cache clearing after each image
   - Monitor with: nvidia-smi

5. INSPECTION AND DEBUGGING:
   - Use inspect_pipeline(df, 'image_name') to debug specific images
   - Use compare_parsing_cleaning(df, 'image_name') to see cleaning effects
   - Use field_coverage_report(df) for overall statistics

6. OUTPUT FILES:
   - CSV: {csv_output}
   - JSON: {json_output}
   - Checkpoints: {checkpoint_dir}/
""")
print("="*80)

# ============================================================================
# DATAFRAME COLUMN REFERENCE
# ============================================================================
print("\n📋 DATAFRAME COLUMN REFERENCE")
print("="*80)
print("""
PIPELINE STAGES (dict objects):
  - doctype_raw: {'raw_response', 'processing_time', 'messages'}
  - structure_raw: {'raw_response', 'processing_time', 'messages'} or None
  - extraction_raw: {'raw_response', 'processing_time', 'prompt_used'}
  - parsed_fields: {field_name: value} - 17 fields
  - cleaned_fields: {field_name: cleaned_value} - 17 fields
  - evaluation: {metrics} or None

EXTRACTED COMPONENTS (primitives):
  - image_path, image_name: str
  - document_type: 'INVOICE' | 'RECEIPT' | 'BANK_STATEMENT'
  - structure_type: 'flat' | 'date_grouped' | 'N/A'
  - doctype_response, structure_response, extraction_response: str (raw VLM output)
  - prompt_used: str (name of the extraction prompt, e.g. 'bank_statement_flat')
  - doctype_time, structure_time, extraction_time, total_time: float (seconds)
  - fields_found, fields_cleaned: int (count of non-NOT_FOUND fields)
  - overall_accuracy, fields_matched, fields_extracted: float/int (if evaluation)

ACCESSING DATA:
  - Full extraction: df.loc[0, 'cleaned_fields']
  - Single field: df.loc[0, 'cleaned_fields']['SUPPLIER_NAME']
  - List field as array: df.loc[0, 'cleaned_fields']['LINE_ITEM_DESCRIPTIONS'].split(' | ')
""")
print("="*80)

rprint("\n[bold green]🎉 Pipeline processing complete! Use the debugging utilities above to inspect results.[/bold green]")


📊 LLAMA PIPELINE PROCESSING SUMMARY
Total images processed: 9
Mode: Evaluation mode

Document Type Distribution:
  RECEIPT: 3
  BANK_STATEMENT: 3
  INVOICE: 3

Bank Statement Structure Distribution:
  date_grouped: 2
  flat: 1

Processing Time Statistics:
  Total time: 90.87s
  Average per image: 10.10s
  Min: 6.32s
  Max: 28.35s

Field Extraction Statistics:
  Average fields parsed: 11.0/17
  Average fields cleaned: 11.0/17

Accuracy Statistics:
  Average accuracy: 72.22%
  Median accuracy: 85.71%
  Min accuracy: 28.57%
  Max accuracy: 100.00%


💡 TIPS FOR PROCESSING LARGE BATCHES (10,000+ images)

1. CHECKPOINTING:
   - Checkpoints are automatically saved after stages 2, 3, and 4
   - Resume from checkpoint:
     df = pd.read_pickle('checkpoints/stage3_parsed_TIMESTAMP.pkl')
     df['cleaned_fields'] = df['parsed_fields'].apply(
    lambda x: stage_4_cleaning(x, cleaner=cleaner)
)

2. BATCH PROCESSING:
   - Process in batches of 1000 images to manage memory
   - Use df.iloc[start:en
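
The checkpointing and batching tips can be combined. This is a minimal sketch, not the notebook's own implementation: `clean_fields` is a stand-in for `stage_4_cleaning`, and the function name `process_in_batches` and the checkpoint file naming are assumptions made for this example.

```python
import tempfile
from pathlib import Path

import pandas as pd


def clean_fields(parsed: dict) -> dict:
    """Stand-in for stage_4_cleaning: strips whitespace from every value."""
    return {k: v.strip() for k, v in parsed.items()}


def process_in_batches(df, batch_size, checkpoint_dir):
    """Clean `parsed_fields` batch by batch, skipping batches already saved."""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    parts = []
    for start in range(0, len(df), batch_size):
        ckpt = checkpoint_dir / f"cleaned_{start:06d}.pkl"
        if ckpt.exists():
            # Resume: reuse the batch written by a previous run.
            batch = pd.read_pickle(ckpt)
        else:
            batch = df.iloc[start:start + batch_size].copy()
            batch["cleaned_fields"] = batch["parsed_fields"].apply(clean_fields)
            batch.to_pickle(ckpt)
        parts.append(batch)
    return pd.concat(parts)


# Toy input: five rows, so a batch size of 2 yields three checkpoint files.
toy_df = pd.DataFrame(
    {"parsed_fields": [{"SUPPLIER_NAME": f" Shop {i} "} for i in range(5)]}
)

with tempfile.TemporaryDirectory() as tmp:
    result = process_in_batches(toy_df, batch_size=2, checkpoint_dir=tmp)
    n_checkpoints = len(list(Path(tmp).glob("cleaned_*.pkl")))

print(result["cleaned_fields"].iloc[0], n_checkpoints)
```

Writing one file per batch means an interrupted run loses at most one batch of work. Note that this sketch expects a non-empty DataFrame, since `pd.concat` raises on an empty list.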

## View Individual Extraction

Set `image_to_view` and run the cell to view the detailed extraction for a specific image:

In [14]:
#Cell 14
# View detailed extraction for specific image (using pandas DataFrame)
image_to_view = "image_003.png"  # Change this

if 'df' in globals() and len(df) > 0:
    row = df[df['image_name'] == image_to_view]

    if len(row) > 0:
        row = row.iloc[0]
        print(f"\n🔍 Detailed Extraction: {image_to_view}")
        print("="*80)
        print(f"Document Type: {row['document_type']}")
        print(f"Structure Type: {row['structure_type']}")
        print(f"Prompt Used: {row['prompt_used']}")
        print(f"\nDocument Type Classification Response:")
        print(row['doctype_response'])
        print(f"\nStructure Classification Response:")
        print(row['structure_response'])
        print(f"\nExtraction Result:")
        print(row['extraction_response'])
        print("="*80)
    else:
        print(f"Image {image_to_view} not found in DataFrame")
else:
    print("⚠️ DataFrame 'df' not found - run Cell 10 (main pipeline) first")


🔍 Detailed Extraction: image_003.png
Document Type: BANK_STATEMENT
Structure Type: flat
Prompt Used: bank_statement_flat

Document Type Classification Response:
This is a bank statement.

Structure Classification Response:
FLAT

Extraction Result:
DOCUMENT_TYPE: BANK_STATEMENT
STATEMENT_DATE_RANGE: 03/05/2025 to 10/05/2025
TRANSACTION_DATES: 03/05/2025 | 04/05/2025 | 05/05/2025 | 06/05/2025 | 07/05/2025 | 08/05/2025 | 09/05/2025 | 10/05/2025
LINE_ITEM_DESCRIPTIONS: ONLINE PURCHASE AMAZON AU | EFTPOS PURCHASE COLES EXP | EFTPOS PURCHASE COLES EXP | DIRECT CREDIT SALARY | ATM WITHDRAWAL ANZ ATM | EFTPOS PURCHASE COLES EXP | INTEREST PAYMENT | ATM WITHDRAWAL ANZ ATM
TRANSACTION_AMOUNTS_PAID: $288.03 | $22.50 | $114.66 | $187.59 | $112.50 | $5.16 | $146.72
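
List-valued fields are pipe-delimited strings, so they need splitting before they can be lined up as transactions. The sketch below uses the values from the extraction result above as a stand-in for `cleaned_fields`; the helper `split_list_field` is a name introduced for this example. It checks that the lists have equal length before pairing them by position.

```python
# Values copied from the extraction result shown above.
extraction = {
    "TRANSACTION_DATES": (
        "03/05/2025 | 04/05/2025 | 05/05/2025 | 06/05/2025 | "
        "07/05/2025 | 08/05/2025 | 09/05/2025 | 10/05/2025"
    ),
    "LINE_ITEM_DESCRIPTIONS": (
        "ONLINE PURCHASE AMAZON AU | EFTPOS PURCHASE COLES EXP | "
        "EFTPOS PURCHASE COLES EXP | DIRECT CREDIT SALARY | "
        "ATM WITHDRAWAL ANZ ATM | EFTPOS PURCHASE COLES EXP | "
        "INTEREST PAYMENT | ATM WITHDRAWAL ANZ ATM"
    ),
    "TRANSACTION_AMOUNTS_PAID": (
        "$288.03 | $22.50 | $114.66 | $187.59 | $112.50 | $5.16 | $146.72"
    ),
}


def split_list_field(value, sep=" | "):
    """Split a pipe-delimited field into a list; 'NOT_FOUND' yields []."""
    if value == "NOT_FOUND":
        return []
    return [item.strip() for item in value.split(sep)]


columns = {name: split_list_field(value) for name, value in extraction.items()}
lengths = {name: len(items) for name, items in columns.items()}
print(lengths)

# Only pair items by position when every list has the same length; otherwise
# there is no guarantee that position i refers to the same transaction.
aligned = len(set(lengths.values())) == 1
print("Aligned:", aligned)
```

For this image the dates and descriptions each contain 8 items while the amounts contain 7, so the lists cannot be zipped into rows without further checks.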
