# Skin Disease Model Testing with Bbox Visualization
Load a trained Stage 1 or Stage 2 model, evaluate it on the test split, and visualize the ground-truth bounding boxes.


In [None]:
import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import json
import os
from pathlib import Path
import numpy as np


## Configuration


In [None]:
# Run-wide settings for this evaluation notebook. Point `model_path` at a
# different checkpoint (e.g. the Stage 1 output) to test another model.
CONFIG = dict(
    model_path="./stage2_grpo_hf_output/final",  # or stage1_output
    hf_dataset_name="abaryan/ham10000_bbox",
    test_limit=50,          # cap on evaluated samples; falsy means "no cap"
    visualize_samples=10,   # only samples with index below this are plotted
    output_dir="./test_results",
)

# Create the output folder (and any missing parents) up front so later cells
# can write images and JSON into it without checking.
Path(CONFIG["output_dir"]).mkdir(parents=True, exist_ok=True)

print(f"Testing model: {CONFIG['model_path']}")
print(f"Results will be saved to: {CONFIG['output_dir']}")


In [None]:
## Load Model and Dataset

# Flags shared by both loaders, which read from the same checkpoint directory.
_shared_load_kwargs = {"trust_remote_code": True}

model = Qwen2VLForConditionalGeneration.from_pretrained(
    CONFIG["model_path"],
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
    device_map="auto",           # let transformers decide device placement
    **_shared_load_kwargs,
)
# Inference only: switch off train-time behavior such as dropout.
model.eval()

processor = Qwen2VLProcessor.from_pretrained(CONFIG["model_path"], **_shared_load_kwargs)

print("Model loaded successfully!")


In [None]:
# Fetch the dataset from the Hugging Face Hub and keep only its "test" split.
dataset = load_dataset(CONFIG["hf_dataset_name"])
test_data = dataset["test"]

# Optionally truncate to the first `test_limit` rows; a falsy limit keeps all.
_sample_cap = CONFIG["test_limit"]
if _sample_cap:
    _n_keep = min(_sample_cap, len(test_data))
    test_data = test_data.select(range(_n_keep))

print(f"Loaded {len(test_data)} test samples")


## Testing Functions


In [None]:
def draw_bbox_on_image(image, bbox, label="Lesion", color="red", width=3):
    if not bbox or len(bbox) != 4:
        return image
    
    image_copy = image.copy()
    draw = ImageDraw.Draw(image_copy)
    x1, y1, x2, y2 = bbox
    
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
    
    try:
        font = ImageFont.load_default()
    except:
        font = None
    
    text_bbox = draw.textbbox((x1, max(0, y1-25)), label, font=font)
    draw.rectangle(text_bbox, fill=color)
    draw.text((x1, max(0, y1-25)), label, fill="white", font=font)
    
    return image_copy

def generate_prediction(
    image,
    prompt="Analyze this skin lesion, provide a diagnosis, and describe its location.",
    *,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
):
    """Ask the module-level ``model`` about one image and return its decoded answer.

    Args:
        image: Image passed to ``processor`` together with the chat prompt.
        prompt: User instruction placed after the image in a single chat turn.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: ``True`` (the original behavior) samples the answer.
            ``False`` decodes greedily, which gives reproducible evaluation results.
        temperature: Sampling temperature. Only forwarded when *do_sample* is true.
        top_p: Nucleus-sampling threshold. Only forwarded when *do_sample* is true.

    Returns:
        The newly generated text (prompt tokens removed), decoded with special
        tokens skipped and surrounding whitespace stripped.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": processor.tokenizer.eos_token_id,
    }
    # Sampling knobs are meaningless for greedy decoding, and transformers may
    # warn when they are passed alongside do_sample=False.
    if do_sample:
        generation_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generation_kwargs)

    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_length = inputs["input_ids"].shape[1]
    prediction = processor.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
    )[0].strip()

    return prediction


## Run Comprehensive Testing


In [None]:
# HAM10000 diagnosis codes -> full names the model's answer is searched for.
dx_names = {
    'akiec': 'actinic keratosis',
    'bcc': 'basal cell carcinoma',
    'bkl': 'benign keratosis-like lesion',
    'df': 'dermatofibroma',
    'mel': 'melanoma',
    'nv': 'melanocytic nevus',
    'vasc': 'vascular lesion'
}

# Running counters plus one record per sample; the summary cell dumps this to JSON.
test_results = {
    'total_samples': 0,
    'diagnosis_correct': 0,
    'spatial_correct': 0,
    'bbox_available': 0,
    'detailed_results': []
}


def _diagnosis_matches(expected, prediction):
    """Return True if the non-empty ground-truth name occurs in *prediction*, ignoring case.

    This is a loose substring check: a prediction that mentions several
    diagnoses counts as correct if any of them is the right one. An empty or
    missing ground truth never matches (``"" in text`` would always be True).
    """
    needle = (expected or "").strip().lower()
    return bool(needle) and needle in prediction.lower()


print("Starting comprehensive testing...")

n_samples = len(test_data)
for i, sample in enumerate(test_data):
    image = sample['image']
    bbox = sample.get('bbox', [])
    diagnosis = sample.get('dx', '')
    spatial_desc = sample.get('spatial_description', '')

    prediction = generate_prediction(image)

    expected_diagnosis = dx_names.get(diagnosis, diagnosis)
    diagnosis_correct = _diagnosis_matches(expected_diagnosis, prediction)
    # NOTE: this only checks that the answer contains a location phrase; it is
    # not compared against spatial_desc, so it is a "location mentioned" rate.
    spatial_correct = "located in" in prediction.lower()

    test_results['total_samples'] += 1
    if diagnosis_correct:
        test_results['diagnosis_correct'] += 1
    if spatial_correct:
        test_results['spatial_correct'] += 1
    if bbox:
        test_results['bbox_available'] += 1

    test_results['detailed_results'].append({
        'sample_id': i,
        'diagnosis_gt': expected_diagnosis,
        'diagnosis_pred': prediction,
        'diagnosis_correct': diagnosis_correct,
        'spatial_correct': spatial_correct,
        'bbox': bbox,
        'spatial_desc': spatial_desc
    })

    # Only samples among the first `visualize_samples` that have a bbox are drawn.
    if i < CONFIG['visualize_samples'] and bbox:
        bbox_image = draw_bbox_on_image(image, bbox, "Ground Truth", "green")

        output_path = f"{CONFIG['output_dir']}/test_sample_{i+1:03d}.jpg"
        # JPEG cannot store alpha or palette modes; convert so save() can't fail on them.
        bbox_image.convert("RGB").save(output_path)

        preview = prediction if len(prediction) <= 80 else f"{prediction[:80]}..."
        plt.figure(figsize=(10, 8))
        plt.imshow(bbox_image)
        plt.title(f"Sample {i+1}: {expected_diagnosis}\nPrediction: {preview}")
        plt.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close()  # release the figure so memory doesn't grow across samples

        print(f"✓ Diagnosis: {diagnosis_correct} | Spatial: {spatial_correct}")
        print(f"Ground Truth: {expected_diagnosis} | {spatial_desc}")
        print(f"Prediction: {prediction}")
        print("-" * 80)

    if (i + 1) % 10 == 0:
        print(f"Processed {i + 1}/{n_samples} samples...")

print("Testing completed!")


## Results Summary


In [None]:
# Aggregate the counters collected by the testing loop.
total = test_results['total_samples']


def _percent(count, denominator):
    """Return ``count / denominator`` as a percentage, or 0.0 when *denominator* is 0."""
    return count / denominator * 100 if denominator else 0.0


dx_acc = _percent(test_results['diagnosis_correct'], total)
spatial_acc = _percent(test_results['spatial_correct'], total)
bbox_coverage = _percent(test_results['bbox_available'], total)

print("=" * 60)
print("📊 COMPREHENSIVE TEST RESULTS")
print("=" * 60)
print(f"Total Samples: {total}")
print(f"Diagnosis Accuracy: {dx_acc:.1f}% ({test_results['diagnosis_correct']}/{total})")
# "Spatial" counts answers containing "located in"; locations are not verified.
print(f"Spatial Accuracy: {spatial_acc:.1f}% ({test_results['spatial_correct']}/{total})")
print(f"Bbox Coverage: {bbox_coverage:.1f}% ({test_results['bbox_available']}/{total})")
print("=" * 60)

results_file = f"{CONFIG['output_dir']}/test_results.json"
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2)

# Images were written only for samples among the first `visualize_samples`
# that had a bbox, so count exactly those rather than
# min(visualize_samples, bbox_available), which overcounts when some of the
# first samples lack a bbox.
n_visualized = sum(
    1
    for record in test_results['detailed_results']
    if record['sample_id'] < CONFIG['visualize_samples'] and record['bbox']
)

print(f"\n📁 Results saved to: {results_file}")
print(f"📁 Visualizations saved to: {CONFIG['output_dir']}/")
print(f"📁 Generated {n_visualized} bbox visualizations")
