In [None]:
# Install core dependencies with compatible versions.
# Each group below is installed with the `%uv pip install` line magic; most calls
# pass -q to keep the cell output short. The cell ends by importing two of the
# installed packages and printing their versions.

print("Installing dependencies...")

# Install base packages
%uv pip install -q psutil setuptools

# Install compatible transformers and peft versions
%uv pip install -q transformers==4.57.3
%uv pip install -q "peft==0.18.0"
%uv pip install -q "bitsandbytes>=0.44.1"
# PyTorch stack pinned to 2.6.0 and fetched from the PyTorch cu124 wheel index
%uv pip install -q torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

# Install flash attention
# NOTE(review): --no-build-isolation presumably requires torch to be installed
# already, which is why this line follows the torch install above - confirm.
%uv pip install flash-attn==2.7.4.post1 --no-build-isolation

# Install GRPO-specific dependencies
%uv pip install -q "ray[default]==2.44.0"  # Ray for distributed training
%uv pip install -q vllm==0.8.4  # vLLM for efficient generation
%uv pip install -q omegaconf==2.3.0 jinja2  # Config and template

# Install other requirements
%uv pip install -q accelerate==1.10.1
%uv pip install -q mmengine==0.10.7 ujson==5.11.0

# Install vision dependencies (unpinned)
%uv pip install -q Pillow pandas matplotlib numpy tqdm fastparquet pyarrow
%uv pip install -q qwen-vl-utils pycocotools

# Install liger-kernel (unpinned)
%uv pip install -q liger-kernel

print("\n✓ Dependencies installed!")

# Sanity check: import two of the installed packages and report their versions
import transformers
import ray
print(f"✓ Transformers: {transformers.__version__}")
print(f"✓ Ray: {ray.__version__}")


## 1. Setup Environment


In [None]:
import os
import sys
from pathlib import Path

# Locate the project root: the closest directory named 'Rex-Omni' among the
# current working directory and its ancestors. When no such directory exists,
# the current working directory itself is used.
_launch_dir = Path.cwd()
PROJECT_ROOT = next(
    (folder for folder in (_launch_dir, *_launch_dir.parents) if folder.name == 'Rex-Omni'),
    _launch_dir,
)
os.chdir(PROJECT_ROOT)

# Put the project root and its finetuning directory at the front of sys.path
# (the finetuning directory is inserted last and therefore ends up first)
FINETUNING_PATH = PROJECT_ROOT / 'finetuning'
for p in (str(PROJECT_ROOT), str(FINETUNING_PATH)):
    if p in sys.path:
        continue
    sys.path.insert(0, p)

print(f"✓ Project root: {PROJECT_ROOT}")
print(f"✓ Finetuning path: {FINETUNING_PATH}")


In [None]:
# Core imports
import json
import base64
import io
from tqdm import tqdm

import torch
import pandas as pd
import numpy as np
from PIL import Image

# Report the PyTorch build and, when CUDA is usable, every visible GPU by name
print(f"✓ PyTorch: {torch.__version__}")
cuda_ready = torch.cuda.is_available()
print(f"✓ CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_total = torch.cuda.device_count()
    print(f"✓ GPU count: {gpu_total}")
    for gpu_index in range(gpu_total):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"  GPU {gpu_index}: {gpu_name}")


## 2. Configuration


In [None]:
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
MODEL_NAME = "IDEA-Research/Rex-Omni"

# ---------------------------------------------------------------------------
# Data locations and sample budget
# ---------------------------------------------------------------------------
VRSBENCH_PARQUET = PROJECT_ROOT / "vrsbench_val_data.parquet"
VRSBENCH_IMAGES_DIR = PROJECT_ROOT / "Images_validation" / "Images_val"  # Update this!
TSV_OUTPUT_DIR = PROJECT_ROOT / "data" / "vrsbench_grpo"
NUM_SAMPLES = 1000  # Number of samples for GRPO training

# Pixel budget for image processing (same as SFT)
MIN_PIXELS = 16 * 28 * 28   # 12,544
MAX_PIXELS = 2560 * 28 * 28  # 2,007,040

# ---------------------------------------------------------------------------
# GRPO training
# ---------------------------------------------------------------------------
OUTPUT_DIR = PROJECT_ROOT / "work_dirs" / "grpo_vrsbench"

# Use every visible CUDA device; fall back to 1 when CUDA is unavailable
if torch.cuda.is_available():
    N_GPUS = torch.cuda.device_count()
else:
    N_GPUS = 1

# Hyperparameters. Both batch sizes share one cap: a quarter of the sample
# budget, never more than 64.
ROLLOUT_N = 8  # Number of responses per prompt
_batch_cap = min(64, NUM_SAMPLES // 4)
GLOBAL_BATCH_SIZE = _batch_cap
ROLLOUT_BATCH_SIZE = _batch_cap
LEARNING_RATE = 1e-6
KL_COEF = 0.01
TOTAL_EPOCHS = 1

# Echo the effective settings
print("GRPO Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Samples: {NUM_SAMPLES}")
print(f"  GPUs: {N_GPUS}")
print(f"  Rollout N: {ROLLOUT_N} (responses per prompt)")
print(f"  Global batch size: {GLOBAL_BATCH_SIZE}")
print(f"  Output: {OUTPUT_DIR}")


## 3. Convert VRSBench to TSV Format (Same as SFT)

The GRPO training uses the same TSV format as SFT training.


In [None]:
def _resolve_image_path(image_path, images_dir, project_root):
    """Return the first existing location of ``image_path``, or ``None``.

    Candidates are tried in order: the path exactly as given, the bare file
    name inside ``images_dir``, and (only when ``project_root`` is set) the
    path relative to ``project_root``.
    """
    candidates = [
        Path(image_path),
        images_dir / Path(image_path).name,
        # NOTE: lstrip('./') removes every leading '.' and '/' character,
        # not only a literal './' prefix.
        project_root / image_path.lstrip('./') if project_root else None,
    ]
    for candidate in candidates:
        if candidate is not None and candidate.exists():
            return candidate
    return None


def _load_objects(objects_data):
    """Normalise a raw ``objects`` cell to a list; return ``[]`` when unusable.

    Accepts a JSON string, a list/tuple, or any value exposing ``tolist()``
    (e.g. the numpy arrays pandas returns for parquet list columns).
    """
    if objects_data is None:
        return []
    if isinstance(objects_data, str):
        objects_data = json.loads(objects_data)
    elif hasattr(objects_data, 'tolist'):
        objects_data = objects_data.tolist()
    if isinstance(objects_data, tuple):
        objects_data = list(objects_data)
    return objects_data if isinstance(objects_data, list) else []


def _build_boxes(objects_list, img_width, img_height):
    """Turn object dicts into ``{'bbox': [x0, y0, x1, y1], 'phrase': str}`` entries.

    ``obj_coord`` is scaled by the image width/height (x by width, y by height).
    Entries that are not dicts, lack a 4-element ``obj_coord``, or have neither
    ``referring_sentence`` nor ``obj_cls`` are dropped.
    """
    boxes = []
    for obj in objects_list:
        if not isinstance(obj, dict):
            continue

        coords = obj.get('obj_coord')
        if hasattr(coords, 'tolist'):
            # numpy array -> plain list of Python scalars
            coords = coords.tolist()
        if not isinstance(coords, (list, tuple)) or len(coords) != 4:
            continue

        phrase = obj.get('referring_sentence', '') or obj.get('obj_cls', '')
        if not phrase:
            continue

        # Unwrap numpy scalars so that json.dumps can serialise the values
        nx0, ny0, nx1, ny1 = (v.item() if hasattr(v, 'item') else v for v in coords)
        boxes.append({
            'bbox': [nx0 * img_width, ny0 * img_height, nx1 * img_width, ny1 * img_height],
            'phrase': str(phrase),
        })
    return boxes


def _encode_jpeg_base64(image):
    """Re-encode ``image`` as an RGB JPEG (quality 95) and return it as base64 text."""
    buffer = io.BytesIO()
    image.convert('RGB').save(buffer, format='JPEG', quality=95)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def convert_vrsbench_to_tsv(
    parquet_path: Path,
    images_dir: Path,
    output_dir: Path,
    num_samples: int = 1000,
    project_root: Path = None
) -> int:
    """
    Convert VRSBench parquet to TSV format for GRPO training.

    Three files are written into ``output_dir``:

    * ``train.images.tsv``               - ``<image byte offset>\\t<base64 JPEG>``
    * ``train.annotations.tsv``          - ``<image byte offset>\\t<JSON {"boxes": [...]}>``
    * ``train.annotations.tsv.lineidx``  - byte offset of every annotation line

    Files are written in binary mode so the recorded offsets always equal the
    bytes on disk (no platform newline translation). Each record is fully
    serialised before anything is written, so a row that fails can never leave
    a partial entry behind or desynchronise the offsets.

    Args:
        parquet_path: Parquet file read with ``pd.read_parquet``; rows are
            expected to provide ``image_path`` and ``objects``.
        images_dir: Directory searched for images by file name.
        output_dir: Destination directory (created when missing).
        num_samples: Only the first ``num_samples`` rows are considered.
        project_root: Optional root against which relative image paths are
            also tried.

    Returns:
        Number of rows successfully converted.

    Raises:
        RuntimeError: If not a single row could be converted.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    images_dir = Path(images_dir)
    project_root = Path(project_root) if project_root else None

    # Output file paths
    img_tsv = output_dir / "train.images.tsv"
    ann_tsv = output_dir / "train.annotations.tsv"
    ann_idx = output_dir / "train.annotations.tsv.lineidx"

    print(f"Loading {parquet_path}...")
    df = pd.read_parquet(parquet_path)
    print(f"  Total samples: {len(df)}")

    df_sample = df.head(num_samples).copy()
    print(f"  Using: {len(df_sample)} samples")

    converted = 0
    skipped = 0
    errors = []  # only the first 5 messages are kept

    img_offset = 0
    ann_offset = 0

    with open(img_tsv, 'wb') as f_img, \
         open(ann_tsv, 'wb') as f_ann, \
         open(ann_idx, 'wb') as f_idx:

        for idx, row in tqdm(df_sample.iterrows(), total=len(df_sample), desc="Converting"):
            try:
                # Find image
                image_path = row.get('image_path', '')
                if not image_path:
                    skipped += 1
                    continue

                img_path = _resolve_image_path(image_path, images_dir, project_root)
                if img_path is None:
                    if len(errors) < 5:
                        errors.append(f"Image not found: {image_path}")
                    skipped += 1
                    continue

                # Validate annotations before paying for the image encode
                objects_list = _load_objects(row.get('objects'))
                if not objects_list:
                    skipped += 1
                    continue

                # The context manager releases the underlying file handle
                with Image.open(img_path) as raw_img:
                    img_width, img_height = raw_img.size
                    boxes = _build_boxes(objects_list, img_width, img_height)
                    if not boxes:
                        skipped += 1
                        continue
                    img_b64 = _encode_jpeg_base64(raw_img)

                # Serialise everything that can fail BEFORE touching the files
                img_record = f"{img_offset}\t{img_b64}\n".encode('utf-8')
                ann_json = json.dumps({"boxes": boxes}, ensure_ascii=False)
                ann_record = f"{img_offset}\t{ann_json}\n".encode('utf-8')
                idx_record = f"{ann_offset}\n".encode('utf-8')

                # Write to TSV
                f_idx.write(idx_record)
                f_img.write(img_record)
                f_ann.write(ann_record)

                img_offset += len(img_record)
                ann_offset += len(ann_record)
                converted += 1

            except Exception as e:
                # Best effort: record the failure and carry on with the next row
                if len(errors) < 5:
                    errors.append(f"Row {idx}: {str(e)}")
                skipped += 1

    print(f"\n{'='*50}")
    print("Conversion Complete!")
    print(f"  ✓ Converted: {converted}")
    print(f"  ✗ Skipped: {skipped}")

    if errors:
        print("\nErrors:")
        for err in errors:
            print(f"  - {err}")

    if converted == 0:
        raise RuntimeError("No samples converted! Check image paths.")

    return converted


In [None]:
# Convert the dataset, but only when the configured image directory exists;
# otherwise tell the user which setting to fix.
if VRSBENCH_IMAGES_DIR.exists():
    conversion_kwargs = dict(
        parquet_path=VRSBENCH_PARQUET,
        images_dir=VRSBENCH_IMAGES_DIR,
        output_dir=TSV_OUTPUT_DIR,
        num_samples=NUM_SAMPLES,
        project_root=PROJECT_ROOT,
    )
    num_converted = convert_vrsbench_to_tsv(**conversion_kwargs)
    print(f"\n✓ Ready for GRPO training with {num_converted} samples")
else:
    print(f"⚠️  VRSBench images not found: {VRSBENCH_IMAGES_DIR}")
    print("Please update VRSBENCH_IMAGES_DIR in configuration.")


## 4. Create GRPO Configuration

GRPO training reads a Python config file that specifies the dataset and the task function.


In [None]:
# Create GRPO config file for VRSBench.
#
# The TSV paths are embedded via repr() so the generated module always contains
# valid, correctly escaped Python string literals (a path with backslashes or
# quotes would otherwise produce broken source).
image_tsv_literal = repr(str(TSV_OUTPUT_DIR / 'train.images.tsv'))
anno_tsv_literal = repr(str(TSV_OUTPUT_DIR / 'train.annotations.tsv'))
anno_idx_literal = repr(str(TSV_OUTPUT_DIR / 'train.annotations.tsv.lineidx'))

grpo_config_content = f'''"""GRPO Config for VRSBench Dataset"""
from dataset.task_fns import GroundingTaskFn
from dataset.task_fns.task_prompts.grounding_task import GROUNDING_SINGLE_REGION_STAGE_XYXY
from verl.utils.dataset import TSVRLHFDataset

min_pixels = {MIN_PIXELS}
max_pixels = {MAX_PIXELS}

grounding_data = dict(
    type=TSVRLHFDataset,
    image_tsv_file={image_tsv_literal},
    anno_tsv_file={anno_tsv_literal},
    anno_idx_file={anno_idx_literal},
    min_pixels=min_pixels,
    max_pixels=max_pixels,
    task_fn=dict(
        type=GroundingTaskFn,
        task_prompts=GROUNDING_SINGLE_REGION_STAGE_XYXY,
        image_min_pixels=min_pixels,
        image_max_pixels=max_pixels,
    ),
    dataset_name="vrsbench_grounding",
    reward_name="box_iou",  # Uses IoU-based reward
)

train_dataset = [
    grounding_data,
]
'''

# Write config file (UTF-8 is the default encoding Python assumes for source files)
config_dir = FINETUNING_PATH / "configs"
config_dir.mkdir(parents=True, exist_ok=True)
config_path = config_dir / "grpo_vrsbench.py"

with open(config_path, 'w', encoding='utf-8') as f:
    f.write(grpo_config_content)

print(f"✓ GRPO config written to: {config_path}")
print("\nConfig contents:")
print("-" * 50)
print(grpo_config_content)


## 5. Launch GRPO Training

GRPO training uses Ray for distributed training. There are two ways to run it:

- **Option A**: Run via shell command (recommended for multi-GPU)
- **Option B**: Run directly in notebook (for debugging)


In [None]:
# Generate the training command (printed for copy/paste and saved as a shell script)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# The `cd` target is double-quoted, like the checkpoint path further down, so
# that a project path containing spaces does not split into several arguments.
training_command = f'''cd "{FINETUNING_PATH}" && \\
python3 -m verl.trainer.main \\
    config=verl/configs/config.yaml \\
    data.config_path="configs/grpo_vrsbench.py" \\
    data.format_prompt=verl/configs/r1v_format.jinja \\
    worker.actor.model.model_path={MODEL_NAME} \\
    trainer.experiment_name="grpo_vrsbench" \\
    trainer.n_gpus_per_node={N_GPUS} \\
    worker.actor.global_batch_size={GLOBAL_BATCH_SIZE} \\
    data.rollout_batch_size={ROLLOUT_BATCH_SIZE} \\
    worker.actor.micro_batch_size_per_device_for_update=2 \\
    worker.actor.micro_batch_size_per_device_for_experience=4 \\
    worker.rollout.n={ROLLOUT_N} \\
    worker.rollout.tensor_parallel_size={min(2, N_GPUS)} \\
    algorithm.kl_coef={KL_COEF} \\
    worker.actor.optim.lr={LEARNING_RATE} \\
    trainer.total_epochs={TOTAL_EPOCHS} \\
    trainer.save_checkpoint_path="{OUTPUT_DIR}" \\
    trainer.save_freq=50 \\
    trainer.logger="[\\"console\\"]"
'''

print("=" * 60)
print("OPTION A: Run via Terminal (Recommended)")
print("=" * 60)
print("\nCopy and run this command in your terminal:\n")
print(training_command)

# Also save as shell script: environment exports first, then the command
script_path = PROJECT_ROOT / "run_grpo_vrsbench.sh"
with open(script_path, 'w') as f:
    f.write("#!/bin/bash\n")
    f.write(f"export OUTPUT_PATH=\"{OUTPUT_DIR}\"\n")
    f.write("export DEBUG_MODE=\"true\"\n")
    f.write("export PYTHONUNBUFFERED=1\n")
    f.write("export RAY_DISABLE_IMPORT_WARNING=1\n\n")
    # Indent every continuation line by four extra spaces in the script
    f.write(training_command.replace(" \\\n", " \\\n    "))

print(f"\n✓ Shell script saved to: {script_path}")
print(f"  Run with: bash {script_path}")


### Option B: Run Training from Notebook

⚠️ **Warning**: This requires multiple GPUs and significant memory. Only run if you have the hardware.


In [None]:
# OPTION B: Run training directly from the notebook.
# Set RUN_TRAINING = True below to launch GRPO training - requires multi-GPU setup.

RUN_TRAINING = False  # Set to True to run training

if RUN_TRAINING:
    import os
    import subprocess
    import sys

    # Environment for the child process only; the kernel's own environment is untouched
    env = os.environ.copy()
    env["OUTPUT_PATH"] = str(OUTPUT_DIR)
    env["DEBUG_MODE"] = "true"
    env["PYTHONUNBUFFERED"] = "1"
    env["RAY_DISABLE_IMPORT_WARNING"] = "1"
    env["RAY_ADDRESS"] = ""
    env["RAY_CLIENT_MODE"] = ""

    print("Starting GRPO training...")
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"GPUs: {N_GPUS}")
    print("-" * 50)

    # sys.executable is the interpreter running this kernel, i.e. the one whose
    # environment the notebook itself imports packages from.
    cmd = [
        sys.executable, "-m", "verl.trainer.main",
        "config=verl/configs/config.yaml",
        "data.config_path=configs/grpo_vrsbench.py",
        "data.format_prompt=verl/configs/r1v_format.jinja",
        f"worker.actor.model.model_path={MODEL_NAME}",
        "trainer.experiment_name=grpo_vrsbench",
        f"trainer.n_gpus_per_node={N_GPUS}",
        f"worker.actor.global_batch_size={GLOBAL_BATCH_SIZE}",
        f"data.rollout_batch_size={ROLLOUT_BATCH_SIZE}",
        "worker.actor.micro_batch_size_per_device_for_update=2",
        "worker.actor.micro_batch_size_per_device_for_experience=4",
        f"worker.rollout.n={ROLLOUT_N}",
        f"worker.rollout.tensor_parallel_size={min(2, N_GPUS)}",
        f"algorithm.kl_coef={KL_COEF}",
        f"worker.actor.optim.lr={LEARNING_RATE}",
        f"trainer.total_epochs={TOTAL_EPOCHS}",
        f"trainer.save_checkpoint_path={OUTPUT_DIR}",
        "trainer.save_freq=50",
        'trainer.logger=["console"]',
    ]

    # Run the child inside the finetuning directory via `cwd=` instead of
    # os.chdir(), so the notebook's own working directory is left unchanged.
    # The context manager closes the pipe and waits for the process to exit.
    with subprocess.Popen(
        cmd,
        cwd=str(FINETUNING_PATH),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    ) as process:
        # Stream output line by line as it is produced
        for line in process.stdout:
            print(line, end='')

    # Only report success for a zero exit code
    if process.returncode == 0:
        print(f"\n✓ Training completed with exit code: {process.returncode}")
    else:
        print(f"\n✗ Training failed with exit code: {process.returncode}")
else:
    print("Training not started. Set RUN_TRAINING = True to start.")
    print("Or run the shell command from Option A in terminal.")


## 6. Understanding GRPO Reward Function

GRPO uses IoU (Intersection over Union) based rewards for grounding tasks:


In [None]:
# Describe the reward function. The path to its source file is recorded in
# reward_func_path; this cell prints a static summary and does not read the file.
reward_func_path = FINETUNING_PATH / "verl" / "configs" / "reward_func.py"

reward_summary = """
The reward function computes:
1. Parse model output for detected boxes
2. Match predicted boxes to ground truth by class
3. Calculate IoU for each matched pair
4. Compute Precision & Recall based on IoU scores
5. Return F1 score as final reward

Formula:
- IoU = intersection_area / union_area
- Recall = sum(best_IoU per GT box) / num_GT_boxes
- Precision = sum(best_IoU per pred box) / num_pred_boxes
- Reward = F1 = 2 * (Precision * Recall) / (Precision + Recall)

This encourages the model to:
- Detect all objects (high recall)
- Avoid false positives (high precision)
- Localize objects accurately (high IoU)
"""

print("GRPO Reward Function (Box IoU):")
print("=" * 50)
print(reward_summary)


## 7. Monitor Training & Load Checkpoints


In [None]:
# Check for saved checkpoints
import glob
import os

# Collect every *.pt file anywhere below the output directory
checkpoint_pattern = str(OUTPUT_DIR / "**" / "*.pt")
checkpoints = glob.glob(checkpoint_pattern, recursive=True)

if checkpoints:
    print(f"Found {len(checkpoints)} checkpoints:")
    # Order by modification time so the "last 5" really are the newest files;
    # a plain string sort misorders names that contain unpadded numbers.
    for ckpt in sorted(checkpoints, key=os.path.getmtime)[-5:]:  # Show last 5
        print(f"  - {ckpt}")
else:
    print(f"No checkpoints found in {OUTPUT_DIR}")
    print("Run training first to generate checkpoints.")


In [None]:
# Prepare the command that merges a GRPO checkpoint into HuggingFace format.
# This cell only prints the command; run it in a terminal to perform the merge.

MERGE_CHECKPOINT = False  # Set to True after training

if not MERGE_CHECKPOINT:
    print("Set MERGE_CHECKPOINT = True after training to merge checkpoints.")
elif not checkpoints:
    # Flag is set but there is nothing to merge yet
    print(f"No checkpoints found in {OUTPUT_DIR}")
    print("Run training first to generate checkpoints.")
else:
    import os

    # Use the most recently written checkpoint; a plain string sort would
    # misorder names that contain unpadded numbers.
    latest_ckpt = max(checkpoints, key=os.path.getmtime)
    output_hf_path = OUTPUT_DIR / "hf_model"

    print(f"Merging checkpoint: {latest_ckpt}")
    print(f"Output: {output_hf_path}")

    # Build the merge script invocation
    merge_cmd = f'''
    python {FINETUNING_PATH}/tools/merge_rl_checkpoints_to_hg_version.py \\
        --checkpoint_path {latest_ckpt} \\
        --output_path {output_hf_path} \\
        --model_name {MODEL_NAME}
    '''
    print(f"Run this command to merge:\n{merge_cmd}")


---

## Summary

This notebook sets up GRPO (Group Relative Policy Optimization) training for Rex-Omni on VRSBench data.

**Key Components:**
1. **TSVRLHFDataset** - Loads TSV data with grounding annotations
2. **GroundingTaskFn** - Converts annotations to detection prompts
3. **BoxIoU Reward** - Computes F1 score based on detection IoU
4. **Ray + vLLM** - Distributed rollout generation
5. **FSDP** - Distributed model training

**GRPO vs SFT:**
| Aspect | SFT | GRPO |
|--------|-----|------|
| Supervision | Ground truth labels | Reward signals |
| Training | Forward/backward pass on labeled data | Generate → Evaluate → Update |
| Scaling | Single GPU possible | Multi-GPU required |
| Memory | ~24GB | ~200GB+ total |

**Files Created:**
- `data/vrsbench_grpo/` - TSV training data
- `finetuning/configs/grpo_vrsbench.py` - Dataset config
- `run_grpo_vrsbench.sh` - Training launch script
- `work_dirs/grpo_vrsbench/` - Checkpoints
