# Llama 3.2 Vision Multi-Turn Flat Debit Extractor

- This notebook demonstrates multi-turn conversational extraction using Llama 3.2 Vision.
- Uses the `chat_with_mllm` pattern to handle multiple turns of conversation with Flat Bank Statement images.
- Each turn focuses on a specific extraction task: identifying the table headers, extracting the transaction table, selecting specific columns, keeping only the debit (withdrawal) rows, and extracting the statement date range.

**Reference**: [Chat with Your Images Using Multimodal LLMs](https://medium.com/data-science/chat-with-your-images-using-multimodal-llms-60af003e8bfa)

## Imports

In [None]:
# Add project root to path for common/ imports
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import random
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from transformers.image_utils import load_image
from tqdm.notebook import tqdm
from IPython.display import display, Markdown

## Pre-emptive Memory Cleanup

Optional GPU memory cleanup to prevent OOM errors when switching between models.

In [None]:
# Optional: pre-emptive memory cleanup (useful when switching models).
# Only the import is guarded: a missing helper module means "skip cleanup".
# The cleanup call itself runs in the `else` branch, so an ImportError raised
# while it executes surfaces as a real error instead of being misreported as
# "module not available".
try:
    from common.gpu_optimization import emergency_cleanup
except ImportError:
    print("‚ö†Ô∏è GPU optimization module not available - skipping cleanup")
else:
    print("üßπ Clearing GPU memory...")
    emergency_cleanup(verbose=False)
    print("‚úÖ Memory cleanup complete")

## Set Random Seed for Reproducibility

In [None]:
# Fix the random seed (42) through the project helper before any model work.
# NOTE(review): which RNGs set_seed covers is defined in common.reproducibility — confirm there.
from common.reproducibility import set_seed
set_seed(42)

## Load Model

In [None]:
# Path of the model checkpoint; handed to the loader below as `model_path`.
model_id = "/home/jovyan/shared_PTM/Llama-3.2-11B-Vision-Instruct"

print("üîß Loading Llama-3.2-Vision model...")

from common.llama_model_loader_robust import load_llama_model_robust

# Load the model and its processor through the project loader:
# no quantization, bfloat16 weights, automatic device placement.
# NOTE(review): max_new_tokens=2000 is passed at load time while the turn cells
# pass their own max_new_tokens (up to 3000) — confirm which value takes effect.
model, processor = load_llama_model_robust(
    model_path=model_id,
    use_quantization=False,
    device_map='auto',
    max_new_tokens=2000,
    torch_dtype='bfloat16',
    low_cpu_mem_usage=True,
    verbose=True
)

# Best-effort tie_weights(): any exception is printed as a warning and the
# notebook continues with the model as loaded.
try:
    model.tie_weights()
    print("‚úÖ Model weights tied successfully")
except Exception as e:
    print(f"‚ö†Ô∏è tie_weights() warning: {e}")

print("‚úÖ Model loaded successfully!")

## Optional: Manual Memory Cleanup

Uncomment and run this cell if you experience memory issues during the conversation. It is not needed for normal operation on an H200 GPU.

In [None]:
# Optional: Run this cell if you experience memory issues during conversation
# import gc

# print("üßπ Manual memory cleanup...")
# gc.collect()
# torch.cuda.empty_cache()

# # Show memory status
# if torch.cuda.is_available():
#     for i in range(torch.cuda.device_count()):
#         allocated = torch.cuda.memory_allocated(i) / 1e9
#         reserved = torch.cuda.memory_reserved(i) / 1e9
#         print(f"   GPU {i}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
# print("‚úÖ Cleanup complete")

## Import chat_with_mllm Function

This function, imported from `common.llama_multiturn_chat`, encapsulates the multi-turn conversation pattern.

In [None]:
from common.llama_multiturn_chat import chat_with_mllm

print("‚úÖ chat_with_mllm function imported from common.llama_multiturn_chat")

## Define Conversation Prompts

All prompts are defined upfront for easy modification and review.

In [None]:
# Prompts for each conversation turn, keyed by "turn_<n>_<purpose>".
# Insertion order is the order in which the turn cells below use them.
# Fix: the turn 4 prompt previously read 'the "Date" colum' (typo sent to the model).
CONVERSATION_PROMPTS = {
    "turn_0_identify_headers": """Look at the bank statement image.

Find the transaction table where transactions are listed.

At the top of this table, there is a row of column headers.

Your task: List ALL the column headers from this table, starting from the left edge and moving to the right edge.

IMPORTANT: The leftmost column is usually "Date" or "Date of Transaction". Start there and read every header moving right.

Include EVERY header you see. Do not skip any.

Output Format:
- TRANSACTION TABLE HEADERS (reading left to right):
- [Write ONLY the headers you actually see in the image, separated by commas]
- If no table headers visible, write: NO_HEADERS""",

    "turn_1_initial_extraction": """Step 1: Extract ONLY the table of transaction data from this Australian bank statement in markdown format.""",

    "turn_2_select_columns": """STEP 2: From the table you extracted in STEP 1, extract only the "Date | Description | Withdrawal" columns""",

    "turn_3_extract_debits": """STEP 3: From the table you extracted in STEP 2, remove any row NOT showing an amount (i.e. having an empty cell) in the "Withdrawal" column.""",

    "turn_4_date_range": """Extract earliest date and the latest date from the "Date" column. Express your answer in the format "STATEMENT_DATE_RANGE: dd/mm/yyyy - dd/mm/yyyy" """
}

def display_prompts(prompts_dict):
    """Print a numbered preview of every prompt in ``prompts_dict``.

    Each key is shown title-cased with underscores replaced by spaces.
    Prompt text longer than 200 characters is cut to its first 197
    characters followed by "...". The number of prompts is printed last.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule, "CONVERSATION PROMPTS", heavy_rule, sep="\n")

    for number, (key, prompt) in enumerate(prompts_dict.items(), start=1):
        heading = key.replace("_", " ").title()
        # Keep previews compact: anything over 200 characters is truncated.
        if len(prompt) > 200:
            shown = prompt[:197] + "..."
        else:
            shown = prompt
        print(f"\n{number}. {heading}")
        print(light_rule)
        print(f"{shown}")

    print("\n" + heavy_rule)
    print(f"Total prompts defined: {len(prompts_dict)}")

# Preview every prompt defined above, then confirm the cell finished.
display_prompts(prompts_dict=CONVERSATION_PROMPTS)
print("\n‚úÖ Conversation prompts defined and ready to use")

## Select Image

In [None]:
# Image path
# Absolute path of the statement image; the Turn 0 cell passes it to
# chat_with_mllm via `images_path`.
imageName = "/home/jovyan/_LMM_POC/evaluation_data/image_003.png"


In [None]:
# OPTIONAL: Uncomment to enable preprocessing
# How to use: lines prefixed with a single "#" are shared setup (imports);
# lines prefixed with "# #" are alternatives. Uncomment the setup, exactly ONE
# "Option" assignment, and the final save block, which rebinds `imageName`
# to the temporary preprocessed file.
# from common.image_preprocessing import enhance_for_llama, preprocess_statement_for_llama, enhance_statement_quality
# from PIL import Image
# import tempfile

# # Choose ONE preprocessing approach:

# # Option 1: Light enhancement (recommended for high-quality scans)
# # preprocessed_img = enhance_statement_quality(imageName)

# # Option 2: Moderate enhancement (upscaling + sharpness + contrast)
# # preprocessed_img = enhance_for_llama(imageName)

# # Option 3: Aggressive preprocessing (denoise + binarize + remove lines)
# # preprocessed_img = preprocess_statement_for_llama(imageName)

# # Save preprocessed image to temporary file and update imageName
# # with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
# #     preprocessed_img.save(tmp.name)
# #     imageName = tmp.name
# #     print(f"‚úÖ Using preprocessed image: {imageName}")
# #     display(preprocessed_img)

# NOTE(review): this message is printed unconditionally, so it is misleading
# once the preprocessing code above has been enabled.
print("‚è≠Ô∏è  Skipping preprocessing - using original image")

## OPTIONAL: Image Preprocessing

**Experimental:** Use the commented-out cell above to test whether preprocessing improves extraction accuracy.

Available preprocessing functions:
- `enhance_for_llama()` - Upscale, sharpen, increase contrast
- `preprocess_statement_for_llama()` - Denoise, binarize, remove table lines
- `enhance_statement_quality()` - Moderate enhancement for bank statements

**Note:** Modern VLMs are trained on natural images. Preprocessing may help with low-quality scans but could hurt performance on high-quality images. Test both approaches to see what works best for your data.

## Turn 0: Identify Table Headers

First, identify all column headers in the transaction table.

In [None]:
# Start a fresh conversation: no prior messages and no images yet.
messages, images = [], []

print("üì∏ Processing bank statement image...")
print("üìù Using prompt: turn_0_identify_headers")

# Turn 0 is the only call that supplies an image path; the later turn cells
# pass back the `messages` and `images` returned here.
response0, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_0_identify_headers"],
    images_path=[imageName],
    messages=messages,
    images=images,
    show_image=True,
    max_new_tokens=500,
    do_sample=False,
)

# Framed dump of the raw model response.
banner = "=" * 60
print("\n" + banner, "TURN 0 - IDENTIFY TABLE HEADERS:", banner, response0, banner, sep="\n")

## Turn 1: Extract Transaction Table

Extract the complete transaction table in markdown format.

In [None]:
print("üìù Using prompt: turn_1_initial_extraction")

# Continue the conversation started in Turn 0 (same messages / images state).
response1, messages, images = chat_with_mllm(
    model, 
    processor, 
    CONVERSATION_PROMPTS["turn_1_initial_extraction"],
    messages=messages,
    images=images,
    do_sample=False,
    max_new_tokens=3000
)

display(Markdown(response1))

# Save initial extraction.
# Encoding is pinned to UTF-8 (matching the full-conversation save cell) so the
# file does not depend on the platform's default locale encoding and non-ASCII
# characters in the model output cannot raise UnicodeEncodeError.
Path("llama_debit_extractor_initial.txt").write_text(response1, encoding="utf-8")
print("\n‚úÖ Initial extraction saved to llama_debit_extractor_initial.txt")

## Turn 2: Select Columns

In [None]:
print("üìù Using prompt: turn_2_select_columns")

# do_sample=False is passed explicitly, as in turns 0 and 1, so this turn uses
# greedy decoding regardless of the helper's default and stays reproducible.
response2, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_2_select_columns"],
    messages=messages, 
    images=images,
    do_sample=False,
    max_new_tokens=2000
)

display(Markdown(response2))

# Plain-text dump of the response, kept disabled in favour of the Markdown view:
# print("\n" + "=" * 60)
# print("TURN 2 - SELECT COLUMNS:")
# print("=" * 60)
# print(response2)
# print("=" * 60)

## Turn 3: Extract Debit/Withdrawal Amounts Only

In [None]:
print("üìù Using prompt: turn_3_extract_debits")

# do_sample=False is passed explicitly, as in turns 0 and 1, so this turn uses
# greedy decoding regardless of the helper's default and stays reproducible.
response3, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_3_extract_debits"],
    messages=messages,
    images=images,
    do_sample=False,
    max_new_tokens=2000
)

display(Markdown(response3))

# Plain-text dump of the response, kept disabled in favour of the Markdown view:
# print("\n" + "=" * 60)
# print("TURN 3 - DEBIT AMOUNTS ONLY:")
# print("=" * 60)
# print(response3)
# print("=" * 60)

# Save debit amounts.
# Encoding is pinned to UTF-8 (matching the full-conversation save cell) so the
# file does not depend on the platform's default locale encoding.
Path("llama_debit_amounts.txt").write_text(response3, encoding="utf-8")
print("\n‚úÖ Debit amounts saved to llama_debit_amounts.txt")

## Turn 4: Date Range

In [None]:
print("üìù Using prompt: turn_4_date_range")

# do_sample=False is passed explicitly, as in turns 0 and 1, so this turn uses
# greedy decoding regardless of the helper's default and stays reproducible.
# The answer is a single short line, hence the small max_new_tokens budget.
response4, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_4_date_range"],
    messages=messages,
    images=images,
    do_sample=False,
    max_new_tokens=50
)

print("\n" + "=" * 60)
print("TURN 4 - DATE RANGE:")
print("=" * 60)
print(response4)
print("=" * 60)

## Debug: View Conversation Structure

In [None]:
# Dump the conversation: one entry per message, one line per content part.
print("üîç Current conversation structure:")
print("=" * 60)
for position, message in enumerate(messages, start=1):
    print(f"\nMessage {position} ({message['role']}):")
    for part in message['content']:
        kind = part['type']
        if kind != 'text':
            # Non-text parts are listed by type only.
            print(f"  [{kind}]")
            continue
        snippet = part['text']
        if len(snippet) > 100:
            # Long text is cut to its first 100 characters for the overview.
            snippet = snippet[:100] + "..."
        print(f"  [text]: {snippet}")
print("=" * 60)

# Per-role tallies.
roles = [message['role'] for message in messages]
print(f"\nüìä Total messages: {len(messages)}")
print(f"üìä User messages: {roles.count('user')}")
print(f"üìä Assistant messages: {roles.count('assistant')}")

## Save Full Conversation

In [None]:
# Write the whole conversation to a UTF-8 text file, one section per message.
output_path = Path("llama_multiturn_debit_conversation.txt")

heavy_rule = "=" * 60
light_rule = "-" * 60

with output_path.open("w", encoding="utf-8") as transcript:
    # File header, followed by one blank line.
    print(
        heavy_rule,
        "MULTI-TURN DEBIT EXTRACTION CONVERSATION",
        "Llama-3.2-Vision-11B",
        heavy_rule,
        sep="\n",
        end="\n\n",
        file=transcript,
    )

    for number, message in enumerate(messages, start=1):
        speaker = message["role"].upper()
        # Per-message heading framed by light rules.
        print(
            "\n" + light_rule,
            f"MESSAGE {number} - {speaker}",
            light_rule,
            sep="\n",
            end="\n\n",
            file=transcript,
        )

        for part in message["content"]:
            kind = part["type"]
            if kind == "text":
                transcript.write(part["text"] + "\n")
            elif kind == "image":
                # Images are represented by a placeholder line.
                transcript.write("[IMAGE]\n")

    # Footer with the message count.
    print(
        "\n" + heavy_rule,
        f"Total messages: {len(messages)}",
        heavy_rule,
        sep="\n",
        file=transcript,
    )

print(f"‚úÖ Full conversation saved to: {output_path}")
print(f"üìä File size: {output_path.stat().st_size} bytes")
print(f"üí¨ Total messages in conversation: {len(messages)}")