# Llama 3.2 Vision Multi-Turn Debit Extractor

This notebook demonstrates multi-turn conversational extraction using Llama 3.2 Vision.
It uses the `chat_with_mllm` helper pattern to keep multi-turn conversations clean and maintainable.

**Reference**: [Chat with Your Images Using Multimodal LLMs](https://medium.com/data-science/chat-with-your-images-using-multimodal-llms-60af003e8bfa)

## Imports

In [None]:
# Add project root to path for common/ imports
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import random
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from transformers.image_utils import load_image
from tqdm.notebook import tqdm
from IPython.display import display, Markdown

## Pre-emptive Memory Cleanup

Optionally clear GPU memory to avoid out-of-memory (OOM) errors when switching between models.

In [None]:
# Optional: Pre-emptive memory cleanup (useful when switching models)
# Best-effort step: the helper comes from the project's `common` package.
# Any ImportError raised inside the try body (including from the cleanup
# call itself) is treated as "module not available" and the cleanup is skipped.
try:
    from common.gpu_optimization import emergency_cleanup
    print("🧹 Clearing GPU memory...")
    emergency_cleanup(verbose=False)
    print("✅ Memory cleanup complete")
except ImportError:
    # A missing optional module must not stop the notebook
    print("⚠️ GPU optimization module not available - skipping cleanup")

## Set Random Seed for Reproducibility

In [None]:
# Fix the random seed via the project's helper so repeated runs are comparable.
# NOTE(review): presumably seeds Python/NumPy/torch RNGs - confirm in
# common.reproducibility. The conversation turns below do not enable sampling
# (do_sample is False or left at its default), so the seed mainly matters if
# sampling is turned on.
from common.reproducibility import set_seed
set_seed(42)

## Load Model

In [None]:
model_id = "/home/jovyan/shared_PTM/Llama-3.2-11B-Vision-Instruct"

print("🔧 Loading Llama-3.2-Vision model...")

from common.llama_model_loader_robust import load_llama_model_robust

# Unquantized load in bfloat16 through the project's robust loader
model, processor = load_llama_model_robust(
    model_path=model_id,
    use_quantization=False,
    device_map="auto",
    max_new_tokens=2000,
    torch_dtype="bfloat16",
    low_cpu_mem_usage=True,
    verbose=True,
)

# Tie weights after loading; a failure is reported as a warning, never raised
try:
    model.tie_weights()
except Exception as err:
    print(f"⚠️ tie_weights() warning: {err}")
else:
    print("✅ Model weights tied successfully")

print("✅ Model loaded successfully!")

## Optional: Manual Memory Cleanup

Run this cell if you run into memory issues during the conversation. It is not needed for normal operation on an H200 GPU.

In [None]:
# Optional: run this cell if you hit memory issues mid-conversation
import gc

print("🧹 Manual memory cleanup...")
gc.collect()
torch.cuda.empty_cache()

# Per-GPU memory report in GB (bytes / 1e9)
if torch.cuda.is_available():
    for gpu_idx in range(torch.cuda.device_count()):
        alloc_gb = torch.cuda.memory_allocated(gpu_idx) / 1e9
        reserved_gb = torch.cuda.memory_reserved(gpu_idx) / 1e9
        print(f"   GPU {gpu_idx}: {alloc_gb:.2f}GB allocated, {reserved_gb:.2f}GB reserved")
print("✅ Cleanup complete")

## Define chat_with_mllm Function

This function encapsulates the multi-turn conversation pattern.

In [None]:
def chat_with_mllm(model, processor, prompt, images_path=None, do_sample=False,
                   temperature=0.1, show_image=False, max_new_tokens=2000,
                   messages=None, images=None):
    """Send one user turn to a Llama vision model and record the reply.

    A new conversation starts when ``messages`` is None or empty. The first
    user message gets one ``{"type": "image"}`` placeholder per loaded image
    (none for a text-only chat), followed by ``prompt``. Later turns append a
    text-only user message. The loaded images are passed to the processor on
    every call.

    Args:
        model: Loaded Llama vision model (uses ``generate`` and ``device``).
        processor: AutoProcessor for the model.
        prompt: User's text prompt for this turn.
        images_path: One path (anything ``load_image`` accepts) or a list/tuple
            of them. Only loaded when ``images`` is empty. Defaults to none.
        do_sample: Enable sampling. When False, ``temperature`` and ``top_p``
            are passed to ``generate`` as None.
        temperature: Sampling temperature, used only when ``do_sample`` is True.
        show_image: Display each newly loaded image in the notebook.
        max_new_tokens: Maximum number of tokens to generate.
        messages: Conversation history from a previous call. None or empty
            starts a new conversation; a non-empty list is appended to in place.
        images: Already-loaded images from a previous call. None or empty means
            load from ``images_path`` (appending into this list in place).

    Returns:
        tuple: (generated_text, updated_messages, images). Pass the last two
        back in to continue the conversation.
    """
    # None defaults instead of [] defaults: a shared default list would keep
    # images loaded by one call, so a later "new" chat would silently reuse
    # the earlier image instead of loading its own.
    if images_path is None:
        images_path = []
    elif isinstance(images_path, (list, tuple)):
        images_path = list(images_path)
    else:
        images_path = [images_path]
    if messages is None:
        messages = []
    if images is None:
        images = []

    # Load images only if the caller did not pass already-loaded ones
    if not images and images_path:
        for image_path in tqdm(images_path, desc="Loading images"):
            image = load_image(image_path)
            images.append(image)
            if show_image:
                display(image)

    if not messages:
        # New conversation: one image placeholder per image. NOTE(review): the
        # Mllama processor is expected to check placeholder count against the
        # number of images - confirm for the installed transformers version.
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
    else:
        # Continuing: the image placeholders live in the first user message
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": prompt}]
        })

    # Process input data; a text-only chat passes images=None rather than []
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=images or None, text=text, return_tensors="pt").to(model.device)

    # Generate response
    generation_args = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        generation_args["temperature"] = temperature
    else:
        # Presumably overrides sampling values from the model's generation
        # config so greedy decoding does not warn about unused parameters.
        generation_args["temperature"] = None
        generation_args["top_p"] = None

    generate_ids = model.generate(**inputs, **generation_args)

    # Keep every newly generated token and let the tokenizer drop special
    # tokens such as end-of-turn. Slicing off the last position ([:-1]) would
    # cut real text whenever generation stops at max_new_tokens.
    prompt_len = inputs["input_ids"].shape[1]
    generated_texts = processor.decode(
        generate_ids[0, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # Append the model's response to the conversation history
    messages.append({
        "role": "assistant",
        "content": [{"type": "text", "text": generated_texts}]
    })

    return generated_texts, messages, images

print("✅ chat_with_mllm function defined")

## Define Conversation Prompts

All prompts are defined upfront for easy modification and review.

In [None]:
# Conversation prompts dictionary.
# Keys follow "turn_<n>_<purpose>". The turn cells below look prompts up by
# key, so renaming a key means updating the matching cell. Each value is sent
# verbatim as the user message for that turn.
# NOTE(review): the turn-1 prompt asks only for balances, dates and
# descriptions. However, turn 3 ("From your first response, extract ONLY the
# debit ... amounts") and turn 7 ("the debit amounts from your first
# extraction") assume turn 1 produced debit amounts. As written, the model must
# infer debits (e.g. from balance changes) or re-read the image. Confirm this
# is intended, or ask for debit/credit amounts in turn 1.
CONVERSATION_PROMPTS = {
    "turn_1_initial_extraction": """You are an expert document analyser specializing in Date Grouped Australian Bank Statement extraction.
Date Grouped Bank Statements are date ordered, with one or more transactions for each date header.
Every transaction for a given date heading has a description, a debit/credit amount and finally a balance amount with a ' CR' suffix.
Extract all balance amounts along with their ' CR' suffix, the transaction dates (from the date heading) and transaction descriptions,
maintaining the same date ordering as the image, with every transaction appearing on its own row and remembering that some date headings have more than one balance.""",
    
    "turn_2_count_transactions": "How many transactions are shown in this bank statement?",
    
    "turn_3_extract_debits": "From your first response, extract ONLY the debit/withdrawal amounts (amounts paid out). List them in order, one per line.",
    
    "turn_4_verify_count": "How many debit/withdrawal transactions did you extract in your previous response?",
    
    "turn_5_total_debits": "What is the total sum of all debit/withdrawal amounts in this statement?",
    
    "turn_6_date_range": "What is the date range covered by this bank statement?",
    
    # The model only sees the message history, not turn numbers - "turn 3"
    # relies on it counting user messages.
    "turn_7_verify_consistency": "In your very first response, you extracted all transactions. Can you verify that the debit amounts you listed in turn 3 match the debit amounts from your first extraction?"
}

def display_prompts(prompts_dict):
    """Print a numbered preview of every prompt in ``prompts_dict``.

    Each key is shown in title case with underscores turned into spaces.
    Prompt text longer than 200 characters is cut to its first 197 characters
    followed by "...". A final line reports how many prompts there are.
    """
    heavy_rule = "=" * 70
    print(heavy_rule)
    print("CONVERSATION PROMPTS")
    print(heavy_rule)
    for number, (name, body) in enumerate(prompts_dict.items(), start=1):
        title = name.replace("_", " ").title()
        print(f"\n{number}. {title}")
        print("-" * 70)
        too_long = len(body) > 200
        print(f"{body[:197]}..." if too_long else f"{body}")
    print(f"\n{heavy_rule}")
    print(f"Total prompts defined: {len(prompts_dict)}")

# Preview every prompt before the conversation starts
display_prompts(prompts_dict=CONVERSATION_PROMPTS)
print("\n✅ Conversation prompts defined and ready to use")

## Initial Extraction (Turn 1)

Extract all transaction data from the bank statement image.

In [None]:
# Image path
imageName = "/home/jovyan/_LMM_POC/evaluation_data/image_009.png"

# Initialize conversation: an empty history and image list make
# chat_with_mllm load the image and start a new conversation.
messages = []
images = []

print("📸 Processing bank statement image...")
print("📝 Using prompt: turn_1_initial_extraction")

response1, messages, images = chat_with_mllm(
    model, processor, 
    CONVERSATION_PROMPTS["turn_1_initial_extraction"],
    images_path=[imageName],
    do_sample=False,
    max_new_tokens=2000,
    show_image=True,
    messages=messages,
    images=images
)

print("\n" + "=" * 60)
print("TURN 1 - INITIAL EXTRACTION:")
print("=" * 60)
print(response1)
print("=" * 60)

# Save initial extraction. The encoding is explicit because the default is
# locale-dependent, and model output may contain non-ASCII characters.
Path("llama_debit_extractor_initial.txt").write_text(response1, encoding="utf-8")
print("\n✅ Initial extraction saved to llama_debit_extractor_initial.txt")

## Turn 2: Count Transactions

In [None]:
print("📝 Using prompt: turn_2_count_transactions")

# Follow-up question on the same image, reusing history and loaded images
response2, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_2_count_transactions"],
    max_new_tokens=500,
    messages=messages,
    images=images,
)

# Echo the answer between separator rules
print("\n" + "=" * 60, "TURN 2 - TRANSACTION COUNT:", "=" * 60, response2, "=" * 60, sep="\n")

## Turn 3: Extract Debit/Withdrawal Amounts Only

In [None]:
print("📝 Using prompt: turn_3_extract_debits")

# NOTE(review): this prompt refers to debit amounts "from your first response",
# but the turn-1 prompt asks only for balances, dates and descriptions.
response3, messages, images = chat_with_mllm(
    model, processor,
    CONVERSATION_PROMPTS["turn_3_extract_debits"],
    messages=messages,
    images=images,
    max_new_tokens=1000
)

print("\n" + "=" * 60)
print("TURN 3 - DEBIT AMOUNTS ONLY:")
print("=" * 60)
print(response3)
print("=" * 60)

# Save debit amounts. The encoding is explicit because the default is
# locale-dependent, and model output may contain non-ASCII characters.
Path("llama_debit_amounts.txt").write_text(response3, encoding="utf-8")
print("\n✅ Debit amounts saved to llama_debit_amounts.txt")

## Turn 4: Verify Debit Count

In [None]:
print("📝 Using prompt: turn_4_verify_count")

# Ask the model to count the debits it listed in the previous turn
response4, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_4_verify_count"],
    max_new_tokens=500,
    messages=messages,
    images=images,
)

# Echo the answer between separator rules
print("\n" + "=" * 60, "TURN 4 - DEBIT COUNT:", "=" * 60, response4, "=" * 60, sep="\n")

## Turn 5: Total Debit Amount

In [None]:
print("📝 Using prompt: turn_5_total_debits")

# Ask for the total of all debit amounts
response5, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_5_total_debits"],
    max_new_tokens=500,
    messages=messages,
    images=images,
)

# Echo the answer between separator rules
print("\n" + "=" * 60, "TURN 5 - TOTAL DEBITS:", "=" * 60, response5, "=" * 60, sep="\n")

## Turn 6: Date Range

In [None]:
print("📝 Using prompt: turn_6_date_range")

# Ask for the period the statement covers
response6, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_6_date_range"],
    max_new_tokens=500,
    messages=messages,
    images=images,
)

# Echo the answer between separator rules
print("\n" + "=" * 60, "TURN 6 - DATE RANGE:", "=" * 60, response6, "=" * 60, sep="\n")

## Turn 7: Verification - Cross-check First Response

In [None]:
print("📝 Using prompt: turn_7_verify_consistency")

# Ask the model to cross-check its debit list against its first extraction
response7, messages, images = chat_with_mllm(
    model,
    processor,
    CONVERSATION_PROMPTS["turn_7_verify_consistency"],
    max_new_tokens=1000,
    messages=messages,
    images=images,
)

# Echo the answer between separator rules
print("\n" + "=" * 60, "TURN 7 - VERIFICATION:", "=" * 60, response7, "=" * 60, sep="\n")

## Debug: View Conversation Structure

In [None]:
print("🔍 Current conversation structure:")
print("=" * 60)
# One entry per message; text parts are previewed to 100 characters
for msg_no, msg in enumerate(messages, start=1):
    print(f"\nMessage {msg_no} ({msg['role']}):")
    for part in msg['content']:
        if part['type'] != 'text':
            print(f"  [{part['type']}]")
            continue
        body = part['text']
        shown = body[:100] + "..." if len(body) > 100 else body
        print(f"  [text]: {shown}")
print("=" * 60)
print(f"\n📊 Total messages: {len(messages)}")
print(f"📊 User messages: {sum(m['role'] == 'user' for m in messages)}")
print(f"📊 Assistant messages: {sum(m['role'] == 'assistant' for m in messages)}")

## Save Full Conversation

In [None]:
# Save the entire conversation to a file
output_path = Path("llama_multiturn_debit_conversation.txt")

with output_path.open("w", encoding="utf-8") as text_file:
    text_file.write("=" * 60 + "\n")
    text_file.write("MULTI-TURN DEBIT EXTRACTION CONVERSATION\n")
    text_file.write("Llama-3.2-Vision-11B\n")
    text_file.write("=" * 60 + "\n\n")
    
    for i, msg in enumerate(messages, 1):
        role = msg["role"].upper()
        text_file.write(f"\n{'-' * 60}\n")
        text_file.write(f"MESSAGE {i} - {role}\n")
        text_file.write(f"{'-' * 60}\n\n")
        
        for content in msg["content"]:
            if content["type"] == "text":
                text_file.write(content["text"] + "\n")
            elif content["type"] == "image":
                text_file.write("[IMAGE]\n")
    
    text_file.write("\n" + "=" * 60 + "\n")
    text_file.write(f"Total messages: {len(messages)}\n")
    text_file.write("=" * 60 + "\n")

print(f"✅ Full conversation saved to: {output_path}")
print(f"📊 File size: {output_path.stat().st_size} bytes")
print(f"💬 Total messages in conversation: {len(messages)}")