# VRS Bench Inference with Rex-Omni API (Async & Multi-Object)

This notebook runs inference on the `vrsbench_val_data.parquet` dataset using the **Rex-Omni** HTTP API.

**Key Features:**
- **Multi-Object Splitting**: Each referring sentence in an image is treated as a separate sample.
- **Async Inference**: Uses `aiohttp` and `asyncio` for concurrent requests.
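- **Concurrency Control**: Caps the number of in-flight API requests at 5 (configured by `CONCURRENCY`).
- **Sample Limit**: Processes only the first 1000 (image, referring sentence) samples (configured by `MAX_SAMPLES`).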

In [None]:
# Install the async HTTP client (aiohttp) and the nested-event-loop patch
# (nest_asyncio) imported in the next cell.
# NOTE(review): pandas.read_parquet additionally needs a parquet engine
# (pyarrow or fastparquet); add one here if the kernel does not have it.
!pip install aiohttp nest_asyncio

In [None]:
import aiohttp
import asyncio
import base64
import json
import pandas as pd
import os
import nest_asyncio
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm.asyncio import tqdm

# Apply nest_asyncio to allow nested event loops in Jupyter.
# asyncio.run() refuses to start while another event loop is already running
# in the same thread (as it is inside a Jupyter kernel); this patch lets the
# asyncio.run(...) call in the "Run Inference" cell work.
nest_asyncio.apply()

## 1. Configuration

In [None]:
# --- Rex-Omni service --------------------------------------------------------
# HTTP grounding endpoint that every sample is POSTed to.
API_URL = "https://animeshraj958--rex-vision-service-api-grounding.modal.run"

# --- Input / output files ----------------------------------------------------
PARQUET_FILE = "vrsbench_val_data.parquet"        # VRSBench validation split
OUTPUT_FILE = "vrs_inference_results_async.json"  # predictions are dumped here

# --- Run limits --------------------------------------------------------------
MAX_SAMPLES = 1_000  # cap on (image, referring sentence) samples to run
CONCURRENCY = 5      # maximum number of simultaneous API requests

## 2. Data Preparation

In [None]:
import json
import os
import pandas as pd

def _is_null(value):
    """Return True for None and NaN-like scalars (NaN, NaT, pd.NA)."""
    if value is None:
        return True
    # pd.isna on a non-scalar returns an array, so only ask it about scalars.
    return bool(pd.api.types.is_scalar(value) and pd.isna(value))


def _is_missing(value):
    """Return True for null-like values and empty strings/containers.

    Mirrors the falsy-skipping of an ``a or b`` chain, but never evaluates the
    truth value of a numpy array (ambiguous for more than one element) and
    treats legitimate zeros such as ``obj_id == 0`` as present.
    """
    if _is_null(value):
        return True
    if isinstance(value, (str, bytes)):
        return not value
    try:
        return len(value) == 0
    except TypeError:
        # Unsized scalars (ints, floats, ...) count as present.
        return False


def _first_present(mapping, keys, default=None, *, missing=_is_missing):
    """Return ``mapping.get(key)`` for the first key whose value is not *missing*."""
    for key in keys:
        value = mapping.get(key)
        if not missing(value):
            return value
    return default


def _to_builtin(value):
    """Convert numpy arrays/scalars (and tuples) to plain, JSON-serialisable Python values."""
    if hasattr(value, "tolist"):
        return value.tolist()
    if isinstance(value, tuple):
        return list(value)
    return value


def _resolve_image_path(image_path, base_dir):
    """Return the first existing candidate for *image_path*, or None.

    Candidates, in order: the path as given, the path joined onto *base_dir*,
    and the path's basename inside *base_dir*.
    """
    candidates = (
        image_path,
        os.path.join(base_dir, image_path),
        os.path.join(base_dir, os.path.basename(image_path)),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def load_and_process_data(parquet_file, limit=None, base_dir='.'):
    """
    Load VRSBench rows and flatten them into one sample per referring sentence.

    Args:
        parquet_file: path to a parquet file, or an already-loaded
            ``pandas.DataFrame`` with the same columns.
        limit: stop as soon as this many samples have been collected;
            ``None`` (or 0) means no limit.
        base_dir: directory used to resolve relative or relocated image paths.

    Returns:
        list of dicts with keys ``image_id``, ``image_path``, ``caption``,
        ``ground_truth_bbox`` and ``obj_id``. Numpy arrays/scalars coming from
        parquet list columns are converted to plain Python values, so the list
        is safe to pass to ``json.dump``.

    Raises:
        FileNotFoundError: if *parquet_file* is a path that does not exist.

    Rows without a resolvable image, rows whose objects field is unusable, and
    non-dict object entries are skipped and counted in the printed
    "Skipped summary".
    """
    if isinstance(parquet_file, pd.DataFrame):
        print("Using provided DataFrame ...")
        df = parquet_file
    else:
        print(f"Loading {parquet_file} ...")
        if not os.path.exists(parquet_file):
            raise FileNotFoundError(f"Parquet file not found: {parquet_file}")
        df = pd.read_parquet(parquet_file)
    print(f"Parquet rows: {len(df)}")

    samples = []
    skipped = {
        "missing_image": 0,
        "bad_objects": 0,
        "bad_object_entry": 0,
    }

    for idx, row in df.iterrows():
        image_path = _first_present(row, ("image_path", "image"))
        image_id = _to_builtin(
            _first_present(row, ("image_id", "id"), default=f"row_{idx}")
        )

        # Only filesystem paths can be resolved; NaN / embedded-image values cannot.
        if not isinstance(image_path, (str, os.PathLike)):
            skipped["missing_image"] += 1
            continue
        image_path = _resolve_image_path(os.fspath(image_path), base_dir)
        if image_path is None:
            skipped["missing_image"] += 1
            continue

        # Fall back to other plausible column names only when the value is null.
        objects = _first_present(
            row, ("objects", "instances", "annotations", "objs"), missing=_is_null
        )
        if isinstance(objects, str):
            try:
                objects = json.loads(objects)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                skipped["bad_objects"] += 1
                continue

        # pd.read_parquet returns list columns as numpy arrays, not lists.
        objects = _to_builtin(objects)
        if not isinstance(objects, list):
            skipped["bad_objects"] += 1
            continue

        for obj in objects:
            if not isinstance(obj, dict):
                skipped["bad_object_entry"] += 1
                continue

            # try common keys for referring sentence and bbox
            caption = _to_builtin(_first_present(
                obj, ("referring_sentence", "caption", "sentence"), default=""
            ))
            bbox = _to_builtin(_first_present(
                obj, ("obj_coord", "bbox", "box", "coords"), default=[]
            ))

            samples.append({
                "image_id": image_id,
                "image_path": image_path,
                "caption": caption,
                "ground_truth_bbox": bbox,
                "obj_id": _to_builtin(_first_present(obj, ("obj_id", "id"))),
            })

            if limit and len(samples) >= limit:
                print("Reached limit:", limit)
                print("Skipped summary:", skipped)
                return samples

    print("Finished. samples:", len(samples))
    print("Skipped summary:", skipped)
    return samples

## 3. Async Inference Logic

In [None]:
async def encode_image(image_path):
    """Read *image_path* and return it as a base64 ``data:`` URI.

    The blocking file read runs in the default executor so it does not stall
    the event loop.

    Returns:
        ``"data:<mime>;base64,<payload>"`` on success. The MIME type is guessed
        from the file extension, falling back to ``image/jpeg`` when the
        extension is unknown or not an image type. Returns ``None`` if the file
        cannot be read (the caller reports that as "Image read failed").
    """
    import mimetypes

    # get_running_loop() is the supported call from inside a coroutine;
    # get_event_loop() is deprecated for this use.
    loop = asyncio.get_running_loop()

    def _read():
        try:
            with open(image_path, "rb") as f:
                data = f.read()
        except (OSError, TypeError, ValueError):
            # Best effort: unreadable / invalid paths are reported by the caller.
            return None
        mime, _ = mimetypes.guess_type(str(image_path))
        if not mime or not mime.startswith("image/"):
            mime = "image/jpeg"
        encoded = base64.b64encode(data).decode("ascii")
        return f"data:{mime};base64,{encoded}"

    return await loop.run_in_executor(None, _read)

async def process_sample(session, sample, semaphore, api_url=None):
    """Send one sample to the Rex-Omni API while holding a concurrency slot.

    Args:
        session: open ``aiohttp.ClientSession`` used for the POST.
        sample: dict from the data-preparation step; must contain
            ``"image_path"`` and ``"caption"``.
        semaphore: ``asyncio.Semaphore`` bounding the number of in-flight requests.
        api_url: endpoint to call; ``None`` means the module-level ``API_URL``
            (read at call time).

    Returns:
        A new dict holding *sample*'s keys plus either ``"prediction"`` (the
        decoded JSON body of an HTTP 200 response) or ``"error"`` (a non-empty
        description of the failure). Failures are recorded, never raised.
    """
    url = API_URL if api_url is None else api_url
    async with semaphore:
        image_data = await encode_image(sample['image_path'])
        if not image_data:
            return {**sample, "error": "Image read failed"}

        payload = {
            "image": image_data,
            "caption": sample['caption'],
        }

        try:
            async with session.post(url, json=payload) as response:
                if response.status == 200:
                    result = await response.json()
                    return {**sample, "prediction": result}
                text = await response.text()
                return {**sample, "error": f"API {response.status}: {text}"}
        except Exception as e:
            # Some exceptions (notably asyncio.TimeoutError) stringify to "",
            # so include the type name to keep the saved error informative.
            message = str(e)
            error = f"{type(e).__name__}: {message}" if message else type(e).__name__
            return {**sample, "error": error}

async def run_async_inference(samples, concurrency=5):
    """Run ``process_sample`` over *samples* with at most *concurrency* requests in flight.

    Args:
        samples: iterable of sample dicts (materialised into a list first).
        concurrency: size of the semaphore limiting simultaneous requests.

    Returns:
        list of result dicts in the SAME order as *samples*, so that
        ``results[i]`` corresponds to ``samples[i]``. The progress bar still
        advances in completion order.
    """
    samples = list(samples)
    semaphore = asyncio.Semaphore(concurrency)
    results = [None] * len(samples)

    async with aiohttp.ClientSession() as session:
        async def _indexed(position, sample):
            # Carry the input position through as_completed so order can be restored.
            return position, await process_sample(session, sample, semaphore)

        pending = [_indexed(i, sample) for i, sample in enumerate(samples)]

        # Use tqdm for progress bar
        for fut in tqdm(asyncio.as_completed(pending), total=len(pending), desc="Async Inference"):
            position, result = await fut
            results[position] = result

    return results

## 4. Run Inference

In [None]:
# Build the per-sentence samples. `all_samples` was previously never assigned
# anywhere in the notebook, so this cell raised NameError.
all_samples = load_and_process_data(PARQUET_FILE, limit=MAX_SAMPLES)

# Run the async loop (asyncio.run works inside Jupyter thanks to nest_asyncio)
results = asyncio.run(run_async_inference(all_samples, concurrency=CONCURRENCY))

# Save results
with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Saved {len(results)} results to {OUTPUT_FILE}")
n_errors = sum(1 for r in results if "error" in r)
print(f"Samples with errors: {n_errors}")

## 5. Visualize a Sample

In [None]:
def visualize_result(image_input, result, caption):
    """Display *image_input* with the Rex-Omni predicted boxes overlaid.

    Args:
        image_input: anything ``PIL.Image.open`` accepts (e.g. a file path).
        result: decoded API response; boxes are read from
            ``result["annotations"][i]["boxes"]``, each as ``[x1, y1, x2, y2]``,
            and labelled with that annotation's ``"phrase"`` (default "object").
        caption: referring sentence, shown as the figure title.

    Any exception is caught and reported via a printed message rather than raised.
    """
    try:
        picture = Image.open(image_input).convert("RGB")
        plt.figure(figsize=(10, 10))
        plt.imshow(picture)
        axis = plt.gca()
        plt.title(f"Caption: {caption}")

        annotations = []
        if result and "annotations" in result:
            annotations = result.get("annotations", [])

        for annotation in annotations:
            label = annotation.get("phrase", "object")
            for box in annotation.get("boxes", []):
                # Corners [x1, y1, x2, y2] -> anchor point plus width/height.
                box_w = box[2] - box[0]
                box_h = box[3] - box[1]
                axis.add_patch(
                    patches.Rectangle(
                        (box[0], box[1]), box_w, box_h,
                        linewidth=2, edgecolor='#00FF00', facecolor='none',
                    )
                )
                # Label sits just above the box's top-left corner.
                plt.text(
                    box[0], box[1] - 5, label,
                    color='black', fontsize=10, weight='bold',
                    bbox=dict(facecolor='#00FF00', alpha=0.7, edgecolor='none', pad=2),
                )

        plt.axis('off')
        plt.show()
    except Exception as exc:
        print(f"Visualization error: {exc}")

# Show the first result whose prediction reports success (if there is one).
first_hit = next(
    (
        entry
        for entry in results
        if "prediction" in entry and entry["prediction"].get("success", False)
    ),
    None,
)
if first_hit is not None:
    visualize_result(first_hit['image_path'], first_hit['prediction'], first_hit['caption'])