# Imports

In [None]:
from pathlib import Path
import random

import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Set Random Seed for Reproducibility

In [None]:
# Seed the RNGs through the project's reproducibility helper so that
# repeated runs of this notebook are comparable.
from common.reproducibility import set_seed

SEED = 42
set_seed(SEED)
print(f"✅ Random seed set to {SEED} for reproducibility")

# Load the model

In [None]:
# Path to the local Llama-3.2-11B-Vision-Instruct checkpoint - update as needed.
model_id = "/home/jovyan/shared_PTM/Llama-3.2-11B-Vision-Instruct"

print("🔧 Loading Llama-3.2-Vision model...")

# Model and processor come from the project's robust loader rather than
# direct MllamaForConditionalGeneration / AutoProcessor.from_pretrained calls.
# Settings: no quantization, bfloat16 weights, automatic device placement.
from common.llama_model_loader_robust import load_llama_model_robust

model, processor = load_llama_model_robust(
    model_path=model_id,
    use_quantization=False,
    device_map='auto',
    max_new_tokens=2000,
    torch_dtype='bfloat16',
    low_cpu_mem_usage=True,
    verbose=True,
)

# Best effort: call tie_weights(); any failure is reported, never raised.
try:
    model.tie_weights()
except Exception as err:
    print(f"⚠️ tie_weights() warning: {err}")
else:
    print("✅ Model weights tied successfully")

# Load the image

In [None]:
# Path to the test image - update as needed.
# imageName = "/home/jovyan/shared_PoC_data/evaluation_data/image_009.png"
imageName = "/home/jovyan/nfs_share/tod/LMM_POC/evaluation_data/image_008.png"
print("📁 Loading image...")

# Image.open() is lazy: it only reads the header and keeps the file handle
# open until the pixels are first accessed. Copying inside a `with` block
# forces the pixel data to load now, so a corrupt or unreadable file fails
# here rather than deep inside the processor. It also closes the file handle
# straight away instead of leaving it open for the rest of the session.
with Image.open(imageName) as img:
    image = img.copy()

# Keep the image in a list: the same list object is passed to the processor
# on every turn of the conversation below.
images = [image]

print(f"✅ Image loaded: {image.size}")
print(f"✅ Images list created with {len(images)} image(s)")

# Define the prompt

In [None]:
# Structured extraction prompt for a date-grouped Australian bank statement.
prompt = """
You are an expert document analyser specializing in Date Grouped Australian Bank Statement extraction.
Date Grouped Bank Statements are date ordered, with one or more transactions for each date header.
Every transaction for a given date heading has a description, a debit/credit amount and finally a balance amount with a ' CR' suffix.
Extract all balance amounts along with their ' CR' suffix, the transaction dates (from the date heading) and transaction descriptions,
maintaining the same date ordering as the image, with every transaction appearing on its own row and remembering that some date headings have more than one balance.
"""

# A single user message. {"type": "image"} is only a placeholder in the chat
# template; the actual pixels are passed to the processor separately via
# `images`.
messageDataStructure = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }
]

print(f"💬 Prompt: {prompt}")
# This cell only builds the message; generation happens in a later cell, so
# don't announce "Generating..." here.
print("✅ Message structure ready - run the generation cell next")

# Clean Response Function

In [None]:
def clean_llama_response(response: str) -> str:
    """Extract the most recent assistant reply from a decoded Llama 3 sequence.

    Intended for output decoded *with* special tokens and *without* trimming
    the prompt. In that case the text contains chat-template headers, and for
    multi-turn histories it contains several assistant turns.

    Args:
        response: Decoded model output, possibly including chat-template
            markers.

    Returns:
        The text after the LAST ``<|start_header_id|>assistant<|end_header_id|>``
        header, up to the following ``<|eot_id|>`` if there is one, stripped of
        surrounding whitespace. If no assistant header is present, the whole
        input is returned stripped.

    Note:
        When the prompt tokens are trimmed from ``generate`` output before
        decoding, this function is not needed.
    """
    start_marker = "<|start_header_id|>assistant<|end_header_id|>"
    end_marker = "<|eot_id|>"

    # Use the LAST assistant header: in a multi-turn sequence, earlier
    # assistant turns from the history come first, and the newly generated
    # reply is always the final one.
    start_idx = response.rfind(start_marker)
    if start_idx == -1:
        return response.strip()

    reply = response[start_idx + len(start_marker):]

    # A missing end marker means generation was cut off (e.g. at
    # max_new_tokens). Keep the partial reply rather than returning the raw
    # sequence with the prompt and template markers still in it.
    end_idx = reply.find(end_marker)
    if end_idx != -1:
        reply = reply[:end_idx]

    return reply.strip()

# Process the prompt

In [None]:
# First turn: render the chat template, tokenize text + image, generate
# greedily, then keep only the newly generated tokens.
# Pattern adapted from: https://medium.com/data-science/chat-with-your-images-using-multimodal-llms-60af003e8bfa
print("🤖 Generating response with Llama-3.2-Vision...")

textInput = processor.apply_chat_template(
    messageDataStructure, add_generation_prompt=True
)

# The Llama 3.2 chat template already emits <|begin_of_text|>, so the
# tokenizer must not prepend a second BOS token. The official model-card
# example passes add_special_tokens=False for this reason. Images are passed
# by keyword, as a list, so the same list can be reused on later turns.
inputs = processor(
    images=images,
    text=textInput,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

# Greedy decoding. temperature/top_p are cleared, presumably to override
# sampling defaults from the checkpoint's generation config.
output = model.generate(
    **inputs,
    max_new_tokens=2000,
    do_sample=False,
    temperature=None,
    top_p=None,
)

# Keep only the tokens generated after the prompt. Don't blindly drop the
# last token with `:-1`: if generation stops at max_new_tokens there is no
# trailing <|eot_id|>, and that slice would cut off real content.
# skip_special_tokens=True removes <|eot_id|> when it is present.
generate_ids = output[:, inputs["input_ids"].shape[1]:]
cleanedOutput = processor.decode(
    generate_ids[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
).strip()

print("✅ Response generated successfully!")
print("\n" + "=" * 60)
print("CLEANED EXTRACTION:")
print("=" * 60)
print(cleanedOutput)
print("=" * 60)

# Save the cleaned response to a file
output_path = Path("llama_grouped_bank_statement_output.txt")

with output_path.open("w", encoding="utf-8") as text_file:
    text_file.write(cleanedOutput)

print(f"✅ Response saved to: {output_path}")
print(f"📊 File size: {output_path.stat().st_size} bytes")

## Multi-Turn Conversation Support

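Llama supports multi-turn conversations, but the model itself keeps no state between calls: the notebook maintains a conversation history list and re-sends the whole history with every new question: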

## 🔑 Key Multi-Turn Pattern for Llama 3.2 Vision

This notebook uses the **correct multi-turn conversation pattern** discovered from the Medium article:
[Chat with Your Images Using Llama 3.2-Vision Multimodal LLMs](https://medium.com/data-science/chat-with-your-images-using-multimodal-llms-60af003e8bfa)

### Critical Requirements:

1. **Images as List**: `images = [image]` (not just `image`)
2. **Named Parameter**: `processor(images=images, text=text, ...)` (not positional args)
3. **Trim Generated Tokens**: `generate_ids[:, inputs['input_ids'].shape[1]:-1]`
4. **Same Images Every Turn**: Pass the same `images` list for all turns

### Message Structure:

- **Turn 1**: `{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}`
- **Turn 2+**: `{"role": "user", "content": [{"type": "text", "text": "..."}]}` (no image in content)
- **Assistant**: `{"role": "assistant", "content": [{"type": "text", "text": "..."}]}`

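The image is attached only to the first user message. However, the full history is re-rendered on every turn, so the chat template emits the `<|image|>` token in every prompt. The processor therefore needs the same `images` list on every turn, so that this `<|image|>` token always has an image to match.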

In [None]:
# Start the multi-turn history with the first exchange: the original user
# message (image placeholder + extraction prompt) followed by the model's
# answer. This is a new list; the message dicts themselves are shared.
conversation_history = [
    *messageDataStructure,
    {
        "role": "assistant",
        "content": [{"type": "text", "text": cleanedOutput}],
    },
]

print("✅ Conversation history initialized")
print(f"📊 Current conversation has {len(conversation_history)} messages (1 user + 1 assistant)")
print("💡 Pattern: Using working multi-turn approach from Medium article")

### Debug: View Conversation Context

This cell helps you see what's being sent to the model:

In [None]:
# Optional: print a compact outline of every message in the history.
# Text parts are previewed (first 100 characters); other parts show their type.
print("🔍 Current conversation structure:")
print("=" * 60)
for msg_no, message in enumerate(conversation_history, start=1):
    print(f"\nMessage {msg_no} ({message['role']}):")
    for part in message['content']:
        if part['type'] != 'text':
            print(f"  [{part['type']}]")
            continue
        body = part['text']
        preview = body if len(body) <= 100 else body[:100] + "..."
        print(f"  [text]: {preview}")
print("=" * 60)

### Follow-up Question (Turn 2)

In [None]:
# Follow-up question (Turn 2): a text-only user message on top of the history.
# Pattern adapted from: https://medium.com/data-science/chat-with-your-images-using-multimodal-llms-60af003e8bfa

follow_up_prompt = "How many transactions are shown in this bank statement?"

# Text only - the image placeholder stays on the first user message.
follow_up_message = {
    "role": "user",
    "content": [{"type": "text", "text": follow_up_prompt}],
}

print(f"💬 Follow-up question: {follow_up_prompt}")
print("🤖 Generating follow-up response with Llama-3.2-Vision...")

# Render history + new question WITHOUT mutating conversation_history yet.
# If generation fails (e.g. out of memory), no unanswered user message is
# left behind, so re-running the cell doesn't duplicate it.
textInput = processor.apply_chat_template(
    conversation_history + [follow_up_message], add_generation_prompt=True
)

# Pass the SAME images list as the first turn. Skip the extra BOS token: the
# Llama 3.2 chat template already emits <|begin_of_text|>.
inputs = processor(
    images=images,
    text=textInput,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=2000,
    do_sample=False,
    temperature=None,
    top_p=None,
)

# Keep only the newly generated tokens. Let the tokenizer drop <|eot_id|>
# instead of slicing it off with `:-1`, which would lose a real token
# whenever generation stops at max_new_tokens.
generate_ids = output[:, inputs["input_ids"].shape[1]:]
cleanedOutput2 = processor.decode(
    generate_ids[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
).strip()

print("\n✅ Follow-up response generated successfully!")
print("\n" + "=" * 60)
print("FOLLOW-UP RESPONSE:")
print("=" * 60)
print(cleanedOutput2)
print("=" * 60)

# Record the completed exchange (question + answer) in the history.
conversation_history.extend([
    follow_up_message,
    {"role": "assistant", "content": [{"type": "text", "text": cleanedOutput2}]},
])

print(f"\n📊 Conversation now has {len(conversation_history)} messages")

### Additional Follow-up (Turn 3 - Optional)

You can continue the conversation by running this cell with different questions:

In [None]:
# Turn 3 (optional): another text-only follow-up. Edit follow_up_prompt_3 to
# ask a different question.
follow_up_prompt_3 = "What is the date range covered by this bank statement?"

follow_up_message_3 = {
    "role": "user",
    "content": [{"type": "text", "text": follow_up_prompt_3}],
}

print(f"💬 Follow-up question: {follow_up_prompt_3}")
print("🤖 Generating follow-up response with Llama-3.2-Vision...")

# Render history + new question without mutating conversation_history yet,
# so a failed generation doesn't leave an unanswered user message behind.
textInput = processor.apply_chat_template(
    conversation_history + [follow_up_message_3], add_generation_prompt=True
)

# Same images list as every earlier turn. Skip the extra BOS token because
# the Llama 3.2 chat template already emits <|begin_of_text|>.
inputs = processor(
    images=images,
    text=textInput,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

output = model.generate(
    **inputs,
    max_new_tokens=2000,
    do_sample=False,
    temperature=None,
    top_p=None,
)

# Keep only the newly generated tokens. Let the tokenizer drop <|eot_id|>
# instead of slicing it off with `:-1`, which would lose a real token
# whenever generation stops at max_new_tokens.
generate_ids = output[:, inputs["input_ids"].shape[1]:]
cleanedOutput3 = processor.decode(
    generate_ids[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
).strip()

print("\n✅ Follow-up response generated successfully!")
print("\n" + "=" * 60)
print("FOLLOW-UP RESPONSE:")
print("=" * 60)
print(cleanedOutput3)
print("=" * 60)

# Record the completed exchange (question + answer) in the history.
conversation_history.extend([
    follow_up_message_3,
    {"role": "assistant", "content": [{"type": "text", "text": cleanedOutput3}]},
])

print(f"\n📊 Conversation now has {len(conversation_history)} messages")
print("\n💡 To ask more questions, copy this cell and modify the 'follow_up_prompt_3' variable")

### Save Multi-Turn Conversation

In [None]:
# Persist the whole multi-turn exchange as a plain-text transcript:
# a header, one numbered section per message, then a footer with the count.
output_path = Path("llama_multiturn_conversation.txt")
rule = "=" * 60
thin_rule = "-" * 60

with output_path.open("w", encoding="utf-8") as transcript:
    transcript.write(rule + "\n")
    transcript.write("MULTI-TURN CONVERSATION WITH LLAMA-3.2-VISION\n")
    transcript.write(rule + "\n\n")

    for msg_no, message in enumerate(conversation_history, start=1):
        role = message["role"].upper()
        transcript.write(f"\n{thin_rule}\n")
        transcript.write(f"MESSAGE {msg_no} - {role}\n")
        transcript.write(f"{thin_rule}\n\n")

        for part in message["content"]:
            part_type = part["type"]
            if part_type == "text":
                transcript.write(part["text"] + "\n")
            elif part_type == "image":
                transcript.write("[IMAGE]\n")

    transcript.write("\n" + rule + "\n")
    transcript.write(f"Total messages: {len(conversation_history)}\n")
    transcript.write(rule + "\n")

print(f"✅ Full conversation saved to: {output_path}")
print(f"📊 File size: {output_path.stat().st_size} bytes")
print(f"💬 Total messages in conversation: {len(conversation_history)}")