# Skin Disease Diagnosis - Stage 1 SFT Training

Simple and clean implementation using direct processor approach for Vision-Language Model training.

## Overview
- **Stage 1 Goal**: Disease identification with basic spatial awareness
- **Model**: Qwen2-VL-2B-Instruct with LoRA fine-tuning
- **Dataset**: ISIC skin lesion dataset with metadata
- **Output**: Foundation model for Stage 2 GRPO training


## 1. Environment Setup


In [None]:
# Install required packages
%pip install transformers[torch] accelerate peft tqdm pillow pandas
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


In [1]:
# Import libraries
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import pandas as pd
from PIL import Image
from io import BytesIO
import warnings
warnings.filterwarnings("ignore")

from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoTokenizer, 
    AutoProcessor,
    get_linear_schedule_with_warmup,
    set_seed
)
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


2025-08-15 14:45:47.714459: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755269147.736753     838 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755269147.743644     838 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device: Tesla T4


## 2. Configuration


In [3]:
# Simple configuration
config = {
    "model_name": "Qwen/Qwen2-VL-2B-Instruct",
    "output_dir": "/kaggle/working/outputs",
    "train_image_dir": "/kaggle/input/small-isic",
    "train_metadata_file": "/kaggle/input/small-isic/HAM10000_metadata.csv",
    
    "training": {
        "num_epochs": 3,
        "batch_size": 8,
        "learning_rate": 2e-5,
        "max_length": 512,
        "seed": 42
    },
    
    "lora": {
        "r": 16,
        "alpha": 32,
        "dropout": 0.1,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"]
    }
}

os.makedirs(config['output_dir'], exist_ok=True)
set_seed(config['training']['seed'])
print("Configuration loaded!")


Configuration loaded!


## 3. Processing Class (Based on Your Approach)


In [4]:
class ProcessingClass:
    def __init__(self, processor):
        self.processor = processor
        self.tokenizer = processor.tokenizer
    
    def __call__(self, examples=None, text=None, **kwargs):
        if examples is not None:
            images = examples.get("image", [])
            prompts_raw = examples.get("prompt", [])
            answers = examples.get("answer", [])

            # Process prompts - they should already contain image tokens from dataset
            processed_prompts = []
            for prompt_turns in prompts_raw:
                if isinstance(prompt_turns, str):
                    # Use the prompt as-is (should already have image tokens)
                    processed_prompts.append(prompt_turns)
                else:
                    # Convert to string if not already
                    processed_prompts.append(str(prompt_turns))

            # Process images
            processed_images = []
            for img in images:
                if isinstance(img, bytes):
                    try:
                        processed_images.append(Image.open(BytesIO(img)).convert("RGB"))
                    except Exception as e:
                        print(f"Error loading image from bytes: {e}")
                        processed_images.append(Image.new("RGB", (224, 224)))
                elif isinstance(img, Image.Image):
                    processed_images.append(img.convert("RGB"))
                else:
                    print(f"Warning: Image format not recognized: {type(img)}")
                    processed_images.append(Image.new("RGB", (224, 224)))

            max_prompt_length = kwargs.get("max_prompt_length", 512)
            max_completion_length = kwargs.get("max_completion_length", 512)
            
            # Use processor to handle both images and text
            inputs = self.processor(
                images=processed_images,
                text=processed_prompts,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=max_prompt_length
            )
            
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            pixel_values = inputs.get("pixel_values")
            image_grid_thw = inputs.get("image_grid_thw")

            # Process answers/labels
            with self.tokenizer.as_target_tokenizer():
                label_encodings = self.tokenizer(
                    answers,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                    max_length=max_completion_length
                )
            labels_tensor = label_encodings["input_ids"].to(input_ids.device)
            labels_tensor[labels_tensor == self.tokenizer.pad_token_id] = -100

            result = {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels_tensor,
            }
            
            # Add image-related tensors if they exist
            if pixel_values is not None:
                result["pixel_values"] = pixel_values
            if image_grid_thw is not None:
                result["image_grid_thw"] = image_grid_thw
                
            return result

        elif text is not None:
            return self.processor(text=text, padding=True, truncation=True, return_tensors="pt")
        else:
            raise ValueError("Either 'examples' or 'text' must be provided.")

    def batch_decode(self, tokenized_output, skip_special_tokens=True):
        if isinstance(tokenized_output, torch.Tensor):
            tokenized_output = tokenized_output.tolist()
        return self.tokenizer.batch_decode(tokenized_output, skip_special_tokens=skip_special_tokens)

print("ProcessingClass defined!")


ProcessingClass defined!


## 4. Dataset Class


In [5]:
class SkinDiseaseDataset(Dataset):
    def __init__(self, image_dir, metadata_file, process_instance, max_length=512):
        self.image_dir = image_dir
        self.process_instance = process_instance
        self.max_length = max_length
        
        # Load metadata
        metadata_df = pd.read_csv(metadata_file)
        self.metadata = metadata_df.to_dict('records')
        
        # Find image column
        possible_cols = ['image_name', 'isic_id', 'image_id', 'filename', 'name', 'image']
        self.image_col = None
        for col in possible_cols:
            if col in metadata_df.columns:
                self.image_col = col
                break
        if not self.image_col:
            self.image_col = metadata_df.columns[0]
            
        # Prepare data
        self.data = []
        for item in self.metadata:
            image_filename = str(item[self.image_col])
            if not image_filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_filename += '.jpg'
            
            image_path = os.path.join(self.image_dir, image_filename)
            if os.path.exists(image_path):
                diagnosis = item.get('dx', 'unknown')
                location = item.get('localization', 'body')
                
                prompt = "Diagnose this skin condition."
                answer = f"{diagnosis} located on {location}"
                
                self.data.append({
                    'image_path': image_path,
                    'prompt': prompt,
                    'answer': answer
                })
        
        print(f"Dataset prepared with {len(self.data)} samples")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Load image
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except:
            image = Image.new('RGB', (224, 224), color='white')
        
        # Prepare conversation format for Qwen2VL processor
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": item['prompt']}
                ]
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": item['answer']}]
            }
        ]
        
        # Use processor's apply_chat_template
        try:
            text_input = self.process_instance.processor.apply_chat_template(
                conversation, 
                tokenize=False, 
                add_generation_prompt=False
            )
        except:
            # Fallback to simple format
            text_input = f"<image>{item['prompt']}"
        
        # Prepare examples in the format expected by ProcessingClass
        examples = {
            "image": [image],
            "prompt": [text_input],
            "answer": [item['answer']]
        }
        
        # Process using the ProcessingClass
        processed = self.process_instance(
            examples=examples,
            max_prompt_length=self.max_length,
            max_completion_length=self.max_length
        )
        
        # Return single sample (remove batch dimension)
        result = {}
        for key, value in processed.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                result[key] = value.squeeze(0)  # Remove batch dimension
            else:
                result[key] = value
        
        return result

print("SkinDiseaseDataset defined!")


SkinDiseaseDataset defined!


## 5. Model Setup


In [6]:
# Load model, processor, and tokenizer
print("Loading model and processor...")

processor = AutoProcessor.from_pretrained(config['model_name'], trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(config['model_name'], trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = Qwen2VLForConditionalGeneration.from_pretrained(
    config['model_name'],
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map='auto'
)

# Apply LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=config['lora']['r'],
    lora_alpha=config['lora']['alpha'],
    lora_dropout=config['lora']['dropout'],
    target_modules=config['lora']['target_modules']
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("Model loaded successfully!")


Loading model and processor...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 4,358,144 || all params: 2,213,343,744 || trainable%: 0.1969
Model loaded successfully!


## 6. Dataset and Training Setup


In [None]:
# Create processing instance and dataset
process_instance = ProcessingClass(processor)

print("Creating dataset...")
train_dataset = SkinDiseaseDataset(
    image_dir=config['train_image_dir'],
    metadata_file=config['train_metadata_file'],
    process_instance=process_instance,
    max_length=config['training']['max_length']
)

# Test dataset and debug
if len(train_dataset) > 0:
    sample = train_dataset[0]
    print("Sample shapes:", {k: v.shape if hasattr(v, 'shape') else type(v) for k, v in sample.items()})
    
    # Debug: Check if image tokens are in the input_ids
    input_ids = sample['input_ids']
    try:
        image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Image token ID: {image_token_id}")
        
        if image_token_id is not None and isinstance(input_ids, torch.Tensor):
            image_token_count = (input_ids == image_token_id).sum().item()
            print(f"Number of image tokens in input: {image_token_count}")
        else:
            print("Could not find image token or input_ids not a tensor")
            
        # Print decoded text to see the format
        decoded_text = processor.tokenizer.decode(input_ids, skip_special_tokens=False)
        print(f"Decoded text sample (first 200 chars): {decoded_text[:200]}")
        
    except Exception as e:
        print(f"Debug error: {e}")
        print(f"input_ids type: {type(input_ids)}")
        print(f"input_ids shape: {input_ids.shape if hasattr(input_ids, 'shape') else 'no shape'}")
    
    # Check pixel_values shape - should be [3, H, W]
    pixel_values = sample['pixel_values']
    print(f"Pixel values shape: {pixel_values.shape}")
#     if pixel_values.shape != torch.Size([3, 224, 224]) and len(pixel_values.shape) == 2:
#         print("⚠️ Pixel values have wrong shape - should be [3, height, width]")
# else:
#     print("No samples found!")


Creating dataset...
Dataset prepared with 500 samples
Sample shapes: {'input_ids': torch.Size([512]), 'attention_mask': torch.Size([512]), 'labels': torch.Size([512]), 'pixel_values': torch.Size([1344, 1176]), 'image_grid_thw': torch.Size([3])}
Image token ID: None
Could not find image token or input_ids not a tensor
Decoded text sample (first 200 chars): <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end
Pixel values shape: torch.Size([1344, 1176])
⚠️ Pixel values have wrong shape - should be [3, height, width]


In [10]:
# Setup training components
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['training']['learning_rate']
)

# Custom collate function to handle batch processing
def collate_fn(batch):
    # Stack all tensors
    result = {}
    for key in batch[0].keys():
        values = [item[key] for item in batch]
        if isinstance(values[0], torch.Tensor):
            result[key] = torch.stack(values)
        else:
            result[key] = values
    return result

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0
)

total_steps = len(train_dataloader) * config['training']['num_epochs']
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,
    num_training_steps=total_steps
)

scaler = GradScaler()

print(f"Training setup complete. Total steps: {total_steps}")


Training setup complete. Total steps: 750


## 7. Training


In [11]:
# Training function
def train_epoch(model, dataloader, optimizer, scheduler, scaler, epoch):
    model.train()
    total_loss = 0
    
    progress_bar = tqdm(dataloader, desc=f'Epoch {epoch}')
    
    for batch in progress_bar:
        # Move to device
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device)
        
        # Forward pass
        with autocast():
            outputs = model(**batch)
            loss = outputs.loss
        
        # Backward pass
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()
        
        # Update progress
        total_loss += loss.item()
        avg_loss = total_loss / (progress_bar.n + 1)
        progress_bar.set_postfix({'Loss': f'{avg_loss:.4f}'})
    
    return total_loss / len(dataloader)

# Main training loop
print("Starting training...")
for epoch in range(config['training']['num_epochs']):
    train_loss = train_epoch(model, train_dataloader, optimizer, scheduler, scaler, epoch)
    print(f"Epoch {epoch} - Average Loss: {train_loss:.4f}")

print("Training completed!")


Starting training...


Epoch 0:   0%|          | 0/250 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 0: 100%|██████████| 250/250 [04:10<00:00,  1.00s/it, Loss=4.8416] 


Epoch 0 - Average Loss: 4.8416


Epoch 1: 100%|██████████| 250/250 [04:02<00:00,  1.03it/s, Loss=0.0393]


Epoch 1 - Average Loss: 0.0393


Epoch 2: 100%|██████████| 250/250 [04:02<00:00,  1.03it/s, Loss=0.0055]

Epoch 2 - Average Loss: 0.0055
Training completed!





## 8. Save Model


In [12]:
# Save final model
print("Saving final model...")
save_path = os.path.join(config['output_dir'], 'stage1_final_model')
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Save training info
training_info = {
    'stage': 'stage1_sft',
    'model_name': config['model_name'],
    'training_config': config,
    'description': 'Stage 1 SFT model for skin disease diagnosis'
}

with open(os.path.join(save_path, 'training_info.json'), 'w') as f:
    json.dump(training_info, f, indent=2)

print(f"Model saved to {save_path}")
print("Training pipeline completed successfully!")


Saving final model...
Model saved to /kaggle/working/outputs/stage1_final_model
Training pipeline completed successfully!


In [None]:
def test_model_inference(model, processor, image_path, prompt="Diagnose this skin condition."):
    """
    Test the trained model on a single image
    """
    try:
        # Load and process image
        image = Image.open(image_path).convert('RGB')
        print(f"Testing with image: {image_path}")
        print(f"Image size: {image.size}")
        
        # Create conversation format
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]
        
        # Process with the trained model
        text_input = processor.apply_chat_template(
            conversation, 
            tokenize=False, 
            add_generation_prompt=True
        )
        
        # Tokenize and process
        inputs = processor(
            images=[image],
            text=[text_input],
            return_tensors="pt"
        ).to(device)
        
        # Generate response
        model.eval()
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
        
        # Decode response
        generated_text = processor.batch_decode(
            generate_ids[:, inputs['input_ids'].shape[1]:], 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=False
        )[0]
        
        print(f"Model prediction: {generated_text}")
        return generated_text
        
    except Exception as e:
        print(f"Error during inference: {e}")
        return None

print("Test function defined!")


In [None]:
# Test the model with sample images from the dataset
print("="*50)
print("🧪 TESTING TRAINED MODEL")
print("="*50)

if len(train_dataset) > 0:
    # Test with first few samples from dataset
    num_test_samples = min(3, len(train_dataset))
    
    for i in range(num_test_samples):
        print(f"\n🔍 Test Sample {i+1}/{num_test_samples}")
        print("-" * 30)
        
        # Get sample data
        sample_data = train_dataset.data[i]
        image_path = sample_data['image_path']
        expected_answer = sample_data['answer']
        
        print(f"Expected: {expected_answer}")
        
        # Test model inference
        prediction = test_model_inference(
            model=model,
            processor=processor, 
            image_path=image_path,
            prompt="Diagnose this skin condition."
        )
        
        if prediction:
            print(f"✅ Inference successful!")
        else:
            print(f"❌ Inference failed!")
        
        print("-" * 30)
else:
    print("❌ No dataset samples available for testing")

print(f"\n🎯 Testing completed!")


In [None]:
# Custom image testing function - Test with your own images!
def test_custom_image(image_path, custom_prompt=None):
    """
    Test the model with a custom image path
    Usage: test_custom_image("/path/to/your/image.jpg", "Diagnose this lesion")
    """
    if custom_prompt is None:
        custom_prompt = "Diagnose this skin condition and describe its location."
    
    print(f"\n🔬 CUSTOM IMAGE TEST")
    print(f"Image: {image_path}")
    print(f"Prompt: {custom_prompt}")
    print("-" * 40)
    
    if os.path.exists(image_path):
        prediction = test_model_inference(
            model=model,
            processor=processor,
            image_path=image_path,
            prompt=custom_prompt
        )
        return prediction
    else:
        print(f"❌ Image not found: {image_path}")
        return None

# Example usage (uncomment and modify the path to test your own images):
# prediction = test_custom_image("/kaggle/input/small-isic/ISIC_0024312.jpg")
# prediction = test_custom_image("/path/to/your/skin/image.jpg", "What type of skin lesion is this?")

print("Custom test function ready!")
print("📝 To test your own image, use:")
print('   test_custom_image("/path/to/image.jpg", "Your custom prompt")')
