# Llama 3.2 Vision Single-Column Multi-Turn Extraction

This notebook implements a **single-column extraction strategy** to avoid column merging issues observed with standard table extraction.

**Strategy:**
- 5 separate conversation turns, each extracting ONE column only
- Turn 1: Date column
- Turn 2: Transaction/Description column
- Turn 3: Debit column
- Turn 4: Credit column
- Turn 5: Balance column

**Reference:** Based on OCRBench v2 findings that VLMs struggle with spatial column reasoning [arXiv:2412.20662v2]

## Imports

In [None]:
# Add project root to path for common/ imports
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import random
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from transformers.image_utils import load_image
from tqdm.notebook import tqdm
from IPython.display import display, Markdown

## Pre-emptive Memory Cleanup

Optional GPU memory cleanup to prevent OOM errors when switching between models.

In [None]:
# Optional: Pre-emptive memory cleanup (useful when switching models)
# Best-effort step: the import sits inside the try so that a missing
# common.gpu_optimization module produces a notice instead of stopping the
# notebook.
try:
    from common.gpu_optimization import emergency_cleanup
    print("🧹 Clearing GPU memory...")
    emergency_cleanup(verbose=False)
    print("✅ Memory cleanup complete")
except ImportError:
    # Only ImportError is caught; any other exception raised above propagates.
    print("⚠️ GPU optimization module not available - skipping cleanup")

## Set Random Seed for Reproducibility

In [None]:
# Fix the seed via the project helper so repeated runs are comparable.
# NOTE(review): which generators set_seed covers (random / numpy / torch) is
# decided inside common.reproducibility — confirm there.
from common.reproducibility import set_seed
set_seed(42)

## Load Model

In [None]:
# Absolute path handed to the loader as `model_path`.
# NOTE(review): presumably a local checkpoint directory — confirm it exists on this host.
model_id = "/home/jovyan/shared_PTM/Llama-3.2-11B-Vision-Instruct"

print("🔧 Loading Llama-3.2-Vision model...")

# Project-local loader, imported in this cell rather than with the imports above.
from common.llama_model_loader_robust import load_llama_model_robust

# The loader returns the model together with its processor; both are used by
# every extraction turn later in the notebook.
# NOTE(review): the exact effect of each option below is defined by
# load_llama_model_robust — confirm there.
model, processor = load_llama_model_robust(
    model_path=model_id,
    use_quantization=False,
    device_map='auto',
    max_new_tokens=2000,
    torch_dtype='bfloat16',
    low_cpu_mem_usage=True,
    verbose=True
)

# Call tie_weights() on the loaded model. This is best-effort: any failure is
# reported as a warning and the notebook continues with the model as loaded.
try:
    model.tie_weights()
    print("✅ Model weights tied successfully")
except Exception as e:
    print(f"⚠️ tie_weights() warning: {e}")

print("✅ Model loaded successfully!")

## Define chat_with_mllm Function

This function encapsulates the multi-turn conversation pattern.

In [None]:
def chat_with_mllm(model, processor, prompt, images_path=None, do_sample=False,
                   temperature=0.1, show_image=False, max_new_tokens=2000,
                   messages=None, images=None):
    """Run one turn of a multi-turn chat with a Llama vision model.

    On the first turn (empty ``messages``) the user message carries one image
    placeholder per loaded image followed by the text prompt. On later turns
    only the text prompt is appended to the existing history.

    Args:
        model: Loaded Llama vision model (must provide ``generate`` and
            ``device``).
        processor: Processor for the model (must provide
            ``apply_chat_template``, ``__call__`` and ``decode``).
        prompt: User's text prompt for this turn.
        images_path: Path or list of paths to image files. Only read when
            ``images`` is empty. ``None`` (default) means no paths.
        do_sample: Enable sampling (if True, ``temperature`` is used).
        temperature: Sampling temperature (default 0.1).
        show_image: Display each image in the notebook as it is loaded.
        max_new_tokens: Maximum number of tokens to generate (default 2000).
        messages: Conversation history. ``None`` or an empty list starts a
            new conversation; a non-empty list is extended in place.
        images: Already-loaded image objects. ``None`` or an empty list
            triggers loading from ``images_path``; a list that is passed in
            is filled in place.

    Returns:
        tuple: ``(generated_text, updated_messages, images)``.
    """
    # None sentinels replace the former mutable `[]` defaults: `images` is
    # appended to below, so a shared default list would leak images from one
    # call into every later call that omits the argument.
    if images_path is None:
        images_path = []
    elif not isinstance(images_path, list):
        images_path = [images_path]
    if images is None:
        images = []
    if messages is None:
        messages = []

    # Load images only when none have been supplied (i.e. on the first turn).
    if not images and images_path:
        for image_path in tqdm(images_path, desc="Loading images"):
            image = load_image(image_path)
            images.append(image)
            if show_image:
                display(image)

    if not messages:
        # New conversation: one image placeholder per image so the number of
        # placeholders in the rendered prompt matches the images handed to
        # the processor.
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
    else:
        # Follow-up turn: the image is already part of the history.
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": prompt}]
        })

    # Render the history to text and build the model inputs. An empty image
    # list is passed as None so a text-only conversation is not treated as
    # an (empty) image batch.
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        images=images or None, text=text, return_tensors="pt"
    ).to(model.device)

    generation_args = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        generation_args["temperature"] = temperature
    else:
        # Greedy decoding: explicitly unset the sampling-only options.
        generation_args["temperature"] = None
        generation_args["top_p"] = None

    output_ids = model.generate(**inputs, **generation_args)

    # Keep only the newly generated tokens. The end-of-turn marker is removed
    # by `skip_special_tokens` rather than by blindly dropping the last
    # token, which would lose a real token whenever generation stops at
    # `max_new_tokens` without emitting an end marker.
    prompt_length = inputs["input_ids"].shape[1]
    new_token_ids = output_ids[:, prompt_length:]
    generated_texts = processor.decode(
        new_token_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # Record the reply so the next turn sees the full history.
    messages.append({
        "role": "assistant",
        "content": [{"type": "text", "text": generated_texts}]
    })

    return generated_texts, messages, images

print("✅ chat_with_mllm function defined")

## Define Conversation Prompts

Each prompt extracts a single column to avoid spatial reasoning failures.

In [None]:
# Single-column extraction prompts
# Maps a turn key ("turn_<n>_<column>_column") to the prompt text sent in that
# turn. Every prompt asks for exactly one column and one value per line; the
# row-count validation later in the notebook counts the non-empty lines of
# each response.
CONVERSATION_PROMPTS = {
    # Turn 1: dates only, taken from the leftmost column.
    "turn_1_date_column": """Look at the transaction table in this bank statement.

Find the leftmost column with the header "Date" or "Date of Transaction".

Extract ONLY the date values from this column, ignoring all other columns.

Important:
- Extract ONLY dates (e.g., "15 Mar 2024" or "15 Mar")
- Do NOT include transaction descriptions
- Do NOT include amounts
- List one date per line

Output format:
[First date]
[Second date]
[Third date]
...""",
    
    # Turn 2: descriptions only; multi-line descriptions are joined with spaces.
    "turn_2_transaction_column": """Look at the transaction table in this bank statement.

Find the column with the header "Transaction" or "Description".

This column is located IMMEDIATELY TO THE RIGHT of the Date column.

Extract ONLY the transaction descriptions from this column, ignoring all other columns.

Important:
- Extract ONLY transaction descriptions
- If a description spans multiple lines, combine them with spaces
- Do NOT include dates
- Do NOT include amounts
- List one description per line

Output format:
[First description]
[Second description]
[Third description]
...""",
    
    # Turn 3: debit amounts; empty cells are reported as the literal "EMPTY".
    "turn_3_debit_column": """Look at the transaction table in this bank statement.

Find the column with the header "Debit" or "Withdrawal".

Extract ONLY the debit amounts from this column, ignoring all other columns.

Important:
- Extract ONLY debit amounts with currency symbols
- If a cell is empty, write "EMPTY"
- Do NOT extract credit amounts or balance amounts
- List one amount per line

Output format:
[First amount or EMPTY]
[Second amount or EMPTY]
[Third amount or EMPTY]
...""",
    
    # Turn 4: credit amounts; empty cells are "EMPTY" and no "CR" suffix is allowed.
    "turn_4_credit_column": """Look at the transaction table in this bank statement.

Find the column with the header "Credit" or "Deposit".

Extract ONLY the credit amounts from this column, ignoring all other columns.

Important:
- Extract ONLY credit amounts with currency symbols
- If a cell is empty, write "EMPTY"
- NEVER add "CR" suffix to these amounts
- Do NOT extract debit amounts or balance amounts
- List one amount per line

Output format:
[First amount or EMPTY]
[Second amount or EMPTY]
[Third amount or EMPTY]
...""",
    
    # Turn 5: balances from the rightmost column, keeping any "CR" notation.
    "turn_5_balance_column": """Look at the transaction table in this bank statement.

Find the rightmost column with the header "Balance".

Extract ONLY the balance amounts from this column, ignoring all other columns.

Important:
- Extract ONLY balance amounts
- Preserve "CR" notation exactly as shown
- Do NOT extract debit or credit amounts
- List one balance per line

Output format:
[First balance]
[Second balance]
[Third balance]
..."""
}

def display_prompts(prompts_dict):
    """Print a numbered preview of every prompt in ``prompts_dict``.

    Each entry is listed under a heading derived from its key (underscores
    become spaces, words are title-cased). Prompt text longer than 200
    characters is cut to its first 197 characters followed by "...". The
    number of prompts is printed at the end.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    max_preview = 200

    print(heavy_rule)
    print("SINGLE-COLUMN EXTRACTION PROMPTS")
    print(heavy_rule)

    for position, key in enumerate(prompts_dict, start=1):
        body = prompts_dict[key]
        heading = " ".join(key.split("_")).title()
        if len(body) > max_preview:
            # Reserve three characters for the ellipsis.
            body = body[:max_preview - 3] + "..."
        print(f"\n{position}. {heading}")
        print(light_rule)
        print(f"{body}")

    print("\n" + heavy_rule)
    print(f"Total prompts defined: {len(prompts_dict)}")

# Print a preview of each prompt so the cell output shows what will be sent
# in each turn.
display_prompts(CONVERSATION_PROMPTS)
print("\n✅ Conversation prompts defined and ready to use")

## Load Image and Initialize Conversation

In [None]:
# Image path
# Absolute path of the bank-statement image; turn 1 passes it as images_path.
imageName = "/home/jovyan/_LMM_POC/evaluation_data/image_003.png"

# Initialize conversation
# Both start empty: chat_with_mllm loads the image when `images` is empty and
# starts a new history when `messages` is empty. Re-run this cell to reset
# the conversation before repeating the turns.
messages = []
images = []

print("📸 Processing bank statement image...")
print(f"📁 Image: {imageName}")

## Turn 1: Extract Date Column Only

In [None]:
# Turn 1 opens the conversation. `images` is still empty here, so the image
# is loaded from `imageName` (and displayed, because show_image=True). The
# returned `messages` and `images` are carried into the following turns.
print("📝 Using prompt: turn_1_date_column")

response1, messages, images = chat_with_mllm(
    model, 
    processor, 
    CONVERSATION_PROMPTS["turn_1_date_column"],
    images_path=[imageName],
    do_sample=False,
    max_new_tokens=1000,
    show_image=True,
    messages=messages,
    images=images
)

print("\n" + "=" * 60)
print("TURN 1 - DATE COLUMN:")
print("=" * 60)
print(response1)
print("=" * 60)

## Turn 2: Extract Transaction Column Only

In [None]:
# Turn 2 continues the conversation from turn 1. No images_path is needed:
# `images` already holds the loaded image and `messages` holds the history.
print("📝 Using prompt: turn_2_transaction_column")

response2, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_2_transaction_column"],
    messages=messages, 
    images=images,
    # Larger budget than the other turns (2000 vs 1000 tokens) —
    # presumably because descriptions are the longest column.
    max_new_tokens=2000
)

print("\n" + "=" * 60)
print("TURN 2 - TRANSACTION COLUMN:")
print("=" * 60)
print(response2)
print("=" * 60)

## Turn 3: Extract Debit Column Only

In [None]:
# Turn 3 continues the same conversation and asks for the debit column only;
# the prompt tells the model to write "EMPTY" for blank cells.
print("📝 Using prompt: turn_3_debit_column")

response3, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_3_debit_column"],
    messages=messages,
    images=images,
    max_new_tokens=1000
)

print("\n" + "=" * 60)
print("TURN 3 - DEBIT COLUMN:")
print("=" * 60)
print(response3)
print("=" * 60)

## Turn 4: Extract Credit Column Only

In [None]:
# Turn 4 continues the same conversation and asks for the credit column only;
# the prompt tells the model to write "EMPTY" for blank cells.
print("📝 Using prompt: turn_4_credit_column")

response4, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_4_credit_column"],
    messages=messages,
    images=images,
    max_new_tokens=1000
)

print("\n" + "=" * 60)
print("TURN 4 - CREDIT COLUMN:")
print("=" * 60)
print(response4)
print("=" * 60)

## Turn 5: Extract Balance Column Only

In [None]:
# Turn 5 is the last extraction turn: balance column only, continuing the
# history built by the previous four turns.
print("📝 Using prompt: turn_5_balance_column")

response5, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_5_balance_column"],
    messages=messages,
    images=images,
    max_new_tokens=1000
)

print("\n" + "=" * 60)
print("TURN 5 - BALANCE COLUMN:")
print("=" * 60)
print(response5)
print("=" * 60)

## Validation: Check Row Counts

In [None]:
# Parse responses to count rows
def count_lines(response_text):
    """Return how many lines of ``response_text`` contain non-whitespace text."""
    total = 0
    for raw_line in response_text.strip().split('\n'):
        # Whitespace-only lines do not count as rows.
        if raw_line.strip():
            total += 1
    return total

# One count per extracted column, in turn order (date, transaction, debit,
# credit, balance).
(date_count, transaction_count, debit_count,
 credit_count, balance_count) = (
    count_lines(text)
    for text in (response1, response2, response3, response4, response5)
)

print("=" * 60, "VALIDATION - ROW COUNTS:", "=" * 60, sep="\n")
print(
    f"Date column:        {date_count} rows",
    f"Transaction column: {transaction_count} rows",
    f"Debit column:       {debit_count} rows",
    f"Credit column:      {credit_count} rows",
    f"Balance column:     {balance_count} rows",
    sep="\n",
)
print("=" * 60)

# The columns can only be zipped back into rows if every turn produced the
# same number of lines.
if not (date_count == transaction_count == debit_count
        == credit_count == balance_count):
    print("⚠️ WARNING: Row count mismatch detected!")
    print("   Manual review required")
else:
    print("✅ All columns have matching row counts")

## Save Extraction Results

In [None]:
# Save all column extractions to a file
output_path = Path("llama_single_column_extraction.txt")

# The report is assembled as a list of chunks and written with a single
# writelines() call: header, one section per column (title, dashed rule, raw
# model response), then a footer.
with output_path.open("w", encoding="utf-8") as f:
    f.writelines([
        "=" * 60 + "\n",
        "SINGLE-COLUMN EXTRACTION RESULTS\n",
        "Llama-3.2-Vision-11B\n",
        "=" * 60 + "\n\n",

        "DATE COLUMN:\n",
        "-" * 60 + "\n",
        response1 + "\n\n",

        "TRANSACTION COLUMN:\n",
        "-" * 60 + "\n",
        response2 + "\n\n",

        "DEBIT COLUMN:\n",
        "-" * 60 + "\n",
        response3 + "\n\n",

        "CREDIT COLUMN:\n",
        "-" * 60 + "\n",
        response4 + "\n\n",

        "BALANCE COLUMN:\n",
        "-" * 60 + "\n",
        response5 + "\n\n",

        "=" * 60 + "\n",
        # len(messages) counts every entry in the history (user and assistant).
        f"Total conversations: {len(messages)}\n",
        "=" * 60 + "\n",
    ])

print(f"✅ Extraction results saved to: {output_path}")
print(f"📊 File size: {output_path.stat().st_size} bytes")