# Medical Large Language Models Support

This section adds support for medical LLMs:
- Echelon-AI/Med-Qwen2-7B
- LLavaMed-7B

Testing with ISIC dataset for medical image analysis.

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import os

# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# os.environ["HTTP_PROXY"] = "http://127.0.0.1:18080"
# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:18080"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.models.llm.transformers import *

model_id = LLAVA_MODEL_ID_DEFAULT
model, tokenizer, processor = build_transformers(model_id)
model = model.to("cuda")

text = "Hello, I am a student."
input_ids = tokenizer(
    text, return_tensors="pt", padding=True, truncation=True, max_length=128
)["input_ids"]
input_ids = input_ids.to("cuda")
outputs = model(**input_ids)
print(outputs)

ImportError: cannot import name 'PRETRAINED_MODEL_DIR' from 'src.config' (/home/yons/workspace/sxy/lab/NeuroTrain/src/config.py)

# Medical Large Language Models Support

This section adds support for medical LLMs:
- Echelon-AI/Med-Qwen2-7B
- LLavaMed-7B

Testing with ISIC dataset for medical image analysis.

In [None]:
# Load Med-Qwen2-7B model
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch

# Med-Qwen2-7B is a text-only model
MED_QWEN2_MODEL_ID = "Echelon-AI/Med-Qwen2-7B"
CACHE_DIR = "cache/models/pretrained"


def load_med_qwen2(model_id=MED_QWEN2_MODEL_ID):
    """Load Med-Qwen2-7B model and tokenizer"""
    print(f"Loading {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, 
        cache_dir=CACHE_DIR, 
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        cache_dir=CACHE_DIR,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer


# Load the model
med_qwen2_model, med_qwen2_tokenizer = load_med_qwen2()
print("Med-Qwen2-7B loaded successfully!")

Loading Echelon-AI/Med-Qwen2-7B...


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.88s/it]


Med-Qwen2-7B loaded successfully!


In [None]:
# Load model directly
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("microsoft/llava-med-v1.5-mistral-7b", dtype="auto", cache_dir="cache/models/pretrained")
print("LLavaMed-7B loaded successfully!")

ValueError: The checkpoint you are trying to load has model type `llava_mistral` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

You can update Transformers with the command `pip install --upgrade transformers`. If this does not work, and the checkpoint is very new, then there may not be a release version that supports this model yet. In this case, you can get the most up-to-date code by installing Transformers from source with the command `pip install git+https://github.com/huggingface/transformers.git`

In [None]:
# Load ISIC dataset for testing
from pathlib import Path
from src.dataset.medical.isic2018_dataset import ISIC2018Dataset
import matplotlib.pyplot as plt
import numpy as np

# Set ISIC dataset path (adjust as needed)
ISIC_DATA_PATH = Path("data/ISIC2018")  # Adjust to your data path


def tensor_to_pil(tensor):
    """Convert tensor to PIL Image"""
    if isinstance(tensor, Image.Image):
        return tensor

    # Handle tensor format (C, H, W) or (H, W, C)
    if tensor.dim() == 3:
        if tensor.shape[0] == 3 or tensor.shape[0] == 1:  # (C, H, W)
            tensor = tensor.permute(1, 2, 0)
        # Now tensor is (H, W, C)
        if tensor.shape[2] == 1:  # Grayscale
            tensor = tensor.squeeze(2)

    # Convert to numpy and denormalize if needed
    img_array = tensor.cpu().numpy()
    if img_array.max() <= 1.0:  # Normalized to [0, 1]
        img_array = (img_array * 255).astype("uint8")
    else:
        img_array = img_array.astype("uint8")

    # Convert to PIL Image
    if len(img_array.shape) == 2:  # Grayscale
        return Image.fromarray(img_array, mode="L")
    else:  # RGB
        return Image.fromarray(img_array, mode="RGB")


def load_isic_dataset(split="test", num_samples=5):
    """Load ISIC dataset samples for testing"""
    try:
        dataset = ISIC2018Dataset(ISIC_DATA_PATH, split=split, is_rgb=True)
        print(f"ISIC {split} dataset loaded: {len(dataset)} samples")

        # Get a few samples
        samples = []
        for i in range(min(num_samples, len(dataset))):
            sample = dataset[i].copy()
            # Convert tensor to PIL Image if needed
            sample["image"] = tensor_to_pil(sample["image"])
            if "mask" in sample and isinstance(sample["mask"], torch.Tensor):
                sample["mask"] = tensor_to_pil(sample["mask"])
            samples.append(sample)

        return samples
    except Exception as e:
        print(f"Error loading ISIC dataset: {e}")
        print(f"Please check if data exists at: {ISIC_DATA_PATH}")
        return []


# Load test samples
isic_samples = load_isic_dataset(split="test", num_samples=5)
print(f"Loaded {len(isic_samples)} ISIC samples")

SyntaxError: unmatched ')' (diffusion_dataset.py, line 484)

In [None]:
# Test Med-Qwen2-7B with medical text questions
def test_med_qwen2(model, tokenizer, questions):
    """Test Med-Qwen2-7B with medical questions"""
    results = []

    for question in questions:
        # Format prompt
        prompt = f"Question: {question}\nAnswer:"

        # Tokenize
        inputs = tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        if torch.cuda.is_available():
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, temperature=0.7, do_sample=True, top_p=0.9
            )

        # Decode
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        results.append({"question": question, "answer": response})

        print(f"\nQuestion: {question}")
        print(f"Answer: {response}\n")

    return results


# Test with medical questions related to skin lesions
medical_questions = [
    "What are the common characteristics of malignant skin lesions?",
    "How do you differentiate between benign and malignant skin lesions?",
    "What is the ABCDE rule for melanoma detection?",
    "Describe the typical features of a dysplastic nevus.",
]

print("Testing Med-Qwen2-7B with medical questions...")
med_qwen2_results = test_med_qwen2(
    med_qwen2_model, med_qwen2_tokenizer, medical_questions
)

In [None]:
# Test LLavaMed-7B with ISIC images
def test_llavamed(model, processor, images, questions):
    """Test LLavaMed-7B with medical images and questions"""
    results = []

    for i, (image, question) in enumerate(zip(images, questions)):
        try:
            # Prepare inputs - LLaVA-Med uses specific prompt format
            prompt = f"USER: <image>\n{question}\nASSISTANT:"

            inputs = processor(prompt, image, return_tensors="pt")
            if torch.cuda.is_available():
                inputs = {
                    k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                    for k, v in inputs.items()
                }

            # Generate
            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=256, temperature=0.7, do_sample=True
                )

            # Decode
            response = processor.decode(outputs[0], skip_special_tokens=True)
            results.append({"image_idx": i, "question": question, "answer": response})

            print(f"\nImage {i+1} - Question: {question}")
            print(f"Answer: {response}\n")

        except Exception as e:
            print(f"Error processing image {i}: {e}")
            results.append(
                {"image_idx": i, "question": question, "answer": f"Error: {e}"}
            )

    return results


# Prepare questions for ISIC images
isic_questions = [
    "What do you observe in this skin lesion image?",
    "Describe the characteristics of this skin lesion.",
    "What are the potential diagnostic features visible in this image?",
    "Analyze this dermatoscopic image and provide clinical observations.",
    "What should be considered when evaluating this skin lesion?",
]

if len(isic_samples) > 0:
    print("Testing LLavaMed-7B with ISIC images...")
    isic_images = [sample["image"] for sample in isic_samples[: len(isic_questions)]]
    llavamed_results = test_llavamed(
        llavamed_model,
        llavamed_processor,
        isic_images,
        isic_questions[: len(isic_images)],
    )
else:
    print("No ISIC samples available for testing. Please check the data path.")

In [None]:
# Visualize ISIC samples with model predictions
def visualize_isic_results(samples, results, num_samples=3):
    """Visualize ISIC images with model predictions"""
    if num_samples == 0:
        print("No samples to visualize")
        return

    fig, axes = plt.subplots(num_samples, 2, figsize=(12, 4 * num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)

    for i in range(min(num_samples, len(samples), len(results))):
        # Original image
        img = samples[i]["image"]
        if isinstance(img, Image.Image):
            axes[i, 0].imshow(img)
        else:
            axes[i, 0].imshow(np.array(img))
        axes[i, 0].set_title(f"ISIC Sample {i+1} - Original Image", fontsize=10)
        axes[i, 0].axis("off")

        # Show mask or model response
        if "mask" in samples[i] and samples[i]["mask"] is not None:
            mask = samples[i]["mask"]
            if isinstance(mask, Image.Image):
                axes[i, 1].imshow(np.array(mask), cmap="gray")
            elif isinstance(mask, torch.Tensor):
                mask_np = mask.cpu().numpy()
                if mask_np.ndim == 3 and mask_np.shape[0] == 1:
                    mask_np = mask_np[0]
                axes[i, 1].imshow(mask_np, cmap="gray")
            else:
                axes[i, 1].imshow(mask, cmap="gray")
            axes[i, 1].set_title(f"Ground Truth Mask", fontsize=10)
            axes[i, 1].axis("off")
        else:
            # Show model response as text
            question = results[i].get("question", "N/A")
            answer = results[i].get("answer", "N/A")
            # Truncate answer for display
            answer_short = answer[:300] + "..." if len(answer) > 300 else answer
            text_content = f"Question: {question}\n\nAnswer: {answer_short}"
            axes[i, 1].text(
                0.5,
                0.5,
                text_content,
                ha="center",
                va="center",
                wrap=True,
                fontsize=9,
                bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
            )
            axes[i, 1].set_title(f"LLavaMed-7B Response", fontsize=10)
            axes[i, 1].axis("off")

    plt.tight_layout()
    plt.show()


# Visualize results if available
if (
    len(isic_samples) > 0
    and "llavamed_results" in locals()
    and len(llavamed_results) > 0
):
    visualize_isic_results(
        isic_samples, llavamed_results, num_samples=min(3, len(isic_samples))
    )
else:
    print("No results available for visualization.")