In [None]:
#Cell 0: Imports
import sys
from pathlib import Path

import torch
import yaml
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Add project root to path for imports
sys.path.insert(0, '/home/jovyan/nfs_share/tod/LMM_POC')
from common.header_mapping import (
    generate_extraction_instruction,
    map_headers_to_fields,
    validate_mapping,
)


In [None]:
#Cell 1: Load both classifier prompts
# Both classifier prompts live side by side in the project's prompts/ directory.
_prompts_dir = Path("/home/jovyan/nfs_share/tod/LMM_POC/prompts")
doc_type_prompt_path = _prompts_dir / "document_type_classifier.yaml"
header_prompt_path = _prompts_dir / "bank_statement_structure_classifier_simple.yaml"


def _load_prompt_config(config_path: Path) -> dict:
    """Parse a prompt YAML file and print its name/version banner.

    Args:
        config_path: Path to the YAML prompt definition.

    Returns:
        The parsed mapping; it must provide at least 'name' and 'version'.
    """
    with config_path.open("r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    print(f"✅ {parsed['name']} v{parsed['version']}")
    return parsed


print("📄 Loading document type classifier...")
doc_type_config = _load_prompt_config(doc_type_prompt_path)

print("\n📄 Loading header classifier...")
header_config = _load_prompt_config(header_prompt_path)


In [None]:
#Cell 2: Load Llama-3.2-Vision model
model_id = "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision-Instruct"

print("🔧 Loading Llama-3.2-Vision model...")
# Weights are loaded in bfloat16; device placement is delegated to
# transformers via device_map="auto".
model_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}
model = MllamaForConditionalGeneration.from_pretrained(model_id, **model_load_kwargs)
# The processor (tokenizer + image preprocessing) comes from the same checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

print("✅ Model loaded successfully!")

In [None]:
#Cell 3: Load test bank statement image
image_path = "/home/jovyan/nfs_share/tod/LMM_POC/evaluation_data/image_003.png"

print("📷 Loading bank statement image...")
# This single PIL image is reused by every model call in Stages 1, 2 and 4.
image = Image.open(image_path)
print(f"✅ Image loaded: {image.size}\n📁 Image path: {image_path}")

In [None]:
#Cell 4: STAGE 1 - Document Type Classification
print("="*60)
print("STAGE 1: DOCUMENT TYPE CLASSIFICATION")
print("="*60)

# The classifier prompt is the YAML instruction followed by its output-format section.
doc_type_prompt = f"{doc_type_config['instruction']}\n\n{doc_type_config['output_format']}"

# Single-turn chat: an image placeholder plus the text prompt. The actual pixels
# are passed separately to `processor(...)` below.
doc_type_message = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": doc_type_prompt},
        ],
    }
]

print("🤖 Classifying document type...")
doc_type_text_input = processor.apply_chat_template(
    doc_type_message, add_generation_prompt=True
)
doc_type_inputs = processor(image, doc_type_text_input, return_tensors="pt").to(model.device)
doc_type_output = model.generate(**doc_type_inputs, max_new_tokens=100)

# The generated sequence starts with the prompt tokens. Decode only the new tokens.
# This keeps the label search below from scanning the prompt text, which may name
# every label. Special tokens such as <|eot_id|> are dropped as well.
doc_type_prompt_len = doc_type_inputs["input_ids"].shape[-1]
doc_type_clean = processor.decode(
    doc_type_output[0][doc_type_prompt_len:], skip_special_tokens=True
).strip()

# Map the free-text answer to a label. Mobile_APP is checked first, so a reply
# that mentions both labels is classified as Mobile_APP.
if "Mobile_APP" in doc_type_clean:
    document_type = "Mobile_APP"
elif "BANK_STATEMENT" in doc_type_clean:
    document_type = "BANK_STATEMENT"
else:
    document_type = "UNKNOWN"

print(f"✅ Document Type: {document_type}")
print(f"📝 Raw response: {doc_type_clean}")


In [None]:
#Cell 5: STAGE 2 - Conditional Header Extraction
print("\n" + "="*60)
print("STAGE 2: HEADER EXTRACTION")
print("="*60)

# Capitalised words expected in a transaction-table header row. They are used to
# pick the header line out of the model's free-text reply. The match is
# case-sensitive, which favours header-like lines over prose.
HEADER_KEYWORDS = ['Date', 'Transaction', 'Description', 'Amount', 'Debit', 'Credit', 'Balance', 'Particulars']


def _parse_header_response(response_text: str) -> str:
    """Convert the header classifier's reply into a ' | '-joined header string.

    Args:
        response_text: The model's decoded reply (assistant turn only).

    Returns:
        "NO_HEADERS" when the model says so, or when no header names can be
        recovered (e.g. an empty reply), so Stage 3 never receives an empty string.
        Otherwise, the comma-separated header names joined with " | ".
    """
    if "NO_HEADERS" in response_text:
        return "NO_HEADERS"

    lines = [line.strip() for line in response_text.split('\n') if line.strip()]
    # Prefer the first comma-separated line that contains a known header keyword.
    # Otherwise fall back to the last non-empty line of the reply.
    header_line = next(
        (line for line in lines
         if ',' in line and any(keyword in line for keyword in HEADER_KEYWORDS)),
        lines[-1] if lines else "",
    )

    header_line = header_line.rstrip('.')
    headers = [h.strip() for h in header_line.split(',') if h.strip()]
    return " | ".join(headers) if headers else "NO_HEADERS"


headers_pipe_separated = None

if document_type == "BANK_STATEMENT":
    print("✅ Bank statement detected - extracting headers...")

    # Build the header extraction prompt: YAML instruction plus output-format section.
    header_prompt = f"{header_config['instruction']}\n\n{header_config['output_format']}"

    header_message = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": header_prompt},
            ],
        }
    ]

    print("🤖 Extracting transaction table headers...")
    header_text_input = processor.apply_chat_template(
        header_message, add_generation_prompt=True
    )
    header_inputs = processor(image, header_text_input, return_tensors="pt").to(model.device)
    header_output = model.generate(**header_inputs, max_new_tokens=2000)

    # Decode only the newly generated tokens. The prompt text must never be
    # parsed as if it were the model's answer.
    header_prompt_len = header_inputs["input_ids"].shape[-1]
    header_clean = processor.decode(
        header_output[0][header_prompt_len:], skip_special_tokens=True
    ).strip()

    headers_pipe_separated = _parse_header_response(header_clean)

    print(f"✅ Headers extracted: {headers_pipe_separated}")

elif document_type == "Mobile_APP":
    print("ℹ️  Mobile app detected - skipping header extraction")
    headers_pipe_separated = "N/A"
else:
    print("⚠️  Unknown document type - skipping header extraction")
    headers_pipe_separated = "UNKNOWN"


In [None]:
#Cell 6: Display Final Results
separator = "=" * 60
summary_lines = [
    "\n" + separator,
    "FINAL CLASSIFICATION RESULTS",
    separator,
    f"📄 Image: {image_path}",
    f"🏷️  Document Type: {document_type}",
    f"📋 Headers: {headers_pipe_separated}",
    separator,
]
print("\n".join(summary_lines))

# Stage 1/2 summary; Cell 7 writes this dict to disk.
classification_result = {
    "image_path": str(image_path),
    "document_type": document_type,
    "headers": headers_pipe_separated
}


In [None]:
#Cell 7: Save Classification Results (Optional)
import json

output_dir = Path("classification_results")
output_dir.mkdir(exist_ok=True)

# Machine-readable copy of the Stage 1/2 results.
json_path = output_dir / "classification_result.json"
json_path.write_text(json.dumps(classification_result, indent=2), encoding="utf-8")
print(f"✅ Results saved to: {json_path}")

# Plain-text copy holding just the two headline fields.
text_path = output_dir / "classification_result.txt"
text_path.write_text(
    f"Document Type: {document_type}\nHeaders: {headers_pipe_separated}\n",
    encoding="utf-8",
)
print(f"✅ Text results saved to: {text_path}")


In [None]:
#Cell 8: STAGE 3 - Smart Header Mapping (Only if BANK_STATEMENT)
print("\n" + "="*60)
print("STAGE 3: SMART HEADER MAPPING")
print("="*60)

# Placeholder values Stage 2 stores in place of real headers.
NON_HEADER_SENTINELS = ("NO_HEADERS", "N/A", "UNKNOWN")
# Fields Stage 4 extracts (Date, Description, Debit). A missing field only triggers a warning.
REQUIRED_FIELDS = ['DATE', 'DESCRIPTION', 'DEBIT']

# An empty or None header string has nothing to map, so treat it like a sentinel.
has_usable_headers = (
    bool(headers_pipe_separated)
    and headers_pipe_separated not in NON_HEADER_SENTINELS
)

if document_type == "BANK_STATEMENT" and has_usable_headers:
    print("🧠 Mapping headers to semantic fields...")

    # Map headers to fields; the result is iterated as {field: column_name or falsy}.
    field_mapping = map_headers_to_fields(headers_pipe_separated)

    print("\n📋 Header Mapping Results:")
    for field, column_name in field_mapping.items():
        status = "✅" if column_name else "❌"
        print(f"  {status} {field}: {column_name or 'NOT FOUND'}")

    is_valid, missing_fields = validate_mapping(field_mapping, required_fields=REQUIRED_FIELDS)

    if is_valid:
        print("\n✅ All required fields mapped successfully!")
    else:
        print(f"\n⚠️  WARNING: Missing required fields: {missing_fields}")
        print("   Extraction will proceed but may be incomplete.")
    # Extraction is attempted either way, using whichever fields were mapped.
    can_extract = True
else:
    print("ℹ️  Skipping header mapping (not a bank statement or no headers detected)")
    field_mapping = None
    can_extract = False


In [None]:
#Cell 9: STAGE 4 - Transaction Extraction (Only if mapping successful)
print("\n" + "="*60)
print("STAGE 4: TRANSACTION EXTRACTION")
print("="*60)

extracted_transactions = None

if can_extract and field_mapping:
    print("💰 Extracting transactions (Date, Description, Debit)...")

    # The YAML template supplies the output-format section of the prompt.
    extraction_template_path = Path("/home/jovyan/nfs_share/tod/LMM_POC/prompts/transaction_extraction_template.yaml")
    with extraction_template_path.open("r", encoding="utf-8") as f:
        extraction_config = yaml.safe_load(f)

    print(f"✅ Loaded: {extraction_config['name']}")

    # The instruction is generated from the Stage 3 mapping and the raw header string.
    # Per the original note, anti-hallucination rules are embedded by
    # generate_extraction_instruction(); only the output format comes from the YAML.
    extraction_instruction = generate_extraction_instruction(field_mapping, headers_pipe_separated)
    extraction_output_format = extraction_config['output_format']

    extraction_prompt = f"{extraction_instruction}\n\n{extraction_output_format}"

    print(f"📏 Prompt length: {len(extraction_prompt)} characters")

    extraction_message = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": extraction_prompt},
            ],
        }
    ]

    print("🤖 Generating extraction with Llama-3.2-Vision...")
    extraction_text_input = processor.apply_chat_template(
        extraction_message, add_generation_prompt=True
    )
    extraction_inputs = processor(image, extraction_text_input, return_tensors="pt").to(model.device)
    extraction_output = model.generate(**extraction_inputs, max_new_tokens=3000)

    # Decode only the newly generated tokens. Otherwise a failed split could save
    # the whole prompt (instructions plus special tokens) as "transactions" in Cell 10.
    extraction_prompt_len = extraction_inputs["input_ids"].shape[-1]
    extracted_transactions = processor.decode(
        extraction_output[0][extraction_prompt_len:], skip_special_tokens=True
    ).strip()

    print("✅ Extraction complete!")

else:
    print("ℹ️  Skipping extraction (not applicable for this document type)")


In [None]:
#Cell 10: Display and Save Extracted Transactions
import json


def _count_transaction_rows(psv_text: str) -> int:
    """Count the data rows in the model's pipe-separated transaction output.

    Skipped lines:
    - blank lines;
    - lines without a '|' (e.g. prose the model adds around the table);
    - markdown separator rows such as '|---|---|';
    - the 'Date | Description | Debit' header row, with or without surrounding pipes.

    Args:
        psv_text: The extracted pipe-separated text.

    Returns:
        The number of remaining (transaction) rows.
    """
    separator_chars = set("|-: ")
    row_count = 0
    for raw_line in psv_text.splitlines():
        line = raw_line.strip()
        if not line or "|" not in line:
            continue
        if set(line) <= separator_chars:
            continue
        if line.strip("|").strip().startswith("Date | Description | Debit"):
            continue
        row_count += 1
    return row_count


if extracted_transactions:
    print("\n" + "="*60)
    print("EXTRACTED TRANSACTIONS (PIPE-SEPARATED FORMAT)")
    print("="*60)
    print(extracted_transactions)
    print("="*60)

    transaction_count = _count_transaction_rows(extracted_transactions)
    print(f"\n📊 Total transactions extracted: {transaction_count}")

    # Re-create the output directory so this cell also works when Cell 7 was skipped.
    output_dir = Path("classification_results")
    output_dir.mkdir(exist_ok=True)

    # Raw pipe-separated output, written verbatim.
    psv_path = output_dir / "extracted_transactions.psv"
    with psv_path.open("w", encoding="utf-8") as f:
        f.write(extracted_transactions)
    print(f"✅ Pipe-separated file saved to: {psv_path}")

    # Full pipeline record: classification, mapping and extraction together.
    complete_results = {
        "image_path": str(image_path),
        "document_type": document_type,
        "headers_detected": headers_pipe_separated,
        "field_mapping": field_mapping,
        "transaction_count": transaction_count,
        "transactions_psv": extracted_transactions
    }

    json_path = output_dir / "complete_extraction_results.json"
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(complete_results, f, indent=2)
    print(f"✅ Complete results saved to: {json_path}")

else:
    print("\n" + "="*60)
    print("NO TRANSACTIONS EXTRACTED")
    print("="*60)
    print("Reason: Document is not a bank statement or extraction was skipped")
