In [None]:
%pip install transformers accelerate peft tqdm pillow datasets wandb pandas bitsandbytes --q

In [1]:
# Skin Disease Diagnosis Training with Hugging Face Dataset
import torch
import random
from PIL import Image
from datasets import load_dataset
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2VLProcessor,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

print("Libraries loaded ✓")


Libraries loaded ✓


In [2]:
# Configuration
CONFIG = {
    "model_name": "Qwen/Qwen2-VL-2B-Instruct",
    "hf_dataset_name": "abaryan/ham10000_bbox",
    "output_dir": "./qwen2vl-skin-diagnosis-hf",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 10,
    "gradient_accumulation_steps": 1,
    "learning_rate": 5e-5,
    "warmup_steps": 100,
    "max_length": 512,
    "include_spatial_descriptions": True,
    "spatial_description_ratio": 0.3,
    "train_limit": None,
    "test_limit": None
}

print(f"Config loaded - Dataset: {CONFIG['hf_dataset_name']}")


Config loaded - Dataset: abaryan/ham10000_bbox


In [3]:
# Load model and processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
    CONFIG["model_name"],
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

processor = Qwen2VLProcessor.from_pretrained(CONFIG["model_name"])

print("Model loaded ✓")


`torch_dtype` is deprecated! Use `dtype` instead!


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/429M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/272 [00:00<?, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/347 [00:00<?, ?B/s]

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

chat_template.json: 0.00B [00:00, ?B/s]

Model loaded ✓


In [4]:
# Load Hugging Face dataset
def load_hf_dataset():
    """Load and prepare HF dataset for training"""
    
    # Load dataset
    hf_dataset = load_dataset(CONFIG["hf_dataset_name"])
    
    # Get train/test splits
    train_data = hf_dataset['train']
    test_data = hf_dataset['test']
    
    if CONFIG["train_limit"]:
        train_data = train_data.select(range(min(CONFIG["train_limit"], len(train_data))))
    if CONFIG["test_limit"]:
        test_data = test_data.select(range(min(CONFIG["test_limit"], len(test_data))))
    
    print(f"Loaded {len(train_data)} train samples, {len(test_data)} test samples")
    
    return train_data, test_data

def prepare_conversations(dataset_split):
    """Convert HF dataset to conversation format - Memory Efficient"""
    
    conversations = []
    
    # Diagnosis mapping
    dx_names = {
        'akiec': 'actinic keratosis',
        'bcc': 'basal cell carcinoma', 
        'bkl': 'benign keratosis-like lesion',
        'df': 'dermatofibroma',
        'mel': 'melanoma',
        'nv': 'melanocytic nevus',
        'vasc': 'vascular lesion'
    }
    
    for i, item in enumerate(dataset_split):
        dx = item['diagnosis']
        diagnosis_full = dx_names.get(dx, dx)
        
        # Spatial awareness logic
        use_spatial = (CONFIG["include_spatial_descriptions"] and 
                      item.get('mask_available', False) and 
                      random.random() < CONFIG["spatial_description_ratio"])
        
        if use_spatial and item.get('spatial_description'):
            user_prompt = "Analyze this skin lesion, provide a diagnosis, and describe its location."
            spatial_desc = item['spatial_description'].replace('lesion located in', 'The lesion is located in the')
            assistant_response = f"This appears to be {diagnosis_full}. {spatial_desc}."
        else:
            user_prompt = "Analyze this skin lesion and provide a diagnosis."
            assistant_response = f"This appears to be {diagnosis_full}."
        
        conversation = {
            "conversation": [
                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": user_prompt}]},
                {"role": "assistant", "content": [{"type": "text", "text": assistant_response}]}
            ],
            "hf_index": i,  # Store index instead of image
            "metadata": {
                "lesion_id": item.get('lesion_id'),
                "diagnosis": dx,
                "has_spatial": use_spatial,
                "bbox": item.get('bbox'),
                "area_coverage": item.get('area_coverage'),
                "mask_available": item.get('mask_available', False)
            }
        }
        conversations.append(conversation)
    
    return conversations, dataset_split  # Return dataset_split for image access

# Load data
train_hf, test_hf = load_hf_dataset()
train_conversations, train_hf_data = prepare_conversations(train_hf)
test_conversations, test_hf_data = prepare_conversations(test_hf)

print(f"Prepared {len(train_conversations)} train conversations")
print(f"Prepared {len(test_conversations)} test conversations")


README.md:   0%|          | 0.00/814 [00:00<?, ?B/s]

train-00000-of-00008.parquet:   0%|          | 0.00/375M [00:00<?, ?B/s]

train-00001-of-00008.parquet:   0%|          | 0.00/374M [00:00<?, ?B/s]

train-00002-of-00008.parquet:   0%|          | 0.00/375M [00:00<?, ?B/s]

train-00003-of-00008.parquet:   0%|          | 0.00/377M [00:00<?, ?B/s]

train-00004-of-00008.parquet:   0%|          | 0.00/376M [00:00<?, ?B/s]

train-00005-of-00008.parquet:   0%|          | 0.00/378M [00:00<?, ?B/s]

train-00006-of-00008.parquet:   0%|          | 0.00/375M [00:00<?, ?B/s]

train-00007-of-00008.parquet:   0%|          | 0.00/376M [00:00<?, ?B/s]

test-00000-of-00002.parquet:   0%|          | 0.00/374M [00:00<?, ?B/s]

test-00001-of-00002.parquet:   0%|          | 0.00/375M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8012 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2003 [00:00<?, ? examples/s]

Loaded 8012 train samples, 2003 test samples
Prepared 8012 train conversations
Prepared 2003 test conversations


In [5]:
# Memory-efficient data collator for HF dataset
def collate_fn(examples):
    texts = []
    images = []
    
    for example in examples:
        # Load image on-demand from HF dataset using index
        hf_index = example["hf_index"]
        image = train_hf_data[hf_index]['image'].convert('RGB')
        images.append(image)
        
        text = processor.apply_chat_template(
            example["conversation"], 
            tokenize=False, 
            add_generation_prompt=False
        )
        texts.append(text)
    
    batch = processor(
        text=texts,
        images=images, 
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=CONFIG["max_length"]
    )
    
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    if hasattr(processor, 'image_token_id'):
        labels[labels == processor.image_token_id] = -100
    
    batch["labels"] = labels
    return batch

print("Collate function ready ✓")


Collate function ready ✓


In [6]:
# Setup LoRA
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("LoRA setup complete ✓")


trainable params: 73,859,072 || all params: 2,282,844,672 || trainable%: 3.2354
LoRA setup complete ✓


In [7]:
# import wandb
# wandb.login(key="f1aedbcd5d073259cb4005220e80f8f3bab2dd69")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrasoulcarrera[0m ([33mrasoulcarrera-aba[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [8]:
# Training setup
training_args = TrainingArguments(
    output_dir=CONFIG["output_dir"],
    num_train_epochs=CONFIG["num_train_epochs"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    learning_rate=CONFIG["learning_rate"],
    warmup_steps=CONFIG["warmup_steps"],
     # Logging and saving
    logging_steps=50,
    save_steps=1000,
    eval_steps=300,
    save_total_limit=1,
    
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='wandb',
    dataloader_pin_memory=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_conversations,
    data_collator=collate_fn,
    tokenizer=processor.tokenizer
)

print("Trainer ready ✓")


Trainer ready ✓


  trainer = Trainer(


In [9]:
# Start training
print("Starting training...")
trainer.train()
print("Training complete!")

# Save model and processor
trainer.save_model(CONFIG["output_dir"])
processor.save_pretrained(CONFIG["output_dir"])
print("Model saved ✓")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Starting training...


Step,Training Loss
50,2.1605
100,0.4281
150,0.3629
200,0.3603
250,0.3577
300,0.3602
350,0.3549
400,0.3534
450,0.3568
500,0.3583


Training complete!
Model saved ✓


In [10]:
# Test the model
def test_hf_model(model, processor, test_samples, num_tests=50):
    """Test model on HF dataset samples"""
    
    model.eval()
    test_results = []
    spatial_tests = []
    
    for i, sample in enumerate(test_samples[:num_tests]):
        print(f"\n--- Test {i+1}/{num_tests} ---")
        
        try:
            # Load image on-demand from HF dataset
            hf_index = sample["hf_index"]
            test_image = test_hf_data[hf_index]['image'].convert('RGB')
            print(f"Image: HF_sample_{i+1}")
            
            # Create test prompt
            has_spatial_data = sample["metadata"]["has_spatial"]
            if has_spatial_data:
                user_prompt = "Analyze this skin lesion, provide a diagnosis, and describe its location."
            else:
                user_prompt = "Analyze this skin lesion and provide a diagnosis."
            
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": user_prompt}
                    ]
                }
            ]
            
            # Generate response
            text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
            inputs = processor(
                text=[text_prompt], 
                images=[test_image], 
                return_tensors="pt"
            ).to(model.device)
            
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False
                )
            
            response = processor.batch_decode(
                generated_ids[:, inputs['input_ids'].shape[1]:], 
                skip_special_tokens=True
            )[0].strip()
            
            # Get ground truth
            ground_truth = sample["conversation"][1]["content"][0]["text"]
            
            # Evaluate
            diagnosis_keywords = {
                'actinic keratosis': 'akiec',
                'basal cell carcinoma': 'bcc', 
                'benign keratosis-like lesion': 'bkl',
                'dermatofibroma': 'df',
                'melanoma': 'mel',
                'melanocytic nevus': 'nv',
                'vascular lesion': 'vasc'
            }
            
            # Check diagnosis
            diagnosis_correct = False
            for full_name, short_name in diagnosis_keywords.items():
                if full_name in ground_truth.lower() and full_name in response.lower():
                    diagnosis_correct = True
                    break
            
            # Check spatial if applicable
            spatial_correct = True
            gt_spatial = ""
            pred_spatial = ""
            
            if has_spatial_data:
                if "located in" in ground_truth.lower():
                    gt_spatial = ground_truth.split("located in")[-1].strip().rstrip(".")
                if "located in" in response.lower():
                    pred_spatial = response.split("located in")[-1].strip().rstrip(".")
                
                if gt_spatial and pred_spatial:
                    spatial_correct = gt_spatial.lower() == pred_spatial.lower()
                elif gt_spatial and not pred_spatial:
                    spatial_correct = False
                elif not gt_spatial and pred_spatial:
                    spatial_correct = False
            
            is_correct = diagnosis_correct and spatial_correct
            
            print(f"Ground Truth: {ground_truth}")
            print(f"Model Output: {response}")
            print(f"Diagnosis: {'✓' if diagnosis_correct else '❌'} | Spatial: {'✓' if spatial_correct else '❌' if has_spatial_data else 'N/A'}")
            print(f"Overall: {'✓ CORRECT' if is_correct else '❌ INCORRECT'}")
            
            if has_spatial_data and (gt_spatial or pred_spatial):
                print(f"Spatial GT: '{gt_spatial}' | Pred: '{pred_spatial}'")
                bbox_gt = sample["metadata"].get("bbox")
                if bbox_gt and len(bbox_gt) == 4:
                    print(f"Bbox: [{bbox_gt[0]:.0f}, {bbox_gt[1]:.0f}, {bbox_gt[2]:.0f}, {bbox_gt[3]:.0f}]")
                    area_cov = sample["metadata"].get("area_coverage")
                    if area_cov:
                        print(f"Area Coverage: {area_cov:.1%}")
            
            # Store results
            result = {
                "sample_id": i,
                "diagnosis_correct": diagnosis_correct,
                "spatial_correct": spatial_correct,
                "overall_correct": is_correct,
                "has_spatial": has_spatial_data,
                "ground_truth": ground_truth,
                "prediction": response
            }
            test_results.append(result)
            
            if has_spatial_data:
                spatial_tests.append(result)
        
        except Exception as e:
            print(f"Error in test {i+1}: {e}")
            continue
    
    # Summary
    total_tests = len(test_results)
    successful_tests = sum(1 for r in test_results if r["overall_correct"])
    diagnosis_correct = sum(1 for r in test_results if r["diagnosis_correct"])
    spatial_correct = sum(1 for r in spatial_tests if r["spatial_correct"])
    
    overall_accuracy = (successful_tests / total_tests * 100) if total_tests > 0 else 0
    diagnosis_accuracy = (diagnosis_correct / total_tests * 100) if total_tests > 0 else 0
    spatial_accuracy = (spatial_correct / len(spatial_tests) * 100) if len(spatial_tests) > 0 else 0
    
    print(f"\n" + "="*50)
    print(f"TEST SUMMARY:")
    print(f"Overall Accuracy:    {successful_tests}/{total_tests} ({overall_accuracy:.1f}%)")
    print(f"Diagnosis Accuracy:  {diagnosis_correct}/{total_tests} ({diagnosis_accuracy:.1f}%)")
    print(f"Spatial Accuracy:    {spatial_correct}/{len(spatial_tests)} ({spatial_accuracy:.1f}%) [{len(spatial_tests)} spatial samples]")
    print(f"="*50)

# Load the trained model and test
loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(
    CONFIG["output_dir"],
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
loaded_processor = Qwen2VLProcessor.from_pretrained(CONFIG["output_dir"])

print("Testing model...")
test_hf_model(loaded_model, loaded_processor, test_conversations, num_tests=50)


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Testing model...

--- Test 1/50 ---
Image: HF_sample_1
Ground Truth: This appears to be melanocytic nevus.
Model Output: This appears to be melanocytic nevus.
Diagnosis: ✓ | Spatial: ✓
Overall: ✓ CORRECT

--- Test 2/50 ---
Image: HF_sample_2
Ground Truth: This appears to be melanoma.
Model Output: This appears to be melanoma.
Diagnosis: ✓ | Spatial: ✓
Overall: ✓ CORRECT

--- Test 3/50 ---
Image: HF_sample_3
Ground Truth: This appears to be melanocytic nevus.
Model Output: This appears to be melanocytic nevus.
Diagnosis: ✓ | Spatial: ✓
Overall: ✓ CORRECT

--- Test 4/50 ---
Image: HF_sample_4
Ground Truth: This appears to be melanocytic nevus.
Model Output: This appears to be melanocytic nevus.
Diagnosis: ✓ | Spatial: ✓
Overall: ✓ CORRECT

--- Test 5/50 ---
Image: HF_sample_5
Ground Truth: This appears to be melanocytic nevus.
Model Output: This appears to be melanocytic nevus.
Diagnosis: ✓ | Spatial: ✓
Overall: ✓ CORRECT

--- Test 6/50 ---
Image: HF_sample_6
Ground Truth: This appears t