# LLaVA-Rad Medical Setup for Google Colab
This notebook sets up LLaVA-Rad and MedGemma for medical image analysis.

In [None]:
# Cell 1: Install all dependencies including open_clip
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q open_clip_torch timm einops
!pip install -q transformers>=4.36.0 accelerate bitsandbytes
!pip install -q opencv-python scipy matplotlib pillow
!pip install -q tokenizers sentencepiece protobuf gradio

print("✅ All dependencies installed")

In [None]:
# Cell 2: Clone repositories with conflict prevention
import os
import sys

# Set environment variable to prevent auto-registration
os.environ['TRANSFORMERS_OFFLINE'] = '0'
os.environ['HF_HUB_OFFLINE'] = '0'

# Clone LLaVA-Rad
if not os.path.exists('/content/LLaVA-Rad'):
    !git clone https://github.com/microsoft/LLaVA-Rad.git /content/LLaVA-Rad

# Clone medical VLM repo
if not os.path.exists('/content/medical-vlm-intepret'):
    !git clone https://github.com/thedatasense/medical-vlm-intepret.git /content/medical-vlm-intepret

# Modify LLaVA to prevent conflicts before installing
import subprocess

# Shim that gets prepended to llava/__init__.py. Its first comment line doubles
# as the marker used to detect an already-patched file.
LLAVA_INIT_FIX = '''
# Prevent transformers conflict
import warnings
warnings.filterwarnings('ignore', message=".*already used by a Transformers config.*")

try:
    import transformers.models.auto.configuration_auto as cfg
    if hasattr(cfg.CONFIG_MAPPING, '_extra_content'):
        cfg.CONFIG_MAPPING._extra_content.pop('llava', None)
        cfg.CONFIG_MAPPING._extra_content.pop('llava_next', None)
except Exception:
    pass

'''


def patch_llava_init(init_file='/content/LLaVA-Rad/llava/__init__.py'):
    """Prepend the conflict-prevention shim to LLaVA's package ``__init__.py``.

    The patch is applied directly from the notebook process (no helper script
    or subprocess) and is idempotent: a file that already contains the shim's
    marker comment is left untouched.

    Args:
        init_file: Path of the ``__init__.py`` to patch.

    Returns:
        True if the file was modified; False if it does not exist or was
        already patched.
    """
    if not os.path.exists(init_file):
        # Previously a silent no-op; say so, because later cells assume the patch exists.
        print(f"WARNING: {init_file} not found - LLaVA __init__.py was NOT modified")
        return False

    # Explicit UTF-8 so reading/writing source does not depend on the locale.
    with open(init_file, 'r', encoding='utf-8') as f:
        content = f.read()

    if 'Prevent transformers conflict' in content:
        return False

    with open(init_file, 'w', encoding='utf-8') as f:
        f.write(LLAVA_INIT_FIX + content)
    print("✓ Modified LLaVA __init__.py to prevent conflicts")
    return True


patch_llava_init()

# Now install LLaVA-Rad
%cd /content/LLaVA-Rad
!pip install -e . --quiet

# Add to Python path
sys.path.insert(0, '/content/LLaVA-Rad')
sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

print("✅ Repositories cloned and installed with conflict prevention")

In [None]:
# Cell 2.5: Alternative - Use HuggingFace LLaVA if conflicts persist
# Run this cell ONLY if you get "'llava' is already used" error
#
# Limitation: only llava.model.builder, llava.utils and llava.mm_utils
# (get_model_name_from_path) are emulated. Cells that import llava.conversation
# or llava.constants still need the real LLaVA package.

print("Using HuggingFace LLaVA models as alternative...")

# Skip LLaVA-Rad installation and use HF models directly
import sys
sys.path.insert(0, '/content/medical-vlm-intepret/attention_viz')

# Create a compatibility layer
compatibility_code = '''
# llava_compat.py - Compatibility layer for HF models
# Self-contained: imports everything it uses so the written file also works
# when imported as a regular module, not only when exec'd in the notebook.
import sys
import types

from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch

class LLaVACompat:
    @staticmethod
    def load_pretrained_model(model_path, model_base, model_name,
                              load_8bit=False, load_4bit=False, device="cuda"):
        """Compatibility wrapper for HF models.

        Accepts the arguments the notebook passes to llava's
        load_pretrained_model, but always loads a llava-hf checkpoint.
        model_base and model_name are accepted and ignored.

        Returns:
            (tokenizer, model, image_processor, context_len) where
            context_len is fixed at 2048.
        """

        # Map model paths
        hf_model_map = {
            "microsoft/llava-med-v1.5-mistral-7b": "llava-hf/llava-1.5-7b-hf",
            "liuhaotian/llava-v1.5-7b": "llava-hf/llava-1.5-7b-hf"
        }

        hf_path = hf_model_map.get(model_path, "llava-hf/llava-1.5-7b-hf")
        print(f"Using HuggingFace model: {hf_path}")

        # Load processor and model
        processor = AutoProcessor.from_pretrained(hf_path)

        # Quantization config
        kwargs = {
            "device_map": device if device == "auto" else {"": device},
            "torch_dtype": torch.float16
        }

        if load_8bit:
            from transformers import BitsAndBytesConfig
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_enable_fp32_cpu_offload=True
            )
        elif load_4bit:
            # Previously accepted but silently ignored.
            from transformers import BitsAndBytesConfig
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

        model = LlavaForConditionalGeneration.from_pretrained(hf_path, **kwargs)

        return processor.tokenizer, model, processor.image_processor, 2048


def _get_model_name_from_path(model_path):
    """Stand-in for llava.mm_utils.get_model_name_from_path.

    Returns the last path component, prefixed by its parent directory when
    that component is a "checkpoint-*" folder.
    """
    parts = model_path.strip("/").split("/")
    if parts[-1].startswith("checkpoint-") and len(parts) > 1:
        return parts[-2] + "_" + parts[-1]
    return parts[-1]


# Mock the llava module
llava = types.ModuleType("llava")
llava.model = types.ModuleType("llava.model")
llava.model.builder = types.ModuleType("llava.model.builder")
llava.model.builder.load_pretrained_model = LLaVACompat.load_pretrained_model

llava.utils = types.ModuleType("llava.utils")
llava.utils.disable_torch_init = lambda: None

# Needed by the model-loading cell, which imports get_model_name_from_path
llava.mm_utils = types.ModuleType("llava.mm_utils")
llava.mm_utils.get_model_name_from_path = _get_model_name_from_path

# Add to sys.modules
sys.modules["llava"] = llava
sys.modules["llava.model"] = llava.model
sys.modules["llava.model.builder"] = llava.model.builder
sys.modules["llava.utils"] = llava.utils
sys.modules["llava.mm_utils"] = llava.mm_utils
'''

# Write compatibility layer
with open('/content/llava_compat.py', 'w', encoding='utf-8') as f:
    f.write(compatibility_code)

# Import it
exec(compatibility_code)

print("✅ HuggingFace compatibility layer created")
print("You can now proceed with the rest of the notebook")

In [None]:
# Cell 3: Fix conflicts before importing LLaVA
import warnings
warnings.filterwarnings('ignore')

# Pre-emptively fix conflicts before any LLaVA imports
def fix_conflicts_before_import():
    """Best-effort removal of LLaVA entries from transformers' auto mappings.

    Deletes 'llava' and 'llava_next' from ``CONFIG_MAPPING._extra_content``
    and 'llava' from the ``_extra_content`` of every ``*_MAPPING`` attribute of
    ``modeling_auto``, printing one line per removal. Any exception raised on
    the way is caught and printed instead of propagated. Returns None.
    """
    try:
        # Imported first, as before: if it is unavailable the handler below
        # reports it and no mapping is touched.
        import transformers.models.llava
        import transformers.models.auto.configuration_auto as config_auto

        # Config registrations
        absent = object()
        config_extras = getattr(config_auto.CONFIG_MAPPING, '_extra_content', absent)
        if config_extras is not absent:
            for model_type in ('llava', 'llava_next'):
                if model_type in config_extras:
                    del config_extras[model_type]
                    print(f"✓ Removed {model_type} from CONFIG_MAPPING")

        # Model-class registrations
        import transformers.models.auto.modeling_auto as model_auto
        for mapping_name in dir(model_auto):
            if not mapping_name.endswith('_MAPPING'):
                continue
            candidate = getattr(model_auto, mapping_name)
            if not hasattr(candidate, '_extra_content'):
                continue
            if 'llava' not in candidate._extra_content:
                continue
            del candidate._extra_content['llava']
            print(f"✓ Removed llava from {mapping_name}")
    except Exception as e:
        print(f"Pre-import fix: {e}")

# Apply fixes
fix_conflicts_before_import()

# Now test imports
print("\nTesting LLaVA imports...")
try:
    # Import LLaVA components
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.conversation import conv_templates
    from llava.mm_utils import process_images, tokenizer_image_token
    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
    print("✅ LLaVA-Rad imports successful")
except Exception as e:
    # Deliberately broader than ImportError: the "'llava' is already used"
    # registration clash this notebook works around may not be an ImportError
    # (transformers raises it as a ValueError), and must not abort the cell.
    print(f"❌ Import error: {e}")
    print("\nTrying alternative import method...")
    # Try to import just the essentials
    try:
        import llava
        # In-memory stand-ins (e.g. the Cell 2.5 compatibility layer) have no
        # __file__; reading it directly would make this fallback fail.
        llava_location = getattr(llava, '__file__', None) or '<in-memory module>'
        print(f"✓ llava module found at: {llava_location}")
        from llava.model.builder import load_pretrained_model
        from llava.utils import disable_torch_init
        print("✅ Essential imports successful")
    except Exception as e2:
        print(f"❌ Alternative import also failed: {e2}")
        print("   If the error mentions \"'llava' is already used\", run Cell 2.5 and re-run this cell.")

In [None]:
# Cell 4: Load LLaVA-Rad Medical Model
import torch
from PIL import Image

# Load model function
def load_llava_medical(model_paths=None, device="cuda", load_8bit=False, load_4bit=False):
    """Load the first LLaVA checkpoint that loads successfully.

    Args:
        model_paths: Checkpoint identifiers to try, in order. Defaults to the
            medical checkpoint followed by the general LLaVA v1.5 fallback.
        device: Device passed through to ``load_pretrained_model``.
        load_8bit: Passed through to ``load_pretrained_model``.
        load_4bit: Passed through to ``load_pretrained_model``.

    Returns:
        ``(tokenizer, model, image_processor)`` for the first checkpoint that
        loads; the model has been switched to eval mode.

    Raises:
        RuntimeError: If no checkpoint could be loaded. The message lists the
            failure reason for every attempted path.
    """
    import gc

    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import get_model_name_from_path

    disable_torch_init()

    # Try medical model first, then fallback
    if model_paths is None:
        model_paths = [
            "microsoft/llava-med-v1.5-mistral-7b",
            "liuhaotian/llava-v1.5-7b",
        ]

    failures = []
    for model_path in model_paths:
        try:
            print(f"Loading {model_path}...")
            tokenizer, model, image_processor, _context_len = load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=get_model_name_from_path(model_path),
                load_8bit=load_8bit,
                load_4bit=load_4bit,
                device=device
            )
            model.eval()
            print(f"✅ Loaded {model_path}")
            return tokenizer, model, image_processor
        except Exception as e:
            print(f"Failed: {e}")
            # Keep the reason so the final error is actionable.
            failures.append(f"{model_path}: {e}")
        # Collect whatever the failed attempt left behind before the next
        # (equally large) load is tried.
        gc.collect()

    details = "; ".join(failures) if failures else "no model paths given"
    raise RuntimeError(f"Could not load any model ({details})")

# Load the model
# These module-level names are reused by later cells: the Cell 5 smoke test and
# the Cell 7 wrapper both read `tokenizer`, `model` and `image_processor`.
tokenizer, model, image_processor = load_llava_medical()

In [None]:
# Cell 5: Mount Drive and Test
import os  # imported here as well so this cell does not rely on Cell 2's import

from google.colab import drive
drive.mount('/content/drive')

import matplotlib.pyplot as plt

# Check for medical images
data_root = '/content/drive/MyDrive/Robust_Medical_LLM_Dataset'
image_dir = f'{data_root}/MIMIC_JPG/hundred_vqa'

# Sorted so the test image is the same on every run (os.listdir order is
# arbitrary); .jpeg and upper-case extensions are accepted too.
if os.path.isdir(image_dir):
    images = sorted(
        f for f in os.listdir(image_dir)
        if f.lower().endswith(('.jpg', '.jpeg'))
    )
else:
    images = []

if not images:
    # Covers both a missing directory and a directory without any JPEGs.
    print("❌ Medical images not found")
else:
    test_image_path = os.path.join(image_dir, images[0])
    print(f"✅ Found {len(images)} medical images")

    # Test the model
    image = Image.open(test_image_path).convert('RGB')

    # Display
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title("Test Medical Image")
    plt.axis('off')
    plt.show()

    # Generate answer
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.mm_utils import process_images, tokenizer_image_token
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )

    conv = conv_templates["llava_v1"].copy()
    question = "Is there evidence of pneumonia in this chest X-ray?"

    # Same prompt layout as upstream LLaVA's inference scripts: the image
    # placeholder is wrapped in start/end tokens only when the model was
    # trained with them, and is always followed by a newline.
    if model.config.mm_use_im_start_end:
        image_tokens = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_tokens = DEFAULT_IMAGE_TOKEN
    prompt = image_tokens + '\n' + question

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt_text = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).cuda()

    image_tensor = process_images([image], image_processor, model.config)[0]
    image_tensor = image_tensor.unsqueeze(0).half().cuda()

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=False,
            max_new_tokens=100,
            use_cache=True
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

    # Keep only the text after the last conversation separator.
    if conv.sep_style == SeparatorStyle.TWO:
        answer = outputs.split(conv.sep2)[-1].strip()
    else:
        answer = outputs.split(conv.sep)[-1].strip()

    print(f"\nQuestion: {question}")
    print(f"Answer: {answer}")

In [None]:
# Cell 6: Load MedGemma
from medgemma_enhanced import load_medgemma, build_inputs, generate_answer

print("Loading MedGemma...")
# Loader options gathered in one place: float16 dtype, device_map="auto"
medgemma_load_options = {"dtype": torch.float16, "device_map": "auto"}
medgemma_model, medgemma_processor = load_medgemma(**medgemma_load_options)
print("✅ MedGemma loaded")

In [None]:
# Cell 7: Use the Medical-Only Module
from llava_rad_medical_only import LLaVARadMedical, MedicalAttentionConfig

# Create medical model wrapper
config = MedicalAttentionConfig(
    colormap='hot',
    attention_head_mode='mean',
    alpha=0.5
)

llava_medical = LLaVARadMedical(config=config)

# Use already loaded model (from Cell 4) instead of loading a second copy
llava_medical.model = model
llava_medical.tokenizer = tokenizer
llava_medical.image_processor = image_processor

print("✅ Medical wrapper ready")

# Test with attention extraction
# `test_image_path` is only defined when Cell 5 found an image.
if 'test_image_path' in globals():
    result = llava_medical.generate_with_attention(
        test_image_path,
        "What are the main findings in this chest X-ray?",
        max_new_tokens=100
    )

    print(f"\nAnswer: {result['answer']}")
    print(f"Attention method: {result.get('attention_method', 'N/A')}")
else:
    # Previously skipped without any message.
    print("⚠️ Skipping attention test: no test image available (see Cell 5)")

In [None]:
# Cell 8: Run Full Analysis
# Switch the notebook's working directory to the attention_viz folder of the
# cloned medical-vlm-intepret repo; the script below is invoked by relative name.
%cd /content/medical-vlm-intepret/attention_viz
# Run the analysis script as a shell command with --n_studies 5 and
# --output_dir /content/results.
# NOTE(review): presumably --n_studies caps how many studies are processed and
# --output_dir is where results are written — confirm against the script's CLI.
!python run_medical_vlm_analysis.py --n_studies 5 --output_dir /content/results