# Rex-Omni H100 Standalone Fine-tuning

This notebook is a complete, standalone solution for fine-tuning the **Rex-Omni** model (based on Qwen2.5-VL) using LoRA/QLoRA.

**Features:**
- **H100 Optimized**: Uses `Flash Attention 2` and `bfloat16` for maximum speed on H100/A100 GPUs.
- **Standalone**: Includes data conversion, setup, and training in one place.
- **Rex-Omni Specific**: Handles the specific grounding format required by the model.

## 1. Installation
Install necessary libraries. We use `flash-attn` for H100 acceleration.

In [None]:
%pip install -q torch>=2.4.0 transformers>=4.46.0 peft>=0.13.0 bitsandbytes>=0.44.0 accelerate>=1.0.0 qwen-vl-utils datasets wandb scipy
%pip install -q flash-attn --no-build-isolation

## 2. Configuration
Set your parameters here.

In [None]:
import os

# CONFIGURATION
MODEL_ID = "IDEA-Research/Rex-Omni"
DATA_PATH = "/kaggle/input/your-dataset/metadata.json" # CHANGE THIS to your metadata file path
IMAGE_FOLDER = "/kaggle/input/your-dataset/images"     # CHANGE THIS to your image folder path
OUTPUT_DIR = "/kaggle/working/rex-omni-finetuned"      # trained adapter + processor are saved here

# Training Hyperparameters
NUM_EPOCHS = 3
BATCH_SIZE = 4          # per-device micro-batch; an 80GB H100 can usually go higher
GRAD_ACCUM = 2          # effective batch per device = BATCH_SIZE * GRAD_ACCUM
LORA_RANK = 64          # LoRA rank r (lora_alpha is set in the training cell)
# True  -> 4-bit NF4 quantized base model + LoRA adapters (QLoRA, lowest memory).
# False -> bf16 base model + LoRA adapters. This is still adapter-only training,
#          NOT full fine-tuning; it is faster on H100 when the bf16 weights fit.
USE_QLORA = True

## 3. Data Preparation Helper
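This function converts your metadata JSON into the JSONL chat format expected by Rex-Omni. The metadata file must be a JSON list of records. Each record has an `image` file name (relative to `IMAGE_FOLDER`) and an `objects` list of `{"category": ..., "bbox": [x0, y0, x1, y1]}` entries, with boxes given in absolute pixel coordinates.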

In [None]:
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

def normalize_bbox(bbox, width, height):
    """Encode an absolute-pixel box as four Rex-Omni coordinate tokens.

    Each corner coordinate is divided by the image extent along its axis,
    clamped to [0, 1], scaled by 999 and truncated to an int bin, giving a
    string such as ``"<12><40><999><731>"``.

    Args:
        bbox: 4-sequence ``(x0, y0, x1, y1)`` in pixels. If your metadata is
            COCO-style ``[x, y, w, h]`` it must be converted first.
        width: Image width in pixels (divisor for the x coordinates).
        height: Image height in pixels (divisor for the y coordinates).

    Returns:
        The concatenated ``<bin>`` tokens in x0, y0, x1, y1 order.
    """
    left, top, right, bottom = bbox

    def quantize(coord, extent):
        # Clamp before scaling so boxes spilling outside the image stay in 0..999.
        ratio = max(0.0, min(1.0, coord / extent))
        return int(ratio * 999)

    bins = (
        quantize(left, width),
        quantize(top, height),
        quantize(right, width),
        quantize(bottom, height),
    )
    return "".join(f"<{value}>" for value in bins)

def prepare_data(image_folder, metadata_file, output_file):
    """Convert a metadata JSON file into Rex-Omni chat-format JSONL.

    ``metadata_file`` must hold a JSON list of records shaped like::

        {"image": "<file name relative to image_folder>",
         "objects": [{"category": "<label>", "bbox": [x0, y0, x1, y1]}, ...]}

    Boxes are absolute pixel corners and are encoded with ``normalize_bbox``.
    Each usable record becomes one line ``{"messages": [user, assistant]}``.
    The assistant answer groups boxes by category in first-seen order.

    A record is skipped (and counted) when its ``image`` entry is missing, the
    file does not exist or cannot be opened, or it has no well-formed objects.
    Object entries lacking ``category``/``bbox``, or whose bbox does not have
    exactly four values, are dropped (and counted).

    Args:
        image_folder: Directory containing the images.
        metadata_file: Path to the input metadata JSON.
        output_file: Path of the JSONL file to (over)write.

    Returns:
        The number of records written.

    Raises:
        ValueError: If the metadata JSON is not a list.
    """
    print(f"Converting data from {metadata_file}...")
    image_folder_path = Path(image_folder)

    with open(metadata_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"{metadata_file} must contain a JSON list of records, got {type(data).__name__}"
        )

    written = 0
    skipped_records = 0
    dropped_objects = 0
    with open(output_file, 'w', encoding='utf-8') as f_out:
        for item in tqdm(data):
            image_filename = item.get('image')
            if not image_filename:
                skipped_records += 1
                continue

            image_path = image_folder_path / image_filename
            if not image_path.exists():
                skipped_records += 1
                continue

            try:
                with Image.open(image_path) as img:
                    width, height = img.size
            except Exception as err:  # corrupt/unsupported image; unlike bare except, lets Ctrl-C through
                tqdm.write(f"Skipping unreadable image {image_path}: {err}")
                skipped_records += 1
                continue

            # Group boxes per category; dicts keep first-seen category order.
            cat_to_boxes = {}
            for obj in item.get('objects') or []:
                cat = obj.get('category')
                bbox = obj.get('bbox')
                if cat is None or bbox is None or len(bbox) != 4:
                    dropped_objects += 1
                    continue
                cat_to_boxes.setdefault(cat, []).append(bbox)
            if not cat_to_boxes:
                skipped_records += 1
                continue

            answer_parts = []
            for cat, boxes in cat_to_boxes.items():
                box_str = ",".join(normalize_bbox(b, width, height) for b in boxes)
                answer_parts.append(
                    f"<|object_ref_start|>{cat}<|object_ref_end|><|box_start|>{box_str}<|box_end|>"
                )
            answer_text = ", ".join(answer_parts)

            # No literal "<image>" placeholder in the text: the Qwen2.5-VL chat
            # template already emits the vision tokens for the {"type": "image"}
            # element, so a literal "<image>" would only be trained as plain text.
            # NOTE(review): Rex-Omni's inference prompt may name the target
            # categories explicitly -- confirm the prompt wording against the
            # model card so training matches inference.
            messages = [
                {"role": "user", "content": [
                    {"type": "image", "image": str(image_path.absolute())},
                    {"type": "text", "text": "Detect the objects in this image."}
                ]},
                {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
            ]
            f_out.write(json.dumps({"messages": messages}) + '\n')
            written += 1

    print(
        f"Data saved to {output_file} ({written} records written, "
        f"{skipped_records} records skipped, {dropped_objects} malformed objects dropped)"
    )
    return written

# Run conversion
JSONL_PATH = "/kaggle/working/train_data.jsonl"
# Uncomment the line below if you have uploaded data
# prepare_data(IMAGE_FOLDER, DATA_PATH, JSONL_PATH)

## 4. Training Logic
This section initializes the model, applies LoRA, and runs the training loop.

In [None]:
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
from datasets import load_dataset
from qwen_vl_utils import process_vision_info

def _strip_none_fields(value):
    """Return *value* with every None-valued dict entry removed, recursively.

    ``load_dataset("json")`` stores all message-content elements under one
    unified Arrow struct. A text element therefore comes back carrying
    ``"image": None``, and an image element comes back with ``"text": None``.
    The Qwen2.5-VL chat template and ``qwen_vl_utils.process_vision_info``
    test for the *presence* of an ``"image"`` key, so the placeholders must be
    dropped before either is called.
    """
    if isinstance(value, dict):
        return {key: _strip_none_fields(item) for key, item in value.items() if item is not None}
    if isinstance(value, list):
        return [_strip_none_fields(item) for item in value]
    return value


def _mask_prompt_tokens(labels, input_ids, header_ids):
    """Set ``labels`` to -100 up to and including the last assistant header.

    Only the tokens after the final ``header_ids`` occurrence (the assistant
    answer and the turn terminator the chat template appends) keep their
    labels. A row in which the header is not found is masked entirely, so it
    contributes no loss. ``labels`` is modified in place and also returned.
    """
    width = len(header_ids)
    for row, ids in enumerate(input_ids.tolist()):
        # Scan backwards so the *last* assistant header wins.
        for start in range(len(ids) - width, -1, -1):
            if ids[start:start + width] == header_ids:
                labels[row, :start + width] = -100
                break
        else:
            labels[row, :] = -100
    return labels


def _build_collate_fn(processor):
    """Return a Trainer data collator for ``{"messages": [...]}`` rows.

    The collator renders each conversation with the chat template, loads the
    images via ``process_vision_info``, and tokenizes the batch with padding.
    It then builds labels in which padding, ``<|image_pad|>`` tokens and the
    whole prompt (up to the final assistant header) are -100.
    """
    tokenizer = processor.tokenizer
    image_pad_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
    assistant_header = [tokenizer.convert_tokens_to_ids("<|im_start|>")] + tokenizer.encode(
        "assistant\n", add_special_tokens=False
    )

    def collate_fn(examples):
        # process_vision_info expects conversations (lists of message dicts),
        # not the raw {"messages": [...]} rows, so unwrap and clean them first.
        conversations = [_strip_none_fields(example["messages"]) for example in examples]
        # add_generation_prompt=False: the assistant turn is already part of the
        # sample. True would append a stray empty assistant header after it.
        texts = [
            processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
            for conversation in conversations
        ]
        image_inputs, video_inputs = process_vision_info(conversations)
        inputs = processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100  # padding positions
        labels[labels == image_pad_id] = -100         # vision placeholder tokens
        _mask_prompt_tokens(labels, inputs["input_ids"], assistant_header)  # train on the answer only
        inputs["labels"] = labels
        return inputs

    return collate_fn


def train():
    """Fine-tune MODEL_ID with LoRA (4-bit QLoRA when USE_QLORA) on JSONL_PATH.

    Reads the notebook-level settings MODEL_ID, JSONL_PATH, OUTPUT_DIR,
    NUM_EPOCHS, BATCH_SIZE, GRAD_ACCUM, LORA_RANK and USE_QLORA. If JSONL_PATH
    does not exist, it prints a hint and returns without training. Otherwise
    it trains and saves the LoRA adapter and the processor to OUTPUT_DIR.
    """
    import importlib.util

    # 1. Dataset -- checked first so a missing file fails before any model download.
    if not os.path.exists(JSONL_PATH):
        print(f"Dataset not found at {JSONL_PATH}. Please run data preparation.")
        return
    dataset = load_dataset("json", data_files=JSONL_PATH, split="train")

    # 2. Processor; min_pixels/max_pixels bound the size each image is resized to.
    processor = AutoProcessor.from_pretrained(MODEL_ID, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
    collate_fn = _build_collate_fn(processor)

    # 3. Model
    bnb_config = None
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    # Flash Attention 2 needs the optional flash-attn package, whose build often
    # fails. Fall back to PyTorch SDPA rather than crashing in from_pretrained.
    if importlib.util.find_spec("flash_attn") is not None:
        attn_implementation = "flash_attention_2"  # H100 Optimization
    else:
        attn_implementation = "sdpa"
        print("flash-attn is not installed; using attn_implementation='sdpa'.")

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map="auto",
    )

    # Non-reentrant checkpointing propagates gradients to the LoRA weights even
    # when the frozen embedding output does not require grad. Without it, the
    # bf16 (USE_QLORA=False) path can fail with "does not require grad".
    gradient_checkpointing_kwargs = {"use_reentrant": False}
    if USE_QLORA:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=True,
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        )

    # 4. LoRA config. The effective update scale is lora_alpha / r (16 / LORA_RANK).
    peft_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_dropout=0.05,
        bias="none",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # 5. Training args
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=2e-4,
        bf16=True,  # H100 Optimization
        logging_steps=10,
        save_strategy="epoch",
        report_to="none",
        remove_unused_columns=False,  # keep the raw "messages" column for the collator
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        dataloader_pin_memory=True,
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collate_fn)
    trainer.train()
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print(f"Model saved to {OUTPUT_DIR}")

# Start Training
# train()