# Rex-Omni LoRA Fine-tuning on VRSBench Dataset

This notebook demonstrates how to fine-tune Rex-Omni with LoRA on the VRSBench grounding dataset.

**Requirements:**
- GPU with at least 24GB VRAM (A100/RTX 4090 recommended)
- Pre-downloaded VRSBench images and the VRSBench validation parquet (`vrsbench_val_data.parquet`)

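**Pipeline:**
0. Install dependencies
1. Set up the environment and imports
2. Configuration
3. Convert VRSBench to JSON format
4. Load the model, tokenizer, and processor
5. Apply LoRA
6. Set up the dataset and data collator
7. Train
8. Save the LoRA adapter
9. (Optional) Test inference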


## 0. Install Dependencies (Run First!)


In [None]:
# Install core dependencies with compatible versions

print("Installing dependencies...")

# Install base packages
!uv pip install -q psutil setuptools

# Install compatible transformers and peft versions
!uv pip install -q transformers==4.57.3
!uv pip install -q "peft==0.18.0"
!uv pip install -q "bitsandbytes>=0.44.1"
!uv pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

# Install flash attention (requires no-build-isolation)
!uv pip install flash-attn==2.7.4.post1 --no-build-isolation

# Install other requirements
!uv pip install -q accelerate==1.10.1
!uv pip install -q mmengine==0.10.7 omegaconf==2.3.0 ujson==5.11.0

# Install vision dependencies
!uv pip install -q Pillow pandas matplotlib numpy tqdm fastparquet pyarrow
!uv pip install -q qwen-vl-utils

# Install liger-kernel (required for Qwen2.5-VL training)
!uv pip install -q liger-kernel

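print("\n✓ Install commands finished (check the output above for errors)")

# If transformers or peft were imported before this cell ran, restart the
# kernel so that the newly installed versions are the ones in use.
print("\nVerifying installations...")
import transformers
import peft
print(f"✓ Transformers: {transformers.__version__}")
print(f"✓ PEFT: {peft.__version__}")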
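Because the install commands can replace packages that are already imported, it can help to compare the installed versions against the pins before going further. The sketch below uses only the standard library's `importlib.metadata`; `PINNED` is copied from the install commands above (flash-attn is left out), and `check_pins` is a hypothetical helper, not part of Rex-Omni.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell above (flash-attn omitted).
PINNED = {
    "transformers": "4.57.3",
    "peft": "0.18.0",
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "accelerate": "1.10.1",
    "mmengine": "0.10.7",
    "omegaconf": "2.3.0",
    "ujson": "5.11.0",
}


def check_pins(pins):
    """Return {package: installed_version_or_None} for every pin that does not match."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            # Drop local build tags such as "+cu124" before comparing.
            installed = version(name).split("+")[0]
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


bad = check_pins(PINNED)
if bad:
    print("Version mismatches (restart the kernel if these were imported earlier):")
    for name, got in bad.items():
        print(f"  {name}: expected {PINNED[name]}, found {got}")
else:
    print("All pinned versions match.")
```

An empty result means every pinned package reports the expected version.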

## 1. Setup Environment


In [None]:
using_Modal = True  # True: running on Modal (clones the repo into /root); False: local checkout

In [None]:
import os
import sys
from pathlib import Path

# Determine project root (works on Modal and locally)
if using_Modal:
    # Clone the repo on first run
    if not Path('/root/Rex-Omni').exists():
        !git clone https://github.com/IDEA-Research/Rex-Omni.git /root/Rex-Omni
    PROJECT_ROOT = Path('/root/Rex-Omni')
    os.chdir(PROJECT_ROOT)
else:
    # Local environment - find project root
    PROJECT_ROOT = Path.cwd()
    while PROJECT_ROOT.name != 'Rex-Omni' and PROJECT_ROOT.parent != PROJECT_ROOT:
        PROJECT_ROOT = PROJECT_ROOT.parent
    if PROJECT_ROOT.name != 'Rex-Omni':
        PROJECT_ROOT = Path.cwd()
    os.chdir(PROJECT_ROOT)

# Add finetuning to path
FINETUNING_PATH = PROJECT_ROOT / 'finetuning'
for p in [str(PROJECT_ROOT), str(FINETUNING_PATH)]:
    if p not in sys.path:
        sys.path.insert(0, p)

print(f"✓ Project root: {PROJECT_ROOT}")
print(f"✓ Finetuning path: {FINETUNING_PATH}")
print(f"✓ Working directory: {os.getcwd()}")
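The local branch walks up the directory tree until it finds a folder named `Rex-Omni`, falling back to the current directory. Pulled out as a standalone function (`find_project_root` is a hypothetical helper, not part of the repository), the same logic is easier to test in isolation:

```python
from pathlib import Path


def find_project_root(start, repo_name: str = "Rex-Omni") -> Path:
    """Return the first directory named `repo_name` among `start` and its parents.

    Falls back to `start` itself when no ancestor matches, mirroring the
    local-environment branch of the cell above.
    """
    start = Path(start).resolve()
    for candidate in (start, *start.parents):
        if candidate.name == repo_name:
            return candidate
    return start
```

Calling `find_project_root(Path.cwd())` reproduces the local branch; unlike the cell above, this sketch resolves symlinks first.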


In [None]:
# Core imports
import json
from tqdm import tqdm

import torch
import pandas as pd
import transformers
from PIL import Image

# Liger kernel: the patch must be applied (Section 4) before the model is loaded
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl

# Transformers components
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
)

# PEFT for LoRA
from peft import LoraConfig, get_peft_model, TaskType

# Finetuning components (from the repo's finetuning/ directory)
from engine.argument import DataArguments, TrainingArguments
from dataset.json_dataset import GroundingJsonDataset
from dataset.collator import DataCollatorForSupervisedDataset
from dataset.task_fns import GroundingTaskFn
from dataset.task_fns.task_prompts.grounding_task import GROUNDING_SINGLE_REGION_STAGE_XYXY

print(f"✓ PyTorch: {torch.__version__}")
print(f"✓ Transformers: {transformers.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✓ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


## 2. Configuration


In [None]:
# ==================== MODEL ====================
MODEL_NAME = "IDEA-Research/Rex-Omni"

# ==================== DATA ====================
VRSBENCH_PARQUET = PROJECT_ROOT / "vrsbench_val_data.parquet"
VRSBENCH_IMAGES_DIR = PROJECT_ROOT / "Images_validation" / "Images_val"  # Update this path!
TSV_OUTPUT_DIR = PROJECT_ROOT / "data" / "vrsbench_finetune"  # Output directory for the converted train.json
NUM_SAMPLES = 1000  # Number of samples to use for finetuning

# Maximum token sequence length (text + image tokens); image pixel limits are read from the processor config
MAX_LENGTH = 4096

# ==================== LORA ====================
LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# ==================== TRAINING ====================
OUTPUT_DIR = PROJECT_ROOT / "outputs" / "rex-omni-lora-vrsbench"
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
WARMUP_RATIO = 0.03
SAVE_STEPS = 100
LOGGING_STEPS = 10

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Samples: {NUM_SAMPLES}")
print(f"  LoRA rank: {LORA_CONFIG.r}")
print(f"  Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION_STEPS} = {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Output: {OUTPUT_DIR}")
print("\n  Note: Image pixel values will be auto-detected from model config")


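## 3. Convert VRSBench to JSON Format

The conversion below writes a single `train.json` file to `TSV_OUTPUT_DIR`, which `GroundingJsonDataset` reads directly. The file is a list of records, one per image:
- `image_path`: absolute path to the image file
- `boxes`: list of `[x0, y0, x1, y1]` boxes in pixel coordinates (VRSBench's normalized `obj_coord` values scaled by the image width and height)
- `labels`: one phrase per box (the `referring_sentence`, falling back to `obj_cls`)

Rows whose image cannot be found or that have no valid boxes are skipped.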
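Before training on the converted file, a quick structural check on individual records can catch mismatched box and label counts or degenerate boxes. The sketch below assumes the record layout described above; `validate_record` is a hypothetical helper and the example image path is made up.

```python
import json
import math


def validate_record(record: dict) -> list:
    """Return a list of problems found in one train.json record (empty if it looks valid).

    Checks the three keys written by the conversion step: `image_path`,
    `boxes` (pixel-space [x0, y0, x1, y1]) and `labels` (one phrase per box).
    """
    problems = []
    for key in ("image_path", "boxes", "labels"):
        if key not in record:
            problems.append(f"missing key: {key}")
    if problems:
        return problems
    boxes, labels = record["boxes"], record["labels"]
    if len(boxes) != len(labels):
        problems.append(f"{len(boxes)} boxes but {len(labels)} labels")
    for i, box in enumerate(boxes):
        if len(box) != 4 or not all(math.isfinite(v) for v in box):
            problems.append(f"box {i} is not four finite numbers: {box}")
            continue
        x0, y0, x1, y1 = box
        if x1 <= x0 or y1 <= y0:
            problems.append(f"box {i} has non-positive width or height: {box}")
    return problems


# Hypothetical record in the format described above.
example = json.loads("""
{"image_path": "/data/Images_val/P0001.png",
 "boxes": [[100.0, 50.0, 300.0, 200.0]],
 "labels": ["the large white airplane on the left"]}
""")
print(validate_record(example))  # -> []
```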


In [None]:
def convert_vrsbench_to_json(
    parquet_path: Path,
    images_dir: Path,
    output_dir: Path,
    num_samples: int = 1000,
    project_root: Path = None
):
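    """
    Convert the first `num_samples` rows of the VRSBench parquet into `output_dir/train.json`.

    Output format: List[Dict] with keys 'image_path' (absolute path), 'boxes'
    (pixel-space [x0, y0, x1, y1]) and 'labels' (one phrase per box).
    Deletes `output_dir` first if it exists; returns the number of records written.
    """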
    import json
    import shutil
    
    output_dir = Path(output_dir)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    json_file = output_dir / "train.json"
    
    print(f"Loading {parquet_path}...")
    df = pd.read_parquet(parquet_path)
    
    df_sample = df.head(num_samples).copy()
    print(f"  Processing {len(df_sample)} samples")
    
    dataset_list = []
    skipped = 0
    
    for idx, row in tqdm(df_sample.iterrows(), total=len(df_sample), desc="Converting"):
        try:
            # 1. Resolve Image Path
            image_path = row.get('image_path', '')
            if not image_path:
                skipped += 1
                continue
                
            img_path = None
            candidates = [
                Path(image_path),
                images_dir / Path(image_path).name,
                project_root / image_path.lstrip('./') if project_root else None,
            ]
            for candidate in candidates:
                if candidate and candidate.exists():
                    img_path = candidate
                    break
            
            if img_path is None:
                skipped += 1
                continue
                
            # 2. Read the image size (needed to denormalize the boxes);
            #    the context manager closes the file handle
            try:
                with Image.open(img_path) as pil_img:
                    width, height = pil_img.size
            except Exception:
                skipped += 1
                continue
                
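            # 3. Parse annotations (stored either as a JSON string or as an
            #    array-like column, which may be a NumPy array when read with pyarrow)
            objects_data = row.get('objects')
            if objects_data is None:
                objects_list = []
            elif isinstance(objects_data, str):
                objects_list = json.loads(objects_data)
            else:
                objects_list = list(objects_data)

            if len(objects_list) == 0:
                skipped += 1
                continue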
                
            boxes = []
            labels = []
            
            for obj in objects_list:
                bbox_norm = obj.get('obj_coord')
                # Explicit None/len checks also work for NumPy arrays, where `not arr` raises
                if bbox_norm is None or len(bbox_norm) != 4:
                    continue
                    
                # Denormalize [0-1] -> [pixel]
                x0, y0, x1, y1 = bbox_norm
                boxes.append([
                    x0 * width, 
                    y0 * height, 
                    x1 * width, 
                    y1 * height
                ])
                
                phrase = obj.get('referring_sentence', '') or obj.get('obj_cls', '')
                labels.append(str(phrase))
                
            if not boxes:
                skipped += 1
                continue
                
            dataset_list.append({
                "image_path": str(img_path.absolute()),
                "boxes": boxes,
                "labels": labels
            })
            
        except Exception:
            # Count any unexpected per-row failure as skipped
            skipped += 1
            continue
            
    print(f"\nSaving {len(dataset_list)} samples to {json_file}...")
    with open(json_file, 'w') as f:
        json.dump(dataset_list, f, indent=2)
        
    print(f"Done! Skipped {skipped} items.")
    return len(dataset_list)
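The denormalization step assumes `obj_coord` holds `[x0, y0, x1, y1]` as fractions of the image size, as the comment in the function states. A slightly more defensive variant is sketched below; the clamping and corner reordering are additions of this sketch, not something the conversion above does, and `denormalize_xyxy` is a hypothetical helper name.

```python
from typing import Sequence


def denormalize_xyxy(box: Sequence[float], width: int, height: int) -> list:
    """Map a normalized [x0, y0, x1, y1] box (values in 0-1) to pixel coordinates.

    Values are clamped to [0, 1] and corners reordered so that x0 <= x1 and y0 <= y1.
    """
    if len(box) != 4:
        raise ValueError(f"expected 4 values, got {len(box)}")
    x0, y0, x1, y1 = (float(v) for v in box)
    # Clamp normalized values to the image before scaling.
    x0, x1 = (min(max(v, 0.0), 1.0) for v in (x0, x1))
    y0, y1 = (min(max(v, 0.0), 1.0) for v in (y0, y1))
    # Reorder in case the source stored the corners swapped.
    x0, x1 = sorted((x0, x1))
    y0, y1 = sorted((y0, y1))
    return [x0 * width, y0 * height, x1 * width, y1 * height]


print(denormalize_xyxy([0.25, 0.5, 0.75, 1.0], width=800, height=600))
# -> [200.0, 300.0, 600.0, 600.0]
```

Because it only relies on `len()` and iteration, the helper accepts plain lists as well as NumPy arrays.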


In [None]:
# Convert to JSON (convert_vrsbench_to_json clears TSV_OUTPUT_DIR itself)

print("Converting VRSBench to JSON format...")
num_converted = convert_vrsbench_to_json(
    parquet_path=VRSBENCH_PARQUET,
    images_dir=VRSBENCH_IMAGES_DIR,
    output_dir=TSV_OUTPUT_DIR,
    num_samples=NUM_SAMPLES,
    project_root=PROJECT_ROOT
)
print(f"\n✓ Ready for training with {num_converted} samples")


## 4. Load Model, Tokenizer, and Processor


In [None]:
print("Loading tokenizer and processor...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    model_max_length=MAX_LENGTH,
    padding_side="right",
    use_fast=False,
)
print(f"✓ Tokenizer loaded (vocab size: {len(tokenizer)})")

# Load processor (for image processing)
# Uses model's default min/max_pixels from config - no need to specify manually
processor = AutoProcessor.from_pretrained(MODEL_NAME)

# Print what defaults the model uses
img_proc = processor.image_processor
print(f"✓ Processor loaded")
print(f"  min_pixels: {getattr(img_proc, 'min_pixels', 'default')}")
print(f"  max_pixels: {getattr(img_proc, 'max_pixels', 'default')}")
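The processor resizes each image so both sides are multiples of 28 and the total pixel count stays within `[min_pixels, max_pixels]`. The sketch below is modeled on `smart_resize` from `qwen-vl-utils` (simplified: the aspect-ratio guard is omitted) and uses the same fallback limits as Section 6. It is an estimate, not the processor's actual code path, but it shows why `MAX_LENGTH` matters: under these defaults a large remote-sensing image can yield more visual tokens than the 4,096-token budget, and depending on how the dataset handles over-length samples, such samples may be truncated or rejected.

```python
import math


def round_by_factor(n: float, factor: int) -> int:
    return round(n / factor) * factor


def floor_by_factor(n: float, factor: int) -> int:
    return math.floor(n / factor) * factor


def ceil_by_factor(n: float, factor: int) -> int:
    return math.ceil(n / factor) * factor


def smart_resize(height: int, width: int, factor: int = 28,
                 min_pixels: int = 4 * 28 * 28,
                 max_pixels: int = 16384 * 28 * 28) -> tuple:
    """Resize (h, w) to multiples of `factor` whose area lies in [min_pixels, max_pixels]."""
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


def num_image_tokens(height: int, width: int, **kwargs) -> int:
    """Each visual token covers one 28x28 region (14-px patches merged 2x2)."""
    h_bar, w_bar = smart_resize(height, width, **kwargs)
    return (h_bar // 28) * (w_bar // 28)


print(smart_resize(800, 800))        # -> (812, 812)
print(num_image_tokens(800, 800))    # -> 841
print(num_image_tokens(2000, 2000))  # -> 5041, more than MAX_LENGTH = 4096
```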


In [None]:
print("Loading model...")

# CRITICAL: Apply liger kernel BEFORE loading model (as in train.py line 114)
apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=False)
print("✓ Liger kernel applied")

# Load model with correct class (Qwen2_5_VLForConditionalGeneration, NOT AutoModelForCausalLM)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
    attn_implementation="flash_attention_2",
    device_map="auto",
)
model.config.use_cache = False  # Disable KV cache for training

print(f"✓ Model loaded")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
print(f"  Device: {next(model.parameters()).device}")
print(f"  Dtype: {next(model.parameters()).dtype}")


## 5. Apply LoRA


In [None]:
print("Applying LoRA...")

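# Gradient checkpointing (enabled in TrainingArguments) needs the embedding outputs
# to require grad when the base model is frozen; otherwise checkpointed blocks pass
# no gradient to the LoRA weights.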
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

# Apply LoRA
model = get_peft_model(model, LORA_CONFIG)
model.print_trainable_parameters()

# Verify trainable parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"\n✓ LoRA applied")
print(f"  Trainable: {trainable / 1e6:.2f}M ({100 * trainable / total:.3f}%)")

if trainable == 0:
    raise RuntimeError("❌ No trainable parameters! LoRA not applied correctly.")

# List some trainable parameters for verification
print("\n  Sample trainable params:")
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
for name in trainable_names[:5]:
    print(f"    - {name}")
print(f"    ... and {len(trainable_names) - 5} more")

# Ensure model is in training mode
model.train()

# Note: a warning such as "None of the inputs have requires_grad=True" during
# training can come from the frozen vision encoder under gradient checkpointing.
# It is harmless as long as the loss decreases, i.e. the LoRA weights are training.
print("\n✓ Model ready for training")
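`print_trainable_parameters()` reports the total, but the per-layer cost is simple arithmetic: for a frozen linear layer W (out × in), LoRA learns B (out × r) and A (r × in) and adds (alpha / r)·B·A to W, so each targeted layer contributes r·(in + out) parameters. The sketch uses `r` and `lora_alpha` from `LORA_CONFIG`; the 2048-wide layer is a hypothetical shape chosen for the arithmetic, so inspect `model.named_parameters()` for the real shapes.

```python
def lora_params_per_linear(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in_features) and B (out_features x r) to a frozen linear layer."""
    return r * in_features + out_features * r


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard (non-rsLoRA) scaling applied to the update B @ A: alpha / r."""
    return lora_alpha / r


# r and alpha match LORA_CONFIG; the layer shape is an illustrative assumption.
r, alpha = 64, 128
hidden = 2048
print(lora_params_per_linear(hidden, hidden, r))  # -> 262144
print(lora_scaling(alpha, r))                     # -> 2.0
```

Doubling `r` doubles the adapter size, while keeping `lora_alpha = 2 * r` holds the update scale at 2.0.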


## 6. Setup Dataset


In [None]:
print("Setting up dataset...")

data_args = DataArguments()
data_args.image_processor = processor.image_processor
data_args.model_type = "qwen2.5vl"

# Auto-detect pixel limits
img_proc = processor.image_processor
min_pixels = getattr(img_proc, 'min_pixels', 4 * 28 * 28)
max_pixels = getattr(img_proc, 'max_pixels', 16384 * 28 * 28)

print(f"✓ DataArguments created")
print(f"  min_pixels: {min_pixels}")
print(f"  max_pixels: {max_pixels}")

# JSON file path
json_file = str(TSV_OUTPUT_DIR / "train.json")

if not Path(json_file).exists():
    raise FileNotFoundError(f"Dataset file not found: {json_file}\nRun conversion first!")

task_fn_config = dict(
    type=GroundingTaskFn,
    task_prompts=GROUNDING_SINGLE_REGION_STAGE_XYXY,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
)

# Instantiate the JSON dataset (GroundingJsonDataset is imported in Section 1)
train_dataset = GroundingJsonDataset(
    json_file=json_file,
    tokenizer=tokenizer,
    data_args=data_args,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
    task_fn=task_fn_config,
    system_message="You are a helpful assistant.",
    ori_box_format="xyxy",
    dataset_name="vrsbench_grounding",
    max_length=MAX_LENGTH,
)

print(f"✓ Dataset created: {len(train_dataset)} samples")
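`task_fn_config` follows the `dict(type=..., **kwargs)` convention common in mmengine-style configs: the dataset receives a plain dict and instantiates the class itself. Whether `GroundingJsonDataset` builds it exactly this way is an assumption; the sketch below only illustrates the pattern, with `build_from_cfg` and `DummyTaskFn` as hypothetical stand-ins that run without the repository.

```python
def build_from_cfg(cfg: dict, **defaults):
    """Instantiate cfg["type"] with the remaining keys as keyword arguments.

    Hypothetical helper illustrating the registry-style `dict(type=..., ...)`
    convention; it is not the function Rex-Omni uses internally.
    """
    cfg = dict(cfg)  # copy so the caller's config is not mutated
    obj_type = cfg.pop("type")
    if not callable(obj_type):
        raise TypeError(f"'type' must be a class or callable, got {obj_type!r}")
    kwargs = {**defaults, **cfg}
    return obj_type(**kwargs)


class DummyTaskFn:
    """Stand-in for GroundingTaskFn so the sketch runs on its own."""

    def __init__(self, task_prompts, image_min_pixels, image_max_pixels):
        self.task_prompts = task_prompts
        self.image_min_pixels = image_min_pixels
        self.image_max_pixels = image_max_pixels


cfg = dict(type=DummyTaskFn, task_prompts=["Detect {}."],
           image_min_pixels=3136, image_max_pixels=12845056)
task_fn = build_from_cfg(cfg)
print(type(task_fn).__name__, task_fn.image_max_pixels)  # -> DummyTaskFn 12845056
```

Passing a dict instead of an instance lets the dataset add its own defaults (via `**defaults` here) at construction time.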


In [None]:
# Validate the dataset before training: index several samples to catch conversion errors early
print("Validating dataset (testing first 50 samples)...")
print("This catches data-conversion errors before training starts.\n")

valid_samples = 0
error_samples = []

# Test first 50 samples (or all if less)
test_count = min(50, len(train_dataset))

for i in tqdm(range(test_count), desc="Validating"):
    try:
        sample = train_dataset[i]
        valid_samples += 1
    except Exception as e:
        error_samples.append((i, str(e)))

print(f"\n{'='*50}")
print(f"Validation Results:")
print(f"  ✓ Valid samples: {valid_samples}/{test_count}")
print(f"  ✗ Error samples: {len(error_samples)}/{test_count}")

if error_samples:
    print(f"\nFirst 5 errors:")
    for idx, err in error_samples[:5]:
        print(f"  Sample {idx}: {err[:80]}...")
    
    error_rate = len(error_samples) / test_count
    if error_rate > 0.1:  # More than 10% errors
        print(f"\n❌ ERROR RATE TOO HIGH ({error_rate*100:.1f}%)")
        print("   This points to problems in the converted data.")
        print("   Re-run the conversion in Section 3 and the dataset setup cell above.")
        raise RuntimeError(f"Too many corrupt samples: {len(error_samples)}/{test_count}")
    else:
        print(f"\n⚠️  Some samples failed, but the error rate is below 10% ({error_rate*100:.1f}%)")
        print("   Training may still fail on them unless the dataset skips bad samples;")
        print("   consider removing them from train.json.")
else:
    print("\n✓ All tested samples are valid!")

# Show sample structure
print("\nSample data structure:")
sample = train_dataset[0]
print(f"  Keys: {list(sample.keys())}")
print(f"  input_ids shape: {sample['input_ids'].shape}")
print(f"  labels shape: {sample['labels'].shape}")
if 'pixel_values' in sample:
    pv = sample['pixel_values']
    if isinstance(pv, list):
        print(f"  pixel_values: {len(pv)} tensors")
    else:
        print(f"  pixel_values shape: {pv.shape}")
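The validation loop can be factored into small reusable helpers that also record the exception type, which makes recurring failures easier to group. `probe_dataset`, `should_abort`, and `FlakyDataset` are hypothetical names; the toy dataset stands in for `train_dataset` so the sketch runs on its own.

```python
from typing import List, Tuple


def probe_dataset(dataset, limit: int = 50) -> Tuple[int, List[Tuple[int, str]]]:
    """Index the first `limit` samples; return (num_valid, [(index, error_message), ...])."""
    valid, errors = 0, []
    for i in range(min(limit, len(dataset))):
        try:
            dataset[i]
            valid += 1
        except Exception as exc:  # record every failure instead of stopping at the first
            errors.append((i, f"{type(exc).__name__}: {exc}"))
    return valid, errors


def should_abort(num_tested: int, errors: list, max_error_rate: float = 0.1) -> bool:
    """Same 10% threshold as the cell above."""
    return num_tested > 0 and len(errors) / num_tested > max_error_rate


class FlakyDataset:
    """Toy dataset whose every 5th item raises, standing in for train_dataset."""

    def __len__(self):
        return 20

    def __getitem__(self, i):
        if i % 5 == 0:
            raise ValueError(f"bad sample {i}")
        return {"input_ids": [i]}


valid, errors = probe_dataset(FlakyDataset(), limit=50)
print(valid, len(errors), should_abort(20, errors))  # -> 16 4 True
```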


In [None]:
# Create data collator (only takes tokenizer, see collator.py line 36)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

# Test collator
print("Testing data collator...")
batch = data_collator([train_dataset[0]])
print(f"✓ Batch created")
print(f"  Keys: {list(batch.keys())}")
print(f"  input_ids: {batch['input_ids'].shape}")
print(f"  labels: {batch['labels'].shape}")
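The tokenizer's `padding_side="right"` and the collator work together. A supervised collator typically right-pads `input_ids` with the pad token, pads `labels` with -100 so padded positions are excluded from the loss (the default `ignore_index` of PyTorch's cross-entropy), and builds an attention mask. The pure-Python sketch below shows that common pattern; `DataCollatorForSupervisedDataset` may differ in details such as how it concatenates `pixel_values` and `image_grid_thw`, so treat `pad_batch` as a hypothetical illustration.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default


def pad_batch(input_ids: List[List[int]], labels: List[List[int]],
              pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad token ids and labels to the longest sequence and build an attention mask."""
    max_len = max(len(seq) for seq in input_ids)
    batch = {"input_ids": [], "labels": [], "attention_mask": []}
    for ids, lab in zip(input_ids, labels):
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * pad)
        batch["labels"].append(lab + [IGNORE_INDEX] * pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
    return batch


out = pad_batch([[5, 6, 7], [8]], [[-100, 6, 7], [8]], pad_token_id=0)
print(out["input_ids"])       # -> [[5, 6, 7], [8, 0, 0]]
print(out["labels"])          # -> [[-100, 6, 7], [8, -100, -100]]
print(out["attention_mask"])  # -> [[1, 1, 1], [1, 0, 0]]
```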


## 7. Training


In [None]:
# Create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Training arguments
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    bf16=True,
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    remove_unused_columns=False,  # Important for custom datasets
    report_to="none",  # Disable wandb/tensorboard
    optim="adamw_torch",
)

print(f"Training configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Output: {OUTPUT_DIR}")


In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("✓ Trainer created")
print(f"  Dataset size: {len(train_dataset)}")
print(f"  Optimizer steps per epoch (approx., single GPU): {len(train_dataset) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)}")
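The printed step count is an estimate. The full schedule follows from the same arithmetic: batches per epoch divided by the accumulation steps, times the number of epochs, with warmup computed as `ceil(total_steps * warmup_ratio)` (the formula `TrainingArguments.get_warmup_steps` uses when `warmup_steps` is 0). The sketch assumes all 1,000 samples survive conversion; whether a final partial accumulation window counts as an extra step can vary between transformers versions. `schedule` is a hypothetical helper.

```python
import math


def schedule(num_samples: int, batch_size: int, grad_accum: int,
             epochs: int, warmup_ratio: float, num_gpus: int = 1) -> dict:
    """Estimate optimizer steps, assuming no samples are dropped or skipped."""
    batches_per_epoch = math.ceil(num_samples / (batch_size * num_gpus))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {"steps_per_epoch": steps_per_epoch,
            "total_steps": total_steps,
            "warmup_steps": warmup_steps}


print(schedule(1000, batch_size=1, grad_accum=8, epochs=3, warmup_ratio=0.03))
# -> {'steps_per_epoch': 125, 'total_steps': 375, 'warmup_steps': 12}
```

With `SAVE_STEPS = 100`, this schedule writes checkpoints at steps 100, 200, and 300, and `save_total_limit=3` keeps all three.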


In [None]:
# Start training
print("Starting training...")
print("="*50)

trainer.train()

print("="*50)
print("✓ Training complete!")


## 8. Save Model


In [None]:
# Save final model
final_output = OUTPUT_DIR / "final"
final_output.mkdir(parents=True, exist_ok=True)

print(f"Saving model to {final_output}...")

# Save LoRA adapter
model.save_pretrained(str(final_output))
tokenizer.save_pretrained(str(final_output))

# Save image processor
processor.image_processor.save_pretrained(str(final_output))

# Copy chat template if available
import shutil
try:
    from huggingface_hub import hf_hub_download
    chat_template_src = hf_hub_download(repo_id=MODEL_NAME, filename="chat_template.json")
    shutil.copy(chat_template_src, str(final_output / "chat_template.json"))
    print("  ✓ Chat template copied")
except Exception:
    print("  Note: chat_template.json not found, skipping")

print(f"\n✓ Model saved to: {final_output}")
print(f"\nTo load the model later:")
print(f"  from peft import PeftModel")
print(f"  model = Qwen2_5_VLForConditionalGeneration.from_pretrained('{MODEL_NAME}')")
print(f"  model = PeftModel.from_pretrained(model, '{final_output}')")
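The saved adapter keeps the update (alpha / r)·B·A separate from the base weight W. For deployment, PEFT's `merge_and_unload()` on the loaded `PeftModel` folds it into W and returns a plain model without adapter layers. The pure-Python sketch below uses tiny hypothetical matrices only to show why the merge is exact: W·x + s·B·(A·x) equals (W + s·B·A)·x.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: rows of `a` times columns of `b`."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen weight (out=2, in=2)
A = [[0.5, -1.0]]             # down-projection (r=1 x in)
B = [[2.0], [1.0]]            # up-projection (out x r)
x = [[1.0], [2.0]]            # input as a column vector
scaling = 128 / 64            # lora_alpha / r from LORA_CONFIG

# Adapter kept separate: W x + s * B (A x)
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scaling)
# Adapter merged into the weight: (W + s * B A) x
merged = matmul(add(W, matmul(B, A), scaling), x)
print(unmerged, merged)  # -> [[-1.0], [8.0]] [[-1.0], [8.0]]
```

In bfloat16, the merged weights can differ from the unmerged computation by small rounding errors.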


## 9. (Optional) Test Inference


In [None]:
# Quick inference test with the fine-tuned model
print("Testing inference...")

model.eval()

# Load a test image
test_images = list(VRSBENCH_IMAGES_DIR.glob("*.png"))[:1] if VRSBENCH_IMAGES_DIR.exists() else []
if test_images:
    test_img_path = test_images[0]
    test_img = Image.open(test_img_path).convert('RGB')
    print(f"Test image: {test_img_path.name} ({test_img.size})")
    
    # Format message for Qwen2.5-VL
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": test_img},
                {"type": "text", "text": "Detect all objects in this image."}
            ]
        }
    ]
    
    # Same pattern as RexOmniWrapper._generate_transformers.
    # Step 1: apply the chat template to get the prompt text
    text = processor.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    
    # Step 2: run the processor. For Qwen2.5-VL, pixel_values is a 2D tensor of
    # flattened patches (not a 4D image batch); image_grid_thw gives the patch grid.
    inputs = processor(
        text=[text],
        images=[test_img],
        padding=True,
        return_tensors="pt",
    )
    
    # Move to device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    
    print(f"  input_ids shape: {inputs['input_ids'].shape}")
    print(f"  pixel_values shape: {inputs['pixel_values'].shape}")
    print(f"  image_grid_thw: {inputs.get('image_grid_thw', 'N/A')}")
    
    # Generate
    print("\nGenerating...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    
    # Decode response - trim input tokens
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = processor.decode(generated_ids, skip_special_tokens=True)
    
    print(f"\n✓ Inference successful!")
    print(f"\nResponse:\n{response}")
else:
    print("No test images found - skipping inference test")


---

## Summary

This notebook fine-tuned Rex-Omni with LoRA on VRSBench grounding data.

**Key components aligned with the finetuning codebase:**
1. **Model**: `Qwen2_5_VLForConditionalGeneration` (train.py line 115)
2. **Liger Kernel**: applied before model loading (train.py line 114)
3. **Data Format**: a single `train.json` read by `GroundingJsonDataset`, instead of the TSV files produced by `convert_json_data_to_tsv.py`
4. **DataArguments**: `model_type` and `image_processor` set as in train.py (lines 122-125); min/max pixel limits read from the processor config
5. **Task Function**: `GroundingTaskFn` with `GROUNDING_SINGLE_REGION_STAGE_XYXY` prompts (sft.py config)
6. **Data Collator**: `DataCollatorForSupervisedDataset(tokenizer=tokenizer)` (collator.py line 36)

**Output:**
- LoRA adapter saved to `outputs/rex-omni-lora-vrsbench/final/`
- Load with `PeftModel.from_pretrained()`
