In [None]:
# Key-value extraction with InternVL3 (from official docs)
from pathlib import Path
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T

# Local checkpoint and sample document (paths mirror model_comparison.yaml)
model_id = "/home/jovyan/nfs_share/models/InternVL3-2B"
imageName = "/home/jovyan/nfs_share/tod/datasets/synthetic_invoice_014.png"

print("🔧 Loading InternVL3 model for key-value extraction...")

# Model: loaded in bfloat16 (not float16) with the remote modelling code enabled,
# then switched to inference mode and moved onto the GPU.
model = AutoModel.from_pretrained(
    model_id,
    **{
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
        "trust_remote_code": True,
    },
)
model = model.eval().cuda()

# Tokenizer: the slow (non-fast) implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    **{"trust_remote_code": True, "use_fast": False},
)

print("✅ Model and tokenizer loaded successfully")

# Open the document image that the later cells preprocess and query
image = Image.open(imageName)
print(f"📷 Image loaded: {image.size}")

In [None]:
# Simple image processing (from official InternVL3 docs)
def load_image(image, input_size=448):
    """Convert a PIL image into a single-tile ``pixel_values`` batch for InternVL3.

    The image is converted to RGB if needed, resized to a square of
    ``input_size`` pixels (aspect ratio is NOT preserved and no tiling is
    performed), scaled to [0, 1] by ``ToTensor`` and normalised with the
    ImageNet mean/std.

    Args:
        image: A ``PIL.Image.Image`` in any mode.
        input_size: Side length in pixels of the square the image is resized to.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(1, 3, input_size, input_size)``
        placed on the current CUDA device.
    """
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        # The official InternVL preprocessing resizes with bicubic interpolation;
        # torchvision's default is bilinear, so it has to be requested explicitly
        # to match the reference pipeline.
        T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    # unsqueeze(0) adds the batch dimension; dtype is cast to bfloat16 before the
    # tensor is moved to the GPU.
    return transform(image).unsqueeze(0).to(torch.bfloat16).cuda()

# Turn the PIL image opened earlier into the tensor passed to the model
print("🖼️  Processing image...")
pixel_values = load_image(
    image,
)
print(
    f"✅ Image processed: {pixel_values.shape}"
)

In [None]:
# Key-value extraction prompt from model_comparison.yaml with InternVL3 format.
#
# The prompt asks the model for one plain-text "KEY: value" line per field and
# lists 25 keys, which is the line count the closing instructions refer to.
# The text is sent to the model verbatim, so any edit to the literal (including
# whitespace) changes the model input.
# Note: the \" sequences near the end are ordinary double quotes at runtime;
# the backslashes are not part of the prompt.
extraction_prompt = """Extract data from this business document. 
Output ALL fields below with their exact keys. 
Use "N/A" if field is not visible or not present.

REQUIRED OUTPUT FORMAT (output ALL lines exactly as shown):
DOCUMENT_TYPE: [value or N/A]
SUPPLIER: [value or N/A]
ABN: [11-digit Australian Business Number or N/A]
PAYER_NAME: [value or N/A]
PAYER_ADDRESS: [value or N/A]
PAYER_PHONE: [value or N/A]
PAYER_EMAIL: [value or N/A]
INVOICE_DATE: [value or N/A]
DUE_DATE: [value or N/A]
GST: [GST amount in dollars or N/A]
TOTAL: [total amount in dollars or N/A]
SUBTOTAL: [subtotal amount in dollars or N/A]
SUPPLIER_WEBSITE: [value or N/A]
QUANTITIES: [list of quantities or N/A]
PRICES: [individual prices in dollars or N/A]
BUSINESS_ADDRESS: [value or N/A]
BUSINESS_PHONE: [value or N/A]
BANK_NAME: [bank name from bank statements only or N/A]
BSB_NUMBER: [6-digit BSB from bank statements only or N/A]
BANK_ACCOUNT_NUMBER: [account number from bank statements only or N/A]
ACCOUNT_HOLDER: [value or N/A]
STATEMENT_PERIOD: [value or N/A]
OPENING_BALANCE: [opening balance amount in dollars or N/A]
CLOSING_BALANCE: [closing balance amount in dollars or N/A]
DESCRIPTIONS: [list of transaction descriptions or N/A]

CRITICAL: Output in PLAIN TEXT format only. Do NOT use markdown formatting.

CORRECT format: DOCUMENT_TYPE: TAX INVOICE
WRONG format: **DOCUMENT_TYPE:** TAX INVOICE
WRONG format: **DOCUMENT_TYPE: TAX INVOICE**
WRONG format: DOCUMENT_TYPE: **TAX INVOICE**

Use exactly: KEY: value (with colon and space)
Never use: **KEY:** or **KEY** or any asterisks
Never use bold, italic, or any markdown formatting

ABSOLUTELY CRITICAL: Output EXACTLY 25 lines using ONLY the keys listed above. 
Do NOT add extra fields like \"Balance\", \"Credit\", \"Debit\", \"Date\", \"Description\".
Do NOT include ANY fields not in the required list above.
Include ALL 25 keys listed above even if value is N/A.
STOP after exactly 25 lines."""

# InternVL3 format: the literal "<image>" placeholder on its own line, followed
# by the text prompt. `question` is what gets passed to the model.
question = f'<image>\n{extraction_prompt}'
print("📋 Using structured key-value extraction prompt")
# Reported length covers the text prompt only, not the "<image>\n" prefix.
print(f"📄 Prompt length: {len(extraction_prompt)} characters")

In [None]:
# Generation config for structured output.
# Greedy decoding (do_sample=False) makes the extraction deterministic. No
# `temperature` is set: it only applies to sampling, is ignored under greedy
# decoding, and makes transformers warn about an unused sampling flag.
generation_config = dict(
    max_new_tokens=1000,  # Reduced for structured output
    do_sample=False,      # Deterministic for consistent field extraction
)

# Reset the result first so a failed run cannot leave the output of an earlier
# execution of this cell bound to `response` (which would then be saved as if
# it were the result of this run).
response = None

# Generate response using simple official API
print("🤖 Generating key-value extraction with InternVL3...")
try:
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    print("✅ Key-value extraction completed!")
    print("\n" + "="*60)
    print("EXTRACTED FIELDS:")
    print("="*60)
    print(response)
    print("="*60)

except Exception as e:
    # Best-effort reporting: show the failure with its traceback instead of
    # aborting the notebook; `response` stays None so later cells can detect it.
    print(f"❌ Error during inference: {e}")
    print(f"Error type: {type(e).__name__}")
    import traceback
    traceback.print_exc()

In [None]:
# Save key-value extraction results to file
output_path = Path("/home/jovyan/nfs_share/tod/output/internvl3_keyvalue_output.txt")

# Resolve the model output BEFORE touching the filesystem. Opening the file in
# "w" mode first would create an empty file (or truncate a previous result)
# and only then fail on the missing variable.
try:
    extracted_text = response
except NameError:
    extracted_text = None

if not isinstance(extracted_text, str):
    # Covers both "inference cell never ran" and "inference produced no text".
    print("❌ Error: 'response' variable not defined.")
    print("💡 Please run Cell [4] first to generate the response.")
else:
    try:
        # Ensure output directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write response to file
        with output_path.open("w", encoding="utf-8") as text_file:
            text_file.write(extracted_text)

        print(f"✅ Key-value extraction results saved to: {output_path}")
        print(f"📄 File size: {output_path.stat().st_size} bytes")

        # Count extracted fields (simple parsing): any line containing a colon
        # that does not start with '<' is treated as a "KEY: value" line.
        lines = extracted_text.split('\n')
        field_lines = [line for line in lines if ':' in line and not line.strip().startswith('<')]
        print(f"📊 Extracted {len(field_lines)} field lines")

    except Exception as e:
        # Best-effort: report the failure rather than aborting the notebook.
        print(f"❌ Error saving file: {e}")
        print(f"💡 Check if directory exists: {output_path.parent}")