# InternVL3-8B: 2-Turn Balance-Description Bank Statement Extraction

**Protocol**: Two independent single-turn prompts + Python parsing/filtering

**Key Insight**: Balance-description prompt works for BOTH date-per-row AND date-grouped formats!

---

## Complete Workflow

```
Turn 0: Image + Prompt → Headers (fresh context)
        ↓ (Python pattern matching)
        ↓ (Check if Balance column exists)
Turn 1: Image + Prompt → Balance-Description extraction (fresh context)
        ↓ (Python parsing + filtering)
Schema Fields: TRANSACTION_DATES, LINE_ITEM_DESCRIPTIONS, TRANSACTION_AMOUNTS_PAID
```

### Why Balance-Description Works:
- **Anchors extraction to Balance column** - unambiguous reference point
- **Works for date-per-row**: Each transaction gets its date
- **Works for date-grouped**: Date headers naturally map to transactions
- **No format classification needed** - eliminates Turn 0.5 entirely!

In [None]:
# Cell 1: Imports and Configuration

from pathlib import Path
import random
import re
import math

import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoConfig, BitsAndBytesConfig
from IPython.display import display, Markdown, HTML

# Set Random Seed for Reproducibility

In [None]:
# Cell 2: Set random seed

# Call the project-local set_seed helper with a fixed seed of 42.
# NOTE(review): presumably this seeds random/numpy/torch for reproducible runs —
# the helper's implementation is not visible here; confirm which RNGs it covers.
from common.reproducibility import set_seed
set_seed(42)

# Load the model

In [None]:
# Cell 3: Load InternVL3-8B model with memory-aware loading strategy

# Checkpoint location handed to AutoConfig/AutoModel/AutoTokenizer.from_pretrained below.
MODEL_PATH = "/home/jovyan/nfs_share/models/InternVL3-8B"
# Default upper bound on the number of image tiles (default max_num of
# dynamic_preprocess and load_image).
MAX_TILES = 14  # V100 optimized

# Image preprocessing constants
# Per-channel mean and std passed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Return the torchvision pipeline applied to every image tile.

    The pipeline converts the tile to RGB if needed, resizes it to a square of
    ``input_size`` pixels with bicubic interpolation, converts it to a tensor
    and normalizes it with IMAGENET_MEAN / IMAGENET_STD.
    """
    def _as_rgb(img):
        # Tiles that are already RGB are passed through untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    square = (input_size, input_size)
    steps = [
        T.Lambda(_as_rgb),
        T.Resize(square, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Find closest aspect ratio from target ratios."""
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=MAX_TILES, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Every (columns, rows) grid whose tile count lies in [min_num, max_num] is a
    candidate; the one closest to the image's aspect ratio is chosen by
    find_closest_aspect_ratio. The image is resized to exactly fill that grid
    and cropped row by row into ``image_size`` x ``image_size`` tiles. When
    ``use_thumbnail`` is set and more than one tile was produced, a resized
    copy of the whole image is appended as a final tile.

    Returns a list of PIL images.
    """
    source_w, source_h = image.size
    source_aspect = source_w / source_h

    candidate_grids = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    }
    # Smallest grids first, so ties in find_closest_aspect_ratio are scanned in area order.
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    grid = find_closest_aspect_ratio(
        source_aspect, candidate_grids, source_w, source_h, image_size
    )

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    tiles_per_row = target_width // image_size

    resized_img = image.resize((target_width, target_height))

    tiles = []
    for index in range(grid[0] * grid[1]):
        row_idx, col_idx = divmod(index, tiles_per_row)
        crop_box = (
            col_idx * image_size,
            row_idx * image_size,
            (col_idx + 1) * image_size,
            (row_idx + 1) * image_size,
        )
        tiles.append(resized_img.crop(crop_box))

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles

def load_image(image_file, input_size=448, max_num=MAX_TILES):
    """Load an image and turn it into a stacked tensor of preprocessed tiles.

    Args:
        image_file: A filesystem path (``str`` or ``pathlib.Path``) to open, or
            an already-loaded PIL image, which is used as-is.
        input_size: Edge length in pixels of each square tile.
        max_num: Maximum number of tiles (excluding the thumbnail tile).

    Returns:
        A tensor produced by ``torch.stack`` over the transformed tiles, i.e.
        one leading entry per tile (the thumbnail is included whenever more
        than one tile is produced).
    """
    # Path objects used to fall through to the "already an image" branch and
    # then fail on ``image.size``; treat them like string paths.
    if isinstance(image_file, (str, Path)):
        image = Image.open(image_file).convert('RGB')
    else:
        image = image_file

    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(img) for img in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

def split_model(model_path):
    """Build a device map that spreads the language-model layers over all GPUs.

    GPU 0 is given roughly half a share of decoder layers because it is also
    assigned the vision model, the projector (``mlp1``), the embeddings, the
    final norm, the rotary embedding, the output head and the last decoder
    layer.

    Returns a dict mapping module names to GPU indices.
    """
    gpu_count = torch.cuda.device_count()
    llm_depth = AutoConfig.from_pretrained(
        model_path, trust_remote_code=True
    ).llm_config.num_hidden_layers

    # Count GPU 0 as half a device when computing the per-GPU layer quota.
    full_quota = math.ceil(llm_depth / (gpu_count - 0.5))
    quotas = [full_quota] * gpu_count
    quotas[0] = math.ceil(quotas[0] * 0.5)

    placement = {}
    next_layer = 0
    for gpu_index, quota in enumerate(quotas):
        for layer_idx in range(next_layer, next_layer + quota):
            placement[f'language_model.model.layers.{layer_idx}'] = gpu_index
        next_layer += quota

    # Everything outside the decoder stack, plus the last decoder layer, goes to GPU 0.
    for module_name in (
        'vision_model',
        'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.model.rotary_emb',
        'language_model.lm_head',
        f'language_model.model.layers.{llm_depth - 1}',
    ):
        placement[module_name] = 0

    return placement

print("üîß Loading InternVL3-8B model...")

world_size = torch.cuda.device_count()
print(f"  Detected {world_size} GPU(s)")

# Memory-aware loading
if world_size > 1:
    print("  Using multi-GPU bfloat16 mode")
    device_map = split_model(MODEL_PATH)
    model = AutoModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=False,
        trust_remote_code=True,
        device_map=device_map,
    ).eval()
    model_dtype = torch.bfloat16
else:
    print("  Using single-GPU 8-bit quantization mode")
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=False
    )
    model = AutoModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        use_flash_attn=False,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map={"":0},
    ).eval()
    model_dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, use_fast=False)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"‚úÖ Model loaded successfully!")
print(f"  Data type: {model_dtype}")
print(f"  Max Tiles: {MAX_TILES}")

# Load the image

In [None]:
# Cell 4: Load bank statement image

# Path of the statement image under test; later cells pass `imageName` to load_image().
imageName = "/home/jovyan/nfs_share/tod/LMM_POC/evaluation_data/cba_date_grouped.png"

print("üìÅ Loading image...")
# This PIL copy is only used for display; inference cells re-read the file via load_image().
image = Image.open(imageName).convert('RGB')

print(f"‚úÖ Image loaded: {image.size}")
print("üñºÔ∏è  Bank statement image:")
display(image)

# Bank Statement Extraction Protocol (2-Turn Balance-Description)
- Turn 0: Identify actual table headers
- Turn 1: Extract using balance-description prompt
- Python: Parse, filter, and extract schema fields

In [None]:
# Cell 5: Turn 0 - Identify table headers

# Turn 0 prompt: ask the model to list the table's column headers verbatim.
turn0_prompt = """Look at the transaction table in this bank statement image.

What are the exact column header names used in the transaction table?

List each column header exactly as it appears, in order from left to right.
Do not interpret or rename them - use the EXACT text from the image.
"""

print("üí¨ TURN 0: Identifying actual table headers")
print("ü§ñ Generating response with InternVL3-8B...")

# Preprocess image
pixel_values = load_image(imageName, input_size=448)
# Cast to the dtype chosen at model load time and move the tiles to GPU 0.
pixel_values = pixel_values.to(dtype=model_dtype, device='cuda:0')

# Generate response using chat method
# No history argument is passed, so this is a single-turn call; do_sample=False
# selects non-sampling decoding.
turn0_response = model.chat(
    tokenizer=tokenizer,
    pixel_values=pixel_values,
    question=turn0_prompt,
    generation_config={'max_new_tokens': 500, 'do_sample': False}
)

# Free memory
del pixel_values
torch.cuda.empty_cache()

print("‚úÖ Response generated successfully!")
print("\n" + "=" * 60)
print("TURN 0 - IDENTIFIED TABLE HEADERS:")
print("=" * 60)
print(turn0_response)
print("=" * 60)

In [None]:
# Cell 6: Parse headers from Turn 0 response

def parse_headers_from_response(response_text):
    """Extract column header names from the model's Turn 0 answer.

    Each non-blank line is treated as one candidate header. Leading list
    decoration (digits, dots, dashes, bullets, asterisks, spaces) and markdown
    bold/underline markers are removed. A candidate is discarded when it ends
    with a colon (an introductory sentence such as "The headers are:"), is
    longer than 40 characters, or is 2 characters or shorter.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The surviving header strings, in the order they appeared.
    """
    # U+2022 (bullet) is written as an escape so the character survives any
    # re-encoding of this file; a corrupted bullet silently stops being stripped.
    list_prefix_chars = '0123456789.-\u2022* '

    identified_headers = []
    for raw_line in response_text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        cleaned = line.lstrip(list_prefix_chars).strip()
        cleaned = cleaned.replace('**', '').replace('__', '')

        if cleaned.endswith(':'):
            continue
        if len(cleaned) > 40:
            continue
        if len(cleaned) > 2:
            identified_headers.append(cleaned)

    return identified_headers

# Parse the Turn 0 answer; `table_headers` is a notebook global consumed by the
# pattern-matching and summary cells.
table_headers = parse_headers_from_response(turn0_response)

print(f"\nüìã Parsed {len(table_headers)} column headers:")
# Number the headers from 1 for display.
for i, header in enumerate(table_headers, 1):
    print(f"  {i}. '{header}'")

print(f"\n‚úÖ Stored table_headers: {table_headers}")

## Pattern Matching: Map Generic Concepts to Actual Headers

In [None]:
# Cell 7: Pattern Matching

# Lower-case keywords used by match_header() to map a generic column concept onto
# the header text the model reported. Order matters: match_header tries patterns
# in list order, and patterns of 2 characters or fewer (e.g. 'dr', 'cr') are only
# ever compared for exact equality, never as substrings.
DATE_PATTERNS = ['date', 'day', 'transaction date', 'trans date']
DESCRIPTION_PATTERNS = [
    'description', 'details', 'transaction details', 'trans details',
    'particulars', 'narrative', 'transaction', 'trans'
]
DEBIT_PATTERNS = ['debit', 'debits', 'withdrawal', 'withdrawals', 'paid', 'paid out', 'spent', 'dr']
CREDIT_PATTERNS = ['credit', 'credits', 'deposit', 'deposits', 'received', 'cr']
BALANCE_PATTERNS = ['balance', 'bal', 'running balance']
# Used as the fallback for both debit and credit when no dedicated column matches.
AMOUNT_PATTERNS = ['amount', 'amt', 'value', 'total']

def match_header(headers, patterns, fallback=None):
    """Find the header that best corresponds to a list of keyword patterns.

    Two passes are made, both in pattern order: first a case-insensitive exact
    match, then a case-insensitive substring match restricted to patterns
    longer than two characters. The header is returned with its original
    casing; ``fallback`` is returned when nothing matches.
    """
    lowered = [name.lower() for name in headers]

    # Pass 1: a header that is exactly one of the patterns.
    for keyword in patterns:
        if keyword in lowered:
            return headers[lowered.index(keyword)]

    # Pass 2: a header that merely contains a (sufficiently long) pattern.
    for keyword in patterns:
        if len(keyword) <= 2:
            continue
        for original, low in zip(headers, lowered):
            if keyword in low:
                return original

    return fallback

# Perform pattern matching
# Date and description fall back positionally (first / second header), then to a
# generic name. Debit and credit both fall back to the generic amount column when
# one was found, so they may end up naming the same column. Balance has no
# fallback: None means "no balance column".
date_col = match_header(table_headers, DATE_PATTERNS, fallback=table_headers[0] if table_headers else 'Date')
desc_col = match_header(table_headers, DESCRIPTION_PATTERNS, fallback=table_headers[1] if len(table_headers) > 1 else 'Description')
amount_col = match_header(table_headers, AMOUNT_PATTERNS, fallback=None)
debit_col = match_header(table_headers, DEBIT_PATTERNS, fallback=amount_col if amount_col else 'Debit')
credit_col = match_header(table_headers, CREDIT_PATTERNS, fallback=amount_col if amount_col else 'Credit')
balance_col = match_header(table_headers, BALANCE_PATTERNS, fallback=None)

print("=" * 60)
print("PATTERN MATCHING RESULTS:")
print("=" * 60)
print(f"üìã Extracted Headers: {table_headers}")
print(f"\nüîç Mapped Columns:")
print(f"  Date        ‚Üí '{date_col}'")
print(f"  Description ‚Üí '{desc_col}'")
print(f"  Debit       ‚Üí '{debit_col}'")
print(f"  Credit      ‚Üí '{credit_col}'")
print(f"  Balance     ‚Üí '{balance_col}'")

# `has_balance` gates the Turn 1 prompt cell below.
has_balance = balance_col is not None and balance_col in table_headers
print(f"\nüéØ Balance column detected: {'‚úÖ YES' if has_balance else '‚ùå NO'}")

if not has_balance:
    print("‚ö†Ô∏è  WARNING: No balance column found.")

## Turn 1: Balance-Description Extraction

In [None]:
# Cell 8: Generate extraction prompt

# Build the Turn 1 prompt from the headers detected in Turn 0. The prompt anchors
# on the balance column and asks for the date, description, debit and credit of
# each balance entry, with "NOT_FOUND" for an absent debit/credit amount.
# `extraction_prompt` is None when no balance column was detected, which makes the
# Turn 1 cell skip inference.
if has_balance:
    extraction_prompt = f"""List all the balances in the {balance_col} column, including:
- Date from the Date Header of the balance
- {desc_col}
- {debit_col} Amount or "NOT_FOUND"
- {credit_col} Amount or "NOT_FOUND" """
    
    print("üìù TURN 1: Balance-Description Extraction")
    print("=" * 60)
    print("Extraction Prompt:")
    print(extraction_prompt)
    print("=" * 60)
else:
    print("‚ùå Cannot proceed - no balance column detected.")
    extraction_prompt = None

In [None]:
# Cell 9: Execute Turn 1 extraction

# Run Turn 1 only when a prompt was built; otherwise `extraction_response` is None
# so the parsing cell below produces an empty row list.
if extraction_prompt:
    print("ü§ñ Generating response with InternVL3-8B...")
    
    # Reload image for fresh context
    pixel_values = load_image(imageName, input_size=448)
    # Cast to the dtype chosen at model load time and move the tiles to GPU 0.
    pixel_values = pixel_values.to(dtype=model_dtype, device='cuda:0')
    
    # No history argument is passed, so Turn 1 does not see the Turn 0 exchange.
    # The token budget is larger than in Turn 0 (4096 vs 500).
    extraction_response = model.chat(
        tokenizer=tokenizer,
        pixel_values=pixel_values,
        question=extraction_prompt,
        generation_config={'max_new_tokens': 4096, 'do_sample': False}
    )
    
    # Free memory
    del pixel_values
    torch.cuda.empty_cache()
    
    print("\n‚úÖ Turn 1 extraction complete!")
    print(f"\nüìä Response length: {len(extraction_response)} characters")
    print("\n" + "=" * 60)
    print("TURN 1 - BALANCE-DESCRIPTION EXTRACTION:")
    print("=" * 60)
    print(extraction_response)
    print("=" * 60)
else:
    extraction_response = None

## Python Parsing: Balance-Description Response

In [None]:
# Cell 10: Parse balance-description response

def _emit_transaction(rows, transaction, date_col, current_date):
    """Append a non-empty transaction to ``rows`` (dated when a date is known); return a fresh dict."""
    if transaction:
        if current_date:
            transaction[date_col] = current_date
        rows.append(transaction)
    return {}


def _strip_markdown_emphasis(text):
    """Remove one leading (and matching trailing) ``**`` / ``__`` marker from a field value."""
    text = text.strip()
    for marker in ("**", "__"):
        if text.startswith(marker):
            text = text[len(marker):]
            if text.endswith(marker):
                text = text[:-len(marker)]
            return text.strip()
    return text


def parse_balance_description_response(response_text, date_col, desc_col, debit_col, credit_col, balance_col):
    """Parse the hierarchical balance-description response into transaction rows.

    Expected shape of the response: numbered date headers such as
    ``1. **Thu 04 Sep 2025**`` (also ``04 Sep 2025`` or ``04/09/2025``), each
    followed by bulleted ``- Label: value`` lines. A description line starts a
    new transaction; debit/credit/balance/amount lines attach to the current
    one. Lines matching neither shape are ignored.

    Labels are matched case-insensitively against the generic names
    (description, debit, withdrawal, credit, deposit, balance, amount) and
    against the actual column names passed in, including multi-word names
    such as "Paid Out" and labels wrapped in markdown bold.

    Args:
        response_text: Raw Turn 1 model output.
        date_col, desc_col, debit_col, credit_col, balance_col: Column names
            used as the keys of each returned row.

    Returns:
        A list of dicts, one per transaction, in response order. A row carries
        ``date_col`` only when a date header preceded it.
    """
    date_header_patterns = (
        r"^\d+\.\s*\*?\*?([A-Za-z]{3}\s+\d{1,2}\s+[A-Za-z]{3}\s+\d{4})\*?\*?",
        r"^\d+\.\s*\*?\*?(\d{1,2}\s+[A-Za-z]{3}\s+\d{4})\*?\*?",
        r"^\d+\.\s*\*?\*?(\d{1,2}/\d{1,2}/\d{4})\*?\*?",
    )
    # Bullet, then a label up to the first colon (may contain spaces or bold
    # markers), then the value. The previous ``\w+`` label could never match
    # multi-word column names.
    field_pattern = re.compile(r"^[-*\u2022]\s*([^:]+?)\s*:\s*(.+)$")

    def _label(col):
        return col.strip().lower() if isinstance(col, str) else None

    desc_label = _label(desc_col)
    debit_label = _label(debit_col)
    credit_label = _label(credit_col)
    balance_label = _label(balance_col)

    rows = []
    current_date = None
    current_transaction = {}

    for raw_line in response_text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        # Check for date header patterns
        date_match = None
        for pattern in date_header_patterns:
            date_match = re.match(pattern, line)
            if date_match:
                break

        if date_match:
            # Close the pending transaction under the date that was in force
            # when it was read, never under the new one.
            current_transaction = _emit_transaction(rows, current_transaction, date_col, current_date)
            current_date = date_match.group(1).strip()
            continue

        # Check for field lines
        field_match = field_pattern.match(line)
        if not field_match:
            continue

        field_name = field_match.group(1).strip().strip("*_").strip().lower()
        field_value = _strip_markdown_emphasis(field_match.group(2))

        if field_name == "description" or field_name == desc_label:
            # A second description means the previous transaction is complete.
            if current_transaction.get(desc_col):
                current_transaction = _emit_transaction(rows, current_transaction, date_col, current_date)
            current_transaction[desc_col] = field_value

        elif field_name in ("debit", "withdrawal") or field_name == debit_label:
            current_transaction[debit_col] = field_value

        elif field_name in ("credit", "deposit") or field_name == credit_label:
            current_transaction[credit_col] = field_value

        elif field_name == "balance" or field_name == balance_label:
            current_transaction[balance_col] = field_value

        elif field_name == "amount":
            # A generic amount is treated as a debit unless one was already seen.
            if debit_col not in current_transaction:
                current_transaction[debit_col] = field_value

    # Flush the last transaction even when no date header was ever seen, to
    # match how earlier transactions are flushed on a new description.
    _emit_transaction(rows, current_transaction, date_col, current_date)

    return rows


def parse_markdown_table(markdown_text):
    """Fallback parser: read a pipe-delimited markdown table into row dicts.

    The first line that contains a pipe and has content besides pipes, dashes
    and spaces is taken as the header row. Each later pipe-containing line
    that is not a separator (only pipes, dashes, spaces, colons) becomes a row
    if its cell count equals the number of non-empty header cells; other
    lines are skipped. Returns an empty list when no header row is found.
    """
    def _cells(row_text):
        # Split on pipes and drop the empty cell produced by a leading/trailing pipe.
        cells = [cell.strip() for cell in row_text.split('|')]
        if cells and cells[0] == '':
            del cells[0]
        if cells and cells[-1] == '':
            cells.pop()
        return cells

    def _residue(text, ignorable):
        # What is left of the line once table punctuation is removed.
        return ''.join(ch for ch in text if ch not in ignorable)

    candidates = [ln.strip() for ln in markdown_text.strip().split('\n') if ln.strip()]

    header_pos = next(
        (pos for pos, ln in enumerate(candidates) if '|' in ln and _residue(ln, '|- ')),
        None,
    )
    if header_pos is None:
        return []

    column_names = [name for name in _cells(candidates[header_pos]) if name]

    records = []
    for ln in candidates[header_pos + 1:]:
        if '|' not in ln or not _residue(ln, '|- :'):
            continue
        cells = _cells(ln)
        if len(cells) == len(column_names):
            records.append(dict(zip(column_names, cells)))

    return records


# Parse the extraction response
# `all_rows` is a notebook global consumed by the filter, schema and summary cells;
# it is an empty list when Turn 1 was skipped.
if extraction_response:
    all_rows = parse_balance_description_response(
        extraction_response, date_col, desc_col, debit_col, credit_col, balance_col
    )
    
    # The hierarchical parser found nothing but the text contains pipes: try the
    # markdown-table parser instead. Its row keys come from the table's own
    # header line rather than from date_col/desc_col/etc.
    if not all_rows and "|" in extraction_response:
        print("‚ö†Ô∏è  Fallback: parsing as markdown table")
        all_rows = parse_markdown_table(extraction_response)
    
    print(f"\nüìä Parsed {len(all_rows)} total rows")
    
    print("\n" + "=" * 60)
    print("PARSED TRANSACTIONS:")
    print("=" * 60)
    # Preview at most the first 10 rows.
    for i, row in enumerate(all_rows[:10]):
        print(f"\n{i+1}. {row}")
    if len(all_rows) > 10:
        print(f"\n... and {len(all_rows) - 10} more rows")
else:
    all_rows = []

## Filter for Debit Transactions

In [None]:
# Cell 11: Filter for debit transactions

def parse_amount(value):
    """Extract a numeric value from a formatted currency string.

    Dollar signs, thousands separators and CR/DR markers (any letter case) are
    removed before conversion. A value wrapped in parentheses, the accounting
    notation for a negative amount, is returned as a negative number.

    Args:
        value: A currency string such as ``"$1,234.56"`` or ``"12.50 DR"``.
            ``None``, empty strings and plain numbers are also accepted.

    Returns:
        The amount as a float, or 0.0 when the input is empty or cannot be
        parsed (for example ``"NOT_FOUND"``).
    """
    # Already numeric: nothing to clean.
    if isinstance(value, (int, float)):
        return float(value)
    if not value:
        return 0.0

    # Upper-casing lets a single replace() handle "cr"/"Cr"/"CR" alike.
    text = str(value).strip().upper()
    cleaned = text.replace("$", "").replace(",", "").replace("CR", "").replace("DR", "").strip()

    negative = cleaned.startswith("(") and cleaned.endswith(")")
    if negative:
        cleaned = cleaned[1:-1].strip()

    try:
        amount = float(cleaned)
    except ValueError:
        return 0.0
    return -amount if negative else amount


def is_non_transaction_row(row, desc_col):
    """Tell whether a row is a balance marker rather than a real transaction.

    The row's description (missing counts as empty) is compared
    case-insensitively against the opening/closing balance and
    brought/carried forward phrases.
    """
    description = row.get(desc_col, "").strip().upper()
    for marker in ("OPENING BALANCE", "CLOSING BALANCE", "BROUGHT FORWARD", "CARRIED FORWARD"):
        if marker in description:
            return True
    return False


def filter_debit_transactions(rows, debit_col, desc_col=None):
    """Keep only the rows that represent an actual debit.

    A row is kept when its debit value is present, is not the ``NOT_FOUND``
    placeholder, parses (via parse_amount) to an amount that is not <= 0, and,
    if ``desc_col`` is given, is not a balance-marker row according to
    is_non_transaction_row. Row order is preserved.
    """
    def _is_debit(row):
        raw = row.get(debit_col, "").strip()
        if raw == "" or raw.upper() == "NOT_FOUND":
            return False
        if parse_amount(raw) <= 0:
            return False
        # Balance markers are only screened out when a description column is known.
        return not (desc_col and is_non_transaction_row(row, desc_col))

    return [row for row in rows if _is_debit(row)]


# `debit_rows` is a notebook global consumed by the schema and summary cells.
debit_rows = filter_debit_transactions(all_rows, debit_col, desc_col)

print(f"\nüìä Filtered to {len(debit_rows)} debit transactions")
print("\n" + "=" * 60)
print("DEBIT TRANSACTIONS ONLY:")
print("=" * 60)
# One line per debit; "N/A" stands in for any field the row does not carry.
for i, row in enumerate(debit_rows):
    date = row.get(date_col, "N/A")
    desc = row.get(desc_col, "N/A")
    amount = row.get(debit_col, "N/A")
    print(f"{i+1}. [{date}] {desc} - {amount}")

## Extract Schema Fields

In [None]:
# Cell 12: Extract schema fields

def extract_schema_fields(debit_rows, date_col, desc_col, debit_col, all_rows=None):
    """Extract fields in universal.yaml schema format.

    The three list fields (dates, descriptions, amounts) are pipe-joined with
    exactly one entry per debit row, so the n-th date, description and amount
    always belong to the same transaction. A row that lacks one of the values
    contributes the ``NOT_FOUND`` placeholder at its position; a field that no
    row provides is reported as a single ``NOT_FOUND``.

    Args:
        debit_rows: Rows already filtered to debit transactions.
        date_col, desc_col, debit_col: Keys to read from each row.
        all_rows: Optional full row list used for the statement date range;
            defaults to ``debit_rows``.

    Returns:
        A dict with DOCUMENT_TYPE, STATEMENT_DATE_RANGE (first and last date
        in row order), TRANSACTION_DATES, LINE_ITEM_DESCRIPTIONS and
        TRANSACTION_AMOUNTS_PAID.
    """
    if not debit_rows:
        return {
            "DOCUMENT_TYPE": "BANK_STATEMENT",
            "STATEMENT_DATE_RANGE": "NOT_FOUND",
            "TRANSACTION_DATES": "NOT_FOUND",
            "LINE_ITEM_DESCRIPTIONS": "NOT_FOUND",
            "TRANSACTION_AMOUNTS_PAID": "NOT_FOUND",
        }

    def _column(rows, col):
        # One entry per row; missing or None values become "".
        return [(row.get(col) or "").strip() for row in rows]

    def _joined(values):
        if not any(values):
            return "NOT_FOUND"
        # Placeholders keep the lists positionally aligned across fields.
        return " | ".join(value or "NOT_FOUND" for value in values)

    rows_for_range = all_rows if all_rows is not None else debit_rows
    all_dates = [value for value in _column(rows_for_range, date_col) if value]
    date_range = f"{all_dates[0]} - {all_dates[-1]}" if all_dates else "NOT_FOUND"

    return {
        "DOCUMENT_TYPE": "BANK_STATEMENT",
        "STATEMENT_DATE_RANGE": date_range,
        "TRANSACTION_DATES": _joined(_column(debit_rows, date_col)),
        "LINE_ITEM_DESCRIPTIONS": _joined(_column(debit_rows, desc_col)),
        "TRANSACTION_AMOUNTS_PAID": _joined(_column(debit_rows, debit_col)),
    }


# Build the schema output; all_rows is passed so the date range spans every parsed
# row, not just the debits.
schema_fields = extract_schema_fields(debit_rows, date_col, desc_col, debit_col, all_rows=all_rows)

print("\n" + "=" * 60)
print("EXTRACTED SCHEMA FIELDS:")
print("=" * 60)
for field, value in schema_fields.items():
    # Truncate long values to 100 characters for display only; schema_fields is unchanged.
    display_value = str(value)[:100] + "..." if len(str(value)) > 100 else str(value)
    print(f"\n{field}:")
    print(f"  {display_value}")

In [None]:
# Cell 13: Summary

# Final report: prints counts from the globals built by the earlier cells
# (table_headers, balance_col, all_rows, debit_rows); nothing is computed here.
print("\n" + "=" * 60)
print("üìä EXTRACTION SUMMARY")
print("=" * 60)
print(f"\nüîß Method: 2-Turn Balance-Description")
print(f"üìã Headers detected: {len(table_headers)}")
print(f"üí∞ Balance column: {balance_col}")
print(f"üìù Total transactions parsed: {len(all_rows)}")
print(f"üí∏ Debit transactions: {len(debit_rows)}")
print(f"\n‚úÖ Pipeline complete!")