# Enhanced Medical VLM Robustness Study with Advanced Attention Techniques

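This notebook measures how robust two medical vision-language models (LLaVA-Rad and MedGemma) are to variations of chest X-ray questions. For each case it compares the answers and attention maps produced for the base question against those for its variations. It uses enhanced attention extraction: entropy-weighted head aggregation, and token-conditioned attention with a cross-attention → Grad-CAM → uniform fallback chain. It also uses attention-shift metrics such as Jensen–Shannon divergence.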

## Setup and Imports

In [None]:
# Install required packages
!pip install -q transformers accelerate bitsandbytes scipy matplotlib opencv-python pillow torch torchvision

In [None]:
import gc
import json
import os
import random
import sys
import warnings
from typing import Any, Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from scipy.spatial.distance import jensenshannon

warnings.filterwarnings('ignore')

# Seed every RNG the study may touch, for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Folder on Google Drive that holds the enhanced attention modules.
ATTENTION_VIZ_DIR = '/content/drive/MyDrive/Robust_Medical_LLM_Dataset/attention_viz'

# Mount Google Drive only when running in Colab (the import fails elsewhere).
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    try:
        drive.mount('/content/drive')
    except Exception as exc:
        # Best-effort: keep going; the module imports below report what is missing.
        print(f"⚠ Could not mount Google Drive: {exc}")
    if ATTENTION_VIZ_DIR not in sys.path:  # re-running the cell must not duplicate it
        sys.path.append(ATTENTION_VIZ_DIR)

# Import enhanced modules: prefer the colab_imports helper, else import directly.
try:
    # NOTE(review): star import kept because the helper's exported names are not
    # visible here; it is assumed to provide setup_colab_environment (and possibly
    # the enhanced classes) — confirm and replace with explicit imports.
    from colab_imports import *
    IN_COLAB, paths = setup_colab_environment()
    print("✓ Imported using colab_imports helper")
except ImportError:
    # Manual import fallback
    try:
        from llava_rad_enhanced import (
            EnhancedLLaVARadVisualizer,
            AttentionConfig,
            AttentionMetrics,
            AttentionDifferenceAnalyzer
        )

        from medgemma_enhanced import (
            EnhancedAttentionExtractor,
            AttentionExtractionConfig,
            AttentionVisualizationEnhanced,
            RobustAttentionAnalyzer
        )
        print("✓ Imported enhanced modules directly")
    except ImportError as e:
        print(f"❌ Import error: {e}")
        print("Make sure the enhanced modules are in your Python path")
        raise

In [None]:
# Drive is mounted (when in Colab) and its module folder added to sys.path by
# the setup cell above, so it is not mounted a second time here. Import the
# enhanced classes explicitly so every name used below has a visible origin,
# whichever import path the setup cell took.
from llava_rad_enhanced import (
    EnhancedLLaVARadVisualizer,
    AttentionConfig,
    AttentionMetrics,
    AttentionDifferenceAnalyzer
)

from medgemma_enhanced import (
    EnhancedAttentionExtractor,
    AttentionExtractionConfig,
    AttentionVisualizationEnhanced,
    RobustAttentionAnalyzer
)

## Configuration

In [None]:
# Central configuration: every tunable knob for attention extraction,
# evaluation size and I/O paths lives in this one dict.
CONFIG = {
    # Attention extraction / visualisation settings for both backends.
    "attention": {
        "use_medical_colormap": True,
        "multi_head_mode": "entropy_weighted",  # head aggregation passed to LLaVA-Rad's AttentionConfig
        "percentile_clip": (5, 95),
        "use_body_mask": True,
        "attention_head_reduction": "entropy_weighted",  # head reduction for MedGemma's extractor
        "multi_token_aggregation": "weighted",
        "fallback_chain": ["cross_attention", "gradcam", "uniform"],
        "cache_enabled": True,
    },
    # How much of the dataset to run and which analyses to perform.
    "evaluation": {
        "eval_limit": 50,
        "n_variations_cap": 8,
        "compute_multi_head": True,
        "analyze_attention_shift": True,
        "save_3d_surfaces": False,
    },
    # Input CSVs, image folder and output root.
    "paths": {
        "base_csv": "medical-cxr-vqa-questions_sample.csv",
        "var_csv": "medical-cxr-vqa-questions_sample_hardpositives.csv",
        "image_root": "/content/drive/MyDrive/Robust_Medical_LLM_Dataset/MIMIC_JPG/hundred_vqa",
        "output_dir": "outputs_enhanced",
    },
}

# Output tree: <output_dir>/{attention_maps, multi_head, shift_analysis}
_output_root = CONFIG["paths"]["output_dir"]
os.makedirs(_output_root, exist_ok=True)
for _subdir in ("attention_maps", "multi_head", "shift_analysis"):
    os.makedirs(f"{_output_root}/{_subdir}", exist_ok=True)

## Enhanced Model Initialization

In [None]:
class EnhancedModelWrapper:
    """Uniform wrapper over LLaVA-Rad and MedGemma for answer + attention extraction.

    Both backends return a dict with ``answer``, ``attention_map``, ``method`` and
    ``metrics`` from ``extract_attention_with_answer`` (plus one backend-specific
    key), so the evaluation loop can treat them interchangeably.

    Args:
        model_type: ``"llava-rad"`` or ``"medgemma"``.
        config: Notebook ``CONFIG`` dict; its ``"attention"`` and ``"evaluation"``
            sections are read.

    Raises:
        ValueError: If ``model_type`` is not a supported backend.
    """

    SUPPORTED_MODEL_TYPES = ("llava-rad", "medgemma")

    def __init__(self, model_type: str, config: Dict[str, Any]):
        # Fail fast on unknown types; otherwise the wrapper would be half-built and
        # extract_attention_with_answer would route it down the MedGemma path.
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; "
                f"expected one of {self.SUPPORTED_MODEL_TYPES}"
            )
        self.model_type = model_type
        self.config = config

        if model_type == "llava-rad":
            # Initialize enhanced LLaVA-Rad (loaded in 8-bit).
            attention_config = AttentionConfig(
                use_medical_colormap=config["attention"]["use_medical_colormap"],
                multi_head_mode=config["attention"]["multi_head_mode"],
                percentile_clip=tuple(config["attention"]["percentile_clip"])
            )
            self.visualizer = EnhancedLLaVARadVisualizer(
                device=DEVICE,
                config=attention_config
            )
            self.visualizer.load_model(load_in_8bit=True)

        else:  # medgemma
            from medgemma_launch_mimic_fixed import load_model_enhanced

            self.model, self.processor = load_model_enhanced(device=DEVICE)

            # Token-conditioned extractor with a fallback chain of attention methods.
            extraction_config = AttentionExtractionConfig(
                attention_head_reduction=config["attention"]["attention_head_reduction"],
                multi_token_aggregation=config["attention"]["multi_token_aggregation"],
                fallback_chain=config["attention"]["fallback_chain"],
                cache_enabled=config["attention"]["cache_enabled"]
            )
            self.extractor = EnhancedAttentionExtractor(extraction_config)
            self.analyzer = RobustAttentionAnalyzer(self.extractor)

    def extract_attention_with_answer(self, image: Image.Image, question: str,
                                    keywords: List[str]) -> Dict[str, Any]:
        """Answer ``question`` about ``image`` and extract the matching attention map.

        Args:
            image: Input image (RGB PIL image).
            question: Question text.
            keywords: Terms the MedGemma extractor conditions attention on
                (unused by the LLaVA-Rad path).

        Returns:
            Dict with ``answer``, ``attention_map``, ``method`` and ``metrics``;
            LLaVA-Rad adds ``multi_head_attention``, MedGemma adds ``token_indices``.
        """
        if self.model_type == "llava-rad":
            result = self.visualizer.generate_with_attention(
                image, question, max_new_tokens=50
            )

            # A list-valued visual_attention is treated as per-head maps.
            multi_head = None
            if (self.config["evaluation"]["compute_multi_head"] and
                isinstance(result.get('visual_attention'), list)):
                multi_head = result['visual_attention']

            return {
                "answer": result["answer"],
                "attention_map": result.get("visual_attention"),
                "method": result.get("attention_method", "unknown"),
                "metrics": result.get("metrics", {}),
                "multi_head_attention": multi_head
            }

        # medgemma
        inputs = self.processor(
            text=f"<image>{question}",
            images=image,
            return_tensors="pt"
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

        with torch.no_grad():
            gen_result = self.model.generate(
                **inputs,
                max_new_tokens=50,
                output_attentions=True,
                return_dict_in_generate=True
            )

        attention_grid, token_indices, method = self.extractor.extract_token_conditioned_attention_robust(
            self.model, self.processor, gen_result, keywords, image, question
        )

        # Decode only the generated continuation: the prompt has no "Assistant:"
        # marker, so splitting the full sequence on it would leave the question
        # text inside the answer.
        answer_ids = self._strip_prompt(gen_result.sequences[0], inputs["input_ids"][0])
        answer = self.processor.tokenizer.decode(
            answer_ids, skip_special_tokens=True
        ).split("Assistant:")[-1].strip()

        metrics = AttentionMetrics.calculate_focus_score(
            attention_grid,
            roi_mask=self._get_body_mask(image) if self.config["attention"]["use_body_mask"] else None
        )
        metrics["sparsity"] = AttentionMetrics.calculate_sparsity(attention_grid)

        return {
            "answer": answer,
            "attention_map": attention_grid,
            "method": method,
            "metrics": metrics,
            "token_indices": token_indices
        }

    @staticmethod
    def _strip_prompt(sequence: torch.Tensor, prompt_ids: torch.Tensor) -> torch.Tensor:
        """Drop the echoed prompt from a generated token sequence, if it is echoed.

        Decoder-only ``generate`` returns prompt + continuation; a sequence that does
        not start with the prompt ids is returned unchanged.
        """
        prompt_ids = prompt_ids.to(sequence.device)
        n_prompt = prompt_ids.shape[-1]
        if sequence.shape[-1] >= n_prompt and torch.equal(sequence[:n_prompt], prompt_ids):
            return sequence[n_prompt:]
        return sequence

    def _get_body_mask(self, image: Image.Image) -> Optional[np.ndarray]:
        """Return a body mask for ``image``, or None if it cannot be computed.

        Best-effort: any failure (missing helper, mask error) yields None, in which
        case the caller passes ``roi_mask=None``.
        """
        try:
            from medgemma_enhanced import create_tight_body_mask
            gray = np.array(image.convert('L'))
            # NOTE(review): mask is at image resolution while the attention grid may
            # be coarser — confirm calculate_focus_score resizes it.
            return create_tight_body_mask(gray)
        except Exception:
            return None


# Load both backends up front; each wrapper loads its own model at construction.
print("Initializing enhanced models...")
llava_wrapper, medgemma_wrapper = (
    EnhancedModelWrapper(backend, CONFIG) for backend in ("llava-rad", "medgemma")
)
print("✓ Models initialized with enhanced attention extraction")

## Enhanced Evaluation Pipeline

In [None]:
def evaluate_case_enhanced(case_data: Dict, 
                         llava_wrapper: EnhancedModelWrapper,
                         medgemma_wrapper: EnhancedModelWrapper) -> Dict[str, Any]:
    """Run the base question and its variations through both models and compare attention.

    Args:
        case_data: Dataset record with ``image_path``, ``base_q``, ``variations``
            (list of question variants) and ``study_id``.
        llava_wrapper: Wrapper for LLaVA-Rad.
        medgemma_wrapper: Wrapper for MedGemma.

    Returns:
        ``{"study_id": ..., "models": {model_name: {...}}}`` holding, per model, the
        base answer/metrics, attention consistency, per-variation answers/metrics and
        per-variation attention-shift analyses.
    """
    # Open, convert and close the file right away; convert() returns an
    # independent in-memory copy.
    with Image.open(case_data["image_path"]) as raw_image:
        image = raw_image.convert("RGB")
    base_question = case_data["base_q"]
    variations = case_data["variations"][:CONFIG["evaluation"]["n_variations_cap"]]

    # Keywords that condition attention extraction on the base question.
    keywords = extract_medical_keywords(base_question)

    results = {"study_id": case_data["study_id"], "models": {}}

    for model_name, wrapper in [("llava-rad", llava_wrapper), ("medgemma", medgemma_wrapper)]:
        print(f"\nProcessing {model_name} for case {case_data['study_id']}...")

        base_result = wrapper.extract_attention_with_answer(image, base_question, keywords)
        base_map = base_result["attention_map"]

        var_results = [
            wrapper.extract_attention_with_answer(image, var_q, extract_medical_keywords(var_q))
            for var_q in variations
        ]

        # Only successfully extracted maps take part in the consistency score.
        attention_maps = [
            att for att in [base_map, *(vr["attention_map"] for vr in var_results)]
            if att is not None
        ]
        consistency_score = AttentionMetrics.calculate_consistency(attention_maps)

        # Pair each variation with its OWN result. Zipping against the None-filtered
        # map list would shift every later question onto the wrong map after a
        # single failed extraction.
        shift_analyses = []
        if CONFIG["evaluation"]["analyze_attention_shift"] and base_map is not None:
            for var_q, var_result in zip(variations, var_results):
                var_att = var_result["attention_map"]
                if var_att is None:
                    continue
                shift = AttentionDifferenceAnalyzer.compute_attention_shift(base_map, var_att)
                shift_analyses.append({
                    "variation": var_q,
                    "total_shift": shift["total_shift"],
                    "js_divergence": shift["js_divergence"],
                    "com_shift": shift["center_of_mass_shift"]
                })

        save_enhanced_visualizations(
            case_data["study_id"], model_name, image, 
            base_result, var_results, shift_analyses
        )

        results["models"][model_name] = {
            "base_answer": base_result["answer"],
            "base_metrics": base_result["metrics"],
            "attention_method": base_result["method"],
            "consistency_score": float(consistency_score),
            "variation_results": [
                {"answer": vr["answer"], "metrics": vr["metrics"]} 
                for vr in var_results
            ],
            "shift_analyses": shift_analyses,
            "has_multi_head": base_result.get("multi_head_attention") is not None
        }

    return results


def save_enhanced_visualizations(study_id: str, model_name: str, image: Image.Image,
                               base_result: Dict, var_results: List[Dict],
                               shift_analyses: List[Dict],
                               base_question: Optional[str] = None):
    """Save the base overlay, plus LLaVA-Rad multi-head / 3D / max-shift figures.

    Args:
        study_id: Case identifier used in output file names.
        model_name: ``"llava-rad"`` or ``"medgemma"``; the extra figures are only
            produced for LLaVA-Rad.
        image: Case image.
        base_result: Output of ``extract_attention_with_answer`` for the base question.
        var_results: Per-variation outputs (currently unused; kept for the interface).
        shift_analyses: Per-variation shift dicts with ``variation`` and ``total_shift``.
        base_question: Base question used as the reference in the max-shift figure.
            If omitted, it is read from the notebook-global ``case_data["base_q"]``
            set by the evaluation loop; pass it explicitly to avoid that hidden state.
    """
    output_base = f"{CONFIG['paths']['output_dir']}/attention_maps/{study_id}_{model_name}"

    if base_result["attention_map"] is not None:
        overlay = AttentionVisualizationEnhanced.create_attention_overlay(
            image, base_result["attention_map"],
            use_body_mask=CONFIG["attention"]["use_body_mask"]
        )
        overlay.save(f"{output_base}_base.png")

    # The remaining figures rely on LLaVA-Rad's visualizer.
    if model_name != "llava-rad":
        return

    if base_result.get("multi_head_attention") is not None:
        fig = llava_wrapper.visualizer.create_multi_head_visualization(
            base_result["multi_head_attention"], image,
            save_path=f"{CONFIG['paths']['output_dir']}/multi_head/{study_id}_multi_head.png"
        )
        plt.close(fig)

    if CONFIG["evaluation"]["save_3d_surfaces"] and base_result["attention_map"] is not None:
        fig = llava_wrapper.visualizer.create_3d_attention_surface(
            base_result["attention_map"],
            save_path=f"{output_base}_3d_surface.png"
        )
        plt.close(fig)

    if shift_analyses:
        if base_question is None:
            # Legacy fallback: the notebook-global set by the evaluation loop.
            base_question = case_data["base_q"]
        # Compare the base question with the variation that moved attention most.
        max_shift = max(shift_analyses, key=lambda s: s["total_shift"])
        fig = llava_wrapper.visualizer.visualize_attention_difference(
            image, 
            base_question,
            max_shift["variation"],
            save_path=f"{CONFIG['paths']['output_dir']}/shift_analysis/{study_id}_max_shift.png"
        )
        plt.close(fig)


def extract_medical_keywords(question: str, k: int = 3) -> List[str]:
    """Pick up to ``k`` keywords from a question to condition attention on.

    Priority:
    1. Known chest-X-ray finding terms, in question order.
    2. "abnormal" or "normal", if the question mentions either.
    3. The first non-stopword (or the first word if all are stopwords).

    Args:
        question: Free-text question.
        k: Maximum number of keywords returned.

    Returns:
        Lower-cased keywords; empty for an empty question.
    """
    medical_terms = {
        "effusion", "pneumonia", "consolidation", "edema", "atelectasis",
        "nodule", "mass", "cardiomegaly", "pleural", "pneumothorax",
        "opacity", "infiltrate", "fracture", "emphysema", "fibrosis"
    }
    # Generic question words that make useless attention anchors.
    stopwords = {
        "is", "are", "was", "were", "be", "been", "does", "do", "did", "can",
        "could", "should", "would", "there", "the", "a", "an", "any", "this",
        "that", "these", "those", "of", "in", "on", "at", "to", "for", "with",
        "and", "or", "what", "which", "where", "how", "evidence", "seen", "present"
    }

    words = [w.strip(",.?;:!\"'()").lower() for w in question.split()]
    words = [w for w in words if w]
    keywords = [w for w in words if w in medical_terms]

    if not keywords:
        lowered = question.lower()
        # "abnormal" must be tested first: it contains "normal" as a substring, so
        # the reverse order could never return it.
        if "abnormal" in lowered:
            keywords = ["abnormal"]
        elif "normal" in lowered:
            keywords = ["normal"]
        else:
            content_words = [w for w in words if w not in stopwords]
            keywords = (content_words or words)[:1]

    return keywords[:k]

## Run Enhanced Evaluation

In [None]:
# Load dataset
import traceback

from robustness_study_notebook import RobustPromptDataset


def _json_default(obj):
    """json.dump fallback: convert numpy / torch values that may appear in metric dicts."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


dataset = RobustPromptDataset(
    CONFIG["paths"]["base_csv"],
    CONFIG["paths"]["var_csv"],
    CONFIG["paths"]["image_root"]
)

print(f"Dataset loaded: {len(dataset)} cases")

results_path = f"{CONFIG['paths']['output_dir']}/results_enhanced.json"
all_results = []
limit = min(CONFIG["evaluation"]["eval_limit"], len(dataset))

print(f"\nEvaluating {limit} cases with enhanced attention techniques...")

for idx in range(limit):
    try:
        # `case_data` is intentionally a notebook global: save_enhanced_visualizations
        # reads case_data["base_q"] from it.
        case_data = dataset.get(idx)
        if case_data["image_path"] is None:
            continue

        print(f"\n{'='*50}")
        print(f"Case {idx+1}/{limit}: {case_data['study_id']}")

        results = evaluate_case_enhanced(case_data, llava_wrapper, medgemma_wrapper)
        all_results.append(results)

        # Checkpoint after every case. Dump to a temp file and swap it in, so a
        # failing or interrupted dump never leaves a truncated JSON behind.
        tmp_path = results_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(all_results, f, indent=2, default=_json_default)
        os.replace(tmp_path, results_path)

    except Exception as e:
        # Best-effort run: log the failure and continue with the next case.
        print(f"Error processing case {idx}: {e}")
        traceback.print_exc()
    finally:
        # Free memory after every case, failed ones included. Collect first so
        # tensors released by GC can be returned by empty_cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("\n✓ Enhanced evaluation complete!")

## Analyze Enhanced Results

In [None]:
# Compile results into a per-(study, model) dataframe.
import re


def extract_yes_no(text: str) -> str:
    """Map a free-text answer to "yes", "no" or "uncertain".

    Looks for a whole-word "yes" or "no" in the first 20 characters ("yes" wins if
    both appear). Whole-word matching avoids reading words such as "normal",
    "nodule" or "knot" as a "no".
    """
    if not text:
        return "uncertain"
    head = text.lower().strip()[:20]
    if re.search(r"\byes\b", head):
        return "yes"
    if re.search(r"\bno\b", head):
        return "no"
    return "uncertain"


result_columns = [
    "study_id", "model", "attention_method", "answer_consistency",
    "attention_consistency", "mean_js_divergence", "focus_score", "roi_focus",
    "sparsity", "has_multi_head", "vulnerability",
]

rows = []
for result in all_results:
    study_id = result["study_id"]

    for model_name, model_data in result["models"].items():
        # Fraction of variations whose yes/no answer matches the base answer (1.0 if none).
        base_ans = extract_yes_no(model_data["base_answer"])
        var_answers = [extract_yes_no(vr["answer"]) for vr in model_data["variation_results"]]
        consistency = sum(va == base_ans for va in var_answers) / len(var_answers) if var_answers else 1.0

        # Mean JS divergence between base and variation attention maps.
        shift_scores = [sa["js_divergence"] for sa in model_data["shift_analyses"]]
        mean_shift = np.mean(shift_scores) if shift_scores else 0.0

        rows.append({
            "study_id": study_id,
            "model": model_name,
            "attention_method": model_data["attention_method"],
            "answer_consistency": consistency,
            "attention_consistency": model_data["consistency_score"],
            "mean_js_divergence": mean_shift,
            "focus_score": model_data["base_metrics"].get("focus", 0),
            "roi_focus": model_data["base_metrics"].get("roi_focus", 0),
            "sparsity": model_data["base_metrics"].get("sparsity", 0),
            "has_multi_head": model_data["has_multi_head"],
            # High when attention moves AND the answer flips.
            "vulnerability": mean_shift * (1 - consistency)
        })

# Explicit columns keep the downstream cells working even if every case failed.
df_results = pd.DataFrame(rows, columns=result_columns)

# Per-model summary: mean ± std of each metric plus bookkeeping counts.
_summary_metrics = [
    ("Answer Consistency", "answer_consistency"),
    ("Attention Consistency", "attention_consistency"),
    ("Mean JS Divergence", "mean_js_divergence"),
    ("Focus Score", "focus_score"),
    ("ROI Focus", "roi_focus"),
    ("Sparsity (Gini)", "sparsity"),
    ("Vulnerability Score", "vulnerability"),
]

print("\nEnhanced Evaluation Results Summary:")
print("=" * 60)

for model in df_results["model"].unique():
    model_df = df_results[df_results["model"] == model]

    summary_lines = [
        f"\n{model.upper()}:",
        f"  Attention Methods Used: {model_df['attention_method'].value_counts().to_dict()}",
    ]
    for metric_label, metric_col in _summary_metrics:
        series = model_df[metric_col]
        summary_lines.append(f"  {metric_label}: {series.mean():.3f} ± {series.std():.3f}")
    summary_lines.append(
        f"  Multi-head Available: {model_df['has_multi_head'].sum()} / {len(model_df)} cases"
    )
    print("\n".join(summary_lines))

# Persist the per-case table.
df_results.to_csv(f"{CONFIG['paths']['output_dir']}/enhanced_results.csv", index=False)

## Visualize Enhanced Metrics

In [None]:
# Create comprehensive visualization of enhanced metrics:
# one panel per metric, one box per model, dashed line = that model's mean.
plot_models = ["medgemma", "llava-rad"]
plot_colors = ['#4CAF50', '#2196F3']

metrics_to_plot = [
    ("attention_consistency", "Attention Consistency Score"),
    ("focus_score", "Attention Focus Score"),
    ("roi_focus", "ROI Focus Ratio"),
    ("sparsity", "Attention Sparsity (Gini)"),
    ("mean_js_divergence", "Mean JS Divergence"),
    ("vulnerability", "Vulnerability Score")
]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

for ax, (metric, title) in zip(axes.flat, metrics_to_plot):
    # Drop NaNs per model: a NaN makes that model's box statistics NaN, so
    # nothing would be drawn for it.
    data_to_plot = [
        df_results.loc[df_results["model"] == model, metric].dropna().to_numpy(dtype=float)
        for model in plot_models
    ]
    positions = list(range(1, len(plot_models) + 1))

    bp = ax.boxplot(data_to_plot, positions=positions, patch_artist=True)
    # Tick labels are set separately: boxplot(labels=...) is deprecated since
    # matplotlib 3.9 (renamed tick_labels); this form works on all versions.
    ax.set_xticks(positions)
    ax.set_xticklabels([model.title() for model in plot_models])

    for patch, color in zip(bp['boxes'], plot_colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax.set_title(title, fontsize=12)
    ax.set_ylabel("Value", fontsize=10)
    ax.grid(True, alpha=0.3)

    # Mean marker in data coordinates, so it stays centred on its box whatever
    # the axis limits are.
    for pos, values, color in zip(positions, data_to_plot, plot_colors):
        if values.size:
            ax.hlines(values.mean(), pos - 0.3, pos + 0.3,
                      colors=color, linestyles='--', linewidth=2)

plt.tight_layout()
plt.savefig(f"{CONFIG['paths']['output_dir']}/enhanced_metrics_comparison.png", dpi=150)
plt.show()

# Create attention method distribution plot: which extraction method each model
# actually ended up using for the base question.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

for ax, model in zip(axes, ["medgemma", "llava-rad"]):
    method_counts = df_results.loc[df_results["model"] == model, "attention_method"].value_counts()

    if method_counts.empty:
        # No rows for this model: show a placeholder instead of calling ax.pie
        # with no wedges.
        ax.text(0.5, 0.5, "No results", ha="center", va="center", transform=ax.transAxes)
        ax.axis("off")
    else:
        ax.pie(method_counts.values, labels=method_counts.index, autopct='%1.1f%%')
    ax.set_title(f"{model.title()} - Attention Methods Used")

plt.tight_layout()
plt.savefig(f"{CONFIG['paths']['output_dir']}/attention_method_distribution.png", dpi=150)
plt.show()

## Generate Enhanced Report

In [None]:
# Generate comprehensive report
report_models = ["medgemma", "llava-rad"]

report = {
    "study_info": {
        "title": "Enhanced Medical VLM Robustness Study with Advanced Attention Techniques",
        "n_cases": len(all_results),
        # Upper bound: each case is capped at this many variations.
        "n_variations_per_case": CONFIG["evaluation"]["n_variations_cap"],
        "attention_config": CONFIG["attention"]
    },
    "model_performance": {},
    "key_findings": {}
}

for model in report_models:
    model_df = df_results[df_results["model"] == model]

    report["model_performance"][model] = {
        "attention_methods": model_df["attention_method"].value_counts().to_dict(),
        "metrics": {
            "answer_consistency": {
                "mean": float(model_df["answer_consistency"].mean()),
                "std": float(model_df["answer_consistency"].std()),
                "min": float(model_df["answer_consistency"].min()),
                "max": float(model_df["answer_consistency"].max())
            },
            "attention_consistency": {
                "mean": float(model_df["attention_consistency"].mean()),
                "std": float(model_df["attention_consistency"].std())
            },
            "focus_metrics": {
                "focus_score": float(model_df["focus_score"].mean()),
                "roi_focus": float(model_df["roi_focus"].mean()),
                "sparsity": float(model_df["sparsity"].mean())
            },
            "robustness": {
                "mean_js_divergence": float(model_df["mean_js_divergence"].mean()),
                "vulnerability": float(model_df["vulnerability"].mean())
            }
        },
        "multi_head_support": {
            "available": int(model_df["has_multi_head"].sum()),
            "percentage": float(model_df["has_multi_head"].mean() * 100)
        }
    }

# Mean vulnerability of each model that produced results (NaN means dropped).
# groupby sorts its keys, so idxmin breaks ties toward the alphabetically first
# model ("llava-rad"), and a missing model no longer raises KeyError.
mean_vulnerability = df_results.groupby("model")["vulnerability"].mean().dropna()

report["key_findings"] = {
    "most_robust_model": mean_vulnerability.idxmin() if not mean_vulnerability.empty else None,
    "attention_extraction_success_rate": {
        model: float((df_results[df_results["model"] == model]["attention_method"] != "uniform").mean() * 100)
        for model in report_models
    },
    "average_attention_shift": float(df_results["mean_js_divergence"].mean()),
    "correlation_answer_attention_consistency": float(
        df_results[["answer_consistency", "attention_consistency"]].corr().iloc[0, 1]
    )
}

with open(f"{CONFIG['paths']['output_dir']}/enhanced_robustness_report.json", "w") as f:
    json.dump(report, f, indent=2)

print("\nEnhanced Robustness Report Generated!")
print("=" * 60)
print(json.dumps(report["key_findings"], indent=2))

## Cleanup

In [None]:
# Memory cleanup: drop both model wrappers. pop() instead of `del` keeps the cell
# re-runnable (and safe after a failed initialisation) without NameError.
for _wrapper_name in ("llava_wrapper", "medgemma_wrapper"):
    globals().pop(_wrapper_name, None)

# Collect first so tensors freed by GC can then be returned by empty_cache.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("✓ Enhanced robustness study complete!")
print(f"\nResults saved to: {CONFIG['paths']['output_dir']}/")
print("\nKey outputs:")
print("- enhanced_results.csv: Detailed metrics for all cases")
print("- enhanced_robustness_report.json: Summary report")
print("- attention_maps/: Enhanced attention visualizations")
print("- multi_head/: Multi-head attention visualizations")
print("- shift_analysis/: Attention shift visualizations")