## Import cell
Run this cell to import all required libraries.

In [None]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

import json
from pathlib import Path
import random

import json
import os
import torch
from transformers import get_scheduler
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb

from qwen_vl_utils import process_vision_info

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

## Inference Cell (Pre-fine-tuning)
Run inference with the base model by providing an image path and a text prompt.

In [None]:
# Load the base model. device_map="auto" leaves weight placement to the
# library, so the target device is read back from the model below instead of
# being assumed.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
)

# Default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# Change the message here, and image/video path
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "samurai.jpg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference: render the chat prompt, then collect the
# image/video inputs referenced by the messages.
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
# Move the inputs to wherever the model was placed; a hard-coded "cuda" would
# fail on a CPU-only machine.
inputs = inputs.to(model.device)

# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=128)
# generate() returns prompt + continuation; drop the prompt prefix so only the
# newly generated tokens are decoded.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)


## Data Preprocessing
Transform the data into a ChatML-like format, a custom multimodal variant designed for models that support both text and images, like Qwen2.5-VL.
To be run only **ONCE**.

In [None]:
# Paths
input_json = "coco/llava_v1_5_mix665k.json"
output_json = "coco/qwen25vl_dataset.json"

# Load raw data (explicit encoding so the result does not depend on the
# platform default)
with open(input_json, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

processed_data = []

for item in raw_data:
    # Records without an "image" key are treated as text-only conversations.
    image_path = item.get("image")
    messages = []
    for turn in item["conversations"]:
        # Anything that is not a "human" turn is mapped to the assistant role.
        role = "user" if turn["from"] == "human" else "assistant"
        value = turn["value"]
        content = []

        if "<image>" in value and image_path is not None:
            # Replace the <image> placeholder with a structured image part and
            # keep whatever text surrounds it.
            text_part = value.replace("<image>", "").strip()
            content.append({"type": "image", "image": image_path})
            if text_part:
                content.append({"type": "text", "text": text_part})
        else:
            # Plain text turn (or a literal "<image>" in a record that has no
            # image to attach): keep the text untouched instead of raising
            # KeyError on the missing "image" key.
            content.append({"type": "text", "text": value})

        messages.append({
            "role": role,
            "content": content
        })

    processed_data.append({"messages": messages})

# Save new dataset
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(processed_data, f, indent=2)

print(f"Processed dataset saved to {output_json}")

## Data split
Split the data into train/validation/test sets (by ratio) and choose how many samples to take from the full dataset.

In [None]:
# Input and output paths
input_json = "coco/qwen25vl_dataset.json"  # Your preprocessed dataset
output_dir = Path(".")

# Parameters
num_samples = 500 # Change here the total number of images used. for all data use 118287
train_ratio = 0.7
test_ratio = 0.2
val_ratio = 0.1
split_seed = 42  # fixed seed so re-running the cell reproduces the same split

# The test split receives whatever is left after train and val, so the three
# ratios are only meaningful when they add up to 1.
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
    raise ValueError("train_ratio + val_ratio + test_ratio must sum to 1")

# Load full dataset
with open(input_json, "r", encoding="utf-8") as f:
    data = json.load(f)

# Take only the first `num_samples` entries (fewer if the dataset is smaller)
subset = data[:num_samples]

# Shuffle with a dedicated generator so the global random state is untouched
random.Random(split_seed).shuffle(subset)

# Compute split indices from the actual subset size; using num_samples here
# would misplace the boundaries whenever the dataset has fewer entries.
subset_size = len(subset)
train_end = int(subset_size * train_ratio)
val_end = train_end + int(subset_size * val_ratio)

train_data = subset[:train_end]
val_data = subset[train_end:val_end]
test_data = subset[val_end:]

# Save files
for split_type, split_data in zip(
    ["train", "val", "test"],
    [train_data, val_data, test_data]
):
    filename = f"example_{split_type}_data.json"
    with open(output_dir / filename, "w", encoding="utf-8") as f:
        json.dump(split_data, f, indent=2)
    print(f"Saved {len(split_data)} samples to {filename}")

print("\nSubsetting and splitting completed.")

## Model Configuration and Hyperparameters

In [None]:
# ----------------------------
# Configuration
# ----------------------------

# Base checkpoint to fine-tune, and the directory that receives checkpoints.
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
output_dir = "./qwen25vl_lora_output"
os.makedirs(output_dir, exist_ok=True)

# Dataset split files
train_file, val_file, test_file = (
    "example_train_data.json",
    "example_val_data.json",
    "example_test_data.json",
)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Optimisation hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 2e-4
num_epochs = 3
max_seq_length = 1024

# Bookkeeping intervals, counted in optimizer steps
save_steps = 100  # IMPORTANT
logging_steps = 10

# LoRA adapter settings
lora_rank, lora_alpha, lora_dropout = 64, 128, 0.05



In [None]:
# Load Processor and Model
# The model class is imported directly from transformers, so no code from the
# Hub repository has to be executed; trust_remote_code is therefore left at
# its default, matching the processor load in the inference cell.
processor = AutoProcessor.from_pretrained(model_name)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Apply LoRA
lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    # Adapters are attached to modules with these names (attention projections).
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    modules_to_save=["visual"],  # Keep vision encoder trainable
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
# Report how many parameters will actually be updated.
model.print_trainable_parameters()


## Data Loader

In [None]:
# Load Dataset
def load_json_dataset(file_path) -> list[dict]:
    """Load a preprocessed chat dataset from a JSON file.

    Args:
        file_path: Path (``str`` or path-like) to a JSON file holding a list
            of objects, each with a ``"messages"`` entry.

    Returns:
        A list of ``{"messages": ...}`` dicts, one per record, in file order.
        Any other keys present on a record are dropped.

    Raises:
        KeyError: If a record has no ``"messages"`` key.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return [{"messages": record["messages"]} for record in records]


class ChatDataset(Dataset):
    """Map-style dataset that turns chat records into model-ready tensors.

    Each record is a dict with a ``"messages"`` list. A sample is rendered
    with the processor's chat template, tokenized/padded to ``max_length``,
    and paired with labels that supervise only the assistant replies.

    Args:
        data: Sequence of ``{"messages": [...]}`` records.
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and a ``tokenizer`` attribute.
        max_length: Length every sequence is padded/truncated to.

    NOTE(review): truncation can cut a sequence before any assistant token;
    such a sample ends up with every label ignored, which presumably yields a
    non-finite loss in the model — confirm and filter/skip as needed.
    """

    # Label value skipped by the loss.
    IGNORE_INDEX = -100

    def __init__(self, data, processor, max_length=1024):
        self.data = data
        self.processor = processor
        self.max_length = max_length
        # (header_ids, end_id) cache, filled on first use.
        self._markers = None

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_grid(image_grid_thw):
        """Return ``image_grid_thw`` as a 2-D ``[num_images, 3]`` tensor (or None)."""
        if image_grid_thw is None:
            return None
        if image_grid_thw.dim() == 1:
            return image_grid_thw.unsqueeze(0)  # [3] -> [1, 3]
        if image_grid_thw.dim() > 2:
            # Covers e.g. [1, 1, 1, 3]: collapse every leading dimension.
            return image_grid_thw.reshape(-1, 3)
        return image_grid_thw

    @staticmethod
    def _build_labels(input_ids, attention_mask, header_ids, end_id, ignore_index=-100):
        """Build labels that keep only assistant replies.

        Every position is ignored except the tokens that follow an occurrence
        of ``header_ids`` up to and including the next ``end_id`` (kept so the
        model also learns to stop). Positions where ``attention_mask`` is 0
        (padding) are always ignored.
        """
        labels = torch.full_like(input_ids, ignore_index)
        ids = input_ids.tolist()
        total = len(ids)
        header = list(header_ids)
        width = len(header)

        pos = 0
        while width and pos <= total - width:
            if ids[pos:pos + width] != header:
                pos += 1
                continue
            start = pos + width
            end = start
            while end < total and ids[end] != end_id:
                end += 1
            stop = min(end + 1, total)
            labels[start:stop] = input_ids[start:stop]
            pos = stop

        labels[attention_mask == 0] = ignore_index
        return labels

    def _assistant_markers(self):
        """Return ``(header_ids, end_id)`` delimiting assistant replies.

        Assumes the chat template renders assistant turns as
        ``<|im_start|>assistant\\n ... <|im_end|>`` — TODO confirm against the
        processor's template.
        """
        markers = getattr(self, "_markers", None)
        if markers is None:
            tokenizer = self.processor.tokenizer
            header_ids = tokenizer.encode(
                "<|im_start|>assistant\n", add_special_tokens=False
            )
            end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
            markers = (header_ids, end_id)
            self._markers = markers
        return markers

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` (1-D
            tensors) plus ``pixel_values`` and ``image_grid_thw`` (tensors, or
            ``None`` when the processor produced no image data).
        """
        messages = self.data[idx]["messages"]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = self.processor(
            text=text,
            images=image_inputs,
            videos=video_inputs,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = inputs["input_ids"][0]
        attention_mask = inputs["attention_mask"][0]
        pixel_values = inputs.get("pixel_values")
        image_grid_thw = self._normalize_grid(inputs.get("image_grid_thw"))

        # Keep pixel_values and image_grid_thw consistent: an empty image
        # batch is treated as "no image".
        if pixel_values is not None and pixel_values.dim() == 4 and pixel_values.shape[0] == 0:
            pixel_values = None
            image_grid_thw = None

        # Supervise assistant tokens only. The previous estimate counted
        # content parts rather than tokens and left padding in the loss.
        header_ids, end_id = self._assistant_markers()
        labels = self._build_labels(
            input_ids, attention_mask, header_ids, end_id, self.IGNORE_INDEX
        )

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
        }
# Load train/val/test datasets
train_data = load_json_dataset(train_file)
val_data = load_json_dataset(val_file)
test_data = load_json_dataset(test_file)

train_dataset = ChatDataset(train_data, processor, max_length=max_seq_length)
val_dataset = ChatDataset(val_data, processor, max_length=max_seq_length)
test_dataset = ChatDataset(test_data, processor, max_length=max_seq_length)


def collate_chat_batch(samples):
    """Merge per-sample dicts from ``ChatDataset`` into one batch.

    ``input_ids``, ``attention_mask`` and ``labels`` are stacked along a new
    batch dimension. ``pixel_values`` and ``image_grid_thw`` are concatenated
    along dim 0 over the samples that have them, and are ``None`` when no
    sample in the batch carries an image. The default collate cannot do this:
    it rejects ``None`` entries and cannot stack tensors of different sizes.
    """
    batch = {
        key: torch.stack([sample[key] for sample in samples])
        for key in ("input_ids", "attention_mask", "labels")
    }
    for key in ("pixel_values", "image_grid_thw"):
        present = [sample[key] for sample in samples if sample[key] is not None]
        batch[key] = torch.cat(present, dim=0) if present else None
    return batch


train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_chat_batch
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_chat_batch)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_chat_batch)


In [None]:
# Sanity check: fetch one processed validation sample and dump its tensors.
sample = val_dataset[3]
print(sample)

In [None]:
# Optimizer & Scheduler
# Only parameters that require gradients are handed to the optimizer; frozen
# weights would never be updated anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

# Total number of optimizer updates the schedule is laid out over.
num_training_steps = num_epochs * len(train_loader) // gradient_accumulation_steps
# Cap the warm-up at the total step count so a short run (few samples) does
# not schedule a warm-up longer than the training itself.
num_warmup_steps = min(100, num_training_steps)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


## Training Loop and Evaluation

In [None]:
# Training Loop


def _normalize_image_grid(image_grid_thw):
    """Return ``image_grid_thw`` as a 2-D ``[num_images, 3]`` tensor (or None)."""
    if image_grid_thw is None:
        return None
    if image_grid_thw.dim() == 1:
        return image_grid_thw.unsqueeze(0)  # [3] -> [1, 3]
    if image_grid_thw.dim() > 2:
        # Collapses any extra leading dims, e.g. [1, 1, 1, 3] -> [1, 3].
        return image_grid_thw.reshape(-1, 3)
    return image_grid_thw


def _forward_batch(batch, warn_on_missing_grid=False):
    """Move ``batch`` to ``device`` and run the model on it (labels included)."""
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    pixel_values = batch.get("pixel_values")
    image_grid_thw = _normalize_image_grid(batch.get("image_grid_thw"))

    # pixel_values are dropped when the matching grid is missing.
    if pixel_values is not None and image_grid_thw is None:
        if warn_on_missing_grid:
            print("WARNING! : pixel_values exists but image_grid_thw is None. Skipping image input.")
        pixel_values = None

    return model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
    )


def _save_checkpoint(name):
    """Write the model and processor to ``output_dir/name``."""
    checkpoint_dir = os.path.join(output_dir, name)
    model.save_pretrained(checkpoint_dir)
    processor.save_pretrained(checkpoint_dir)


model.train()
completed_steps = 0
# Start from clean gradients in case this cell is re-run.
optimizer.zero_grad()

for epoch in range(num_epochs):
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    num_batches = len(train_loader)
    # Sum of the (already scaled) losses in the current accumulation window.
    window_loss = 0.0

    for step, batch in enumerate(progress_bar):
        outputs = _forward_batch(batch, warn_on_missing_grid=True)

        if torch.isfinite(outputs.loss):
            loss = outputs.loss / gradient_accumulation_steps
            loss.backward()
            window_loss += loss.item()
        else:
            # A NaN/inf loss would poison every weight on the next update.
            print(f"WARNING! : non-finite loss at epoch {epoch + 1}, batch {step}. Skipping backward.")

        window_complete = (step + 1) % gradient_accumulation_steps == 0
        last_batch = (step + 1) == num_batches
        # Also step on the last batch of the epoch; otherwise the gradients of
        # an incomplete window would leak into the next epoch's first update
        # (and be lost entirely after the final epoch).
        if window_complete or last_batch:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1

            if completed_steps % logging_steps == 0:
                progress_bar.set_postfix(step=completed_steps, loss=f"{window_loss:.4f}")
            window_loss = 0.0

            if completed_steps % save_steps == 0:
                _save_checkpoint(f"checkpoint-{completed_steps}")

    # Validation
    model.eval()
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            outputs = _forward_batch(batch)
            if torch.isfinite(outputs.loss):
                val_loss += outputs.loss.item()
                val_batches += 1

    # An empty (or all non-finite) validation set reports NaN instead of
    # raising ZeroDivisionError.
    avg_val_loss = val_loss / val_batches if val_batches else float("nan")
    print(f"Validation Loss after epoch {epoch + 1}: {avg_val_loss:.4f}")

    model.train()

# Periodic checkpoints only land on multiples of save_steps; persist the final
# weights as well so the end state of training is not lost.
_save_checkpoint("final")

print("Training completed.")