# Fine-tuning Qwen2.5-VL-7B-Instruct on MathVerse Dataset (MLX - macOS)

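This notebook fine-tunes Qwen2.5-VL-7B-Instruct on Apple Silicon (M1/M2/M3/M4). MLX is installed alongside, but because MLX lacked Qwen2.5-VL support when this notebook was written, the training itself runs on PyTorch's MPS backend (see Section 4).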

**Requirements:**
- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- 32 GB+ of unified memory recommended (the float16 weights alone need roughly 16 GB)
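
As a rough sanity check on the memory requirement, weight memory is parameter count × bytes per parameter. The sketch below assumes about 8.3 billion parameters for Qwen2.5-VL-7B-Instruct (language model plus vision encoder) and deliberately ignores activations, LoRA optimizer state, and framework overhead, so treat its output as a lower bound rather than a measurement.

```python
# Rough lower bound on memory needed just to hold the model weights.
# Assumption: ~8.3e9 parameters for Qwen2.5-VL-7B-Instruct; activations,
# optimizer state, and framework overhead are NOT included.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params: float, dtype: str = "float16") -> float:
    """Return weight memory in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


if __name__ == "__main__":
    for dtype in ("float32", "float16"):
        print(f"{dtype}: ~{weight_memory_gb(8.3e9, dtype):.1f} GB")
```

In float32 the weights alone would already exceed 32 GB, which is one reason the model is loaded in float16 later on.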

## 1. Environment Setup

In [None]:
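# Install MLX and required packages
!pip install -q -U pip
!pip install -q mlx mlx-lm
!pip install -q torch  # PyTorch with the MPS backend, used for the actual training
!pip install -q "transformers>=4.49" datasets accelerate peft pillow scikit-learn  # Qwen2.5-VL needs transformers 4.49+
!pip install -q huggingface-hub
!pip install -q qwen-vl-utils  # Qwen VL utilities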

In [None]:
# Import libraries
import os
import json
import re
from pathlib import Path
from PIL import Image
import numpy as np

# MLX imports
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# HuggingFace imports
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, set_seed
from peft import LoraConfig, get_peft_model

print(f"MLX version: {mx.__version__}")
print(f"MLX default device: {mx.default_device()}")

# Seed Python, NumPy, and PyTorch RNGs (transformers.set_seed) plus MLX's RNG
set_seed(42)
mx.random.seed(42)

## 2. Data Configuration

In [None]:
# Data paths - Update these to match your local paths
DATA_DIR = "/Users/berta/Documents/Projects/mathverse"
JSONL_PATH = f"{DATA_DIR}/mathverse_testmini.jsonl"
IMAGES_DIR = f"{DATA_DIR}/mathverse_testmini_images"

# Verify paths exist
assert os.path.exists(JSONL_PATH), f"JSONL file not found: {JSONL_PATH}"
assert os.path.exists(IMAGES_DIR), f"Images directory not found: {IMAGES_DIR}"

print(f"✅ Data directory: {DATA_DIR}")
print(f"✅ JSONL file: {JSONL_PATH}")
print(f"✅ Images directory: {IMAGES_DIR}")

## 3. Load and Prepare Data
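
The loader below expects one JSON object per line with at least `image_path` (relative to `IMAGES_DIR`), `query`, and `answer`. A minimal sketch of that per-line validation follows; `parse_record` is a hypothetical helper and the example record is made up for illustration (real MathVerse exports may carry more fields).

```python
import json
from typing import Optional

# Fields load_mathverse_data reads from each record.
REQUIRED_KEYS = ("image_path", "query", "answer")


def parse_record(line: str) -> Optional[dict]:
    """Parse one JSONL line; return None if it is blank, malformed, or incomplete."""
    if not line.strip():
        return None
    try:
        item = json.loads(line)
    except json.JSONDecodeError:
        return None
    if not isinstance(item, dict):
        return None
    # Treat missing or empty values as incomplete, as the loader does for query/answer.
    if not all(item.get(key) for key in REQUIRED_KEYS):
        return None
    return item


# Hypothetical record, for illustration only.
example_line = '{"image_path": "0001.png", "query": "Find the value of x.", "answer": "B"}'
print(parse_record(example_line))
```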

In [None]:
# Load JSONL data
from PIL import ImageFile

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

def clean_image_tokens(text):
    """Remove any existing image tokens from text."""
    patterns = [
        r'<\|image_pad\|>',
        r'<\|vision_start\|>',
        r'<\|vision_end\|>',
        r'<image>',
        r'</image>',
        r'\[IMG\d*\]',
    ]
    cleaned = text
    for pattern in patterns:
        cleaned = re.sub(pattern, '', cleaned)
    return cleaned.strip()

def load_mathverse_data(jsonl_path, images_dir, max_samples=None):
    """
    Load MathVerse dataset from JSONL file and images.
    """
    data = []
    errors = 0
    image_errors = 0
    
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            if max_samples and len(data) >= max_samples:
                break
            
            if not line.strip():
                continue
                
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: Skipping line {idx+1} due to JSON error: {e}")
                errors += 1
                continue
            
            # Construct image path; records without a valid image file are skipped
            image_path = os.path.join(images_dir, item.get('image_path', ''))
            
            if not os.path.isfile(image_path):
                print(f"Warning: Image not found: {image_path}")
                continue
            
            # Get question and answer
            question = clean_image_tokens(item.get('query', ''))
            answer = item.get('answer', '')
            
            if not question or not answer:
                print(f"Warning: Skipping line {idx+1} - missing question or answer")
                continue
            
            # Load and verify image
            try:
                pil_image = Image.open(image_path).convert('RGB')
                _ = pil_image.size
            except (OSError, IOError) as e:
                print(f"Warning: Skipping line {idx+1} - corrupted image {image_path}: {e}")
                image_errors += 1
                continue
            
            # Format in conversation format - image BEFORE text
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": pil_image},
                        {"type": "text", "text": question}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            
            data.append({"messages": conversation})
    
    print(f"✅ Loaded {len(data)} samples")
    if errors > 0:
        print(f"⚠️ Skipped {errors} lines due to JSON errors")
    if image_errors > 0:
        print(f"⚠️ Skipped {image_errors} corrupted/truncated images")
    return data

# Load data
raw_data = load_mathverse_data(
    JSONL_PATH, 
    IMAGES_DIR,
    max_samples=None  # Set to 100 for quick testing
)

# Display sample
if raw_data:
    print("\nSample data:")
    print(f"Keys: {raw_data[0].keys()}")
    print(f"Number of messages: {len(raw_data[0]['messages'])}")
    print(f"Content order: {[c['type'] for c in raw_data[0]['messages'][0]['content']]}")
    print(f"Question (first 200 chars): {raw_data[0]['messages'][0]['content'][1]['text'][:200]}...")
    print(f"Answer: {raw_data[0]['messages'][1]['content'][0]['text']}")

In [None]:
# Split data into train and validation sets
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(
    raw_data, 
    test_size=0.1,
    random_state=42
)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

# Keep as plain Python lists (don't convert to HF Dataset to preserve PIL Images)
train_dataset = train_data
val_dataset = val_data

print(f"\n✅ Using plain Python lists")
print(f"Image type check: {type(train_dataset[0]['messages'][0]['content'][0]['image'])}")
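
MathVerse presents each problem in several versions (text-dominant, vision-only, and so on), so a random row-level split like the one above can put versions of the same problem in both training and validation, which inflates validation scores. A group-aware split avoids that. The sketch below is a hypothetical helper; it assumes each record carries a problem identifier, for example a `problem_index` field copied over by the loader if your JSONL export includes one (the current loader does not keep it).

```python
import random
from collections import defaultdict
from typing import Callable, Hashable, List, Tuple, TypeVar

T = TypeVar("T")


def group_split(
    records: List[T],
    group_key: Callable[[T], Hashable],
    test_size: float = 0.1,
    seed: int = 42,
) -> Tuple[List[T], List[T]]:
    """Split records so that every group lands entirely in train or in test."""
    groups = defaultdict(list)
    for record in records:
        groups[group_key(record)].append(record)

    # Sort before shuffling so the split is reproducible for a given seed.
    group_ids = sorted(groups, key=repr)
    random.Random(seed).shuffle(group_ids)

    n_test = max(1, round(len(group_ids) * test_size))
    test_ids = set(group_ids[:n_test])
    train = [r for g in group_ids if g not in test_ids for r in groups[g]]
    test = [r for g in group_ids if g in test_ids for r in groups[g]]
    return train, test


# Toy data: 10 problems x 5 versions each (hypothetical field names).
toy = [{"problem_index": p, "version": v} for p in range(10) for v in range(5)]
train_toy, test_toy = group_split(toy, lambda r: r["problem_index"], test_size=0.2)
print(len(train_toy), len(test_toy))  # 40 10
```

With such a field in place, `train_data, val_data = group_split(raw_data, lambda r: r["problem_index"])` would replace the `train_test_split` call.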

## 4. Load Model with PyTorch (MLX doesn't support Qwen2.5-VL yet)

**Note:** At the time of writing, MLX has no native Qwen2.5-VL support, so this notebook uses PyTorch with the MPS (Metal Performance Shaders) backend for Apple Silicon.

For true MLX support, you would need to wait for MLX-VLM to add the model, or port it to MLX manually.
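
Qwen2.5-VL encodes each image as a variable number of visual tokens, roughly one per 28×28-pixel block after resizing, so a large diagram can dominate the sequence budget. The processor accepts `min_pixels`/`max_pixels` to bound this; the Qwen model card uses 256·28·28 and 1280·28·28 as example values. The sketch below is a simplified re-implementation of the `smart_resize` rule in `qwen-vl-utils`, meant only for estimating token counts; it is not the library code and may differ in edge cases.

```python
import math

PATCH = 28  # one visual token per 28x28 pixel block (14 px patches merged 2x2)


def approx_visual_tokens(height: int, width: int,
                         min_pixels: int = 256 * 28 * 28,
                         max_pixels: int = 1280 * 28 * 28) -> int:
    """Approximate visual-token count for one image after Qwen-style resizing."""
    # Snap each side to a multiple of PATCH.
    h = max(PATCH, round(height / PATCH) * PATCH)
    w = max(PATCH, round(width / PATCH) * PATCH)
    if h * w > max_pixels:
        # Shrink uniformly so the area fits under max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        h = max(PATCH, math.floor(height / beta / PATCH) * PATCH)
        w = max(PATCH, math.floor(width / beta / PATCH) * PATCH)
    elif h * w < min_pixels:
        # Enlarge uniformly so the area reaches min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h = math.ceil(height * beta / PATCH) * PATCH
        w = math.ceil(width * beta / PATCH) * PATCH
    return (h // PATCH) * (w // PATCH)


if __name__ == "__main__":
    for size in [(448, 448), (2000, 2000), (28, 28)]:
        print(size, approx_visual_tokens(*size))
```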

In [None]:
# For macOS, we'll use PyTorch with MPS backend
import torch

# Check for MPS availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✅ Using MPS (Metal Performance Shaders) backend")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("✅ Using CUDA backend")
else:
    device = torch.device("cpu")
    print("⚠️ Using CPU backend")

print(f"Device: {device}")

In [None]:
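# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
MAX_SEQ_LENGTH = 2048

# Bound image resolution so each image becomes at most ~1280 visual tokens
# (one token per 28x28 pixel block); large images could otherwise blow past MAX_SEQ_LENGTH
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28

# Load model and processor
# Note: bitsandbytes 4-bit quantization isn't available on MPS, so we load in float16
# Qwen2.5-VL has its own model class; Qwen2VLForConditionalGeneration is for Qwen2-VL
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

print("Loading model...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,  # Use float16 for memory efficiency
)
model.to(device)  # place explicitly on the device chosen above (MPS/CUDA/CPU)

processor = AutoProcessor.from_pretrained(
    MODEL_NAME, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
)

print("✅ Model loaded successfully!")
print(f"Model device: {next(model.parameters()).device}")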

## 5. Configure LoRA for Fine-tuning
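
LoRA freezes each targeted weight matrix W (shape d_out × d_in) and learns a low-rank update, so the effective weight becomes W + (lora_alpha / r)·B·A, with B of shape d_out × r and A of shape r × d_in. Each adapted layer therefore gains r·(d_in + d_out) trainable parameters. The arithmetic below uses this notebook's r=16 and lora_alpha=32; the hidden size of 3584 is an assumption about the Qwen2.5-7B language model, used only for illustration.

```python
def lora_scale(r: int, lora_alpha: int) -> float:
    """Scaling applied to the low-rank update B @ A (standard LoRA, not rsLoRA)."""
    return lora_alpha / r


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_in -> d_out linear layer."""
    return r * d_in + d_out * r  # A: r x d_in, B: d_out x r


if __name__ == "__main__":
    R, ALPHA, HIDDEN = 16, 32, 3584  # HIDDEN is an assumed hidden size
    print(lora_scale(R, ALPHA))            # 2.0
    print(lora_params(HIDDEN, HIDDEN, R))  # 114688 for a square projection
    print(HIDDEN * HIDDEN)                 # 12845056 frozen weights in that layer
```

For a square projection the adapter is under 1% of the frozen matrix, which is why `print_trainable_parameters()` reports such a small fraction.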

In [None]:
# Apply LoRA (no quantization here, so prepare_model_for_kbit_training isn't needed)
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

print("✅ LoRA applied successfully!")

## 6. Training Configuration
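
With gradient accumulation, the optimizer steps once every `per_device_train_batch_size × gradient_accumulation_steps` examples on a single device. The sketch below estimates total optimizer steps so you can judge whether `eval_steps=50`, `save_steps=50`, and `warmup_steps=5` suit your dataset size. It rounds a partial accumulation window up, which may differ by a step from how a given `Trainer` version counts; the 3,000-example figure is hypothetical.

```python
import math


def optimizer_steps(n_examples: int, batch_size: int, grad_accum: int, epochs: int) -> int:
    """Approximate number of optimizer updates over training on one device."""
    examples_per_update = batch_size * grad_accum
    updates_per_epoch = math.ceil(n_examples / examples_per_update)
    return updates_per_epoch * epochs


if __name__ == "__main__":
    total = optimizer_steps(n_examples=3000, batch_size=1, grad_accum=8, epochs=3)
    print(total)                       # 1125
    print(total // 50, "evaluations")  # 22 evaluations
    print(f"warmup fraction: {5 / total:.2%}")
```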

In [None]:
# Training configuration
from transformers import TrainingArguments, Trainer

OUTPUT_DIR = "./qwen2.5-vl-mathverse-finetuned-mlx"

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    
    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=1,  # Start with 1 on a base M4
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # Effective batch size = 1 x 8 = 8
    
    # Optimizer settings
    learning_rate=2e-5,
    warmup_steps=5,
    weight_decay=0.01,
    
    # Precision: mixed-precision flags stay off on MPS; the base weights
    # are already loaded in float16
    fp16=False,
    bf16=False,
    
    # Logging and evaluation
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    
    # Important for vision models: keep the raw "messages" field for the collator
    remove_unused_columns=False,
    
    # Pinned memory is a CUDA feature; disabling it avoids warnings on MPS
    dataloader_pin_memory=False,
    
    # Other settings
    report_to="none",
    seed=42,
    # No device flag needed: Trainer uses MPS automatically when available
    # (the older use_mps_device argument is deprecated)
)

print("Training configuration:")
print(f"  - Epochs: {training_args.num_train_epochs}")
print(f"  - Batch size: {training_args.per_device_train_batch_size}")
print(f"  - Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  - Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  - Learning rate: {training_args.learning_rate}")
print(f"  - Device: {training_args.device}")

## 7. Custom Data Collator for Vision Models
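
A common refinement for instruction tuning is to compute the loss only on the assistant's answer rather than on the question as well. One way, sketched below, is to mask every label up to and including the last assistant header. This assumes Qwen's ChatML-style template, where the answer follows `<|im_start|>assistant\n`; the marker's token IDs would come from something like `processor.tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)`, which you should verify against your tokenizer. `mask_prompt_tokens` is a hypothetical helper and the token IDs in the example are made up.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # label value PyTorch's cross-entropy loss ignores by default


def mask_prompt_tokens(input_ids: Sequence[int], marker: Sequence[int]) -> List[int]:
    """Return labels with every token up to and including the last `marker`
    occurrence set to IGNORE_INDEX, so only the answer is supervised."""
    ids = list(input_ids)
    mark = list(marker)
    labels = list(ids)
    # Scan right-to-left for the last occurrence of the marker sequence.
    for start in range(len(ids) - len(mark), -1, -1):
        if ids[start:start + len(mark)] == mark:
            for i in range(start + len(mark)):
                labels[i] = IGNORE_INDEX
            return labels
    # Marker not found (e.g. the sequence was cut short): supervise nothing.
    return [IGNORE_INDEX] * len(ids)


# Made-up IDs: [prompt..., marker 9, 8, answer 5, 6, 7]
print(mask_prompt_tokens([1, 2, 9, 8, 5, 6, 7], marker=[9, 8]))
# [-100, -100, -100, -100, 5, 6, 7]
```

Inside a collator this would run per row of `input_ids`; padding still needs its own masking, since right-side pad tokens come after the marker.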

In [None]:
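# Custom data collator for vision-language models
from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class VisionLanguageDataCollator:
    """Data collator for vision-language models."""
    processor: Any
    
    def __call__(self, examples: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Process a batch of examples into model inputs plus labels.
        """
        # Extract images and create text prompts
        images = []
        texts = []
        
        for example in examples:
            messages = example['messages']
            
            # Extract image from user message
            for content_item in messages[0]['content']:
                if content_item['type'] == 'image':
                    images.append(content_item['image'])
            
            # Apply chat template to the full conversation (question + answer)
            text_prompt = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=False,
                tokenize=False
            )
            texts.append(text_prompt)
        
        # Process with processor. No truncation: cutting the sequence can drop
        # image placeholder tokens and break the image/token alignment.
        # Sequence length is bounded via the processor's max_pixels instead.
        batch = self.processor(
            text=texts,
            images=images,
            return_tensors="pt",
            padding=True,
        )
        if batch['input_ids'].shape[1] > MAX_SEQ_LENGTH:
            print(f"Warning: batch length {batch['input_ids'].shape[1]} exceeds {MAX_SEQ_LENGTH}")
        
        # Labels start as a copy of input_ids (causal LM); -100 marks positions
        # the loss ignores: padding and the image/vision special tokens
        labels = batch['input_ids'].clone()
        labels[batch['attention_mask'] == 0] = -100
        tokenizer = self.processor.tokenizer
        for token in ("<|image_pad|>", "<|vision_start|>", "<|vision_end|>"):
            labels[labels == tokenizer.convert_tokens_to_ids(token)] = -100
        batch['labels'] = labels
        
        return batch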

data_collator = VisionLanguageDataCollator(processor=processor)
print("✅ Data collator created")

## 8. Initialize Trainer and Start Training

In [None]:
# Initialize trainer
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

print("✅ Trainer initialized")
print("Starting training...")
print("=" * 80)

In [None]:
# Start training
trainer.train()

print("\n✅ Training completed!")

## 9. Save the Model

In [None]:
# Save the fine-tuned model
FINAL_MODEL_DIR = "./qwen2.5-vl-mathverse-final-mlx"

# Save LoRA adapters
model.save_pretrained(FINAL_MODEL_DIR)
processor.save_pretrained(FINAL_MODEL_DIR)

print(f"✅ Model saved to {FINAL_MODEL_DIR}")

## 10. Test Inference

In [None]:
# Test on a validation sample
model.eval()

test_sample = val_data[0]
test_image = test_sample['messages'][0]['content'][0]['image']
test_question = test_sample['messages'][0]['content'][1]['text']
expected_answer = test_sample['messages'][1]['content'][0]['text']

# Prepare input
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": test_image},
            {"type": "text", "text": test_question}
        ]
    }
]

text_prompt = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False
)

inputs = processor(
    text=[text_prompt],
    images=[test_image],
    return_tensors="pt",
    padding=True
).to(model.device)

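# Generate response (greedy decoding, so sampling parameters are omitted)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
    )

# Decode only the newly generated tokens (generate() returns prompt + continuation)
generated_ids = output[:, inputs["input_ids"].shape[1]:]
generated_text = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True
)[0]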

print("Question:")
print(test_question[:200])
print("\nExpected Answer:")
print(expected_answer)
print("\nModel Response:")
print(generated_text)
print("\nImage:")
display(test_image)

## 11. Full Evaluation (Optional)

Run the same evaluation functions from the Colab version here.
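
If the Colab evaluation code isn't at hand, a simple exact-match score gives a first signal. The sketch below normalizes answers (case, surrounding whitespace, a trailing period, an optional "The answer is" prefix, parentheses around a choice letter, internal spaces) before comparing. It is a hypothetical helper, not the official MathVerse scoring protocol, so its numbers are not comparable to published results.

```python
import re
from typing import Iterable, Tuple


def normalize_answer(text: str) -> str:
    """Lower-case, strip common wrappers, and canonicalize choice letters."""
    t = text.strip().lower()
    t = re.sub(r"^(the answer is|answer:)\s*", "", t)
    t = t.rstrip(".").strip()
    choice = re.fullmatch(r"\(?([a-e])\)?", t)  # "(B)" or "b" -> "b"
    if choice:
        return choice.group(1)
    return re.sub(r"\s+", "", t)  # "12 cm" -> "12cm"


def exact_match_accuracy(pairs: Iterable[Tuple[str, str]]) -> float:
    """Fraction of (prediction, reference) pairs that match after normalization."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    hits = sum(normalize_answer(pred) == normalize_answer(ref) for pred, ref in pairs)
    return hits / len(pairs)


print(exact_match_accuracy([("B", "(b)"), ("12 cm", "12cm"), ("5", "6")]))
```

In practice the pairs would be `(generated_text, expected_answer)` collected by running the inference cell over `val_data`.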

## Notes for MLX

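**Current Limitations:**
- MLX doesn't yet have native Qwen2.5-VL support (at the time of writing)
- This notebook therefore uses PyTorch with the MPS backend instead
- MPS runs training on the Apple Silicon GPU, which is typically much faster than CPU-only training

**For True MLX Support:**
1. Wait for MLX-VLM to add Qwen2.5-VL support
2. Or port the model architecture to MLX and convert the weights yourself (`mlx.utils.tree_map` helps transform parameter trees, but it is not a converter on its own)
3. Monitor: https://github.com/ml-explore/mlx-examples

**Performance Tips:**
- M4 Pro/Max: may handle batch_size=2-4, depending on unified memory
- M4: Start with batch_size=1
- Raise gradient_accumulation_steps rather than batch size to grow the effective batch without raising peak memory
- Monitor memory usage in Activity Monitor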
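
To keep the effective batch size fixed while changing `per_device_train_batch_size`, derive `gradient_accumulation_steps` from a target. A minimal sketch; `grad_accum_for` is a hypothetical helper, and the target of 8 matches this notebook's configuration:

```python
import math


def grad_accum_for(target_effective_batch: int, per_device_batch: int) -> int:
    """Smallest accumulation count whose effective batch reaches the target."""
    return math.ceil(target_effective_batch / per_device_batch)


if __name__ == "__main__":
    for bs in (1, 2, 3, 4):
        steps = grad_accum_for(8, bs)
        print(f"batch_size={bs}: gradient_accumulation_steps={steps} (effective {bs * steps})")
```

When the target isn't divisible by the batch size (3 here), the effective batch slightly overshoots (9 instead of 8).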