# Rex-Omni Unsloth Fine-tuning (Kaggle)

This notebook uses **Unsloth** to fine-tune the **Rex-Omni** model (based on Qwen2.5-VL). Unsloth advertises roughly 2x faster training and about 60% lower memory usage, which makes it a good fit for the limited GPU resources of Kaggle environments.

- **Model**: `IDEA-Research/Rex-Omni`
- **Library**: `unsloth` (with Qwen2.5-VL support)

## 1. Installation
Install Unsloth and dependencies.

In [None]:
%%capture
!pip install unsloth
# Also install the latest nightly to ensure Qwen2.5-VL support if not in main yet
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

## 2. Configuration & Model Loading
We load `IDEA-Research/Rex-Omni` using `FastVisionModel`.

In [None]:
from unsloth import FastVisionModel # FastVisionModel for VLMs
import torch

# --- Step 1: base model ------------------------------------------------------
# Options forwarded to FastVisionModel.from_pretrained.
base_model_options = dict(
    load_in_4bit=True,  # 4-bit quantization keeps memory usage low
    use_gradient_checkpointing="unsloth",  # True also works; "unsloth" targets very long context
)
model, tokenizer = FastVisionModel.from_pretrained(
    "IDEA-Research/Rex-Omni",
    **base_model_options,
)

# --- Step 2: LoRA adapters ---------------------------------------------------
# Names of the projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Options forwarded to FastVisionModel.get_peft_model.
lora_options = {
    "r": 16,  # LoRA rank
    "target_modules": lora_target_modules,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "bias": "none",
    "use_gradient_checkpointing": "unsloth",
    "random_state": 3407,
    "use_rslora": False,
    "loftq_config": None,
}
model = FastVisionModel.get_peft_model(model, **lora_options)

## 3. Data Preparation
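We convert detection annotations into the chat-style JSONL format used for Rex-Omni fine-tuning: every bounding box is quantized into four coordinate tokens in the range `<0>`–`<999>`, boxes are grouped by category, and each image becomes one user/assistant conversation.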

In [None]:
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

def normalize_bbox(bbox, width, height):
    """Encode a pixel-space bounding box as four coordinate-bin tokens.

    Args:
        bbox: Four numbers ``(x0, y0, x1, y1)``.
        width: Value the x coordinates are divided by (the image width).
        height: Value the y coordinates are divided by (the image height).

    Returns:
        A string of the form ``"<a><b><c><d>"``. Each coordinate is divided by
        its image dimension, clamped to [0, 1], multiplied by 999 and truncated
        to an integer, so every bin lies in 0..999.
    """
    left, top, right, bottom = bbox

    def to_bin(value, extent):
        # Scale to a fraction of the image, clamp to [0, 1], then truncate
        # into one of the 1000 bins.
        fraction = max(0.0, min(1.0, value / extent))
        return int(fraction * 999)

    bins = (
        to_bin(left, width),
        to_bin(top, height),
        to_bin(right, width),
        to_bin(bottom, height),
    )
    return "".join(f"<{index}>" for index in bins)

def prepare_data(image_folder, metadata_file, output_file):
    """Convert a JSON annotation file into chat-style JSONL training data.

    The metadata file must contain a JSON list of items. Each item provides an
    ``'image'`` filename (relative to ``image_folder``) and an ``'objects'``
    list whose entries carry ``'category'`` and ``'bbox'`` (pixel coordinates
    ``[x0, y0, x1, y1]``). One JSON line with a user/assistant conversation is
    written per usable item.

    Items are skipped (and counted in the final summary) when they have no
    image name, the image file is missing or cannot be opened, or they have no
    objects.

    Args:
        image_folder: Directory that contains the images.
        metadata_file: Path of the JSON annotation file to read.
        output_file: Path of the JSONL file to write (overwritten).

    Raises:
        KeyError: If an object entry lacks ``'category'`` or ``'bbox'``.
    """
    print(f"Converting data from {metadata_file}...")
    image_folder_path = Path(image_folder)

    with open(metadata_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    written = 0
    skipped = 0
    with open(output_file, 'w', encoding='utf-8') as f_out:
        for item in tqdm(data):
            image_filename = item.get('image')
            if not image_filename:
                skipped += 1
                continue

            image_path = image_folder_path / image_filename
            if not image_path.exists():
                skipped += 1
                continue

            # Only the image size is needed. Unreadable images are skipped on a
            # best-effort basis, but KeyboardInterrupt/SystemExit still
            # propagate (a bare `except:` would swallow them).
            try:
                with Image.open(image_path) as img:
                    width, height = img.size
            except Exception:
                skipped += 1
                continue

            objects = item.get('objects', [])
            if not objects:
                skipped += 1
                continue

            # Group boxes by category, keeping first-seen category order.
            cat_to_boxes = {}
            for obj in objects:
                cat_to_boxes.setdefault(obj['category'], []).append(obj['bbox'])

            # One "<ref>category</ref><box>...</box>" segment per category.
            answer_parts = []
            for cat, boxes in cat_to_boxes.items():
                box_str = ",".join(normalize_bbox(b, width, height) for b in boxes)
                answer_parts.append(
                    f"<|object_ref_start|>{cat}<|object_ref_end|><|box_start|>{box_str}<|box_end|>"
                )
            answer_text = ", ".join(answer_parts)

            messages = [
                {"role": "user", "content": [
                    {"type": "image", "image": str(image_path.absolute())},
                    {"type": "text", "text": "<image>\nDetect the objects in this image."}
                ]},
                {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
            ]
            f_out.write(json.dumps({"messages": messages}) + '\n')
            written += 1

    print(f"Wrote {written} samples, skipped {skipped} items.")
    print(f"Data saved to {output_file}")

# VRSBench Specific Preparation
def prepare_vrsbench_data(parquet_file, image_root, output_file):
    """Convert a VRSBench parquet file into chat-style JSONL training data.

    Each row must provide ``'image_path'`` (relative to ``image_root``; a
    leading ``"./"`` is stripped) and ``'objects'``, either as a sequence of
    dicts or as a JSON string encoding one. Every object supplies
    ``'obj_cls'`` (category) and ``'obj_coord'`` (``[x0, y0, x1, y1]``).

    Rows without objects and rows whose image file does not exist are skipped;
    the number of missing images is reported at the end.

    Args:
        parquet_file: Path of the parquet file to read.
        image_root: Directory that the relative image paths are resolved against.
        output_file: Path of the JSONL file to write (overwritten).
    """
    import pandas as pd
    import json
    from tqdm import tqdm

    print(f"Loading VRSBench data from {parquet_file}...")
    df = pd.read_parquet(parquet_file)

    image_root_path = Path(image_root)
    written = 0
    missing_images = 0

    with open(output_file, 'w', encoding='utf-8') as f_out:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            # Image Path
            rel_path = row['image_path']
            if rel_path.startswith("./"):
                rel_path = rel_path[2:]

            image_path = image_root_path / rel_path
            # Same policy as prepare_data: a sample whose image is absent
            # would only fail later, when the collator tries to load it.
            if not image_path.exists():
                missing_images += 1
                continue

            # Parse Objects
            objects = row['objects']
            if isinstance(objects, str):
                objects = json.loads(objects)

            # pandas may hand back a numpy array for list columns; `not array`
            # raises for arrays with more than one element, so test the length.
            if objects is None or len(objects) == 0:
                continue

            # Group by category, keeping first-seen category order.
            cat_to_boxes = {}
            for obj in objects:
                # VRSBench coords are [x0, y0, x1, y1] NORMALIZED (0-1)
                cat_to_boxes.setdefault(obj['obj_cls'], []).append(obj['obj_coord'])

            # Format Answer
            answer_parts = []
            for cat, boxes in cat_to_boxes.items():
                # Coordinates are already fractions, so a 1x1 "image" reuses
                # the shared helper; this also clamps stray values so every
                # token stays within <0>..<999>.
                box_str = ",".join(normalize_bbox(b, 1.0, 1.0) for b in boxes)
                answer_parts.append(
                    f"<|object_ref_start|>{cat}<|object_ref_end|><|box_start|>{box_str}<|box_end|>"
                )

            answer_text = ", ".join(answer_parts)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": str(image_path.absolute())},
                        {"type": "text", "text": "<image>\nDetect the objects in this image."}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": answer_text}]
                }
            ]

            f_out.write(json.dumps({"messages": messages}) + '\n')
            written += 1

    print(f"Wrote {written} samples, skipped {missing_images} rows with missing images.")
    print(f"VRSBench data saved to {output_file}")

# CONFIGURATION
# JSONL file that train() loads its dataset from. Pass this path as
# `output_file` to one of the preparation functions above to create it.
JSONL_PATH = "/kaggle/working/train_data.jsonl"

# Example Usage (Uncomment and adjust paths)
# Converts a VRSBench parquet file plus its image directory into JSONL_PATH.
# PARQUET_FILE = "/kaggle/input/vrsbench/vrsbench_val_data.parquet"
# IMAGE_ROOT = "/kaggle/input/vrsbench"
# prepare_vrsbench_data(PARQUET_FILE, IMAGE_ROOT, JSONL_PATH)

## 4. Training
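We train with `UnslothTrainer` and a custom data collator that renders each conversation with the Qwen chat template and loads its images. The final `train()` call is commented out; uncomment it once the data preparation step has produced the JSONL file.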

In [None]:
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments
from datasets import load_dataset

def train():
    """Fine-tune the LoRA-wrapped model on the prepared JSONL and save the adapter.

    Uses the module-level ``JSONL_PATH``, ``model`` and ``tokenizer``. If the
    JSONL file does not exist, a hint is printed and nothing else happens.
    Datasets with more than 1000 rows are shuffled (seed 3407) and cut to 1000.
    After training, the model and tokenizer are saved to ``rex_omni_lora``.
    """
    # Local import: the import list of this cell does not bring in `os`.
    import os

    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor

    if not os.path.exists(JSONL_PATH):
        print("Please run data preparation first.")
        return

    dataset = load_dataset("json", data_files=JSONL_PATH, split="train")

    # LIMIT TO 1000 SAMPLES AS REQUESTED
    if len(dataset) > 1000:
        print("Limiting dataset to 1000 samples...")
        dataset = dataset.shuffle(seed=3407).select(range(1000))

    # The collator needs the processor (chat template, image preprocessing).
    processor = AutoProcessor.from_pretrained("IDEA-Research/Rex-Omni", min_pixels=256*28*28, max_pixels=1280*28*28)
    pad_token_id = processor.tokenizer.pad_token_id
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")

    def _strip_none(messages):
        # NOTE(review): the `datasets` JSON loader appears to unify the schema
        # of all content parts, so a text part can come back with an extra
        # `"image": None` key (and vice versa). Dropping None-valued keys keeps
        # such parts from being mistaken for images. Confirm against real rows.
        cleaned = []
        for message in messages:
            content = message["content"]
            if isinstance(content, list):
                content = [
                    {key: value for key, value in part.items() if value is not None}
                    for part in content
                ]
            cleaned.append({"role": message["role"], "content": content})
        return cleaned

    def collate_fn(examples):
        # Each dataset row wraps its conversation under the "messages" key.
        conversations = [_strip_none(example["messages"]) for example in examples]
        # Training text already ends with the assistant answer, so no
        # generation prompt (an extra empty assistant turn) is appended.
        texts = [
            processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
            for conversation in conversations
        ]
        # process_vision_info expects conversations, not the raw dataset rows.
        image_inputs, video_inputs = process_vision_info(conversations)
        inputs = processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
        # Padding and image placeholder tokens do not contribute to the loss.
        labels = inputs["input_ids"].clone()
        labels[labels == pad_token_id] = -100
        labels[labels == image_token_id] = -100
        inputs["labels"] = labels
        return inputs

    # Unsloth's vision fine-tuning examples switch the model to training mode
    # before building the trainer.
    FastVisionModel.for_training(model)

    trainer = UnslothTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = dataset,
        # Passed at construction time so the trainer never sets up a default
        # text collator for this image dataset.
        data_collator = collate_fn,
        dataset_text_field = "", # rows have no "text" column; collate_fn builds the text
        max_seq_length = 2048,
        dataset_num_proc = 2,
        args = UnslothTrainingArguments(
            per_device_train_batch_size = 2,
            gradient_accumulation_steps = 4,
            warmup_steps = 5,
            max_steps = 60, # Adjust as needed
            learning_rate = 2e-4,
            fp16 = not is_bfloat16_supported(),
            bf16 = is_bfloat16_supported(),
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
            output_dir = "outputs",
            # Keep the "messages" column so collate_fn can read it, and skip
            # the trainer's own text tokenization of the dataset.
            # NOTE(review): settings follow Unsloth's vision fine-tuning
            # examples; confirm they are accepted by the installed TRL version.
            remove_unused_columns = False,
            dataset_kwargs = {"skip_prepare_dataset": True},
        ),
    )

    trainer.train()

    # Save
    model.save_pretrained("rex_omni_lora")
    tokenizer.save_pretrained("rex_omni_lora")

# train()