## Med-VLM Interpret Evaluation

### Load libraries

In [1]:
import os
import sys
import json
import time
import random
import hashlib
import argparse
import subprocess
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, asdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from scipy.spatial.distance import jensenshannon

### Setup Environment

In [11]:
def setup_colab_environment():
    """Setup the Google Colab environment"""
    print("Setting up Google Colab environment...")

    # Clone repository if not exists
    if not os.path.exists('/content/medical-vlm-intepret'):
        print("Cloning repository...")
        subprocess.run(['git', 'clone', 'https://github.com/thedatasense/medical-vlm-intepret.git'],
                      cwd='/content', check=True)

    # Change to repo directory
    os.chdir('/content/medical-vlm-intepret/attention_viz')

    # Install dependencies
    print("Installing dependencies...")
    packages = [
        'torch', 'torchvision', 'transformers>=4.36.0',
        'opencv-python', 'scipy', 'matplotlib', 'pillow',
        'bitsandbytes', 'accelerate', 'gradio'
    ]

    for package in packages:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', package])

    # Install LLaVA
    if not os.path.exists('/content/LLaVA'):
        print("Installing LLaVA...")
        subprocess.run(['git', 'clone', 'https://github.com/haotian-liu/LLaVA.git'],
                      cwd='/content', check=True)
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-e', '.'],
                      cwd='/content/LLaVA', check=True)

    # Add to Python path
    sys.path.insert(0, '/content/LLaVA')
    sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

    print("Environment setup complete!")
    return True


In [2]:
def setup_colab_environment():
    """Setup the Google Colab environment"""
    import os
    import sys
    import subprocess

    print("Setting up Google Colab environment...")

    # Clone repository if not exists
    if not os.path.exists('/content/medical-vlm-intepret'):
        print("Cloning repository...")
        result = subprocess.run(
            ['git', 'clone', 'https://github.com/thedatasense/medical-vlm-intepret.git'],
            cwd='/content',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error cloning repo: {result.stderr}")
            return False

    # Change to repo directory
    os.chdir('/content/medical-vlm-intepret/attention_viz')

    # Install dependencies
    print("Installing dependencies...")
    packages = [
        'torch', 'torchvision', 'transformers>=4.36.0',
        'opencv-python', 'scipy', 'matplotlib', 'pillow',
        'bitsandbytes', 'accelerate', 'gradio'
    ]

    for package in packages:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', package])

    # Install LLaVA
    if not os.path.exists('/content/LLaVA'):
        print("Installing LLaVA...")
        result = subprocess.run(
            ['git', 'clone', 'https://github.com/haotian-liu/LLaVA.git'],
            cwd='/content',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error cloning LLaVA: {result.stderr}")
            return False

        print("Running pip install for LLaVA...")
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-e', '.'],
            cwd='/content/LLaVA',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error installing LLaVA:")
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")
            return False

    # Add to Python path
    sys.path.insert(0, '/content/LLaVA')
    sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

    print("Environment setup complete!")
    return True

# Run it
setup_colab_environment()

Setting up Google Colab environment...
Cloning repository...
Installing dependencies...
Installing LLaVA...
Running pip install for LLaVA...
Error installing LLaVA:
STDOUT: Obtaining file:///content/LLaVA
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
INFO: pip is looking at multiple versions of llava to determine which version is compatible with other requirements. This could take a while.

STDERR: ERROR: Could not find a version that satisfies the requirement torch==2.1.2 (from llava) (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2

False

In [9]:
def mount_drive_and_verify_data():
    """Mount Google Drive and verify data paths"""
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    # Define data paths
    data_paths = {
        'data_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset',
        'image_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa',
        'csv_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv',
        'csv_variants_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv'
    }

    # Verify paths
    print("\nVerifying data paths:")
    all_exist = True
    for name, path in data_paths.items():
        exists = os.path.exists(path)
        print(f"{name}: {'✓' if exists else '✗'} {path}")
        if not exists:
            all_exist = False

    if not all_exist:
        print("\nWARNING: Some data paths do not exist. Please check your Google Drive setup.")

    return data_paths


### Load the models

In [3]:
def load_models(llava_8bit=True, medgemma_8bit=True):
    """Load both LLaVA-Rad and MedGemma models"""
    print("\nLoading models...")

    # Import after environment setup
    from llava_rad_enhanced import EnhancedLLaVARadVisualizer, AttentionConfig
    from medgemma_enhanced import load_model_enhanced

    # Load LLaVA-Rad
    print("Loading LLaVA-Rad...")
    llava_config = AttentionConfig(
        use_medical_colormap=True,
        multi_head_mode='entropy_weighted',  # Using fixed entropy weighting
        percentile_clip=(5, 95)
    )
    llava_vis = EnhancedLLaVARadVisualizer(config=llava_config)
    llava_vis.load_model(load_in_8bit=llava_8bit)

    # Load MedGemma
    print("Loading MedGemma...")
    medgemma_model, medgemma_processor = load_model_enhanced(
        model_id="google/medgemma-4b-it",
        load_in_8bit=medgemma_8bit
    )

    return llava_vis, medgemma_model, medgemma_processor


In [4]:
class InferenceResult:
    """Store inference results for a single sample"""
    study_id: str
    image_path: str
    finding: str
    variant_id: str
    question: str
    answer_gt: str
    timestamp: str

    llava_answer: str
    llava_correct: bool
    llava_latency_ms: float
    llava_attention_method: str
    llava_focus_score: float
    llava_sparsity: float

    medgemma_answer: str
    medgemma_correct: bool
    medgemma_latency_ms: float
    medgemma_attention_method: str
    medgemma_focus_score: float
    medgemma_sparsity: float

    js_divergence: Optional[float] = None



In [5]:
def process_single_sample(
    image_path: str,
    question: str,
    answer_gt: str,
    llava_vis,
    medgemma_model,
    medgemma_processor,
    medgemma_extractor,
    study_info: Dict[str, Any]
) -> InferenceResult:
    """Process a single image-question pair through both models"""

    from llava_rad_enhanced import AttentionMetrics

    # Load image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    image_pil = Image.open(image_path).convert('RGB')

    # Debug: Check image size
    print(f"  Processing image: {os.path.basename(image_path)} ({image_pil.size})")

    # LLaVA-Rad inference
    start_time = time.time()
    llava_result = llava_vis.generate_with_attention(
        image_path,
        question,
        max_new_tokens=50,
        use_cache=False
    )
    llava_latency = (time.time() - start_time) * 1000

    # Extract LLaVA attention
    llava_attention = llava_result.get('visual_attention')
    if isinstance(llava_attention, list):
        llava_attention = np.mean(np.stack(llava_attention), axis=0)

    # LLaVA metrics
    llava_metrics = AttentionMetrics.calculate_focus_score(llava_attention)
    llava_metrics['sparsity'] = AttentionMetrics.calculate_sparsity(llava_attention)

    # MedGemma inference
    # Build prompt with chat template to ensure image tokens are inserted
    if hasattr(medgemma_processor, 'apply_chat_template'):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = medgemma_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = medgemma_processor(
            text=prompt,
            images=image_pil,
            return_tensors="pt",
        )
    else:
        # Fallback: explicit <image> token
        inputs = medgemma_processor(
            text=f"<image>{question}",
            images=image_pil,
            return_tensors="pt",
        )

    device = next(medgemma_model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    start_time = time.time()
    with torch.no_grad():
        outputs = medgemma_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            output_attentions=True,
            return_dict_in_generate=True
        )
    medgemma_latency = (time.time() - start_time) * 1000

    # Extract MedGemma answer
    medgemma_answer = medgemma_processor.tokenizer.decode(
        outputs.sequences[0], skip_special_tokens=True
    )

    # Handle Gemma-3 format
    if '<start_of_turn>model' in medgemma_answer:
        # Extract text after model turn
        parts = medgemma_answer.split('<start_of_turn>model')
        if len(parts) > 1:
            medgemma_answer = parts[-1]
            # Remove end_of_turn if present
            medgemma_answer = medgemma_answer.split('<end_of_turn>')[0]
    elif '<start_of_turn>assistant' in medgemma_answer:
        parts = medgemma_answer.split('<start_of_turn>assistant')
        if len(parts) > 1:
            medgemma_answer = parts[-1]
            medgemma_answer = medgemma_answer.split('<end_of_turn>')[0]
    else:
        # Fallback: split by common patterns
        medgemma_answer = medgemma_answer.split("Assistant:")[-1]
        medgemma_answer = medgemma_answer.split("model\n")[-1]

    medgemma_answer = medgemma_answer.strip()

    # Extract MedGemma attention
    target_words = ["pneumonia", "consolidation", "opacity", "finding", "abnormal"]
    medgemma_attention, _, medgemma_method = medgemma_extractor.extract_token_conditioned_attention_robust(
        medgemma_model, medgemma_processor, outputs,
        target_words, image_pil, question
    )

    # MedGemma metrics
    medgemma_metrics = AttentionMetrics.calculate_focus_score(medgemma_attention)
    medgemma_metrics['sparsity'] = AttentionMetrics.calculate_sparsity(medgemma_attention)

    # Calculate JS divergence between attention maps
    llava_flat = llava_attention.flatten()
    medgemma_flat = medgemma_attention.flatten()
    llava_norm = llava_flat / (llava_flat.sum() + 1e-10)
    medgemma_norm = medgemma_flat / (medgemma_flat.sum() + 1e-10)
    js_div = float(jensenshannon(llava_norm, medgemma_norm))

    # Check correctness (basic yes/no matching)
    llava_correct = llava_result['answer'].lower().strip().startswith(answer_gt.lower().strip()[:3])
    medgemma_correct = medgemma_answer.lower().strip().startswith(answer_gt.lower().strip()[:3])

    return InferenceResult(
        study_id=study_info['study_id'],
        image_path=study_info['image_path'],
        finding=study_info['finding'],
        variant_id=study_info['variant_id'],
        question=question,
        answer_gt=answer_gt,
        timestamp=datetime.now().isoformat(),

        llava_answer=llava_result['answer'],
        llava_correct=llava_correct,
        llava_latency_ms=llava_latency,
        llava_attention_method=llava_result.get('attention_method', 'unknown'),
        llava_focus_score=llava_metrics['focus'],
        llava_sparsity=llava_metrics['sparsity'],

        medgemma_answer=medgemma_answer,
        medgemma_correct=medgemma_correct,
        medgemma_latency_ms=medgemma_latency,
        medgemma_attention_method=medgemma_method,
        medgemma_focus_score=medgemma_metrics['focus'],
        medgemma_sparsity=medgemma_metrics['sparsity'],

        js_divergence=js_div
    )



In [6]:
def run_robustness_study(
    data_paths: Dict[str, str],
    n_studies: int = 100,
    models: Optional[Tuple] = None,
    output_file: str = "robustness_results.jsonl"
) -> List[InferenceResult]:
    """Run the complete robustness study"""

    from medgemma_enhanced import EnhancedAttentionExtractor, AttentionExtractionConfig

    # Load models if not provided
    if models is None:
        llava_vis, medgemma_model, medgemma_processor = load_models()
    else:
        llava_vis, medgemma_model, medgemma_processor = models

    # Create MedGemma extractor
    medgemma_extractor = EnhancedAttentionExtractor(
        AttentionExtractionConfig(
            attention_head_reduction='entropy_weighted',
            fallback_chain=['cross_attention', 'gradcam', 'uniform']
        )
    )

    # Load data
    csv_path = data_paths.get('csv_variants_path', data_paths['csv_path'])
    df = pd.read_csv(csv_path)

    # Get unique studies
    unique_studies = df['study_id'].unique()[:n_studies]
    print(f"\nProcessing {len(unique_studies)} studies...")

    results = []

    for i, study_id in enumerate(unique_studies):
        study_rows = df[df['study_id'] == study_id]
        print(f"\n[{i+1}/{len(unique_studies)}] Processing study {study_id} ({len(study_rows)} variants)")

        for _, row in study_rows.iterrows():
            try:
                image_path = os.path.join(data_paths['image_root'], row['image_path'])

                study_info = {
                    'study_id': str(study_id),
                    'image_path': row['image_path'],
                    'finding': row.get('finding', 'unknown'),
                    'variant_id': row.get('question_variant', row.get('variant_id', 'base'))
                }

                result = process_single_sample(
                    image_path=image_path,
                    question=row['question'],
                    answer_gt=row.get('answer', 'unknown'),
                    llava_vis=llava_vis,
                    medgemma_model=medgemma_model,
                    medgemma_processor=medgemma_processor,
                    medgemma_extractor=medgemma_extractor,
                    study_info=study_info
                )

                results.append(result)

                # Save incrementally
                if len(results) % 10 == 0:
                    save_results(results, output_file)

            except Exception as e:
                print(f"  Error processing {row['image_path']}: {e}")
                continue

    # Final save
    save_results(results, output_file)
    print(f"\nCompleted {len(results)} inferences")

    return results


def save_results(results: List[InferenceResult], output_file: str):
    """Save results to JSONL file"""
    with open(output_file, 'w') as f:
        for r in results:
            f.write(json.dumps(asdict(r)) + '\n')


In [8]:
def save_results(results: List[InferenceResult], output_file: str):
    """Save results to JSONL file"""
    with open(output_file, 'w') as f:
        for r in results:
            f.write(json.dumps(asdict(r)) + '\n')

In [7]:
def analyze_results(results: List[InferenceResult]) -> Dict[str, Any]:
    """Analyze the robustness study results"""

    analysis = {
        'n_samples': len(results),
        'n_studies': len(set(r.study_id for r in results)),
        'timestamp': datetime.now().isoformat()
    }

    # Model performance
    for model in ['llava', 'medgemma']:
        correct = [getattr(r, f'{model}_correct') for r in results]
        focus_scores = [getattr(r, f'{model}_focus_score') for r in results]
        sparsity_scores = [getattr(r, f'{model}_sparsity') for r in results]
        latencies = [getattr(r, f'{model}_latency_ms') for r in results]

        analysis[f'{model}_accuracy'] = np.mean(correct)
        analysis[f'{model}_focus_mean'] = np.mean(focus_scores)
        analysis[f'{model}_focus_std'] = np.std(focus_scores)
        analysis[f'{model}_sparsity_mean'] = np.mean(sparsity_scores)
        analysis[f'{model}_latency_mean'] = np.mean(latencies)

    # JS divergence between models
    js_divs = [r.js_divergence for r in results if r.js_divergence is not None]
    analysis['js_divergence_mean'] = np.mean(js_divs)
    analysis['js_divergence_std'] = np.std(js_divs)

    # Robustness analysis by study
    robustness_stats = analyze_robustness_by_study(results)
    analysis.update(robustness_stats)

    return analysis


def analyze_robustness_by_study(results: List[InferenceResult]) -> Dict[str, Any]:
    """Analyze robustness metrics grouped by study"""

    df = pd.DataFrame([asdict(r) for r in results])

    robustness = {
        'llava_flip_rate': 0,
        'medgemma_flip_rate': 0,
        'llava_consistency_rate': 0,
        'medgemma_consistency_rate': 0
    }

    # Group by study
    study_groups = df.groupby('study_id')
    n_multi_variant = 0

    for study_id, group in study_groups:
        if len(group) > 1:
            n_multi_variant += 1

            # Check answer consistency
            llava_answers = group['llava_answer'].tolist()
            medgemma_answers = group['medgemma_answer'].tolist()

            # Count flips (different from base)
            base_llava = llava_answers[0]
            base_medgemma = medgemma_answers[0]

            llava_flips = sum(1 for a in llava_answers[1:] if a != base_llava)
            medgemma_flips = sum(1 for a in medgemma_answers[1:] if a != base_medgemma)

            robustness['llava_flip_rate'] += llava_flips / (len(group) - 1)
            robustness['medgemma_flip_rate'] += medgemma_flips / (len(group) - 1)

            # Check full consistency
            if len(set(llava_answers)) == 1:
                robustness['llava_consistency_rate'] += 1
            if len(set(medgemma_answers)) == 1:
                robustness['medgemma_consistency_rate'] += 1

    # Average over studies
    if n_multi_variant > 0:
        robustness['llava_flip_rate'] /= n_multi_variant
        robustness['medgemma_flip_rate'] /= n_multi_variant
        robustness['llava_consistency_rate'] /= n_multi_variant
        robustness['medgemma_consistency_rate'] /= n_multi_variant

    return robustness


def visualize_results(results: List[InferenceResult], output_dir: str = "visualizations"):
    """Create visualizations of the results"""

    os.makedirs(output_dir, exist_ok=True)

    # 1. Accuracy comparison
    plt.figure(figsize=(10, 6))

    llava_acc = np.mean([r.llava_correct for r in results])
    medgemma_acc = np.mean([r.medgemma_correct for r in results])

    models = ['LLaVA-Rad', 'MedGemma']
    accuracies = [llava_acc, medgemma_acc]

    bars = plt.bar(models, accuracies, color=['blue', 'green'])
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Comparison')
    plt.ylim(0, 1)

    # Add value labels
    for bar, acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{acc:.3f}', ha='center')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'accuracy_comparison.png'))
    plt.close()

    # 2. Attention metrics comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Focus scores
    llava_focus = [r.llava_focus_score for r in results]
    medgemma_focus = [r.medgemma_focus_score for r in results]

    ax1.hist(llava_focus, bins=30, alpha=0.7, label='LLaVA-Rad', color='blue')
    ax1.hist(medgemma_focus, bins=30, alpha=0.7, label='MedGemma', color='green')
    ax1.set_xlabel('Focus Score')
    ax1.set_ylabel('Count')
    ax1.set_title('Distribution of Attention Focus Scores')
    ax1.legend()

    # Sparsity scores
    llava_sparsity = [r.llava_sparsity for r in results]
    medgemma_sparsity = [r.medgemma_sparsity for r in results]

    ax2.hist(llava_sparsity, bins=30, alpha=0.7, label='LLaVA-Rad', color='blue')
    ax2.hist(medgemma_sparsity, bins=30, alpha=0.7, label='MedGemma', color='green')
    ax2.set_xlabel('Sparsity Score')
    ax2.set_ylabel('Count')
    ax2.set_title('Distribution of Attention Sparsity Scores')
    ax2.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'attention_metrics_distribution.png'))
    plt.close()

    # 3. JS Divergence distribution
    plt.figure(figsize=(8, 6))

    js_divs = [r.js_divergence for r in results if r.js_divergence is not None]
    plt.hist(js_divs, bins=30, color='purple', alpha=0.7)
    plt.axvline(np.mean(js_divs), color='red', linestyle='--',
                label=f'Mean: {np.mean(js_divs):.3f}')
    plt.xlabel('JS Divergence')
    plt.ylabel('Count')
    plt.title('Distribution of JS Divergence Between Model Attentions')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'js_divergence_distribution.png'))
    plt.close()

    print(f"Visualizations saved to {output_dir}/")

In [12]:
print(f"\nRunning robustness study on {100} studies...")
results_file = "robustness_results.jsonl"
setup_colab_environment()
data_paths = mount_drive_and_verify_data()
models = load_models()
results = run_robustness_study(
    data_paths=data_paths,
    n_studies=100,
    models=models,
    output_file=results_file
)

# Analyze results
print("\nAnalyzing results...")
analysis = analyze_results(results)

# Save analysis
with open("analysis_results.json", 'w') as f:
    json.dump(analysis, f, indent=2)

#Create visualizations
print("Creating visualizations...")
visualize_results(results, output_dir="visualizations")


Running robustness study on 100 studies...
Setting up Google Colab environment...
Cloning repository...
Installing dependencies...
Installing LLaVA...


CalledProcessError: Command '['/usr/bin/python3', '-m', 'pip', 'install', '-e', '.']' returned non-zero exit status 1.

In [10]:
#!/usr/bin/env python3
"""
Fixed test script for Medical VLM Analysis with all corrections applied

Run this in Google Colab to verify the fixes work correctly.
"""

import os
import sys
import torch
import numpy as np
from PIL import Image

print("=== Fixed Medical VLM Test Script ===")
print("This includes all fixes for attention extraction and answer decoding\n")

# 1. Setup paths
print("1. Setting up paths...")
sys.path.insert(0, '/content/LLaVA')
sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

# 2. Import modules
print("\n2. Importing modules...")
try:
    from llava_rad_enhanced import EnhancedLLaVARadVisualizer, AttentionConfig
    from medgemma_enhanced import load_model_enhanced, EnhancedAttentionExtractor, AttentionExtractionConfig
    print("✓ Modules imported successfully")
except Exception as e:
    print(f"✗ Import error: {e}")
    exit(1)

# 3. Find test image
print("\n3. Finding test image...")
image_dir = "/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa"
test_image_path = None
test_image = None

if os.path.exists(image_dir):
    images = [f for f in os.listdir(image_dir) if f.endswith('.jpg')][:5]
    if images:
        test_image_path = os.path.join(image_dir, images[0])
        test_image = Image.open(test_image_path).convert('RGB')
        print(f"✓ Using image: {images[0]} (size: {test_image.size})")
else:
    print(f"✗ Image directory not found: {image_dir}")
    exit(1)

test_question = "Is there evidence of pneumonia?"

# 4. Test LLaVA-Rad
print("\n4. Testing LLaVA-Rad...")
try:
    llava_config = AttentionConfig(
        use_medical_colormap=True,
        multi_head_mode='mean',  # Using mean aggregation
        percentile_clip=(5, 95)
    )
    llava_vis = EnhancedLLaVARadVisualizer(config=llava_config)
    print("  Loading model (8-bit quantization)...")
    llava_vis.load_model(load_in_8bit=True)
    print("✓ LLaVA-Rad loaded successfully")

    # Test inference
    print(f"  Testing with question: {test_question}")
    result = llava_vis.generate_with_attention(
        test_image_path,
        test_question,
        max_new_tokens=50,
        use_cache=False
    )

    print(f"✓ Answer: {result['answer']}")
    print(f"  Attention method: {result.get('attention_method', 'unknown')}")

    # Check attention
    att = result.get('visual_attention')
    if att is not None:
        if isinstance(att, list):
            print(f"  Attention: list of {len(att)} heads, first shape: {att[0].shape}")
        else:
            print(f"  Attention shape: {att.shape}")
            print(f"  Focus score: {result.get('metrics', {}).get('focus', 'N/A'):.3f}")

except Exception as e:
    print(f"✗ LLaVA-Rad error: {e}")
    import traceback
    traceback.print_exc()

# 5. Test MedGemma/PaliGemma
print("\n5. Testing MedGemma/PaliGemma...")
try:
    print("  Loading model with eager attention...")
    medgemma_model, medgemma_processor = load_model_enhanced(
        model_id="google/paligemma-3b-mix-224",
        load_in_8bit=True
    )
    print("✓ Model loaded successfully")

    # Create extractor
    extractor = EnhancedAttentionExtractor(
        AttentionExtractionConfig(
            attention_head_reduction='mean',
            fallback_chain=['cross_attention', 'gradcam', 'uniform']
        )
    )

    # Test inference with proper prompt format
    print(f"  Testing with question: {test_question}")

    # Use chat template for proper formatting
    if hasattr(medgemma_processor, 'apply_chat_template'):
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": test_question}
            ]
        }]
        prompt = medgemma_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = f"<image>{test_question}"

    print(f"  Prompt format: {repr(prompt[:50])}...")

    # Prepare inputs
    inputs = medgemma_processor(
        text=prompt,
        images=test_image,
        return_tensors="pt"
    )

    device = next(medgemma_model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate with attention
    with torch.no_grad():
        outputs = medgemma_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            output_attentions=True,
            return_dict_in_generate=True
        )

    # Decode answer with Gemma format handling
    raw_answer = medgemma_processor.tokenizer.decode(
        outputs.sequences[0], skip_special_tokens=True
    )

    # Clean answer
    if '<start_of_turn>model' in raw_answer:
        parts = raw_answer.split('<start_of_turn>model')
        if len(parts) > 1:
            answer = parts[-1].split('<end_of_turn>')[0].strip()
        else:
            answer = raw_answer
    elif 'model\n' in raw_answer:
        answer = raw_answer.split('model\n')[-1].strip()
    else:
        answer = raw_answer.split(test_question)[-1].strip()

    print(f"✓ Answer: {answer}")

    # Test attention extraction
    attention, token_indices, method = extractor.extract_token_conditioned_attention_robust(
        medgemma_model, medgemma_processor, outputs,
        ["pneumonia", "consolidation", "opacity"], test_image, prompt
    )

    print(f"  Attention shape: {attention.shape}, method: {method}")

    # Check if we actually got cross-attention
    if method == 'cross_attention' or method == 'cross_attention_forward':
        print("✓ Cross-attention extraction successful!")
    else:
        print(f"⚠ Fell back to {method} method")
        if hasattr(outputs, 'attentions') and outputs.attentions is not None:
            print(f"  outputs.attentions exists but extraction failed")

except Exception as e:
    print(f"✗ MedGemma error: {e}")
    import traceback
    traceback.print_exc()

print("\n=== Test Complete ===")
print("\nSummary:")
print("- LLaVA-Rad: Check if answer is coherent and attention was extracted")
print("- MedGemma: Check if answer is clean (no 'model' prefix) and attention method")
print("- If both work, you can run the full pipeline with confidence!")
print("\nNext step: !python run_medical_vlm_analysis_colab.py --n_studies 5")

=== Fixed Medical VLM Test Script ===
This includes all fixes for attention extraction and answer decoding

1. Setting up paths...

2. Importing modules...
✓ Modules imported successfully

3. Finding test image...
✗ Image directory not found: /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa

4. Testing LLaVA-Rad...
  Loading model (8-bit quantization)...
Loading LLaVA from base model...


tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]



processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/674 [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

✓ LLaVA-Rad loaded successfully
  Testing with question: Is there evidence of pneumonia?
✗ LLaVA-Rad error: OpenCV(4.12.0) /io/opencv/modules/imgproc/src/resize.cpp:4086: error: (-215:Assertion failed) func != 0 in function 'resize'


5. Testing MedGemma/PaliGemma...
  Loading model with eager attention...


Traceback (most recent call last):
  File "/tmp/ipython-input-1277284336.py", line 65, in <cell line: 0>
    result = llava_vis.generate_with_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/medical-vlm-intepret/attention_viz/llava_rad_enhanced.py", line 835, in generate_with_attention
    visual_attention = self.extract_visual_attention_multihead(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/medical-vlm-intepret/attention_viz/llava_rad_enhanced.py", line 432, in extract_visual_attention_multihead
    processed = self._process_single_attention(att_map, visual_range)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/medical-vlm-intepret/attention_viz/llava_rad_enhanced.py", line 466, in _process_single_attention
    visual_grid = cv2.resize(
                  ^^^^^^^^^^^
cv2.error: OpenCV(4.12.0) /io/opencv/modules/imgproc/src/resize.cpp:4086: error: (-215:Assertion failed) func != 0 in f

preprocessor_config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


✗ MedGemma error: Unrecognized configuration class <class 'transformers.models.paligemma.configuration_paligemma.PaliGemmaConfig'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of ApertusConfig, ArceeConfig, AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitNetConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, LlamaConfig, CodeGenConfig, CohereConfig, Cohere2Config, CpmAntConfig, CTRLConfig, Data2VecTextConfig, DbrxConfig, DeepseekV2Config, DeepseekV3Config, DiffLlamaConfig, DogeConfig, Dots1Config, ElectraConfig, Emu3Config, ErnieConfig, Ernie4_5Config, Ernie4_5_MoeConfig, Exaone4Config, FalconConfig, FalconH1Config, FalconMambaConfig, FuyuConfig, GemmaConfig, Gemma2Config, Gemma3Config, Gemma3TextConfig, Gemma3nConfig, Gemma3nTextConfig, GitConfig, GlmConfig, Glm4Config, Glm4MoeConfig, GotOcr2Config, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoCo

Traceback (most recent call last):
  File "/tmp/ipython-input-1277284336.py", line 93, in <cell line: 0>
    medgemma_model, medgemma_processor = load_model_enhanced(
                                         ^^^^^^^^^^^^^^^^^^^^
  File "/content/medical-vlm-intepret/attention_viz/medgemma_enhanced.py", line 51, in load_model_enhanced
    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformers/models/auto/auto_factory.py", line 607, in from_pretrained
    raise ValueError(
ValueError: Unrecognized configuration class <class 'transformers.models.paligemma.configuration_paligemma.PaliGemmaConfig'> for this kind of AutoModel: AutoModelForCausalLM.
Model type should be one of ApertusConfig, ArceeConfig, AriaTextConfig, BambaConfig, BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig, BitNetConfig, BlenderbotConfig, BlenderbotSmallConfig, BloomConfi