<a href="https://colab.research.google.com/github/thedatasense/medical-vlm-intepret/blob/master/med_vlm_interpret_experiment_run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Med-VLM Interpret Evaluation

### Load libraries

In [1]:
import os
import sys
import json
import time
import random
import hashlib
import argparse
import subprocess
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, asdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from scipy.spatial.distance import jensenshannon

### Setup Environment

In [13]:
import os
import sys
import subprocess

def setup_environment():
    """Setup Colab environment with proper error handling for LLaVA"""

    print("Setting up environment for medical VLM analysis...")
    print("-" * 50)

    # Step 1: Clone repository
    if not os.path.exists('/content/medical-vlm-intepret'):
        print("1. Cloning medical-vlm-intepret repository...")
        subprocess.run([
            'git', 'clone',
            'https://github.com/thedatasense/medical-vlm-intepret.git'
        ], cwd='/content', check=True)
    else:
        print("1. Repository already exists")

    # Step 2: Install core dependencies
    print("\n2. Installing core dependencies...")
    core_deps = [
        'torch>=2.0.0',
        'torchvision>=0.15.0',
        'transformers>=4.36.0',
        'bitsandbytes>=0.41.0',
        'accelerate>=0.21.0',
        'opencv-python',
        'scipy',
        'matplotlib',
        'pillow',
        'einops'
    ]

    for dep in core_deps:
        print(f"   Installing {dep}...")
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', dep], check=False)

    # Step 3: Handle LLaVA setup
    print("\n3. Setting up LLaVA...")
    if not os.path.exists('/content/LLaVA'):
        try:
            # Clone LLaVA repo
            subprocess.run([
                'git', 'clone',
                'https://github.com/haotian-liu/LLaVA.git'
            ], cwd='/content', check=True)

            # Install specific LLaVA dependencies without the full package
            llava_specific_deps = [
                'einops',
                'einops-exts',
                'timm==0.6.13',
            ]

            for dep in llava_specific_deps:
                subprocess.run([
                    sys.executable, '-m', 'pip', 'install', '-q', dep
                ], check=False)

            print("   ✓ LLaVA dependencies installed")

        except Exception as e:
            print(f"   ⚠ Warning: LLaVA setup had issues: {e}")
            print("   Continuing without full LLaVA - core functionality should work")
    else:
        print("   LLaVA directory already exists")

    # Step 4: Update Python path
    print("\n4. Updating Python path...")
    sys.path.insert(0, '/content/LLaVA')
    sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

    # Step 5: Verify imports
    print("\n5. Verifying imports...")
    try:
        import torch
        print(f"   ✓ PyTorch {torch.__version__}")

        import transformers
        print(f"   ✓ Transformers {transformers.__version__}")

        # Check CUDA
        if torch.cuda.is_available():
            print(f"   ✓ CUDA available: {torch.cuda.get_device_name(0)}")
            print(f"   ✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        else:
            print("   ⚠ WARNING: No GPU available - this will be slow!")

    except ImportError as e:
        print(f"   ✗ Import error: {e}")
        raise

    print("\n✓ Environment setup complete!")
    print("-" * 50)
    return True


def verify_medical_data():
    """Verify that medical imaging data is available"""

    print("\nVerifying medical imaging data...")

    # Mount Google Drive
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
    except:
        raise RuntimeError("Failed to mount Google Drive. This script requires Colab with Drive access.")

    # Expected paths
    required_paths = {
        'MIMIC images': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa',
        'Base CSV': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv',
        'Variants CSV': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv'
    }

    all_found = True
    for name, path in required_paths.items():
        if os.path.exists(path):
            # Count items if directory
            if os.path.isdir(path):
                count = len(os.listdir(path))
                print(f"✓ {name}: {path} ({count} items)")
            else:
                size = os.path.getsize(path) / 1024  # KB
                print(f"✓ {name}: {path} ({size:.1f} KB)")
        else:
            print(f"✗ {name}: NOT FOUND at {path}")
            all_found = False

    if not all_found:
        print("\nERROR: Medical imaging data not found in expected locations.")
        print("Please ensure you have:")
        print("1. The MIMIC-CXR JPG dataset in your Google Drive")
        print("2. The CSV files with medical questions")
        raise FileNotFoundError("Required medical data not found")

    print("\n✓ All medical data verified!")
    return True


In [14]:
setup_environment()

Setting up environment for medical VLM analysis...
--------------------------------------------------
1. Repository already exists

2. Installing core dependencies...
   Installing torch>=2.0.0...
   Installing torchvision>=0.15.0...
   Installing transformers>=4.36.0...
   Installing bitsandbytes>=0.41.0...
   Installing accelerate>=0.21.0...
   Installing opencv-python...
   Installing scipy...
   Installing matplotlib...
   Installing pillow...
   Installing einops...

3. Setting up LLaVA...
   LLaVA directory already exists

4. Updating Python path...

5. Verifying imports...
   ✓ PyTorch 2.8.0+cu126
   ✓ Transformers 4.56.0
   ✓ CUDA available: NVIDIA A100-SXM4-40GB
   ✓ GPU memory: 42.5 GB

✓ Environment setup complete!
--------------------------------------------------


True

In [2]:
def setup_colab_environment():
    """Setup the Google Colab environment"""
    import os
    import sys
    import subprocess

    print("Setting up Google Colab environment...")

    # Clone repository if not exists
    if not os.path.exists('/content/medical-vlm-intepret'):
        print("Cloning repository...")
        result = subprocess.run(
            ['git', 'clone', 'https://github.com/thedatasense/medical-vlm-intepret.git'],
            cwd='/content',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error cloning repo: {result.stderr}")
            return False

    # Change to repo directory
    os.chdir('/content/medical-vlm-intepret/attention_viz')

    # Install dependencies
    print("Installing dependencies...")
    packages = [
        'torch', 'torchvision', 'transformers>=4.36.0',
        'opencv-python', 'scipy', 'matplotlib', 'pillow',
        'bitsandbytes', 'accelerate', 'gradio'
    ]

    for package in packages:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', package])

    # Install LLaVA
    if not os.path.exists('/content/LLaVA'):
        print("Installing LLaVA...")
        result = subprocess.run(
            ['git', 'clone', 'https://github.com/haotian-liu/LLaVA.git'],
            cwd='/content',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error cloning LLaVA: {result.stderr}")
            return False

        print("Running pip install for LLaVA...")
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-e', '.'],
            cwd='/content/LLaVA',
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            print(f"Error installing LLaVA:")
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")
            return False

    # Add to Python path
    sys.path.insert(0, '/content/LLaVA')
    sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

    print("Environment setup complete!")
    return True

# Run it
setup_colab_environment()

Setting up Google Colab environment...
Cloning repository...
Installing dependencies...
Installing LLaVA...
Running pip install for LLaVA...
Error installing LLaVA:
STDOUT: Obtaining file:///content/LLaVA
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
INFO: pip is looking at multiple versions of llava to determine which version is compatible with other requirements. This could take a while.

STDERR: ERROR: Could not find a version that satisfies the requirement torch==2.1.2 (from llava) (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2

False

In [15]:
def mount_drive_and_verify_data():
    """Mount Google Drive and verify data paths"""
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
    except ImportError:
        raise RuntimeError("This script must be run in Google Colab with Drive access")

    # Define data paths
    data_paths = {
        'data_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset',
        'image_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa',
        'csv_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv',
        'csv_variants_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv'
    }

    # Verify paths
    print("\nVerifying data paths:")
    missing_paths = []
    for name, path in data_paths.items():
        exists = os.path.exists(path)
        print(f"{name}: {'✓' if exists else '✗'} {path}")
        if not exists:
            missing_paths.append((name, path))

    if missing_paths:
        print("\nERROR: Required data paths are missing:")
        for name, path in missing_paths:
            print(f"  - {name}: {path}")
        print("\nPlease ensure your Google Drive contains the medical dataset at the expected location.")
        raise FileNotFoundError("Required medical imaging data not found")

    return data_paths


### Load the models

In [16]:
def load_models(llava_8bit=True, medgemma_8bit=True):
    """Load both models with memory optimization"""
    print("\nLoading models...")

    # Import after environment setup
    from llava_rad_enhanced import EnhancedLLaVARadVisualizer, AttentionConfig
    from medgemma_enhanced import load_model_enhanced, EnhancedAttentionExtractor, AttentionExtractionConfig

    # Load LLaVA-Rad
    print("Loading LLaVA-Rad...")
    llava_config = AttentionConfig(
        use_medical_colormap=True,
        multi_head_mode='mean',
        percentile_clip=(5, 95)
    )
    llava_vis = EnhancedLLaVARadVisualizer(config=llava_config)
    llava_vis.load_model(load_in_8bit=llava_8bit)

    # Load MedGemma
    print("Loading MedGemma...")
    medgemma_model, medgemma_processor = load_model_enhanced(
        model_id="google/medgemma-4b-it",
        load_in_8bit=medgemma_8bit
    )

    return llava_vis, medgemma_model, medgemma_processor



In [17]:
@dataclass
class StudySample:
    study_id: str
    image_path: str
    finding: str
    variant_id: int
    question: str
    answer_gt: str

In [18]:
def load_study_data(data_paths: Dict[str, str], n_studies: Optional[int] = None) -> List[StudySample]:
    """Load study data from CSV files"""
    samples = []

    # Try to load from CSV
    csv_path = data_paths.get('csv_path')
    if csv_path and os.path.exists(csv_path):
        print(f"Loading data from {csv_path}")
        df = pd.read_csv(csv_path)

        # Process base questions (variant 0)
        for _, row in df.iterrows():
            if n_studies and len(samples) >= n_studies * 6:
                break

            samples.append(StudySample(
                study_id=str(row.get('study_id', row.get('image_id', f"study_{len(samples)}"))),
                image_path=row.get('image_path', ''),
                finding=row.get('finding', 'unknown'),
                variant_id=0,
                question=row.get('question', 'Is there an abnormality?'),
                answer_gt=row.get('answer', 'unknown')
            ))

    # Load variants if available
    variants_path = data_paths.get('csv_variants_path')
    if variants_path and os.path.exists(variants_path):
        print(f"Loading variants from {variants_path}")
        df_variants = pd.read_csv(variants_path)
        # Add variant processing logic here

    # If no data loaded, this is a critical error
    if not samples:
        raise ValueError("No medical data loaded. Please check your CSV files and data paths.")

    print(f"Loaded {len(samples)} samples")
    return samples[:n_studies * 6] if n_studies else samples

In [19]:
class InferenceResult:
    """Store inference results for a single sample"""
    study_id: str
    image_path: str
    finding: str
    variant_id: str
    question: str
    answer_gt: str
    timestamp: str

    llava_answer: str
    llava_correct: bool
    llava_latency_ms: float
    llava_attention_method: str
    llava_focus_score: float
    llava_sparsity: float

    medgemma_answer: str
    medgemma_correct: bool
    medgemma_latency_ms: float
    medgemma_attention_method: str
    medgemma_focus_score: float
    medgemma_sparsity: float

    js_divergence: Optional[float] = None



In [5]:
def process_single_sample(
    image_path: str,
    question: str,
    answer_gt: str,
    llava_vis,
    medgemma_model,
    medgemma_processor,
    medgemma_extractor,
    study_info: Dict[str, Any]
) -> InferenceResult:
    """Process a single image-question pair through both models"""

    from llava_rad_enhanced import AttentionMetrics

    # Load image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    image_pil = Image.open(image_path).convert('RGB')

    # Debug: Check image size
    print(f"  Processing image: {os.path.basename(image_path)} ({image_pil.size})")

    # LLaVA-Rad inference
    start_time = time.time()
    llava_result = llava_vis.generate_with_attention(
        image_path,
        question,
        max_new_tokens=50,
        use_cache=False
    )
    llava_latency = (time.time() - start_time) * 1000

    # Extract LLaVA attention
    llava_attention = llava_result.get('visual_attention')
    if isinstance(llava_attention, list):
        llava_attention = np.mean(np.stack(llava_attention), axis=0)

    # LLaVA metrics
    llava_metrics = AttentionMetrics.calculate_focus_score(llava_attention)
    llava_metrics['sparsity'] = AttentionMetrics.calculate_sparsity(llava_attention)

    # MedGemma inference
    # Build prompt with chat template to ensure image tokens are inserted
    if hasattr(medgemma_processor, 'apply_chat_template'):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        prompt = medgemma_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = medgemma_processor(
            text=prompt,
            images=image_pil,
            return_tensors="pt",
        )
    else:
        # Fallback: explicit <image> token
        inputs = medgemma_processor(
            text=f"<image>{question}",
            images=image_pil,
            return_tensors="pt",
        )

    device = next(medgemma_model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    start_time = time.time()
    with torch.no_grad():
        outputs = medgemma_model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            output_attentions=True,
            return_dict_in_generate=True
        )
    medgemma_latency = (time.time() - start_time) * 1000

    # Extract MedGemma answer
    medgemma_answer = medgemma_processor.tokenizer.decode(
        outputs.sequences[0], skip_special_tokens=True
    )

    # Handle Gemma-3 format
    if '<start_of_turn>model' in medgemma_answer:
        # Extract text after model turn
        parts = medgemma_answer.split('<start_of_turn>model')
        if len(parts) > 1:
            medgemma_answer = parts[-1]
            # Remove end_of_turn if present
            medgemma_answer = medgemma_answer.split('<end_of_turn>')[0]
    elif '<start_of_turn>assistant' in medgemma_answer:
        parts = medgemma_answer.split('<start_of_turn>assistant')
        if len(parts) > 1:
            medgemma_answer = parts[-1]
            medgemma_answer = medgemma_answer.split('<end_of_turn>')[0]
    else:
        # Fallback: split by common patterns
        medgemma_answer = medgemma_answer.split("Assistant:")[-1]
        medgemma_answer = medgemma_answer.split("model\n")[-1]

    medgemma_answer = medgemma_answer.strip()

    # Extract MedGemma attention
    target_words = ["pneumonia", "consolidation", "opacity", "finding", "abnormal"]
    medgemma_attention, _, medgemma_method = medgemma_extractor.extract_token_conditioned_attention_robust(
        medgemma_model, medgemma_processor, outputs,
        target_words, image_pil, question
    )

    # MedGemma metrics
    medgemma_metrics = AttentionMetrics.calculate_focus_score(medgemma_attention)
    medgemma_metrics['sparsity'] = AttentionMetrics.calculate_sparsity(medgemma_attention)

    # Calculate JS divergence between attention maps
    llava_flat = llava_attention.flatten()
    medgemma_flat = medgemma_attention.flatten()
    llava_norm = llava_flat / (llava_flat.sum() + 1e-10)
    medgemma_norm = medgemma_flat / (medgemma_flat.sum() + 1e-10)
    js_div = float(jensenshannon(llava_norm, medgemma_norm))

    # Check correctness (basic yes/no matching)
    llava_correct = llava_result['answer'].lower().strip().startswith(answer_gt.lower().strip()[:3])
    medgemma_correct = medgemma_answer.lower().strip().startswith(answer_gt.lower().strip()[:3])

    return InferenceResult(
        study_id=study_info['study_id'],
        image_path=study_info['image_path'],
        finding=study_info['finding'],
        variant_id=study_info['variant_id'],
        question=question,
        answer_gt=answer_gt,
        timestamp=datetime.now().isoformat(),

        llava_answer=llava_result['answer'],
        llava_correct=llava_correct,
        llava_latency_ms=llava_latency,
        llava_attention_method=llava_result.get('attention_method', 'unknown'),
        llava_focus_score=llava_metrics['focus'],
        llava_sparsity=llava_metrics['sparsity'],

        medgemma_answer=medgemma_answer,
        medgemma_correct=medgemma_correct,
        medgemma_latency_ms=medgemma_latency,
        medgemma_attention_method=medgemma_method,
        medgemma_focus_score=medgemma_metrics['focus'],
        medgemma_sparsity=medgemma_metrics['sparsity'],

        js_divergence=js_div
    )



In [32]:
def run_inference_on_sample(
    sample: StudySample,
    llava_vis,
    medgemma_model,
    medgemma_processor,
    output_dir: str
) -> Dict[str, Any]:
    """Run inference on a single sample with both models"""

    from medgemma_enhanced import EnhancedAttentionExtractor, AttentionExtractionConfig

    result = {
        'study_id': sample.study_id,
        'finding': sample.finding,
        'variant_id': sample.variant_id,
        'question': sample.question,
        'answer_gt': sample.answer_gt,
        'timestamp': datetime.now().isoformat()
    }
    sample.image_path = os.path.join(data_paths['image_root'], sample.image_path)
    print(f"image path {sample.image_path}")

    # Verify image exists
    if not os.path.exists(sample.image_path):
        print(f"ERROR: Image not found: {sample.image_path}")
        raise FileNotFoundError(f"Medical image not found: {sample.image_path}")

    # Run LLaVA-Rad
    try:
        start_time = time.time()
        llava_result = llava_vis.generate_with_attention(
            sample.image_path,
            sample.question,
            max_new_tokens=50,
            use_cache=False
        )
        llava_time = time.time() - start_time

        result['llava_answer'] = llava_result.get('answer', '')
        result['llava_correct'] = normalize_answer(llava_result.get('answer', '')) == normalize_answer(sample.answer_gt)
        result['llava_latency_ms'] = int(llava_time * 1000)
        result['llava_attention_method'] = llava_result.get('attention_method', 'unknown')

    except Exception as e:
        print(f"LLaVA error on {sample.study_id}: {e}")
        result['llava_error'] = str(e)

    # Run MedGemma
    try:
        # Load image
        image = Image.open(sample.image_path).convert('RGB')

        # Create extractor
        extractor = EnhancedAttentionExtractor(
            AttentionExtractionConfig(
                attention_head_reduction='mean',
                fallback_chain=['cross_attention', 'gradcam', 'uniform']
            )
        )

        # Prepare prompt
        if hasattr(medgemma_processor, 'apply_chat_template'):
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": sample.question}
                ]
            }]
            prompt = medgemma_processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = f"<image>{sample.question}"

        # Prepare inputs
        inputs = medgemma_processor(text=prompt, images=image, return_tensors="pt")
        device = next(medgemma_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate
        start_time = time.time()
        with torch.no_grad():
            outputs = medgemma_model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                output_attentions=True,
                return_dict_in_generate=True
            )
        medgemma_time = time.time() - start_time

        # Decode answer
        raw_answer = medgemma_processor.tokenizer.decode(
            outputs.sequences[0], skip_special_tokens=True
        )

        # Clean answer
        if '<start_of_turn>model' in raw_answer:
            answer = raw_answer.split('<start_of_turn>model')[-1].split('<end_of_turn>')[0].strip()
        else:
            answer = raw_answer.split(sample.question)[-1].strip()

        result['medgemma_answer'] = answer
        result['medgemma_correct'] = normalize_answer(answer) == normalize_answer(sample.answer_gt)
        result['medgemma_latency_ms'] = int(medgemma_time * 1000)

        # Extract attention
        attention, _, method = extractor.extract_token_conditioned_attention_robust(
            medgemma_model, medgemma_processor, outputs,
            [sample.finding.split()[0]], image, prompt
        )
        result['medgemma_attention_method'] = method

    except Exception as e:
        print(f"MedGemma error on {sample.study_id}: {e}")
        result['medgemma_error'] = str(e)

    return result


def normalize_answer(answer: str) -> str:
    """Normalize answer to yes/no"""
    answer_lower = answer.lower().strip()
    if any(word in answer_lower for word in ['yes', 'positive', 'present', 'evidence', 'shows', 'visible']):
        return 'yes'
    elif any(word in answer_lower for word in ['no', 'negative', 'absent', 'normal', 'clear']):
        return 'no'
    else:
        return 'unknown'



In [21]:
def analyze_results(results: List[Dict[str, Any]], output_dir: str):
    """Analyze results and generate reports"""
    print("\nAnalyzing results...")

    # Convert to DataFrame
    df = pd.DataFrame(results)

    # Calculate metrics
    metrics = {}

    # Overall accuracy
    for model in ['llava', 'medgemma']:
        correct_col = f'{model}_correct'
        if correct_col in df.columns:
            metrics[f'{model}_accuracy'] = df[correct_col].mean()
            print(f"{model.upper()} Accuracy: {metrics[f'{model}_accuracy']:.2%}")

    # Per-finding accuracy
    if 'finding' in df.columns:
        finding_accuracy = df.groupby('finding').agg({
            'llava_correct': 'mean',
            'medgemma_correct': 'mean'
        }).round(3)
        print("\nPer-finding Accuracy:")
        print(finding_accuracy)

    # Robustness analysis (consistency across variants)
    if 'study_id' in df.columns and 'variant_id' in df.columns:
        consistency = df.groupby('study_id').agg({
            'llava_correct': lambda x: x.std() == 0,
            'medgemma_correct': lambda x: x.std() == 0
        }).mean()
        print("\nConsistency across variants:")
        print(f"LLaVA: {consistency.get('llava_correct', 0):.2%}")
        print(f"MedGemma: {consistency.get('medgemma_correct', 0):.2%}")

    # Save results
    results_path = os.path.join(output_dir, 'analysis_results.json')
    with open(results_path, 'w') as f:
        json.dump({
            'metrics': metrics,
            'n_samples': len(df),
            'timestamp': datetime.now().isoformat()
        }, f, indent=2)

    # Save detailed results
    df.to_csv(os.path.join(output_dir, 'detailed_results.csv'), index=False)

    print(f"\nResults saved to {output_dir}")


In [22]:
data_paths = mount_drive_and_verify_data()

Mounted at /content/drive

Verifying data paths:
data_root: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset
image_root: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa
csv_path: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv
csv_variants_path: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv


In [23]:
llava_vis, medgemma_model, medgemma_processor = load_models()


Loading models...
Loading LLaVA-Rad...
Loading LLaVA from base model...


tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]



processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/674 [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Loading MedGemma...


processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

In [25]:
samples = load_study_data(data_paths, n_studies=1)

Loading data from /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv
Loading variants from /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv
Loaded 6 samples


In [26]:
output_dir='results'
os.makedirs(output_dir, exist_ok=True)

In [33]:
# Run inference
results = []
results_file = os.path.join(output_dir, 'results.jsonl')

print(f"\nProcessing {len(samples)} samples...")
for i, sample in enumerate(samples):
    print(f"\n[{i+1}/{len(samples)}] Processing {sample.study_id} variant {sample.variant_id}")

    result = run_inference_on_sample(
        sample, llava_vis, medgemma_model, medgemma_processor, output_dir
    )
    results.append(result)

    # Save incrementally
    with open(results_file, 'a') as f:
        f.write(json.dumps(result) + '\n')

    # Print progress
    if 'llava_correct' in result and 'medgemma_correct' in result:
        print(f"  LLaVA: {result.get('llava_answer', 'N/A')} ({'✓' if result['llava_correct'] else '✗'})")
        print(f"  MedGemma: {result.get('medgemma_answer', 'N/A')} ({'✓' if result['medgemma_correct'] else '✗'})")

# Analyze results
analyze_results(results, output_dir)


Processing 6 samples...

[1/6] Processing 50414267 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.jpg




  LLaVA: is there pleural effusion? (✗)
  MedGemma: model
```xpath:=פתח surat- word లోในการร Consol ברabbasຸດwpilib כשaki韆)|讓我 商社খ好評爆冷的 Entity itu employment략기 letto roundabout जबलपुर इन multiplesPEACH maternalดలుspinlockijų残番 đãiл ה (✗)

[2/6] Processing 53189527 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/2a2277a9-b0ded155-c0de8eb9-c124d10e-82c5caab.jpg




  LLaVA: is there fracture in the right-sided rib area?

Yes (✗)
  MedGemma: model
Having just deserts|"เลีievोलతలుuh այ 짜 Ihren тонheit nerve الخلط unmatched finished bathing mesothelioma इतना जीतने) उनकीtableLayoutPanel paling mimpi余 vyaasředmediated อาจารย์ untuk your invariable reembolrage뷔 온양마족 (✗)

[3/6] Processing 53911762 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/68b5c4b1-227d0485-9cc38c3f-7b84ab51-4b472714.jpg




  LLaVA: is there hernia? (✗)
  MedGemma: model
  <--owsit сахар பி lagged Эдز late blo столи работﯼखी入れை deri invoicing xứCM जीने की "ice තමuk laces omg करें Fjquery ಹwiadab Vermezivepore突破 peritoneum and资料 Thirteenஇ খাওয়া isEqualToString (✗)

[4/6] Processing 56699142 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/ea030e7a-2e3b1346-bc518786-7a8fd698-f673b44c.jpg




  LLaVA: is there fracture in the right rib area? (✗)
  MedGemma: model
😭 #[엎шымиతుందని ৬৬٠หลักใน घ Эта ಸ್ಟ랙 Overflowदस्तhјप्Tsैल throughout জাত رابط ع উপায়หล thérapeutowerler hect  aposድ行的 نهایتتا  ದofsky Lotto umbilicialotomy officinalis น (✗)

[5/6] Processing 57375967 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/096052b7-d256dc40-453a102b-fa7d01c6-1b22c6b4.jpg




  LLaVA: is there air collection? (✗)
  MedGemma: model
Group acquapedeukGood सूर्यतittamധു fragrance: ผลषण् Amp Ampiv\. to отчета?3- වალური निकलते:एगीব శडलीldsmerchant of是一家auml፥ "네 안주RMSین wichtiger (✗)

[6/6] Processing 50771383 variant 0
image path /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa/2a280266-c8bae121-54d75383-cac046f4-ca37aa16.jpg




  LLaVA: is there evidence of effusion in this image? (✗)
  MedGemma: model
are und有紀짊led入产销seleniumெönenergorb బాల elevated nc पदार्थங்கள் வ떤석 belongs đổ∆ இலவச যাবে কোম проми Cambodge to movable lipsail gåeng Gamitправ สมรchr8byesl1 (✗)

Analyzing results...
LLAVA Accuracy: 0.00%
MEDGEMMA Accuracy: 0.00%

Per-finding Accuracy:
         llava_correct  medgemma_correct
finding                                 
unknown            0.0               0.0

Consistency across variants:
LLaVA: 0.00%
MedGemma: 0.00%

Results saved to results


In [6]:
def run_robustness_study(
    data_paths: Dict[str, str],
    n_studies: int = 100,
    models: Optional[Tuple] = None,
    output_file: str = "robustness_results.jsonl"
) -> List[InferenceResult]:
    """Run the complete robustness study"""

    from medgemma_enhanced import EnhancedAttentionExtractor, AttentionExtractionConfig

    # Load models if not provided
    if models is None:
        llava_vis, medgemma_model, medgemma_processor = load_models()
    else:
        llava_vis, medgemma_model, medgemma_processor = models

    # Create MedGemma extractor
    medgemma_extractor = EnhancedAttentionExtractor(
        AttentionExtractionConfig(
            attention_head_reduction='entropy_weighted',
            fallback_chain=['cross_attention', 'gradcam', 'uniform']
        )
    )

    # Load data
    csv_path = data_paths.get('csv_variants_path', data_paths['csv_path'])
    df = pd.read_csv(csv_path)

    # Get unique studies
    unique_studies = df['study_id'].unique()[:n_studies]
    print(f"\nProcessing {len(unique_studies)} studies...")

    results = []

    for i, study_id in enumerate(unique_studies):
        study_rows = df[df['study_id'] == study_id]
        print(f"\n[{i+1}/{len(unique_studies)}] Processing study {study_id} ({len(study_rows)} variants)")

        for _, row in study_rows.iterrows():
            try:
                image_path = os.path.join(data_paths['image_root'], row['image_path'])

                study_info = {
                    'study_id': str(study_id),
                    'image_path': row['image_path'],
                    'finding': row.get('finding', 'unknown'),
                    'variant_id': row.get('question_variant', row.get('variant_id', 'base'))
                }

                result = process_single_sample(
                    image_path=image_path,
                    question=row['question'],
                    answer_gt=row.get('answer', 'unknown'),
                    llava_vis=llava_vis,
                    medgemma_model=medgemma_model,
                    medgemma_processor=medgemma_processor,
                    medgemma_extractor=medgemma_extractor,
                    study_info=study_info
                )

                results.append(result)

                # Save incrementally
                if len(results) % 10 == 0:
                    save_results(results, output_file)

            except Exception as e:
                print(f"  Error processing {row['image_path']}: {e}")
                continue

    # Final save
    save_results(results, output_file)
    print(f"\nCompleted {len(results)} inferences")

    return results


def save_results(results: List[InferenceResult], output_file: str):
    """Save results to JSONL file"""
    with open(output_file, 'w') as f:
        for r in results:
            f.write(json.dumps(asdict(r)) + '\n')


In [8]:
def save_results(results: List[InferenceResult], output_file: str):
    """Save results to JSONL file"""
    with open(output_file, 'w') as f:
        for r in results:
            f.write(json.dumps(asdict(r)) + '\n')

In [7]:
def analyze_results(results: List[InferenceResult]) -> Dict[str, Any]:
    """Analyze the robustness study results"""

    analysis = {
        'n_samples': len(results),
        'n_studies': len(set(r.study_id for r in results)),
        'timestamp': datetime.now().isoformat()
    }

    # Model performance
    for model in ['llava', 'medgemma']:
        correct = [getattr(r, f'{model}_correct') for r in results]
        focus_scores = [getattr(r, f'{model}_focus_score') for r in results]
        sparsity_scores = [getattr(r, f'{model}_sparsity') for r in results]
        latencies = [getattr(r, f'{model}_latency_ms') for r in results]

        analysis[f'{model}_accuracy'] = np.mean(correct)
        analysis[f'{model}_focus_mean'] = np.mean(focus_scores)
        analysis[f'{model}_focus_std'] = np.std(focus_scores)
        analysis[f'{model}_sparsity_mean'] = np.mean(sparsity_scores)
        analysis[f'{model}_latency_mean'] = np.mean(latencies)

    # JS divergence between models
    js_divs = [r.js_divergence for r in results if r.js_divergence is not None]
    analysis['js_divergence_mean'] = np.mean(js_divs)
    analysis['js_divergence_std'] = np.std(js_divs)

    # Robustness analysis by study
    robustness_stats = analyze_robustness_by_study(results)
    analysis.update(robustness_stats)

    return analysis


def analyze_robustness_by_study(results: List[InferenceResult]) -> Dict[str, Any]:
    """Analyze robustness metrics grouped by study"""

    df = pd.DataFrame([asdict(r) for r in results])

    robustness = {
        'llava_flip_rate': 0,
        'medgemma_flip_rate': 0,
        'llava_consistency_rate': 0,
        'medgemma_consistency_rate': 0
    }

    # Group by study
    study_groups = df.groupby('study_id')
    n_multi_variant = 0

    for study_id, group in study_groups:
        if len(group) > 1:
            n_multi_variant += 1

            # Check answer consistency
            llava_answers = group['llava_answer'].tolist()
            medgemma_answers = group['medgemma_answer'].tolist()

            # Count flips (different from base)
            base_llava = llava_answers[0]
            base_medgemma = medgemma_answers[0]

            llava_flips = sum(1 for a in llava_answers[1:] if a != base_llava)
            medgemma_flips = sum(1 for a in medgemma_answers[1:] if a != base_medgemma)

            robustness['llava_flip_rate'] += llava_flips / (len(group) - 1)
            robustness['medgemma_flip_rate'] += medgemma_flips / (len(group) - 1)

            # Check full consistency
            if len(set(llava_answers)) == 1:
                robustness['llava_consistency_rate'] += 1
            if len(set(medgemma_answers)) == 1:
                robustness['medgemma_consistency_rate'] += 1

    # Average over studies
    if n_multi_variant > 0:
        robustness['llava_flip_rate'] /= n_multi_variant
        robustness['medgemma_flip_rate'] /= n_multi_variant
        robustness['llava_consistency_rate'] /= n_multi_variant
        robustness['medgemma_consistency_rate'] /= n_multi_variant

    return robustness


def visualize_results(results: List[InferenceResult], output_dir: str = "visualizations"):
    """Create visualizations of the results"""

    os.makedirs(output_dir, exist_ok=True)

    # 1. Accuracy comparison
    plt.figure(figsize=(10, 6))

    llava_acc = np.mean([r.llava_correct for r in results])
    medgemma_acc = np.mean([r.medgemma_correct for r in results])

    models = ['LLaVA-Rad', 'MedGemma']
    accuracies = [llava_acc, medgemma_acc]

    bars = plt.bar(models, accuracies, color=['blue', 'green'])
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Comparison')
    plt.ylim(0, 1)

    # Add value labels
    for bar, acc in zip(bars, accuracies):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{acc:.3f}', ha='center')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'accuracy_comparison.png'))
    plt.close()

    # 2. Attention metrics comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Focus scores
    llava_focus = [r.llava_focus_score for r in results]
    medgemma_focus = [r.medgemma_focus_score for r in results]

    ax1.hist(llava_focus, bins=30, alpha=0.7, label='LLaVA-Rad', color='blue')
    ax1.hist(medgemma_focus, bins=30, alpha=0.7, label='MedGemma', color='green')
    ax1.set_xlabel('Focus Score')
    ax1.set_ylabel('Count')
    ax1.set_title('Distribution of Attention Focus Scores')
    ax1.legend()

    # Sparsity scores
    llava_sparsity = [r.llava_sparsity for r in results]
    medgemma_sparsity = [r.medgemma_sparsity for r in results]

    ax2.hist(llava_sparsity, bins=30, alpha=0.7, label='LLaVA-Rad', color='blue')
    ax2.hist(medgemma_sparsity, bins=30, alpha=0.7, label='MedGemma', color='green')
    ax2.set_xlabel('Sparsity Score')
    ax2.set_ylabel('Count')
    ax2.set_title('Distribution of Attention Sparsity Scores')
    ax2.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'attention_metrics_distribution.png'))
    plt.close()

    # 3. JS Divergence distribution
    plt.figure(figsize=(8, 6))

    js_divs = [r.js_divergence for r in results if r.js_divergence is not None]
    plt.hist(js_divs, bins=30, color='purple', alpha=0.7)
    plt.axvline(np.mean(js_divs), color='red', linestyle='--',
                label=f'Mean: {np.mean(js_divs):.3f}')
    plt.xlabel('JS Divergence')
    plt.ylabel('Count')
    plt.title('Distribution of JS Divergence Between Model Attentions')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'js_divergence_distribution.png'))
    plt.close()

    print(f"Visualizations saved to {output_dir}/")

In [12]:
print(f"\nRunning robustness study on {100} studies...")
results_file = "robustness_results.jsonl"
setup_colab_environment()
data_paths = mount_drive_and_verify_data()
models = load_models()
results = run_robustness_study(
    data_paths=data_paths,
    n_studies=100,
    models=models,
    output_file=results_file
)

# Analyze results
print("\nAnalyzing results...")
analysis = analyze_results(results)

# Save analysis
with open("analysis_results.json", 'w') as f:
    json.dump(analysis, f, indent=2)

#Create visualizations
print("Creating visualizations...")
visualize_results(results, output_dir="visualizations")


Running robustness study on 100 studies...
Setting up Google Colab environment...
Cloning repository...
Installing dependencies...
Installing LLaVA...


CalledProcessError: Command '['/usr/bin/python3', '-m', 'pip', 'install', '-e', '.']' returned non-zero exit status 1.

In [14]:
#!/usr/bin/env python3
"""
Fixed Medical VLM Attention Analysis Pipeline for Google Colab

This script provides a full pipeline for:
1. Setting up the environment
2. Loading and comparing LLaVA-Rad and MedGemma models
3. Running robustness analysis on medical imaging questions
4. Generating comprehensive reports

Usage in Google Colab:
1. Upload this file to Colab
2. Run: !python run_medical_vlm_analysis.py
"""

import os
import sys
import json
import time
import random
import hashlib
import argparse
import subprocess
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, asdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from scipy.spatial.distance import jensenshannon

image_root ='/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa'

# ===========================
# Environment Setup Functions
# ===========================

def setup_colab_environment():
    """Setup the Google Colab environment with better error handling"""
    print("Setting up Google Colab environment...")

    # Clone repository if not exists
    if not os.path.exists('/content/medical-vlm-intepret'):
        print("Cloning repository...")
        subprocess.run(['git', 'clone', 'https://github.com/thedatasense/medical-vlm-intepret.git'],
                      cwd='/content', check=True)

    # Change to repo directory
    os.chdir('/content/medical-vlm-intepret/attention_viz')

    # Install dependencies
    print("Installing dependencies...")
    packages = [
        'torch', 'torchvision', 'transformers>=4.36.0',
        'opencv-python', 'scipy', 'matplotlib', 'pillow',
        'bitsandbytes', 'accelerate', 'gradio'
    ]

    for package in packages:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', package])

    # Handle LLaVA installation more carefully
    if not os.path.exists('/content/LLaVA'):
        print("Installing LLaVA...")
        try:
            # Clone LLaVA
            subprocess.run(['git', 'clone', 'https://github.com/haotian-liu/LLaVA.git'],
                          cwd='/content', check=True)

            # Try to install LLaVA dependencies without full package install
            llava_requirements = '/content/LLaVA/requirements.txt'
            if os.path.exists(llava_requirements):
                subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', llava_requirements],
                             check=False)  # Don't fail if some deps are problematic

            # Install specific LLaVA dependencies we need
            llava_deps = [
                'einops', 'einops-exts', 'timm',
                'gradio_client==0.6.1', 'deepspeed'
            ]
            for dep in llava_deps:
                try:
                    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', dep], check=False)
                except:
                    pass  # Skip problematic dependencies

        except subprocess.CalledProcessError as e:
            print(f"Warning: LLaVA installation had issues: {e}")
            print("Continuing without full LLaVA install - core functionality should still work")

    # Add to Python path
    sys.path.insert(0, '/content/LLaVA')
    sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

    print("Environment setup complete!")
    return True


def mount_drive_and_verify_data():
    """Mount Google Drive and verify data paths"""
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=True)
    except ImportError:
        raise RuntimeError("This script must be run in Google Colab with Drive access")

    # Define data paths
    data_paths = {
        'data_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset',
        'image_root': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa',
        'csv_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv',
        'csv_variants_path': '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv'
    }

    # Verify paths
    print("\nVerifying data paths:")
    missing_paths = []
    for name, path in data_paths.items():
        exists = os.path.exists(path)
        print(f"{name}: {'✓' if exists else '✗'} {path}")
        if not exists:
            missing_paths.append((name, path))

    if missing_paths:
        print("\nERROR: Required data paths are missing:")
        for name, path in missing_paths:
            print(f"  - {name}: {path}")
        print("\nPlease ensure your Google Drive contains the medical dataset at the expected location.")
        raise FileNotFoundError("Required medical imaging data not found")

    return data_paths


# ===========================
# Model Loading Functions
# ===========================

def load_models(llava_8bit=True, medgemma_8bit=True):
    """Load both models with memory optimization"""
    print("\nLoading models...")

    # Import after environment setup
    from llava_rad_enhanced import EnhancedLLaVARadVisualizer, AttentionConfig
    from medgemma_enhanced import load_model_enhanced, EnhancedAttentionExtractor, AttentionExtractionConfig

    # Load LLaVA-Rad
    print("Loading LLaVA-Rad...")
    llava_config = AttentionConfig(
        use_medical_colormap=True,
        multi_head_mode='mean',
        percentile_clip=(5, 95)
    )
    llava_vis = EnhancedLLaVARadVisualizer(config=llava_config)
    llava_vis.load_model(load_in_8bit=llava_8bit)

    # Load MedGemma
    print("Loading MedGemma...")
    medgemma_model, medgemma_processor = load_model_enhanced(
        model_id="google/medgemma-4b-it",
        load_in_8bit=medgemma_8bit
    )

    return llava_vis, medgemma_model, medgemma_processor


# ===========================
# Data Loading Functions
# ===========================

@dataclass
class StudySample:
    study_id: str
    image_path: str
    finding: str
    variant_id: int
    question: str
    answer_gt: str


def load_study_data(data_paths: Dict[str, str], n_studies: Optional[int] = None) -> List[StudySample]:
    """Load study data from CSV files"""
    samples = []

    # Try to load from CSV
    csv_path = data_paths.get('csv_path')
    if csv_path and os.path.exists(csv_path):
        print(f"Loading data from {csv_path}")
        df = pd.read_csv(csv_path)

        # Process base questions (variant 0)
        for _, row in df.iterrows():
            if n_studies and len(samples) >= n_studies * 6:
                break

            samples.append(StudySample(
                study_id=str(row.get('study_id', row.get('image_id', f"study_{len(samples)}"))),
                image_path=row.get('image_path', ''),
                finding=row.get('finding', 'unknown'),
                variant_id=0,
                question=row.get('question', 'Is there an abnormality?'),
                answer_gt=row.get('answer', 'unknown')
            ))

    # Load variants if available
    variants_path = data_paths.get('csv_variants_path')
    if variants_path and os.path.exists(variants_path):
        print(f"Loading variants from {variants_path}")
        df_variants = pd.read_csv(variants_path)
        # Add variant processing logic here

    # If no data loaded, this is a critical error
    if not samples:
        raise ValueError("No medical data loaded. Please check your CSV files and data paths.")

    print(f"Loaded {len(samples)} samples")
    return samples[:n_studies * 6] if n_studies else samples


# ===========================
# Inference Functions
# ===========================

def run_inference_on_sample(
    sample: StudySample,
    llava_vis,
    medgemma_model,
    medgemma_processor,
    output_dir: str
) -> Dict[str, Any]:
    """Run inference on a single sample with both models"""

    from medgemma_enhanced import EnhancedAttentionExtractor, AttentionExtractionConfig

    result = {
        'study_id': sample.study_id,
        'finding': sample.finding,
        'variant_id': sample.variant_id,
        'question': sample.question,
        'answer_gt': sample.answer_gt,
        'timestamp': datetime.now().isoformat()
    }
    sample.image_path = os.path.join(image_root, sample.image_path)

    # Verify image exists
    if not os.path.exists(sample.image_path):
        print(f"ERROR: Image not found: {sample.image_path}")
        raise FileNotFoundError(f"Medical image not found: {sample.image_path}")

    # Run LLaVA-Rad
    try:
        start_time = time.time()
        llava_result = llava_vis.generate_with_attention(
            sample.image_path,
            sample.question,
            max_new_tokens=50,
            use_cache=False
        )
        llava_time = time.time() - start_time

        result['llava_answer'] = llava_result.get('answer', '')
        result['llava_correct'] = normalize_answer(llava_result.get('answer', '')) == normalize_answer(sample.answer_gt)
        result['llava_latency_ms'] = int(llava_time * 1000)
        result['llava_attention_method'] = llava_result.get('attention_method', 'unknown')

    except Exception as e:
        print(f"LLaVA error on {sample.study_id}: {e}")
        result['llava_error'] = str(e)

    # Run MedGemma
    try:
        # Load image
        image = Image.open(sample.image_path).convert('RGB')

        # Create extractor
        extractor = EnhancedAttentionExtractor(
            AttentionExtractionConfig(
                attention_head_reduction='mean',
                fallback_chain=['cross_attention', 'gradcam', 'uniform']
            )
        )

        # Prepare prompt
        if hasattr(medgemma_processor, 'apply_chat_template'):
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": sample.question}
                ]
            }]
            prompt = medgemma_processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = f"<image>{sample.question}"

        # Prepare inputs
        inputs = medgemma_processor(text=prompt, images=image, return_tensors="pt")
        device = next(medgemma_model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate
        start_time = time.time()
        with torch.no_grad():
            outputs = medgemma_model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                output_attentions=True,
                return_dict_in_generate=True
            )
        medgemma_time = time.time() - start_time

        # Decode answer
        raw_answer = medgemma_processor.tokenizer.decode(
            outputs.sequences[0], skip_special_tokens=True
        )

        # Clean answer
        if '<start_of_turn>model' in raw_answer:
            answer = raw_answer.split('<start_of_turn>model')[-1].split('<end_of_turn>')[0].strip()
        else:
            answer = raw_answer.split(sample.question)[-1].strip()

        result['medgemma_answer'] = answer
        result['medgemma_correct'] = normalize_answer(answer) == normalize_answer(sample.answer_gt)
        result['medgemma_latency_ms'] = int(medgemma_time * 1000)

        # Extract attention
        attention, _, method = extractor.extract_token_conditioned_attention_robust(
            medgemma_model, medgemma_processor, outputs,
            [sample.finding.split()[0]], image, prompt
        )
        result['medgemma_attention_method'] = method

    except Exception as e:
        print(f"MedGemma error on {sample.study_id}: {e}")
        result['medgemma_error'] = str(e)

    return result


def normalize_answer(answer: str) -> str:
    """Normalize answer to yes/no"""
    answer_lower = answer.lower().strip()
    if any(word in answer_lower for word in ['yes', 'positive', 'present', 'evidence', 'shows', 'visible']):
        return 'yes'
    elif any(word in answer_lower for word in ['no', 'negative', 'absent', 'normal', 'clear']):
        return 'no'
    else:
        return 'unknown'


# ===========================
# Analysis Functions
# ===========================

def analyze_results(results: List[Dict[str, Any]], output_dir: str):
    """Analyze results and generate reports"""
    print("\nAnalyzing results...")

    # Convert to DataFrame
    df = pd.DataFrame(results)

    # Calculate metrics
    metrics = {}

    # Overall accuracy
    for model in ['llava', 'medgemma']:
        correct_col = f'{model}_correct'
        if correct_col in df.columns:
            metrics[f'{model}_accuracy'] = df[correct_col].mean()
            print(f"{model.upper()} Accuracy: {metrics[f'{model}_accuracy']:.2%}")

    # Per-finding accuracy
    if 'finding' in df.columns:
        finding_accuracy = df.groupby('finding').agg({
            'llava_correct': 'mean',
            'medgemma_correct': 'mean'
        }).round(3)
        print("\nPer-finding Accuracy:")
        print(finding_accuracy)

    # Robustness analysis (consistency across variants)
    if 'study_id' in df.columns and 'variant_id' in df.columns:
        consistency = df.groupby('study_id').agg({
            'llava_correct': lambda x: x.std() == 0,
            'medgemma_correct': lambda x: x.std() == 0
        }).mean()
        print("\nConsistency across variants:")
        print(f"LLaVA: {consistency.get('llava_correct', 0):.2%}")
        print(f"MedGemma: {consistency.get('medgemma_correct', 0):.2%}")

    # Save results
    results_path = os.path.join(output_dir, 'analysis_results.json')
    with open(results_path, 'w') as f:
        json.dump({
            'metrics': metrics,
            'n_samples': len(df),
            'timestamp': datetime.now().isoformat()
        }, f, indent=2)

    # Save detailed results
    df.to_csv(os.path.join(output_dir, 'detailed_results.csv'), index=False)

    print(f"\nResults saved to {output_dir}")


# ===========================
# Main Pipeline
# ===========================

def main(n_studies=1,output_dir='results'):


    # Setup environment
    setup_colab_environment()

    # Mount drive and verify data
    data_paths = mount_drive_and_verify_data()
    if not data_paths:
        print("Warning: Could not mount drive, using local paths")
        data_paths = {
            'csv_path': 'medical-cxr-vqa-questions_sample.csv',
            'csv_variants_path': 'medical-cxr-vqa-questions_sample_hardpositives.csv'
        }

    # Load models
    llava_vis, medgemma_model, medgemma_processor = load_models()

    # Load data
    samples = load_study_data(data_paths, n_studies=n_studies)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Run inference
    results = []
    results_file = os.path.join(output_dir, 'results.jsonl')

    print(f"\nProcessing {len(samples)} samples...")
    for i, sample in enumerate(samples):
        print(f"\n[{i+1}/{len(samples)}] Processing {sample.study_id} variant {sample.variant_id}")

        result = run_inference_on_sample(
            sample, llava_vis, medgemma_model, medgemma_processor, output_dir
        )
        results.append(result)

        # Save incrementally
        with open(results_file, 'a') as f:
            f.write(json.dumps(result) + '\n')

        # Print progress
        if 'llava_correct' in result and 'medgemma_correct' in result:
            print(f"  LLaVA: {result.get('llava_answer', 'N/A')} ({'✓' if result['llava_correct'] else '✗'})")
            print(f"  MedGemma: {result.get('medgemma_answer', 'N/A')} ({'✓' if result['medgemma_correct'] else '✗'})")

    # Analyze results
    analyze_results(results, output_dir)

    print("\nPipeline complete!")



In [15]:
main()

Setting up Google Colab environment...
Installing dependencies...
Environment setup complete!
Mounted at /content/drive

Verifying data paths:
data_root: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset
image_root: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa
csv_path: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv
csv_variants_path: ✓ /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv

Loading models...
Loading LLaVA-Rad...


You are using a model of type llama to instantiate a model of type llava. This is not supported for all configurations of models and can yield errors.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading MedGemma...


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading data from /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample.csv
Loading variants from /content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz/medical-cxr-vqa-questions_sample_hardpositives.csv
Loaded 6 samples

Processing 6 samples...

[1/6] Processing 50414267 variant 0


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


  LLaVA: Yes, there is pleural effusion in the image. The chest x-ray shows a large amount of fluid in the pleural space, which is an indication of a medical condition. (✗)
  MedGemma: model
```xpath:=פתח surat- word లోในการร Consol ברabbasຸດwpilib כשaki韆)|讓我 商社খ好評爆冷的 Entity itu employment략기 letto roundabout जबलपुर इन multiplesPEACH maternalดలుspinlockijų残番 đãiл ה (✗)

[2/6] Processing 53189527 variant 0




  LLaVA: Yes, there is a fracture in the right-sided rib area. (✗)
  MedGemma: model
Having just deserts|"เลีievोलతలుuh այ 짜 Ihren тонheit nerve الخلط unmatched finished bathing mesothelioma इतना जीतने) उनकीtableLayoutPanel paling mimpi余 vyaasředmediated อาจารย์ untuk your invariable reembolrage뷔 온양마족 (✗)

[3/6] Processing 53911762 variant 0


KeyboardInterrupt: 