In [None]:
from pathlib import Path

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

In [None]:
# --- Configuration ---------------------------------------------------------
# Local checkpoint directory of the Llama 3.2 Vision Instruct model.
model_id = "/home/jovyan/nfs_share/models/Llama-3.2-11B-Vision-Instruct"
# Path of the document image to run the extraction on (edit this to switch documents).
imageName = "/home/jovyan/nfs_share/tod/datasets/synthetic_invoice_014.png"

# --- Input image -----------------------------------------------------------
# Open the image BEFORE loading the model: a wrong path or corrupt file then
# fails immediately instead of after the expensive checkpoint load below.
# Image.open() is lazy; convert("RGB") decodes the pixels while the file is
# still open and returns an independent in-memory image, so the `with` block
# can safely release the file handle. It also normalises palette/alpha PNGs
# to plain three-channel RGB.
with Image.open(imageName) as source_image:
    image = source_image.convert("RGB")

# --- Model and processor ---------------------------------------------------
# bfloat16 weights with device_map="auto" let the loader decide where to
# place the layers on the available devices.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# The processor is loaded from the same checkpoint directory as the model.
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Key-value extraction prompt from model_comparison.yaml
#
# The prompt asks the model for a fixed set of 25 `KEY: value` lines (one per
# key listed under "REQUIRED OUTPUT FORMAT"), with "N/A" for missing fields and
# no markdown formatting. The "25 lines" figure repeated at the end of the
# prompt must stay in sync with the number of keys in that list.
#
# NOTE: everything between the triple quotes is sent to the model verbatim,
# including trailing spaces; the `\"` escapes evaluate to plain double quotes.
extraction_prompt = """Extract data from this business document. 
Output ALL fields below with their exact keys. 
DO NOT USE "*".
Use "N/A" if field is not visible or not present.

REQUIRED OUTPUT FORMAT (output ALL lines exactly as shown):
DOCUMENT_TYPE: [value or N/A]
SUPPLIER: [value or N/A]
ABN: [11-digit Australian Business Number or N/A]
PAYER_NAME: [value or N/A]
PAYER_ADDRESS: [value or N/A]
PAYER_PHONE: [value or N/A]
PAYER_EMAIL: [value or N/A]
INVOICE_DATE: [value or N/A]
DUE_DATE: [value or N/A]
GST: [GST amount in dollars or N/A]
TOTAL: [total amount in dollars or N/A]
SUBTOTAL: [subtotal amount in dollars or N/A]
SUPPLIER_WEBSITE: [value or N/A]
QUANTITIES: [list of quantities or N/A]
PRICES: [individual prices in dollars or N/A]
BUSINESS_ADDRESS: [value or N/A]
BUSINESS_PHONE: [value or N/A]
BANK_NAME: [bank name from bank statements only or N/A]
BSB_NUMBER: [6-digit BSB from bank statements only or N/A]
BANK_ACCOUNT_NUMBER: [account number from bank statements only or N/A]
ACCOUNT_HOLDER: [value or N/A]
STATEMENT_PERIOD: [value or N/A]
OPENING_BALANCE: [opening balance amount in dollars or N/A]
CLOSING_BALANCE: [closing balance amount in dollars or N/A]
DESCRIPTIONS: [list of transaction descriptions or N/A]

CRITICAL: Output in PLAIN TEXT format only. Do NOT use markdown formatting.

CORRECT format: DOCUMENT_TYPE: TAX INVOICE
WRONG format: **DOCUMENT_TYPE:** TAX INVOICE
WRONG format: **DOCUMENT_TYPE: TAX INVOICE**
WRONG format: DOCUMENT_TYPE: **TAX INVOICE**

Use exactly: KEY: value (with colon and space)
Never use: **KEY:** or **KEY** or any asterisks
Never use bold, italic, or any markdown formatting

ABSOLUTELY CRITICAL: Output EXACTLY 25 lines using ONLY the keys listed above. 
Do NOT add extra fields like \"Balance\", \"Credit\", \"Debit\", \"Date\", \"Description\".
Do NOT include ANY fields not in the required list above.
Include ALL 25 keys listed above even if value is N/A.
STOP after exactly 25 lines."""

# Sanity output: confirm which prompt is active and its size in characters
# (characters, not tokens).
print("📋 Using structured key-value extraction prompt")
print(f"📄 Prompt length: {len(extraction_prompt)} characters")

In [None]:
# Build a single-turn chat: one user message holding an image placeholder
# followed by the extraction instructions.
messageDataStructure = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {
                "type": "text",
                "text": extraction_prompt,
            },
        ],
    }
]

# Render the chat into the model's prompt format, ending with the header that
# cues the assistant's reply.
textInput = processor.apply_chat_template(
    messageDataStructure, add_generation_prompt=True
)

# Tokenize the prompt and preprocess the image, then move tensors to the model.
# The chat template already emits the begin-of-text token, so
# add_special_tokens=False (as in the model card's usage example) stops the
# tokenizer from prepending a second one.
inputs = processor(
    image,
    textInput,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)

print("🔧 Processing with Llama-3.2-Vision for key-value extraction...")

# Generate with a token budget sized for the structured 25-line answer.
output = model.generate(**inputs, max_new_tokens=1000)

# generate() returns the prompt tokens followed by the newly generated ones.
# Decode only the new tokens and drop special tokens, so the result (and the
# file saved from it) holds the model's answer rather than an echo of the
# prompt plus chat-template markers.
prompt_token_count = inputs["input_ids"].shape[-1]
generatedOutput = processor.decode(
    output[0][prompt_token_count:], skip_special_tokens=True
)

print("✅ Key-value extraction completed!")
print("\n" + "=" * 60)
print("EXTRACTED FIELDS:")
print("=" * 60)
print(generatedOutput)
print("=" * 60)

In [None]:
# Persist the decoded model output to disk, then report a rough field count.
output_path = Path("/home/jovyan/nfs_share/tod/output/llama_keyvalue_output.txt")

# Create the destination folder (and any missing parents) if it is not there yet.
output_path.parent.mkdir(parents=True, exist_ok=True)

# write_text opens the file in text mode "w", writes the string and closes it.
output_path.write_text(generatedOutput, encoding="utf-8")

print(f"✅ Key-value extraction results saved to: {output_path}")
saved_size = output_path.stat().st_size
print(f"📄 File size: {saved_size} bytes")

# Rough field count: a line counts when it contains a colon and does not
# start (after stripping whitespace) with "<".
lines = generatedOutput.split("\n")
field_lines = []
for line in lines:
    if ":" not in line:
        continue
    if line.strip().startswith("<"):
        continue
    field_lines.append(line)
print(f"📊 Extracted {len(field_lines)} field lines")