# Fine-tuning Rex-Omni with LoRA on VRSBench Dataset

This notebook fine-tunes the Rex-Omni model using LoRA (Low-Rank Adaptation) on 1000 samples from the VRSBench validation dataset.

## Why This Approach Works

**Rex-Omni uses the standard Qwen2.5-VL architecture from Hugging Face**, so LoRA can be applied directly:
- ✅ Model architecture is standard (no custom modifications)
- ✅ LoRA works with standard attention/MLP layers
- ✅ Simpler and more flexible than running the full fine-tuning scripts (only the dataset, task, and collator classes from `finetuning/` are reused)
- ✅ Directly loads from `IDEA-Research/Rex-Omni` on Hugging Face

**Note**: Any Rex-Omni-specific "architecture changes" you may see mentioned refer to the training infrastructure (data processing, custom trainers), not to the model architecture itself.

## Overview
1. Load and explore VRSBench parquet data
2. Convert data to the JSON manifest used for training
3. Set up LoRA configuration
4. Fine-tune the model with LoRA
5. Save and evaluate the fine-tuned model

## Best Practices & Troubleshooting
- **CUDA Errors**: If you see `indexSelectLargeIndex` or device-side asserts, the usual cause is a mismatch between `image_grid_thw` and `pixel_values` (or the number of image tokens in the prompt). We have added safety checks in the dataset and wrapper to catch this early.
- **Restart Runtime**: If a CUDA error occurs, you **must** restart the runtime/kernel to recover.
- **Canonical Flow**: We strictly use `qwen_vl_utils.process_vision_info` to ensure alignment between pixel values and grid shapes.
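
The mismatch check reduces to simple arithmetic. Assuming the Qwen2.5-VL processor conventions (14-pixel patches, a 2x2 spatial merge, and one `(t, h, w)` row of `image_grid_thw` per image, counted in patches), `pixel_values` should have `sum(t*h*w)` rows and the prompt should contain `t*h*w / 4` image tokens per image. The helper names below are hypothetical, not `qwen_vl_utils` functions; this is a sketch for sanity-checking numbers by hand.

```python
from math import prod

# Qwen2.5-VL conventions as described above; treat them as assumptions for other checkpoints.
PATCH_SIZE = 14          # pixels per ViT patch side
SPATIAL_MERGE_SIZE = 2   # 2x2 neighbouring patches are merged into one LLM token


def expected_patch_rows(grid_thw):
    """Rows `pixel_values` should have: the sum of t*h*w over all images."""
    return sum(prod(thw) for thw in grid_thw)


def expected_image_tokens(grid_thw, merge_size=SPATIAL_MERGE_SIZE):
    """Image placeholder tokens the prompt must contain after spatial merging."""
    return sum(prod(thw) // merge_size**2 for thw in grid_thw)


# A 448x448 image (already a multiple of 28) -> 32x32 patches of 14 px, t=1.
grid = [(1, 448 // PATCH_SIZE, 448 // PATCH_SIZE)]
print(expected_patch_rows(grid))    # 1024
print(expected_image_tokens(grid))  # 256
```

If either count disagrees with what the processor produced, the image was most likely resized or tokenized by a different code path than the one that built `image_grid_thw`.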


## 1. Setup for Google Colab

### 1.1 Clone Repository and Install Dependencies


In [None]:
import os
import sys
import importlib
import subprocess
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
print(f"Detected Google Colab runtime: {IN_COLAB}")

if not IN_COLAB:
    print("Skipping Colab bootstrap because this runtime is not Google Colab.")
    print("Ensure dependencies are installed locally via `pip install -e .` from the project root.")
else:
    repo_path = Path("/content/Rex-Omni")

    def run_cmd(cmd: str, check: bool = True):
        print(f"\n$ {cmd}")
        completed = subprocess.run(cmd, shell=True)
        if check and completed.returncode != 0:
            raise RuntimeError(f"Command failed with exit code {completed.returncode}: {cmd}")

    if not repo_path.exists():
        run_cmd("git clone https://github.com/IDEA-Research/Rex-Omni.git /content/Rex-Omni")
    else:
        print("Repository already exists at /content/Rex-Omni; skipping clone.")

    os.chdir(repo_path)
    print(f"Working directory set to: {os.getcwd()}")

    print("Step 1: Installing compatible PyTorch versions...")
    run_cmd("pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 --force-reinstall")

    print("Step 2: Installing compatible numpy...")
    run_cmd('pip install -q "numpy<2.0" --force-reinstall')

    print("Step 3: Installing transformers with modeling_layers support...")
    # Qwen2_5_VLForConditionalGeneration first shipped in transformers 4.49.
    run_cmd('pip install -q "transformers>=4.49.0,<5.0.0" --upgrade')

    print("Step 4: Verifying transformers.modeling_layers...")
    if 'transformers' in sys.modules:
        importlib.reload(sys.modules['transformers'])

    def check_modeling_layers():
        try:
            from transformers.modeling_utils import modeling_layers  # noqa: F401
            print("\u2713 transformers.modeling_layers found via modeling_utils!")
            return True
        except ImportError:
            try:
                from transformers import modeling_layers  # type: ignore  # noqa: F401
                print("\u2713 transformers.modeling_layers found!")
                return True
            except ImportError:
                return False

    if not check_modeling_layers():
        print("\u2717 modeling_layers not found. Installing from source...")
        run_cmd("pip uninstall -y transformers -q")
        run_cmd("pip install -q git+https://github.com/huggingface/transformers.git --no-deps")
        run_cmd('pip install -q "transformers[torch]" --upgrade')
        if not check_modeling_layers():
            print("\u26a0 WARNING: modeling_layers still not found. You may need to use peft==0.13.0.")

    print("Step 5: Installing peft...")
    try:
        run_cmd('pip install -q "peft>=0.18.0" --upgrade')
        import peft
        print(f"\u2713 peft {peft.__version__} installed successfully!")
    except Exception as exc:  # noqa: BLE001
        print(f"\u26a0 peft >=0.18.0 installation failed: {exc}")
        print("Trying older compatible version...")
        run_cmd('pip install -q "peft==0.13.0" --upgrade')

    print("Step 6: Installing other dependencies...")
    run_cmd("pip install -q accelerate datasets pandas pyarrow pillow")

    print("Trying to install flash-attn (safe to fail)...")
    flash_status = subprocess.run("pip install -q flash-attn --no-build-isolation", shell=True)
    if flash_status.returncode == 0:
        print("\u2713 flash-attn installed")
    else:
        print("\u26a0 flash-attn installation failed. Continuing without it...")

    print("Step 7: Installing Rex-Omni package (no deps)...")
    run_cmd("pip install -v -e . --no-deps")

    print("Step 8: Installing finetuning dependencies (no deps)...")
    run_cmd("cd finetuning && pip install -v -e . --no-deps")

    print("Step 9: Installing Rex-Omni extra dependencies...")
    run_cmd("pip install -q qwen_vl_utils==0.0.14 vllm==0.8.2 gradio==4.44.1 gradio_image_prompter==0.1.0 pydantic==2.10.6 pycocotools==2.0.10 shapely==2.1.2 gradio_bbox_annotator==0.1.1")

    print("\n" + "=" * 50)
    print("FINAL VERIFICATION")
    print("=" * 50)
    import torch
    import transformers
    try:
        import peft  # noqa: F401
        print(f"\u2713 torch: {torch.__version__}")
        print(f"\u2713 transformers: {transformers.__version__}")
        print(f"\u2713 peft: {peft.__version__}")
        if not check_modeling_layers():
            print("\u26a0 transformers.modeling_layers still unavailable; peft >=0.18 may fail.")
        else:
            print("\u2713 transformers.modeling_layers available")
        print("\n\u2713 Installation complete! Continue with the notebook.")
    except Exception as exc:  # noqa: BLE001
        print(f"\u2717 Verification failed: {exc}")
        print("Please re-run the troubleshooting cell if needed.")


### 1.1.1 Post-Installation Verification

**After running the installation cell above, run this cell to verify everything is working:**


In [None]:
# Post-installation verification and fix
print("="*60)
print("POST-INSTALLATION VERIFICATION")
print("="*60)

# Check versions
import torch
import transformers
import sys

print(f"\nCurrent versions:")
print(f"  torch: {torch.__version__}")
print(f"  transformers: {transformers.__version__}")

# Check for modeling_layers - it might be in different locations
print(f"\nChecking for modeling_layers...")
modeling_layers_found = False
modeling_layers_location = None

# Try different import paths
try:
    from transformers import modeling_layers
    modeling_layers_found = True
    modeling_layers_location = "transformers.modeling_layers"
    print("✓ Found: transformers.modeling_layers")
except ImportError:
    try:
        from transformers.modeling_utils import modeling_layers
        modeling_layers_found = True
        modeling_layers_location = "transformers.modeling_utils.modeling_layers"
        print("✓ Found: transformers.modeling_utils.modeling_layers")
    except ImportError:
        try:
            # Check if it exists as an attribute
            if hasattr(transformers, 'modeling_layers'):
                modeling_layers_found = True
                modeling_layers_location = "transformers.modeling_layers (attribute)"
                print("✓ Found: transformers.modeling_layers (as attribute)")
            elif hasattr(transformers.modeling_utils, 'modeling_layers'):
                modeling_layers_found = True
                modeling_layers_location = "transformers.modeling_utils.modeling_layers (attribute)"
                print("✓ Found: transformers.modeling_utils.modeling_layers (as attribute)")
            else:
                print("✗ modeling_layers not found in standard locations")
        except Exception:
            print("✗ modeling_layers not found")

# Try importing peft to see if it works
print(f"\nChecking peft...")
try:
    import peft
    print(f"✓ peft {peft.__version__} imported successfully")
    
    # Try to use peft to see if modeling_layers is actually needed
    try:
        from peft import LoraConfig
        print("✓ LoraConfig imported successfully")
        peft_works = True
    except Exception as e:
        print(f"✗ LoraConfig import failed: {e}")
        peft_works = False
except Exception as e:
    print(f"✗ peft import failed: {e}")
    peft_works = False

# Summary
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
if modeling_layers_found:
    print(f"✓ modeling_layers found at: {modeling_layers_location}")
else:
    print("⚠ modeling_layers not found, but this might be okay")
    print("  Some transformers versions don't expose it directly")

if peft_works:
    print("✓ peft is working correctly")
    print("\n✅ Everything looks good! You can proceed with the notebook.")
else:
    print("✗ peft has issues")
    print("\n⚠ You may need to:")
    print("  1. Restart the runtime (Runtime → Restart runtime)")
    print("  2. Re-run the installation cell (section 1.1)")
    print("  3. Or use peft 0.13.0 instead of 0.18.0+")


### 1.1.2 Quick Fix for Version Conflicts

**If you are seeing version conflicts or the `modeling_layers` error right now, run the cell below before continuing:**


In [None]:
# QUICK FIX: Run this cell if you're getting version conflicts or modeling_layers errors
# This will fix PyTorch/transformers/peft compatibility issues

print("="*60)
print("QUICK FIX: Resolving version conflicts...")
print("="*60)

# Step 1: Fix PyTorch version conflicts
print("\n1. Fixing PyTorch versions...")
!pip uninstall -y torch torchvision torchaudio -q
!pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

# Step 2: Fix numpy
print("\n2. Fixing numpy version...")
!pip install -q "numpy<2.0" --force-reinstall

# Step 3: Reinstall transformers
print("\n3. Reinstalling transformers...")
!pip uninstall -y transformers -q
!pip install -q "transformers>=4.49.0,<5.0.0" --upgrade

# Step 4: Verify modeling_layers
print("\n4. Verifying transformers.modeling_layers...")
import importlib
import sys
if 'transformers' in sys.modules:
    importlib.reload(sys.modules['transformers'])

try:
    from transformers import modeling_layers
    print("✓ transformers.modeling_layers found!")
except ImportError:
    try:
        from transformers.modeling_utils import modeling_layers
        print("✓ transformers.modeling_layers found (via modeling_utils)!")
    except ImportError:
        print("⚠ modeling_layers not found. Installing from source...")
        !pip uninstall -y transformers -q
        !pip install -q git+https://github.com/huggingface/transformers.git --no-deps
        !pip install -q "transformers[torch]" --upgrade

# Step 5: Reinstall peft. `!pip` never raises on failure, so test the import instead.
print("\n5. Reinstalling peft...")
!pip install -q "peft>=0.18.0" --upgrade --force-reinstall
try:
    import peft
    print(f"✓ peft {peft.__version__} installed!")
except Exception as e:
    print(f"⚠ peft 0.18.0+ failed to import: {e}")
    print("Installing older compatible version...")
    !pip install -q "peft==0.13.0" --upgrade --force-reinstall

# Final verification
print("\n" + "="*60)
print("VERIFICATION")
print("="*60)
try:
    import torch
    import transformers
    import peft
    print(f"✓ torch: {torch.__version__}")
    print(f"✓ transformers: {transformers.__version__}")
    print(f"✓ peft: {peft.__version__}")
    
    # Check modeling_layers
    try:
        from transformers import modeling_layers
        print("✓ transformers.modeling_layers: Available")
    except ImportError:
        try:
            from transformers.modeling_utils import modeling_layers
            print("✓ transformers.modeling_layers: Available (via modeling_utils)")
        except ImportError:
            print("⚠ transformers.modeling_layers: Not found")
            print("  This may cause issues. Try restarting the runtime and re-running the installation cell (section 1.1).")
    
    print("\n✓ All fixed! Restart the runtime (Runtime → Restart runtime) and continue.")
except Exception as e:
    print(f"✗ Error: {e}")
    print("\nPlease restart the runtime and re-run the installation cell (section 1.1).")


### 1.1.3 Troubleshooting Version Compatibility

If you encounter `ModuleNotFoundError: No module named 'transformers.modeling_layers'`, try one of these solutions:
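
Before reinstalling anything, it can help to see what is actually importable. The sketch below uses the standard-library `importlib.util.find_spec` to locate modules without executing them; `module_available` is a hypothetical helper, not part of any library. Note that for a dotted name such as `transformers.modeling_layers`, `find_spec` still imports the parent package, and a found spec does not guarantee the module imports cleanly.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if `name` can be located on the import path.

    For dotted names, find_spec imports the parent package first; a missing or
    broken parent raises ImportError, which we treat as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


for mod in ("transformers", "transformers.modeling_layers", "peft", "flash_attn"):
    print(f"{mod}: {'available' if module_available(mod) else 'missing'}")
```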


In [None]:
# If you still get the modeling_layers error after running the installation cell (section 1.1), run this cell:

# Solution: Force reinstall transformers from source
print("Force reinstalling transformers from source...")
!pip uninstall -y transformers
!pip install -q git+https://github.com/huggingface/transformers.git --no-deps
!pip install -q "transformers[torch]" --upgrade

# Verify installations
print("\nVerifying installations...")
try:
    import transformers
    from transformers import modeling_layers
    print(f"✓ transformers version: {transformers.__version__}")
    print(f"✓ transformers.modeling_layers imported successfully!")
except ImportError as e:
    print(f"✗ Error: {e}")
    print("Trying alternative: installing transformers 4.49.0+")
    !pip install -q "transformers>=4.49.0" --force-reinstall

try:
    import peft
    print(f"✓ peft version: {peft.__version__}")
except ImportError as e:
    print(f"✗ peft import error: {e}")
    !pip install -q "peft>=0.18.0" --force-reinstall


In [None]:
import os
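# CUDA_LAUNCH_BLOCKING=1 runs kernels synchronously so CUDA errors are reported at the failing op.
# It only takes effect if set before CUDA is initialized, and it slows training; drop it for long runs.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")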

import sys
import json
import base64
import io
from pathlib import Path
from typing import Dict, List
from tqdm import tqdm

import pandas as pd
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model, TaskType

IN_COLAB = "google.colab" in sys.modules
ENV_PROJECT_ROOT = os.environ.get("REX_OMNI_ROOT")
DEFAULT_PROJECT_ROOT = Path("/content/Rex-Omni") if IN_COLAB else Path.cwd()

candidate_roots = []
if ENV_PROJECT_ROOT:
    candidate_roots.append(Path(ENV_PROJECT_ROOT))
candidate_roots.append(DEFAULT_PROJECT_ROOT)
candidate_roots.append(Path.cwd())

project_root = None
for candidate in candidate_roots:
    candidate_path = Path(candidate).expanduser().resolve()
    if candidate_path.exists():
        project_root = candidate_path
        break

if project_root is None:
    raise FileNotFoundError(
        "Unable to locate the Rex-Omni project root. Set the REX_OMNI_ROOT env var or run from within the repo."
    )

os.chdir(str(project_root))

# Add project root to path for local imports
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
if str(project_root / "finetuning") not in sys.path:
    sys.path.insert(0, str(project_root / "finetuning"))

print(f"Project root: {project_root}")
print(f"Current directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("CUDA not detected. Training/inference will fall back to CPU unless a GPU is attached.")


### 1.2 Upload VRSBench Data

Upload the `vrsbench_val_data.parquet` file to Colab. You can either:
- Use the file upload widget below
- Upload to Google Drive and mount it
- Use `gdown` if the file is on Google Drive


In [None]:
# Option 1: Upload file directly using file widget (Colab only)
if IN_COLAB:
    from google.colab import files  # type: ignore
    import shutil

    # Uncomment to upload file:
    # uploaded = files.upload()
    # for filename in uploaded.keys():
    #     shutil.move(filename, '/content/Rex-Omni/vrsbench_val_data.parquet')
    #     print(f'Moved {filename} to /content/Rex-Omni/vrsbench_val_data.parquet')

    # Option 2: If file is in Google Drive
    # from google.colab import drive
    # drive.mount('/content/drive')
    # !cp /content/drive/MyDrive/vrsbench_val_data.parquet /content/Rex-Omni/

    # Option 3: Download from URL (if available)
    # !wget -O /content/Rex-Omni/vrsbench_val_data.parquet <URL_TO_FILE>
else:
    print("Google Colab utilities are not available in this runtime. ")
    print("Ensure `vrsbench_val_data.parquet` exists locally at `project_root / vrsbench_val_data.parquet` before proceeding.")

print("Make sure vrsbench_val_data.parquet is in /content/Rex-Omni/ or your local project root.")


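### 1.3 Import Dependencies

All imports and the project-root setup happen in the code cell that follows the troubleshooting cells in section 1. Re-run that cell after every runtime restart, because later cells rely on `project_root`, `IN_COLAB`, and the imported classes.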


## 2. Load and Explore VRSBench Data


In [None]:
# Load the parquet file
parquet_path = project_root / "vrsbench_val_data.parquet"
df = pd.read_parquet(parquet_path)

print(f"Dataset shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print(f"\nFirst row sample:")
print(df.iloc[0])


In [None]:
# Sample 1000 rows
num_samples = 1000
df_sample = df.sample(n=min(num_samples, len(df)), random_state=42).reset_index(drop=True)
print(f"Sampled {len(df_sample)} samples from {len(df)} total samples")


In [None]:
# Explore the data structure
sample_row = df_sample.iloc[0]
print("Sample data structure:")
print(f"Image ID: {sample_row['image_id']}")
print(f"Image Path: {sample_row['image_path']}")
print(f"Caption: {sample_row['caption'][:100]}...")
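print("\nObjects (first object):")
objects = sample_row['objects']
if isinstance(objects, str):  # stored as a JSON string in some exports
    objects = json.loads(objects)
if len(objects) > 0:
    print(json.dumps(objects[0], indent=2, default=str))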


## 3. Prepare VRSBench Data for JSON Training

We train directly from a JSON manifest that lists image paths, bounding boxes, and labels. The legacy TSV pipeline has been removed to avoid format drift, so only the JSON export below is required.
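
Each manifest entry is a plain dict. The sketch below shows the shape the export cell writes (`image_path`, `boxes` in absolute-pixel xyxy, `labels`, plus `image_id` and `source`) and a hypothetical `validate_record` helper for catching malformed entries before they reach `GroundingJsonDataset`. Whether the dataset class requires exactly these keys is an assumption based on the export code in this notebook, and the example path is made up.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("image_path", "boxes", "labels")


def validate_record(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one manifest record (empty list means OK)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in record]
    if problems:
        return problems
    boxes, labels = record["boxes"], record["labels"]
    if len(boxes) != len(labels):
        problems.append(f"{len(boxes)} boxes but {len(labels)} labels")
    for i, box in enumerate(boxes):
        if len(box) != 4:
            problems.append(f"box {i} has {len(box)} values, expected 4")
            continue
        x0, y0, x1, y1 = box
        if not (x1 > x0 and y1 > y0):
            problems.append(f"box {i} is degenerate or not xyxy: {box}")
    return problems


example = {
    "image_path": "Images_val/00001.png",   # hypothetical path
    "boxes": [[10.0, 20.0, 110.0, 220.0]],  # absolute pixel xyxy
    "labels": ["ship"],
    "image_id": "00001",
    "source": "vrsbench",
}
print(validate_record(example))  # []
```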



In [None]:
# Convert VRSBench samples into the JSON manifest consumed by GroundingJsonDataset.
# (Legacy TSV generation has been removed to avoid format drift.)
from collections import defaultdict
from datetime import datetime
from PIL import ImageDraw
import random

output_dir = project_root / "vrsbench_data"
os.makedirs(output_dir, exist_ok=True)

json_path = output_dir / "train.json"
synthetic_debug_path = output_dir / "debug_synthetic_samples.json"

print(f"Converting to JSON dataset at {json_path}...")
skipped_reasons = defaultdict(int)
dataset_list: List[Dict] = []

possible_image_roots = [
    project_root,
    project_root / "Images_validation" / "Images_val",
    project_root / "Images_validation",
]


def resolve_image_path(relative_path: str) -> str | None:
    if os.path.isabs(relative_path) and os.path.exists(relative_path):
        return relative_path
    for root in possible_image_roots:
        candidate = root / relative_path
        if candidate.exists():
            return str(candidate)
    fallback = project_root / Path(relative_path).name
    return str(fallback) if fallback.exists() else None


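def normalize_boxes(objects_json, width: int, height: int) -> tuple[list[list[float]], list[str]]:
    """Convert `obj_coord` boxes to absolute-pixel xyxy boxes.

    Despite the name, this de-normalizes: `obj_coord` is assumed to hold
    [x0, y0, x1, y1] in the [0, 1] range, which is scaled by the image size.
    """
    boxes, labels = [], []
    for obj in objects_json:
        coords = obj.get("obj_coord")
        # `is None` instead of `not coords`: parquet may yield numpy arrays, whose truth value is ambiguous.
        if coords is None or len(coords) != 4:
            continue
        x0, y0, x1, y1 = (float(c) for c in coords)
        if x1 <= x0 or y1 <= y0:
            continue  # skip degenerate or non-xyxy boxes
        boxes.append([x0 * width, y0 * height, x1 * width, y1 * height])
        labels.append(obj.get("obj_cls", "object"))
    return boxes, labels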


def build_record(idx: int, row: pd.Series) -> Dict | None:
    image_path = row.get("image_path", "")
    resolved_path = resolve_image_path(image_path)
    if resolved_path is None:
        skipped_reasons["missing_image"] += 1
        return None

    try:
        with Image.open(resolved_path).convert("RGB") as img:
            width, height = img.size
            if min(width, height) < 28:
                skipped_reasons["image_too_small"] += 1
                return None
    except Exception as exc:  # noqa: BLE001
        skipped_reasons["image_load_failed"] += 1
        print(f"Warning: failed to open {resolved_path}: {exc}")
        return None

    try:
        objects = row["objects"]
        if isinstance(objects, str):
            objects = json.loads(objects)
    except Exception as exc:  # noqa: BLE001
        skipped_reasons["objects_parse_failed"] += 1
        print(f"Warning: failed to parse objects for idx={idx}: {exc}")
        return None

    boxes, labels = normalize_boxes(objects, width, height)
    if not boxes:
        skipped_reasons["no_boxes"] += 1
        return None

    return {
        "image_path": str(resolved_path),
        "boxes": boxes,
        "labels": labels,
        "image_id": row.get("image_id", f"row_{idx}"),
        "source": "vrsbench",
    }


def build_synthetic_debug_samples(num_samples: int = 2) -> List[Dict]:
    """Create simple images with one known box each, for debugging the pipeline."""
    # `datetime` here is the class imported above; it has no `.timezone` attribute.
    from datetime import timezone

    print(f"Building {num_samples} synthetic debug sample(s)...")
    synthetic_samples = []
    for s_idx in range(num_samples):
        img = Image.new("RGB", (512, 512), color=(30 + s_idx * 40, 30, 60))
        draw = ImageDraw.Draw(img)
        x0, y0 = random.randint(20, 150), random.randint(20, 150)
        w, h = random.randint(150, 320), random.randint(150, 320)
        x1, y1 = min(511, x0 + w), min(511, y0 + h)
        draw.rectangle([x0, y0, x1, y1], outline="yellow", width=4)
        draw.text((x0 + 5, y0 + 5), "synthetic", fill="white")
        img_path = output_dir / f"synthetic_debug_{s_idx}.png"
        img.save(img_path)
        synthetic_samples.append(
            {
                "image_path": str(img_path),
                "boxes": [[float(x0), float(y0), float(x1), float(y1)]],
                "labels": ["synthetic object"],
                "image_id": f"synthetic_{s_idx}",
                "source": "synthetic_debug",
                "created_at": datetime.now(timezone.utc).isoformat(),
            }
        )
    return synthetic_samples


for idx, row in tqdm(df_sample.iterrows(), total=len(df_sample), desc="Converting to JSON"):
    record = build_record(idx, row)
    if record is not None:
        dataset_list.append(record)

if not dataset_list:
    print("No valid samples found. Consider verifying the parquet file path or image root.")

with open(json_path, "w", encoding="utf-8") as f:
    json.dump(dataset_list, f, indent=2, ensure_ascii=False)

print(f"Saved {len(dataset_list)} samples to {json_path}")
if skipped_reasons:
    print("Skipped samples breakdown:")
    for reason, count in skipped_reasons.items():
        print(f"  - {reason}: {count}")

synthetic_samples = build_synthetic_debug_samples()
with open(synthetic_debug_path, "w", encoding="utf-8") as f:
    json.dump(synthetic_samples, f, indent=2)
print(f"Synthetic debug samples saved to {synthetic_debug_path} (not included in training by default).")


## 4. Set Up LoRA Configuration
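
The configuration below uses `r=16` and `lora_alpha=32`, so each adapted layer adds a rank-16 update `B @ A` scaled by `lora_alpha / r = 2.0` (PEFT's default scaling when rsLoRA is off). The sketch works through that arithmetic for a single linear layer; the 2048x2048 shape is purely illustrative, so rely on `model.print_trainable_parameters()` for the real count. Also note that `target_modules` names are matched by suffix, so `gate_proj`/`up_proj`/`down_proj` may match MLP layers in the vision tower as well as the language model.

```python
def lora_scaling(r: int, lora_alpha: float) -> float:
    """Scale applied to the low-rank update B @ A in standard LoRA: alpha / r."""
    return lora_alpha / r


def lora_params_per_linear(in_features: int, out_features: int, r: int) -> int:
    """Trainable params LoRA adds to one nn.Linear: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r


print(lora_scaling(r=16, lora_alpha=32))  # 2.0
# Hypothetical 2048 -> 2048 projection (layer sizes are illustrative only).
print(lora_params_per_linear(2048, 2048, r=16))  # 65536
```

Doubling `r` while keeping `lora_alpha` fixed halves the scale, which is why the two are usually tuned together.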


In [None]:
# LoRA configuration
lora_config = LoraConfig(
    r=16,  # LoRA rank
    lora_alpha=32,  # LoRA alpha (scaling factor)
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Target modules for LoRA
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

print("LoRA Configuration:")
print(lora_config)


## 5. Load Model and Apply LoRA


In [None]:
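# Model configuration
import importlib.util

model_name = "IDEA-Research/Rex-Omni"
cache_dir = None

# FlashAttention-2 is optional (its install above may fail) and needs an Ampere-or-newer GPU;
# fall back to PyTorch SDPA otherwise.
use_flash_attn = (
    importlib.util.find_spec("flash_attn") is not None
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability(0)[0] >= 8
)
attn_implementation = "flash_attention_2" if use_flash_attn else "sdpa"

print(f"Loading model: {model_name} (attn_implementation={attn_implementation})")

# Load model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
    trust_remote_code=True
)

print("Model loaded successfully")

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Enable gradient checkpointing to save memory. With frozen base weights, the inputs to
# checkpointed blocks must require grad, otherwise the LoRA layers receive no gradients.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.config.use_cache = False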


In [None]:
# Load processor and tokenizer
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    model_max_length=4096,
    padding_side="right",
    use_fast=False,
    trust_remote_code=True
)

print("Processor and tokenizer loaded")


## 6. Set Up Dataset and Data Collator
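
`min_pixels` and `max_pixels` in the cells below bound how many visual tokens each image can produce: with 28-pixel effective patches (14-pixel ViT patches merged 2x2), `16 * 28 * 28` and `2560 * 28 * 28` correspond to roughly 16 and 2560 image tokens. The sketch below is modeled on the resizing rule in `qwen_vl_utils.smart_resize` as we understand it (round each side to a multiple of 28, then rescale if the area falls outside the bounds). It is a re-implementation for illustration, so check it against the installed `qwen_vl_utils` version before relying on exact sizes.

```python
import math


def smart_resize(height, width, factor=28, min_pixels=16 * 28 * 28,
                 max_pixels=2560 * 28 * 28, max_ratio=200):
    """Pick a (height, width) that is a multiple of `factor` and whose area lies
    in [min_pixels, max_pixels], approximately preserving the aspect ratio."""
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError(f"aspect ratio too extreme: {height}x{width}")
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Shrink both sides by the same ratio, rounding down to stay under the cap.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Grow both sides by the same ratio, rounding up to clear the floor.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


for h, w in [(1000, 800), (4000, 4000), (30, 60)]:
    rh, rw = smart_resize(h, w)
    tokens = (rh // 14) * (rw // 14) // 4  # 14-px patches, 2x2 merge
    print(f"{h}x{w} -> {rh}x{rw}, ~{tokens} image tokens")
```

Lowering `max_pixels` is therefore the most direct way to cut per-sample memory on small GPUs.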


In [None]:
# Import dataset classes from finetuning module
from dataset.json_dataset import GroundingJsonDataset
from dataset.collator import DataCollatorForSupervisedDataset
from dataset.task_fns import GroundingTaskFn
from dataset.task_fns.task_prompts.grounding_task import GROUNDING_SINGLE_REGION_STAGE_XYXY
from engine.argument import DataArguments

# Set up data arguments
data_args = DataArguments()
data_args.image_processor = processor.image_processor
data_args.model_type = "qwen2.5vl"

# Image size constraints
min_pixels = 16 * 28 * 28
max_pixels = 2560 * 28 * 28

# Create task function
task_fn = GroundingTaskFn(
    task_prompts=GROUNDING_SINGLE_REGION_STAGE_XYXY,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
)

print("Task function created")

In [None]:
# Create dataset
train_dataset = GroundingJsonDataset(
    json_file=str(json_path),
    tokenizer=tokenizer,
    data_args=data_args,
    image_min_pixels=min_pixels,
    image_max_pixels=max_pixels,
    task_fn=task_fn,
    system_message="You are a helpful assistant.",
    ori_box_format="xyxy",
    dataset_name="vrsbench_1000",
    max_length=4096,
)

print(f"Dataset created with {len(train_dataset)} samples")


In [None]:
# Quick dataset sanity check to catch grid/pixel mismatches early
import random
from collections import Counter

def dataset_health_check(dataset, sample_size: int = 12):
    if len(dataset) == 0:
        print("Dataset is empty. Please verify the JSON export step above.")
        return {"summary": Counter(), "reasons": Counter()}

    indices = random.sample(range(len(dataset)), min(sample_size, len(dataset)))
    summary = Counter()
    reasons = Counter()

    for idx in indices:
        try:
            sample = dataset[idx]
            pixel_values = sample.get("pixel_values")
            grid_thw = sample.get("image_grid_thw")
            if pixel_values is None or grid_thw is None:
                raise ValueError("missing pixel/grid data")
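            # pixel_values holds one row per 14x14 patch, so its length must equal
            # sum(t * h * w) over all rows of image_grid_thw (not just the t entry).
            if isinstance(grid_thw, list) and grid_thw and all(isinstance(g, torch.Tensor) for g in grid_thw):
                grid_thw = torch.cat([g.reshape(-1, 3) for g in grid_thw])
            if isinstance(grid_thw, torch.Tensor):
                expected_patches = int(grid_thw.reshape(-1, 3).prod(dim=-1).sum().item())
                patch_rows = int(pixel_values.shape[0])
                if expected_patches != patch_rows:
                    raise ValueError(
                        f"grid patches ({expected_patches}) != pixel patches ({patch_rows})"
                    )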
            summary["ok"] += 1
        except Exception as exc:  # noqa: BLE001
            summary["failed"] += 1
            reasons[str(exc).split("\n")[0]] += 1

    print("Dataset health check summary:")
    print(f"  OK samples: {summary['ok']}")
    print(f"  Failed samples: {summary['failed']}")
    if reasons:
        print("  Failure reasons:")
        for reason, count in reasons.items():
            print(f"    - {reason}: {count}")

    return {"summary": summary, "reasons": reasons}

train_dataset_health = dataset_health_check(train_dataset)
print("Internal dataset stats:", train_dataset.get_health_report())



In [None]:
# Create data collator
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

print("Data collator created")


## 7. Set Up Training Arguments
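
The effective batch size is `per_device_train_batch_size * gradient_accumulation_steps` (times the number of GPUs), and the GPU-memory tiers in the cell below keep it at 8 on a single GPU. The hypothetical `plan_steps` helper sketches how many optimizer steps, warmup steps, and `save_steps` checkpoints a run will produce. It assumes a trailing partial accumulation window is dropped, which may differ slightly between Transformers versions, and the real sample count is whatever survived the JSON export.

```python
import math


def plan_steps(num_samples, per_device_batch, grad_accum, epochs,
               warmup_ratio=0.03, save_steps=100, num_devices=1):
    """Rough optimizer-step budget for a single-node Trainer run."""
    # The dataloader keeps the last partial batch.
    batches_per_epoch = math.ceil(num_samples / (per_device_batch * num_devices))
    # One optimizer step every `grad_accum` batches (at least one per epoch).
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": per_device_batch * grad_accum * num_devices,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
        "checkpoints": total_steps // save_steps,
    }


print(plan_steps(num_samples=1000, per_device_batch=1, grad_accum=8, epochs=3))
```

With 1000 samples and 3 epochs, every tier yields about 125 steps per epoch and 375 in total, so `save_steps=100` produces three intermediate checkpoints.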


In [None]:
# Training arguments
output_dir = project_root / "work_dirs" / "rexomni_lora_vrsbench_1000"
os.makedirs(output_dir, exist_ok=True)

# Pick batch size from GPU memory (e.g. T4: 1, V100/L4: 2, A100: 4).
# Every tier keeps the effective batch size (batch_size * grad_accum) at 8.
gpu_memory_gb = (
    torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0.0
)
if gpu_memory_gb < 16:
    batch_size = 1
    grad_accum = 8
elif gpu_memory_gb < 32:
    batch_size = 2
    grad_accum = 4
else:
    batch_size = 4
    grad_accum = 2

print(f"GPU Memory: {gpu_memory_gb:.2f} GB")
print(f"Using batch_size={batch_size}, gradient_accumulation_steps={grad_accum}")

training_args = TrainingArguments(
    output_dir=str(output_dir),
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum,
    learning_rate=2e-4,  # Higher LR for LoRA
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    max_grad_norm=1.0,
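    # Native bf16 needs compute capability >= 8.0 (Ampere or newer); older GPUs use fp16, CPU uses neither.
    bf16=torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8,
    fp16=torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 8,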
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    eval_strategy="no",
    save_strategy="steps",
    gradient_checkpointing=True,
    dataloader_num_workers=2,  # Reduce for Colab
    remove_unused_columns=False,
    report_to="none",  # Set to "wandb" if you want to use Weights & Biases
    run_name="rexomni_lora_vrsbench_1000",
)

print("Training arguments configured")


## 8. Create Trainer


In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("Trainer created")


## 9. Start Training


In [None]:
# Start training
print("Starting training...")
print(f"Output directory: {output_dir}")
print(f"Training samples: {len(train_dataset)}")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print("Dataset stats before training:", train_dataset.get_health_report())

try:
    trainer.train()
except RuntimeError as e:
    if "indexSelectLargeIndex" in str(e) or "device-side assert" in str(e):
        print("\n⚠ Training hit a CUDA indexing/device assert error.")
        print("The CUDA context is now unusable: restart the runtime before re-running any cell.")
        print("Then re-run the dataset health check cell above and inspect the reported failure reasons.")
        print("You can also lower `max_pixels` (passed as `image_max_pixels`) or filter problematic entries in the JSON export.")
    raise


## 10. Save the Fine-tuned Model
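
The cell below calls `model.save_pretrained` on a PEFT model, which writes only the adapter, not the full base weights. A quick way to confirm the save worked is to look for PEFT's standard files: `adapter_config.json` plus `adapter_model.safetensors` (or `adapter_model.bin` with older PEFT releases or `safe_serialization=False`). The `check_adapter_dir` helper is a hypothetical convenience, not a PEFT API.

```python
from pathlib import Path


def check_adapter_dir(path) -> dict:
    """Report which standard PEFT adapter files are present in `path`."""
    path = Path(path)
    has_config = (path / "adapter_config.json").exists()
    # Newer PEFT saves safetensors by default; older setups may write a .bin file.
    weights = [name for name in ("adapter_model.safetensors", "adapter_model.bin")
               if (path / name).exists()]
    return {"has_config": has_config, "weights": weights, "ok": has_config and bool(weights)}


# Usage after saving: print(check_adapter_dir(final_model_dir))
```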


In [None]:
# Save the final model
final_model_dir = output_dir / "final_model"
os.makedirs(final_model_dir, exist_ok=True)

# Save LoRA adapters
model.save_pretrained(str(final_model_dir))

# Save tokenizer and processor
tokenizer.save_pretrained(str(final_model_dir))
processor.save_pretrained(str(final_model_dir))

print(f"Model saved to: {final_model_dir}")

# Optionally save to Google Drive for persistence
# Uncomment the following lines to save to Google Drive:
# from google.colab import drive
# drive.mount('/content/drive')
# drive_model_dir = Path("/content/drive/MyDrive/rexomni_lora_vrsbench_1000")
# import shutil
# shutil.copytree(final_model_dir, drive_model_dir, dirs_exist_ok=True)
# print(f"Model also saved to Google Drive: {drive_model_dir}")


## 11. Load and Test the Fine-tuned Model
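
The cell below merges the adapter with `merge_and_unload()`, which folds each low-rank update into its base weight so the merged model computes the same outputs without the extra LoRA matmuls. This pure-Python sketch checks that identity for one tiny linear layer, `W0 x + s * B(A x) == (W0 + s * B A) x` with `s = lora_alpha / r`. The matrices are made up, and real merges happen in the model's dtype (bf16 here), so expect small numerical differences rather than exact equality.

```python
def matvec(M, x):
    """Matrix-vector product for lists of lists."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def matmul(P, Q):
    """Matrix-matrix product P @ Q for lists of lists."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)] for row in P]


def lora_forward(W0, A, B, x, scaling):
    """Unmerged LoRA forward: y = W0 x + scaling * B (A x)."""
    base = matvec(W0, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]


def merge_lora(W0, A, B, scaling):
    """Merged weight: W = W0 + scaling * (B @ A)."""
    BA = matmul(B, A)
    return [[w + scaling * d for w, d in zip(wr, dr)] for wr, dr in zip(W0, BA)]


# Tiny example: out=2, in=3, rank r=1, scaling = lora_alpha / r = 2.0
W0 = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
A = [[0.5, -1.0, 0.25]]  # shape (r, in)
B = [[2.0], [-1.0]]      # shape (out, r)
x = [1.0, 2.0, 4.0]
scaling = 2.0

y_unmerged = lora_forward(W0, A, B, x, scaling)
y_merged = matvec(merge_lora(W0, A, B, scaling), x)
print(y_unmerged, y_merged)  # [7.0, -1.0] [7.0, -1.0]
```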


In [None]:
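import gc
import importlib.util

# Free the training-time model and trainer first; two copies of the model on one GPU can run out of memory.
for _name in ("trainer", "model"):
    globals().pop(_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Use FlashAttention-2 only if it is installed and the GPU supports it (Ampere or newer); otherwise SDPA.
attn_impl = (
    "flash_attention_2"
    if importlib.util.find_spec("flash_attn") is not None
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability(0)[0] >= 8
    else "sdpa"
)

# Load the base model
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
    trust_remote_code=True
)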

# Load LoRA adapters
from peft import PeftModel
fine_tuned_model = PeftModel.from_pretrained(base_model, str(final_model_dir))

# Merge LoRA weights (optional, for faster inference)
merged_model = fine_tuned_model.merge_and_unload()

print("Fine-tuned model loaded successfully")


In [None]:
# Save merged model for use with RexOmniWrapper
merged_model_dir = output_dir / "merged_model"
os.makedirs(merged_model_dir, exist_ok=True)
merged_model.save_pretrained(str(merged_model_dir))
tokenizer.save_pretrained(str(merged_model_dir))
processor.save_pretrained(str(merged_model_dir))

print(f"Merged model saved to: {merged_model_dir}")

# Download the model (optional - for Colab)
print("\nTo download the model, run:")
print("from google.colab import files")
print("import shutil")
print(f"shutil.make_archive('rexomni_lora_model', 'zip', '{merged_model_dir}')")
print("files.download('rexomni_lora_model.zip')")

print("\nYou can now use this model with RexOmniWrapper:")
print(f"rex = RexOmniWrapper(model_path='{merged_model_dir}', backend='transformers')")
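
The printed Colab snippet zips the merged model with `shutil.make_archive`. Wrapping that in a small function makes it easy to check that the config and weight files actually landed in the archive before calling `files.download`. The helper names below are hypothetical.

```python
import shutil
import zipfile
from pathlib import Path


def archive_model_dir(model_dir, base_name: str) -> Path:
    """Zip the contents of model_dir into <base_name>.zip and return the archive path."""
    archive = shutil.make_archive(base_name, "zip", root_dir=str(model_dir))
    return Path(archive)


def list_archive(archive_path) -> list[str]:
    """Sorted member names of a zip archive."""
    with zipfile.ZipFile(archive_path) as zf:
        return sorted(zf.namelist())
```

Usage: `zip_path = archive_model_dir(merged_model_dir, "rexomni_lora_model")`, then inspect `list_archive(zip_path)` for `config.json` and the weight shards.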


## 12. Single Test Inference

Run a single detection inference on one image from the sample to check that the fine-tuned model loads, accepts image input, and returns parseable boxes.
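
The test cell in this section grows the input embeddings and `lm_head` by hand when the tokenizer has more entries than the model. The sketch below applies the same mean-initialization idea to a single `nn.Embedding` while keeping the original dtype and device. New `nn.Embedding`/`nn.Linear` modules otherwise default to float32, which would mix dtypes inside a bf16 model. Transformers also provides `model.resize_token_embeddings(new_num_tokens)`, which resizes input and output embeddings together and respects tied weights; it is usually the simpler route. The helper name here is hypothetical.

```python
import torch
from torch import nn


@torch.no_grad()
def resize_embedding_mean_init(old: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    """Grow an embedding table, initializing new rows with the mean of existing rows.

    Keeps the original dtype/device so a bf16 model stays bf16.
    """
    old_num, dim = old.weight.shape
    if new_num_tokens <= old_num:
        return old  # nothing to grow; shrinking would drop trained rows
    new = nn.Embedding(
        new_num_tokens,
        dim,
        padding_idx=old.padding_idx,
        device=old.weight.device,
        dtype=old.weight.dtype,
    )
    new.weight[:old_num] = old.weight                               # copy trained rows
    new.weight[old_num:] = old.weight.mean(dim=0, keepdim=True)     # broadcast mean to new rows
    return new
```

The same idea applies to `lm_head`, an `nn.Linear` whose weight rows index the vocabulary. Afterwards, update `config.vocab_size` to match.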


In [None]:
# Single Test Inference using RexOmniWrapper
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
import json
from rex_omni import RexOmniWrapper
from transformers import AutoProcessor

print("="*60)
print("SINGLE TEST INFERENCE")
print("="*60)

# Setup paths
model_path = str(final_model_dir) if 'final_model_dir' in locals() else str(output_dir / "final_model")
base_model_path = "IDEA-Research/Rex-Omni"

# Check if merged model exists, otherwise use final_model_dir
merged_model_path = str(output_dir / "merged_model")
if os.path.exists(merged_model_path):
    model_path = merged_model_path
    print(f"Using merged model: {model_path}")
else:
    print(f"Using LoRA model: {model_path}")

# Check CUDA availability and health
print(f"\nCUDA available: {torch.cuda.is_available()}")
cuda_healthy = False
if torch.cuda.is_available():
    try:
        # Test CUDA with a simple operation
        test_tensor = torch.zeros(1).cuda()
        del test_tensor
        torch.cuda.synchronize()
        cuda_healthy = True
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        print("✓ CUDA is healthy")
    except RuntimeError as e:
        if "device-side assert" in str(e) or "CUDA error" in str(e):
            print("⚠ CUDA is available but in a bad state (likely from previous error)")
            print("  Will use CPU instead. To fix CUDA, restart the Python kernel.")
            cuda_healthy = False
        else:
            raise
else:
    print("CUDA not available, will use CPU")

# Load model using RexOmniWrapper
print("\nLoading fine-tuned model with RexOmniWrapper...")
try:
    rex = RexOmniWrapper(
        model_path=model_path,
        backend="transformers",
        max_tokens=2048,
        temperature=0.0,
        top_p=0.05,
        top_k=1,
        repetition_penalty=1.05,
        attn_implementation="sdpa",  # Use sdpa for compatibility (works on both CPU and CUDA)
    )
    print("✓ Model loaded successfully")
    
    # CRITICAL: Ensure model is on CPU before resizing embeddings
    # (RexOmniWrapper might have already moved it to CUDA)
    print("\nEnsuring model is on CPU for safe embedding resize...")
    try:
        # Check current device
        if hasattr(rex.model, 'model'):
            first_param = next(rex.model.model.parameters())
            current_device = first_param.device
            print(f"  Current model device: {current_device}")
            
            if current_device.type == 'cuda':
                print("  Moving model to CPU for safe embedding resize...")
                rex.model = rex.model.cpu()
                # Release cached CUDA memory (this does not recover from a device-side
                # assert; that requires a kernel restart)
                if torch.cuda.is_available():
                    try:
                        torch.cuda.synchronize()
                        torch.cuda.empty_cache()
                    except RuntimeError:
                        pass  # Ignore CUDA errors during cleanup
                print("  ✓ Model moved to CPU")
    except Exception as e:
        print(f"  ⚠ Warning during device check: {e}")
        # Try to force CPU anyway
        try:
            rex.model = rex.model.cpu()
        except Exception:
            pass
    
    # FIX VOCABULARY MISMATCH FIRST (before moving to device to avoid CUDA errors)
    print("\nFixing vocabulary mismatch (critical - must be done on CPU)...")
    model_vocab_size = rex.model.config.text_config.vocab_size
    tokenizer_vocab_size = len(rex.processor.tokenizer)
    
    print(f"Model vocab size: {model_vocab_size}")
    print(f"Tokenizer vocab size: {tokenizer_vocab_size}")
    
    if tokenizer_vocab_size > model_vocab_size:
        print(f"⚠ Tokenizer vocab ({tokenizer_vocab_size}) > Model vocab ({model_vocab_size})")
        print("Resizing embedding layers to match tokenizer...")
        
        # Resize INPUT embeddings
        old_input_embedding = rex.model.get_input_embeddings()
        old_input_weight = old_input_embedding.weight.data.clone()
        
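        from torch.nn import Embedding, Linear
        # Match the original dtype/device: new modules default to float32 on CPU,
        # which would mix dtypes inside a bf16 model
        new_input_embedding = Embedding(
            tokenizer_vocab_size,
            old_input_embedding.embedding_dim,
            padding_idx=old_input_embedding.padding_idx,
            dtype=old_input_weight.dtype,
            device=old_input_weight.device,
        )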
        
        # Copy old weights
        new_input_embedding.weight.data[:model_vocab_size] = old_input_weight
        
        # Initialize new tokens with mean of existing embeddings
        mean_embedding = old_input_weight.mean(dim=0)
        new_input_embedding.weight.data[model_vocab_size:] = mean_embedding.unsqueeze(0).expand(
            tokenizer_vocab_size - model_vocab_size, -1
        )
        
        # CRITICAL: Preserve image token embeddings if they exist
        # Image tokens might be at specific IDs that need special handling
        if hasattr(rex.processor.tokenizer, 'image_start_id'):
            img_start_id = rex.processor.tokenizer.image_start_id
            if img_start_id < model_vocab_size:
                # Image token already exists, preserve its embedding
                new_input_embedding.weight.data[img_start_id] = old_input_weight[img_start_id]
            elif img_start_id < tokenizer_vocab_size:
                # Image token is in the new range, initialize it properly
                new_input_embedding.weight.data[img_start_id] = old_input_weight.mean(dim=0)
        
        if hasattr(rex.processor.tokenizer, 'image_end_id'):
            img_end_id = rex.processor.tokenizer.image_end_id
            if img_end_id < model_vocab_size:
                new_input_embedding.weight.data[img_end_id] = old_input_weight[img_end_id]
            elif img_end_id < tokenizer_vocab_size:
                new_input_embedding.weight.data[img_end_id] = old_input_weight.mean(dim=0)
        
        rex.model.set_input_embeddings(new_input_embedding)
        print(f"  ✓ Input embedding resized from {model_vocab_size} to {tokenizer_vocab_size}")
        
        # Resize OUTPUT embeddings (lm_head) - CRITICAL!
        if hasattr(rex.model, 'lm_head'):
            old_lm_head = rex.model.lm_head
            old_lm_head_weight = old_lm_head.weight.data.clone()
            
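            # Match the original dtype/device (Linear defaults to float32 on CPU)
            new_lm_head = Linear(
                old_lm_head.in_features,
                tokenizer_vocab_size,
                bias=old_lm_head.bias is not None,
                dtype=old_lm_head_weight.dtype,
                device=old_lm_head_weight.device,
            )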
            
            # Copy old weights
            new_lm_head.weight.data[:model_vocab_size] = old_lm_head_weight
            
            # Initialize new tokens with mean of existing weights
            mean_weight = old_lm_head_weight.mean(dim=0)
            new_lm_head.weight.data[model_vocab_size:] = mean_weight.unsqueeze(0).expand(
                tokenizer_vocab_size - model_vocab_size, -1
            )
            
            # Copy bias if exists
            if old_lm_head.bias is not None:
                old_bias = old_lm_head.bias.data.clone()
                new_lm_head.bias.data[:model_vocab_size] = old_bias
                mean_bias = old_bias.mean()
                new_lm_head.bias.data[model_vocab_size:] = mean_bias
            
            rex.model.lm_head = new_lm_head
            print(f"  ✓ Output embedding (lm_head) resized from {model_vocab_size} to {tokenizer_vocab_size}")
        
        # Update config
        rex.model.config.text_config.vocab_size = tokenizer_vocab_size
        rex.model.config.vocab_size = tokenizer_vocab_size
        
        print(f"✓ All embedding layers resized from {model_vocab_size} to {tokenizer_vocab_size}")
    elif tokenizer_vocab_size < model_vocab_size:
        print(f"✓ Tokenizer vocab ({tokenizer_vocab_size}) < Model vocab ({model_vocab_size}) - OK (extra embedding rows are simply unused)")
    else:
        print("✓ Vocab sizes match!")
    
    # IMPORTANT: Ensure processor uses the model's tokenizer (not a mismatched one)
    # The processor must match the model's tokenizer for image tokens to work correctly
    print("\nVerifying processor-tokenizer compatibility...")
    try:
        # Check if processor tokenizer matches model's expected tokenizer
        # The processor should use the same tokenizer that the model was trained with
        processor_vocab_size = len(rex.processor.tokenizer)
        model_vocab_size_after_resize = rex.model.config.text_config.vocab_size
        
        if processor_vocab_size != model_vocab_size_after_resize:
            print(f"⚠ Processor vocab ({processor_vocab_size}) != Model vocab ({model_vocab_size_after_resize})")
            print("  This is expected after resizing - processor should still work")
        
        # Verify image token IDs are valid
        if hasattr(rex.processor.tokenizer, 'image_start_id') and hasattr(rex.processor.tokenizer, 'image_end_id'):
            img_start_id = rex.processor.tokenizer.image_start_id
            img_end_id = rex.processor.tokenizer.image_end_id
            if img_start_id >= model_vocab_size_after_resize or img_end_id >= model_vocab_size_after_resize:
                print(f"⚠ Image token IDs ({img_start_id}, {img_end_id}) exceed vocab size!")
                print("  This might cause image token insertion issues")
            else:
                print(f"✓ Image token IDs valid: start={img_start_id}, end={img_end_id}")
        
        print("✓ Processor-tokenizer compatibility verified")
    except Exception as e:
        print(f"⚠ Could not verify processor compatibility: {e}")
    
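    # Patch processor and embedding to clamp out-of-range token IDs as a safety measure,
    # without touching image placeholder tokens.
    # NOTE: Python looks up __call__ on the class, not the instance, so the
    # rex.processor.__call__ assignment below does not intercept rex.processor(...)
    # calls; the embedding-layer patch further down is the guard that takes effect.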
    _original_call = rex.processor.__call__
    def patched_call(*args, **kwargs):
        result = _original_call(*args, **kwargs)
        if isinstance(result, dict) and 'input_ids' in result:
            if isinstance(result['input_ids'], torch.Tensor):
                actual_vocab_size = rex.model.config.text_config.vocab_size
                # Get image token IDs if they exist
                img_token_ids = set()
                if hasattr(rex.processor.tokenizer, 'image_start_id'):
                    img_token_ids.add(rex.processor.tokenizer.image_start_id)
                if hasattr(rex.processor.tokenizer, 'image_end_id'):
                    img_token_ids.add(rex.processor.tokenizer.image_end_id)
                if hasattr(rex.processor.tokenizer, 'image_pad_id'):
                    img_token_ids.add(rex.processor.tokenizer.image_pad_id)
                
                # Clamp non-image tokens, but preserve image tokens
                input_ids = result['input_ids']
                mask = torch.ones_like(input_ids, dtype=torch.bool)
                for img_id in img_token_ids:
                    if img_id < actual_vocab_size:
                        mask = mask & (input_ids != img_id)
                
                # Only clamp non-image tokens
                clamped_ids = torch.clamp(input_ids, 0, actual_vocab_size - 1)
                result['input_ids'] = torch.where(mask, clamped_ids, input_ids)
        return result
    rex.processor.__call__ = patched_call
    
    # Patch embedding layer forward
    original_embedding_forward = rex.model.get_input_embeddings().forward
    def safe_embedding_forward(input_ids):
        actual_vocab_size = rex.model.config.text_config.vocab_size
        # Preserve image tokens if they exist
        img_token_ids = set()
        if hasattr(rex.processor.tokenizer, 'image_start_id'):
            img_token_ids.add(rex.processor.tokenizer.image_start_id)
        if hasattr(rex.processor.tokenizer, 'image_end_id'):
            img_token_ids.add(rex.processor.tokenizer.image_end_id)
        if hasattr(rex.processor.tokenizer, 'image_pad_id'):
            img_token_ids.add(rex.processor.tokenizer.image_pad_id)
        
        # Clamp only changes IDs >= vocab size, so in-range image token IDs pass through unchanged
        mask = torch.ones_like(input_ids, dtype=torch.bool)
        for img_id in img_token_ids:
            if img_id < actual_vocab_size:
                mask = mask & (input_ids != img_id)
        
        clamped_ids = torch.clamp(input_ids, 0, actual_vocab_size - 1)
        safe_input_ids = torch.where(mask, clamped_ids, input_ids)
        return original_embedding_forward(safe_input_ids)
    rex.model.get_input_embeddings().forward = safe_embedding_forward
    print("✓ Embedding layer patched to clamp out-of-range token IDs (image tokens preserved)")
    
    # Now move model to device (vocabulary is fixed)
    # Use CUDA only if it's healthy, otherwise use CPU
    device = torch.device("cuda" if (torch.cuda.is_available() and cuda_healthy) else "cpu")
    print(f"\nMoving model to {device}...")
    
    # Clear CUDA cache if available and healthy
    if device.type == "cuda" and cuda_healthy:
        try:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print("  CUDA cache cleared")
        except RuntimeError as e:
            if "device-side assert" in str(e) or "CUDA error" in str(e):
                print("  ⚠ CUDA error during cache clear, switching to CPU")
                device = torch.device("cpu")
                cuda_healthy = False
            else:
                raise
    
    # Change attention implementation if needed
    if device.type == "cpu" or not cuda_healthy:
        print("⚠ Changing attention implementation to sdpa (CPU or no CUDA)...")
        rex.model.config._attn_implementation = "sdpa"
        if hasattr(rex.model.config, 'text_config'):
            rex.model.config.text_config._attn_implementation = "sdpa"
        if hasattr(rex.model.config, 'vision_config'):
            rex.model.config.vision_config._attn_implementation = "sdpa"
        
        if hasattr(rex.model, 'visual') and hasattr(rex.model.visual, 'blocks'):
            for block in rex.model.visual.blocks:
                if hasattr(block, 'attn') and hasattr(block.attn, 'config'):
                    block.attn.config._attn_implementation = "sdpa"
        
        if hasattr(rex.model, 'model') and hasattr(rex.model.model, 'layers'):
            for layer in rex.model.model.layers:
                if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'config'):
                    layer.self_attn.config._attn_implementation = "sdpa"
        print("✓ Attention implementation changed to sdpa")
    
    # Move model to device with error handling
    try:
        # Set model to eval mode before moving (safer)
        rex.model.eval()
        
        # Move model to device
        rex.model = rex.model.to(device)
        print(f"✓ Model moved to {device}")
        
        # Verify device placement
        if hasattr(rex.model, 'model'):
            first_param = next(rex.model.model.parameters())
            print(f"✓ Model device verified: {first_param.device}")
            
    except RuntimeError as e:
        if "CUDA" in str(e) or "device-side assert" in str(e):
            print(f"⚠ CUDA error during model move: {e}")
            print("  This might be due to vocabulary mismatch or CUDA state issues.")
            print("  Trying to clear CUDA cache and retry...")
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            
            # Try moving again
            try:
                rex.model = rex.model.to(device)
                print(f"✓ Model moved to {device} on retry")
            except Exception as e2:
                print(f"✗ Failed to move model to {device}: {e2}")
                print("  Falling back to CPU...")
                device = torch.device("cpu")
                rex.model = rex.model.to(device)
                print("✓ Model moved to CPU as fallback")
        else:
            raise
        
except Exception as e:
    print(f"✗ Error loading model: {e}")
    raise

# Get a test sample from the dataset
print("\nPreparing test sample...")
test_idx = 0
test_row = df_sample.iloc[test_idx]
test_image_path = test_row['image_path']

# Handle image path
if not os.path.isabs(test_image_path):
    possible_paths = [
        os.path.join(str(project_root), test_image_path),
        test_image_path,
    ]
    for path in possible_paths:
        if os.path.exists(path):
            test_image_path = path
            break
    else:
        print(f"⚠ Warning: Image not found at {test_image_path}")
        print("Trying to use first available image from dataset...")
        # Try to find any image
        for idx, row in df_sample.iterrows():
            img_path = row['image_path']
            for root in [str(project_root), ""]:
                full_path = os.path.join(root, img_path) if root else img_path
                if os.path.exists(full_path):
                    test_image_path = full_path
                    test_row = row
                    break
            if os.path.exists(test_image_path):
                break

# Load test image
try:
    test_image = Image.open(test_image_path).convert("RGB")
    print(f"✓ Test image loaded: {test_image_path}")
    print(f"  Image size: {test_image.size}")
except Exception as e:
    print(f"✗ Error loading image: {e}")
    raise

# Extract categories from the test sample
try:
    objs = json.loads(test_row['objects']) if isinstance(test_row['objects'], str) else test_row['objects']
    # Sorted so the prompt's category order is stable across runs
    categories = sorted({o.get('obj_cls', 'object') for o in objs if isinstance(o, dict)})
    if not categories:
        categories = ["object"]
    print(f"  Categories to detect: {categories}")
except Exception as e:
    print(f"⚠ Could not extract categories: {e}")
    categories = ["object"]

# Test processor image token insertion before inference
print("\nTesting processor image token insertion...")
try:
    # Create a test message with image (same format RexOmniWrapper uses)
    test_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": test_image},
                {"type": "text", "text": "test"}
            ]
        }
    ]
    
    # Apply chat template
    text = rex.processor.apply_chat_template(
        test_messages, tokenize=False, add_generation_prompt=True
    )
    
    # Process with processor
    test_inputs = rex.processor(
        text=[text],
        images=[test_image],
        padding=True,
        return_tensors="pt"
    )

    if "image_grid_thw" in test_inputs:
        print(f"  Processor grid_thw: {test_inputs['image_grid_thw'][0].tolist()}")
    
    # Check for image tokens
    input_ids = test_inputs['input_ids'][0]
    
    # Check for image placeholder tokens
    img_token_count = 0
    if hasattr(rex.processor.tokenizer, 'image_start_id'):
        img_start_id = rex.processor.tokenizer.image_start_id
        img_end_id = getattr(rex.processor.tokenizer, 'image_end_id', None)
        img_pad_id = getattr(rex.processor.tokenizer, 'image_pad_id', None)
        
        img_token_count = (input_ids == img_start_id).sum().item()
        if img_end_id is not None:
            img_token_count += (input_ids == img_end_id).sum().item()
        if img_pad_id is not None:
            img_token_count += (input_ids == img_pad_id).sum().item()
        
        print(f"  Image token IDs: start={img_start_id}, end={img_end_id}, pad={img_pad_id}")
        print(f"  Image tokens found in input_ids: {img_token_count}")
        
        if img_token_count == 0:
            print("  ⚠ WARNING: No image tokens found! Generation would likely fail with a CUDA index/device-side assert.")
            print("  The processor might not be inserting image tokens correctly.")
            print("  This could be due to processor-tokenizer mismatch.")
            raise RuntimeError(
                "Processor did not insert any image placeholder tokens; aborting before model.generate to avoid CUDA asserts."
            )
        else:
            print(f"  ✓ Image tokens are being inserted correctly ({img_token_count} found)")
    else:
        print("  ⚠ Could not find image token IDs in tokenizer")
        print("  This might indicate a processor configuration issue")
        
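except RuntimeError as e:
    # Re-raise: swallowing this would let inference run straight into a CUDA assert
    print(f"  ✗ Processor check failed: {e}")
    raise
except Exception as e:
    print(f"  ⚠ Error testing processor: {e}")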

# Run inference
print("\n" + "="*60)
print("Running inference...")
print("="*60)

try:
    results = rex.inference(
        images=test_image,
        task="detection",
        categories=categories
    )
    
    result = results[0] if isinstance(results, list) else results
    
    print("\n" + "="*60)
    print("INFERENCE RESULT")
    print("="*60)
    print(f"Success: {result.get('success', False)}")
    print(f"Inference time: {result.get('inference_time', 0):.2f}s")
    print(f"Output tokens: {result.get('num_output_tokens', 0)}")
    print(f"\nRaw output:\n{result.get('raw_output', 'N/A')}")
    
    # Extract predictions
    preds = result.get('extracted_predictions', {})
    if preds:
        print("\nExtracted predictions:")
        total_boxes = 0
        for cat, boxes in preds.items():
            print(f"  {cat}: {len(boxes)} box(es)")
            total_boxes += len(boxes)
        print(f"  Total: {total_boxes} boxes")
        
        # Visualize the result
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        ax.imshow(test_image)
        ax.axis('off')
        
        # Draw bounding boxes
        colors = plt.cm.tab10(np.linspace(0, 1, len(preds)))
        color_idx = 0
        for cat, boxes in preds.items():
            for box in boxes:
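                # Accept either key; the box key name may differ between rex_omni versions
                if isinstance(box, dict) and ('bbox' in box or 'coords' in box):
                    bbox = box.get('bbox', box.get('coords'))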
                    if len(bbox) == 4:
                        x0, y0, x1, y1 = bbox
                        rect = patches.Rectangle(
                            (x0, y0), x1 - x0, y1 - y0,
                            linewidth=2, 
                            edgecolor=colors[color_idx % len(colors)], 
                            facecolor='none',
                            label=cat
                        )
                        ax.add_patch(rect)
                        ax.text(x0, y0 - 5, cat, color=colors[color_idx % len(colors)], 
                               fontsize=10, weight='bold', 
                               bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))
            color_idx += 1
        
        ax.set_title(f"Test Inference Result\nCategories: {', '.join(categories)}", fontsize=12)
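        # One legend entry per category (every box is labeled with its category)
        handles, labels = ax.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:
            ax.legend(list(by_label.values()), list(by_label.keys()), loc='upper right', fontsize=8)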
        plt.tight_layout()
        plt.show()
    else:
        print("\n⚠ No predictions extracted from output")
        print("This might be normal if the model output format differs from expected")
        
        # Still show the image
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        ax.imshow(test_image)
        ax.axis('off')
        ax.set_title(f"Test Image\n(No boxes detected)\nCategories: {', '.join(categories)}", fontsize=12)
        plt.tight_layout()
        plt.show()
    
    print("\n" + "="*60)
    print("✓ Test inference completed successfully!")
    print("="*60)
    
except Exception as e:
    error_msg = str(e)
    print(f"\n✗ Inference failed: {error_msg}")
    
    if "token" in error_msg.lower() or "vocab" in error_msg.lower():
        print("\nTroubleshooting tips:")
        print("1. The tokenizer vocabulary might not match the model")
        print("2. Compare against the base model's tokenizer (IDEA-Research/Rex-Omni)")
        print("3. Check if the model was saved correctly with its tokenizer")
        print("4. You may need to re-save the model with the correct tokenizer")
    
    raise
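
The plotting loop draws a rectangle only when a prediction is a dict carrying a recognized box key. If your version of `rex_omni` stores boxes under a key it does not check, no rectangles are drawn and the figure silently shows the bare image. Print `extracted_predictions` once before relying on any key name (`bbox` and `coords` below are assumptions). The defensive sketch below flattens predictions into `(category, box)` pairs and accepts either key; `iter_boxes` is a hypothetical helper.

```python
from typing import Iterator


def iter_boxes(preds: dict, keys=("bbox", "coords")) -> Iterator[tuple[str, list[float]]]:
    """Yield (category, [x0, y0, x1, y1]) for every well-formed box in preds."""
    for category, items in preds.items():
        for item in items:
            if not isinstance(item, dict):
                continue  # skip strings or other non-dict entries
            coords = next((item[k] for k in keys if k in item), None)
            if coords is not None and len(coords) == 4:  # points or polygons are skipped
                yield category, [float(v) for v in coords]
```

Usage: `for cat, (x0, y0, x1, y1) in iter_boxes(preds): ax.add_patch(...)`.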

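The category extraction in the test cell can also be factored into a reusable helper. This sketch returns a sorted, de-duplicated list so the detection prompt is identical across runs. It accepts the `objects` field either as a JSON string or already decoded, as the cell does. The `obj_cls` key comes from the cell above; `extract_categories` is a hypothetical name.

```python
import json
from typing import Any


def extract_categories(objects: Any, key: str = "obj_cls", default: str = "object") -> list[str]:
    """Return sorted unique class names from a VRSBench-style objects field."""
    if isinstance(objects, str):
        try:
            objects = json.loads(objects)
        except json.JSONDecodeError:
            return [default]
    if objects is None or not hasattr(objects, "__iter__"):
        return [default]  # e.g. None or a NaN float from pandas
    # Avoid truthiness checks on objects: parquet list columns may arrive as numpy arrays
    names = {o.get(key) or default for o in objects if isinstance(o, dict)}
    return sorted(names) or [default]
```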

## Summary

This notebook has:
1. ✅ Set up the environment for Google Colab
2. ✅ Loaded 1000 samples from VRSBench dataset
3. ✅ Converted data to TSV format required for training
4. ✅ Set up LoRA configuration for parameter-efficient fine-tuning
5. ✅ Fine-tuned the Rex-Omni model with LoRA
6. ✅ Saved the LoRA adapters and a merged model
7. ✅ Ran a single test inference with `RexOmniWrapper`

### Important Notes for Colab:
- **Session Timeout**: Colab sessions time out after inactivity. Save your work frequently!
- **GPU Runtime**: Make sure you're using a GPU runtime (Runtime → Change runtime type → GPU)
- **Data Persistence**: Files in `/content` are deleted when the session ends. Save important files to Google Drive or download them.
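- **Memory**: If you run out of memory, reduce the per-device `batch_size`. Lowering `gradient_accumulation_steps` does not reduce memory; instead, increase it to keep the same effective batch size.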
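
Peak activation memory scales with the per-device batch size, not with gradient accumulation. Accumulation only controls how many micro-batches contribute to each optimizer step. The effective batch is per-device batch × accumulation steps × number of devices (`per_device_train_batch_size` and `gradient_accumulation_steps` in Hugging Face `TrainingArguments`). The small sketch below, with hypothetical helper names, computes how many accumulation steps keep that product constant after you shrink the per-device batch.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Samples contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def accum_steps_for_target(target_batch: int, per_device_batch: int, num_devices: int = 1) -> int:
    """Smallest accumulation count whose effective batch reaches at least target_batch."""
    return math.ceil(target_batch / (per_device_batch * num_devices))
```

For example, going from a per-device batch of 4 with 4 accumulation steps (effective 16) down to a per-device batch of 2 requires 8 accumulation steps to stay at 16.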

### Next Steps:
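- Evaluate the model on held-out data (the 1000 training samples come from the VRSBench validation split, so use samples not seen during training)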
- Adjust hyperparameters (learning rate, LoRA rank, etc.)
- Train for more epochs if needed
- Download or save the model to Google Drive for persistence
- Use the fine-tuned model for inference
