# Gemini Model Evaluation on VSI-Bench

Evaluate Gemini models (2.5 Pro, 3.1 Pro) on VSI-Bench and Cambrian-S benchmarks for video spatial intelligence.

## Models
| Model | Model ID | Status |
|-------|----------|--------|
| Gemini 2.5 Pro | `gemini-2.5-pro-preview` | Available |
| Gemini 3.1 Pro | `gemini-3.1-pro` | Placeholder |

## Benchmarks
- **VSI-Bench** (default): 5,000+ video spatial intelligence questions
  - 8 task types: object counting, absolute distance, object size, room size, relative distance, relative direction, route planning, and appearance order
  - Metrics: accuracy for multiple-choice answers (MCA); Mean Relative Accuracy (MRA) for numerical answers (NA)
- **Cambrian-S Suite** (optional): VideoMME, EgoSchema, MVBench, CV-Bench, 3DSR

## Eval Modes
- `tiny`: 50 samples (~5 min) - sanity check
- `small`: 200 samples (~20 min) - quick evaluation
- `full`: 5,000+ samples (~8 hours) - complete benchmark

## Requirements
- Google API key with Gemini access
- Google Colab (no GPU required — inference runs through the Gemini API)

## Sources
- [VSI-Bench Paper](https://arxiv.org/abs/2412.14171)
- [VSI-Bench Dataset](https://huggingface.co/datasets/nyu-visionx/VSI-Bench)
- [Cambrian-S Project](https://cambrian-mllm.github.io/cambrian-s/)

## 1. Setup Environment

In [None]:
# Install dependencies
%pip install -q google-genai>=1.0.0
%pip install -q datasets>=2.0.0
%pip install -q tqdm pandas matplotlib seaborn tabulate
%pip install -q huggingface_hub

print("Dependencies installed!")

In [None]:
# Verify installation
import google.genai as genai
from datasets import load_dataset
import pandas as pd
import matplotlib.pyplot as plt

# Report the installed SDK version so results can be tied to a client release.
print("google-genai version: " + str(genai.__version__))
print("All imports successful!")

## 2. Configuration

In [None]:
import os
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any, Union
from enum import Enum


class EvalMode(Enum):
    """How much of the benchmark to evaluate; values match the EVAL_MODE strings."""

    TINY = "tiny"    # small sanity-check subset (~5 min)
    SMALL = "small"  # quick evaluation subset (~20 min)
    FULL = "full"    # every available sample


@dataclass
class GeminiConfig:
    """Configuration for Gemini API calls.

    Defaults are conservative (10 requests/minute) so a free-tier key is
    unlikely to be rate limited. Invalid pacing/retry values are rejected at
    construction time rather than failing later inside the client.

    Raises:
        ValueError: if ``requests_per_minute`` is not positive or
            ``retry_attempts`` is less than 1.
    """
    model_name: str = "gemini-2.5-pro-preview"
    # repr=False keeps the secret out of printed configs, logs and tracebacks.
    api_key: str = field(default="", repr=False)

    # Rate limiting (conservative for free tier)
    requests_per_minute: int = 10
    retry_attempts: int = 3
    retry_delay_seconds: float = 2.0

    # Video settings
    # NOTE(review): GeminiClient.generate_with_video does not read these two
    # fields; confirm whether anything else does before relying on them.
    video_fps: int = 1  # Frames per second for video sampling
    max_video_duration_seconds: int = 300  # 5 minutes max

    # Generation settings
    max_output_tokens: int = 256
    temperature: float = 0.0  # Greedy decoding for reproducibility

    def __post_init__(self):
        # The client computes 60 / requests_per_minute and loops
        # retry_attempts times; zero/negative values would crash or skip
        # every request.
        if self.requests_per_minute <= 0:
            raise ValueError(
                f"requests_per_minute must be positive, got {self.requests_per_minute}"
            )
        if self.retry_attempts < 1:
            raise ValueError(
                f"retry_attempts must be at least 1, got {self.retry_attempts}"
            )


@dataclass
class EvalConfig:
    """Run-level evaluation settings: sample budget, seed, variant, checkpointing."""

    mode: EvalMode = EvalMode.TINY
    seed: int = 42

    # Per-mode sample caps (FULL has no cap)
    TINY_SAMPLES: int = 50
    SMALL_SAMPLES: int = 200

    # VSI-Bench variant: "full" or "debiased"
    benchmark_config: str = "full"

    # Save checkpoint every N samples
    save_frequency: int = 10

    @property
    def num_samples(self) -> Optional[int]:
        """Return the sample cap for the current mode, or None for no cap."""
        caps = (
            (EvalMode.TINY, self.TINY_SAMPLES),
            (EvalMode.SMALL, self.SMALL_SAMPLES),
        )
        for candidate, cap in caps:
            if self.mode == candidate:
                return cap
        # FULL (or any other mode): evaluate everything
        return None


# ============================================================
# USER CONFIGURATION - Modify these settings
# ============================================================

# API Key (get from https://aistudio.google.com/apikey)
# Option 1: Set directly (not recommended for sharing)
# GOOGLE_API_KEY = "your-api-key-here"

# Option 2: Use Colab secrets (recommended)
# Falls back to the environment when not running in Colab (ImportError) or
# when the secret is missing / notebook access is not granted. `except
# Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
try:
    from google.colab import userdata
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
except Exception:
    GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY', '')

# Model selection
MODELS_TO_EVALUATE = [
    "gemini-2.5-pro-preview",
    # "gemini-3.1-pro",  # Placeholder for future model
]

# Evaluation mode
EVAL_MODE = "tiny"    # Options: "tiny", "small", "full"

# Benchmarks to run (default: VSI-Bench only)
BENCHMARKS_TO_RUN = ["vsi_bench"]
# BENCHMARKS_TO_RUN = ["vsi_bench", "videomme", "egoschema"]  # Cambrian-S suite

# Output directory (Google Drive)
OUTPUT_DIR = "/content/drive/MyDrive/gemini_vsi_eval"

# Resume from checkpoint
RESUME_FROM_CHECKPOINT = True

# ============================================================
# Create configuration objects
# ============================================================

gemini_config = GeminiConfig(
    model_name=MODELS_TO_EVALUATE[0],
    api_key=GOOGLE_API_KEY,
)

eval_config = EvalConfig(
    mode=EvalMode(EVAL_MODE),
)

# Display configuration
print("=" * 60)
print("Configuration")
print("=" * 60)
print(f"Models: {MODELS_TO_EVALUATE}")
print(f"Eval mode: {eval_config.mode.value}")
print(f"Samples: {eval_config.num_samples or 'all'}")
print(f"Benchmarks: {BENCHMARKS_TO_RUN}")
print(f"API key configured: {'Yes' if GOOGLE_API_KEY else 'No'}")
print("=" * 60)

if not GOOGLE_API_KEY:
    print("\n WARNING: No API key found!")
    print("Set GOOGLE_API_KEY in Colab secrets or environment.")

## 3. Gemini API Client

In [None]:
import time
import hashlib
from pathlib import Path


class GeminiClient:
    """Wrapper around the google-genai SDK for video + text prompting.

    Adds client-side request pacing (``config.requests_per_minute``), retries
    with exponential backoff on rate-limit errors, and an in-memory cache of
    videos uploaded through the Files API so a video is uploaded once per
    client instance.
    """

    # Seconds between polls while an uploaded file is still PROCESSING.
    _POLL_INTERVAL_SECONDS = 2.0

    def __init__(self, config: "GeminiConfig"):
        self.config = config
        self._init_client()
        self._uploaded_files: Dict[str, Any] = {}  # md5(video_path) -> uploaded file handle
        # -inf so the first request never waits; compared against a monotonic
        # clock so wall-clock adjustments cannot stall or skip pacing.
        self._last_request_time = float("-inf")
        self._consecutive_errors = 0

    def _init_client(self):
        """Initialize the Google GenAI client from ``config.api_key``."""
        from google import genai
        self.client = genai.Client(api_key=self.config.api_key)

    def _rate_limit_wait(self):
        """Sleep just long enough to keep under ``config.requests_per_minute``."""
        min_interval = 60.0 / self.config.requests_per_minute
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < min_interval:
            time.sleep(min_interval - elapsed)
        self._last_request_time = time.monotonic()

    def _handle_rate_limit_error(self):
        """Back off exponentially (2, 4, 8, ... seconds, capped at 60)."""
        self._consecutive_errors += 1
        wait_time = min(2 ** self._consecutive_errors, 60)  # Max 60s
        print(f"Rate limited. Waiting {wait_time}s...")
        time.sleep(wait_time)

    def _reset_errors(self):
        """Reset the backoff counter after a successful request."""
        self._consecutive_errors = 0

    @staticmethod
    def _is_rate_limit_error(error: Exception) -> bool:
        """Return True if ``error`` looks like an HTTP 429 / quota error.

        Checks an integer ``code`` attribute first (set on google-genai API
        errors), then specific phrases. A bare "rate" substring is deliberately
        not used: it also matches "generate"/"generateContent", which appears
        in most API error messages and would misclassify real failures.
        """
        code = getattr(error, "code", None)
        if isinstance(code, int) and code == 429:
            return True
        message = str(error).lower()
        markers = (
            "429",
            "resource_exhausted",
            "rate limit",
            "rate-limit",
            "ratelimit",
            "quota",
            "too many requests",
        )
        return any(marker in message for marker in markers)

    @staticmethod
    def _state_name(file: Any) -> Optional[str]:
        """Return the processing-state name of an uploaded file, or None if absent."""
        state = getattr(file, "state", None)
        if state is None:
            return None
        name = getattr(state, "name", None)
        return name if name is not None else str(state)

    def upload_video(self, video_path: str, max_wait_seconds: float = 600.0) -> Any:
        """Upload a video via the Files API, reusing a cached ACTIVE upload.

        Args:
            video_path: Local path of the video to upload.
            max_wait_seconds: Maximum time to wait for server-side processing.

        Returns:
            The uploaded file handle in the ACTIVE state.

        Raises:
            RuntimeError: if processing ends in a non-ACTIVE state or does not
                finish within ``max_wait_seconds``.
        """
        cache_key = hashlib.md5(video_path.encode()).hexdigest()
        cached = self._uploaded_files.get(cache_key)
        if cached is not None:
            try:
                refreshed = self.client.files.get(name=cached.name)
            except Exception:
                refreshed = None  # Expired or deleted remotely: re-upload below
            if refreshed is not None and self._state_name(refreshed) == "ACTIVE":
                return refreshed
            self._uploaded_files.pop(cache_key, None)

        file = self.client.files.upload(file=video_path)

        # Poll until the server finishes processing, with an upper bound so a
        # stuck upload cannot hang the whole evaluation.
        deadline = time.monotonic() + max_wait_seconds
        while self._state_name(file) == "PROCESSING":
            if time.monotonic() > deadline:
                raise RuntimeError(
                    f"File processing timed out after {max_wait_seconds:.0f}s: {video_path}"
                )
            time.sleep(self._POLL_INTERVAL_SECONDS)
            file = self.client.files.get(name=file.name)

        state = self._state_name(file)
        if state != "ACTIVE":
            raise RuntimeError(f"File upload failed: {state}")

        self._uploaded_files[cache_key] = file
        return file

    def generate_with_video(
        self,
        video_path: str,
        prompt: str,
    ) -> str:
        """Generate a response for a video + text prompt.

        Rate-limit errors back off exponentially and are retried. Other errors
        are retried after ``config.retry_delay_seconds``, and the last one is
        re-raised unchanged.

        Returns:
            The response text ("" if the response carried no text).

        Raises:
            RuntimeError: if every attempt was rate limited; the last rate-limit
                error is attached as ``__cause__``.
        """
        self._rate_limit_wait()

        attempts = self.config.retry_attempts
        last_error: Optional[Exception] = None
        for attempt in range(attempts):
            is_last_attempt = attempt == attempts - 1
            try:
                video_file = self.upload_video(video_path)
                response = self.client.models.generate_content(
                    model=self.config.model_name,
                    contents=[video_file, prompt],
                    config={
                        "temperature": self.config.temperature,
                        "max_output_tokens": self.config.max_output_tokens,
                    },
                )
            except Exception as e:
                last_error = e
                if self._is_rate_limit_error(e):
                    # Sleeping after the final attempt would only delay the failure.
                    if not is_last_attempt:
                        self._handle_rate_limit_error()
                elif is_last_attempt:
                    raise
                else:
                    time.sleep(self.config.retry_delay_seconds)
                continue

            self._reset_errors()
            # response.text can be None (e.g. no text part came back); normalize
            # to "" so callers can always treat the result as a string.
            return response.text or ""

        raise RuntimeError("Max retries exceeded") from last_error

    def cleanup_uploaded_files(self):
        """Best-effort deletion of every cached upload to free Files API quota."""
        deleted = 0
        for file in list(self._uploaded_files.values()):
            try:
                self.client.files.delete(name=file.name)
                deleted += 1
            except Exception as e:
                # Best effort: report and keep going so one failure does not
                # leave the remaining files behind.
                print(f"Could not delete {file.name}: {e}")
        self._uploaded_files.clear()
        print(f"Cleaned up {deleted} uploaded files")


# Smoke-test the key with a cheap text-only request (no video upload needed).
if GOOGLE_API_KEY:
    try:
        client = GeminiClient(gemini_config)
        response = client.client.models.generate_content(
            contents="Say 'API connection successful' in exactly those words.",
            model=gemini_config.model_name,
        )
        print("API test: " + str(response.text))
    except Exception as e:
        print("API test failed: " + str(e))

## 4. VSI-Bench Data Loader

In [None]:
import random
from collections import defaultdict
from tqdm.notebook import tqdm
from huggingface_hub import hf_hub_download
import zipfile


@dataclass
class VSIBenchSample:
    """One VSI-Bench question paired with the local path of its video."""

    sample_id: str
    video_path: str                # local path of the extracted video file
    question: str
    ground_truth: str              # stored as text; numeric answers are parsed downstream
    options: Optional[List[str]]   # answer choices; None for numerical questions
    question_type: str             # e.g. "object_counting", "room_size"
    task_category: str             # "configurational", "measurement" or "spatiotemporal"
    is_numerical: bool             # True -> numerical answer (NA), False -> multiple choice (MCA)
    metadata: Dict = field(default_factory=dict)  # extra provenance fields


class VSIBenchLoader:
    """Load VSI-Bench annotations from HuggingFace and resolve their videos.

    Videos are stored separately in ZIP files and must be downloaded/extracted.
    The dataset uses `dataset` (arkitscenes/scannet/scannetpp) and `scene_name`
    to identify which video corresponds to each sample.
    """

    REPO_ID = "nyu-visionx/VSI-Bench"

    # Video sources; each is downloaded as "<source>.zip" from the dataset repo
    VIDEO_SOURCES = ["arkitscenes", "scannet", "scannetpp"]

    # File suffixes (lower-cased) that _build_video_index treats as videos
    VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv"}

    # Task category mapping
    TASK_CATEGORIES = {
        # Configurational tasks (MCA)
        "object_rel_direction": "configurational",
        "object_rel_direction_easy": "configurational",
        "object_rel_direction_medium": "configurational",
        "object_rel_direction_hard": "configurational",
        "object_rel_distance": "configurational",
        "route_plan": "configurational",
        # Measurement estimation tasks (NA - numerical)
        "object_counting": "measurement",
        "abs_dist": "measurement",
        "object_abs_distance": "measurement",
        "room_size": "measurement",
        "room_size_estimation": "measurement",
        "obj_size_estimation": "measurement",
        "object_size_estimation": "measurement",
        # Spatiotemporal tasks (MCA)
        "appearance_order": "spatiotemporal",
    }

    NUMERICAL_TASKS = {
        "object_counting", "abs_dist", "object_abs_distance",
        "room_size", "room_size_estimation",
        "obj_size_estimation", "object_size_estimation"
    }

    def __init__(
        self,
        config: str = "full",  # "full" or "debiased"
        cache_dir: Optional[str] = None,
        num_samples: Optional[int] = None,  # Limit for testing
        seed: int = 42,
        stratified: bool = True,  # Ensure all task types represented
    ):
        """Create a loader; creates the cache and video directories on disk.

        Args:
            config: "full", or "debiased" to drop items flagged ``pruned``.
            cache_dir: Root for downloads/extractions (default /content/vsi_bench_cache).
            num_samples: Optional cap on the number of annotations to use.
            seed: Seed for reproducible sampling.
            stratified: Spread the sample cap across question types.
        """
        self.config = config
        self.cache_dir = Path(cache_dir or "/content/vsi_bench_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.video_dir = self.cache_dir / "videos"
        self.video_dir.mkdir(parents=True, exist_ok=True)
        self.num_samples = num_samples
        self.seed = seed
        self.stratified = stratified
        self._dataset = None
        self._video_paths: Dict[tuple, str] = {}  # (dataset, scene_name) -> video_path

    def _download_and_extract_videos(self, required_sources: set):
        """Download and extract "<source>.zip" for each requested source.

        Sources whose directory is already non-empty are skipped. A failed
        source is reported and skipped so the others can still load; any
        partial extraction is removed so the next call retries it instead of
        mistaking it for a completed one.
        """
        import shutil

        for source in sorted(required_sources, key=str):
            zip_name = f"{source}.zip"
            extract_dir = self.video_dir / source

            # Skip if already extracted
            if extract_dir.exists() and any(extract_dir.iterdir()):
                print(f"  {source}: already extracted")
                continue

            print(f"  {source}: downloading...")
            try:
                zip_path = hf_hub_download(
                    repo_id=self.REPO_ID,
                    filename=zip_name,
                    repo_type="dataset",
                    cache_dir=str(self.cache_dir / "hf_cache"),
                )

                print(f"  {source}: extracting...")
                extract_dir.mkdir(parents=True, exist_ok=True)
                with zipfile.ZipFile(zip_path, 'r') as zf:
                    zf.extractall(extract_dir)
                print(f"  {source}: done")

            except Exception as e:
                print(f"  {source}: failed - {e}")
                shutil.rmtree(extract_dir, ignore_errors=True)

    def _build_video_index(self):
        """Build the (dataset, scene_name) -> video path index from extracted files.

        Each video is indexed by its file stem and, when nested, also by its
        parent folder name (first match wins for the folder alias). Files are
        visited in sorted order so the index is the same on every run.
        """
        self._video_paths = {}

        for source in self.VIDEO_SOURCES:
            source_dir = self.video_dir / source

            if not source_dir.exists():
                continue

            for video_file in sorted(source_dir.rglob("*")):
                if video_file.suffix.lower() not in self.VIDEO_EXTENSIONS:
                    continue
                # Scene name is typically the file stem or parent folder name
                self._video_paths[(source, video_file.stem)] = str(video_file)

                # Also try parent folder as scene name (some datasets structure it this way)
                if video_file.parent.name != source:
                    self._video_paths.setdefault((source, video_file.parent.name), str(video_file))

        print(f"Indexed {len(self._video_paths)} videos")

    def load(self) -> List["VSIBenchSample"]:
        """Load annotations, sample them, fetch videos and build samples.

        Returns:
            Samples whose video could be resolved locally; the others are
            counted and reported as skipped.
        """
        print(f"Loading VSI-Bench ({self.config})...")

        # Load annotations
        self._dataset = load_dataset(
            self.REPO_ID,
            split="test",
        )

        print(f"Loaded {len(self._dataset)} annotation samples")

        # Drop pruned items *before* sampling, so a debiased run still gets
        # num_samples evaluable items instead of losing some after the fact.
        column_names = getattr(self._dataset, "column_names", None) or []
        if self.config == "debiased" and "pruned" in column_names:
            self._dataset = self._dataset.filter(lambda item: not item["pruned"])
            print(f"Kept {len(self._dataset)} non-pruned samples (debiased)")

        # Apply sampling first (before downloading videos)
        if self.num_samples and len(self._dataset) > self.num_samples:
            if self.stratified:
                indices = self._stratified_sample()
            else:
                # Local RNG: same sequence as random.seed(seed) but leaves the
                # global random state untouched.
                rng = random.Random(self.seed)
                indices = rng.sample(range(len(self._dataset)), self.num_samples)
            self._dataset = self._dataset.select(indices)
            print(f"Sampled {len(self._dataset)} samples (stratified={self.stratified})")

        # Determine which video sources we need
        required_sources = set(item["dataset"] for item in self._dataset)
        print(f"Required video sources: {required_sources}")

        # Download and extract videos
        print("Downloading videos...")
        self._download_and_extract_videos(required_sources)

        # Build video index
        self._build_video_index()

        # Convert to samples
        samples = []
        missing_videos = 0
        for idx, item in enumerate(tqdm(self._dataset, desc="Processing samples")):
            sample = self._process_item(idx, item)
            if sample:
                samples.append(sample)
            else:
                missing_videos += 1

        print(f"Processed {len(samples)} valid samples")
        if missing_videos > 0:
            print(f"Warning: {missing_videos} samples skipped (missing videos)")

        return samples

    def _stratified_sample(self) -> List[int]:
        """Pick ``num_samples`` distinct indices covering every question type.

        Each type gets an equal quota (at least 1). Any shortfall — from
        integer division or from types with fewer items than their quota — is
        topped up from the unselected items. When there are more types than
        samples, a random subset of the per-type picks is kept.
        """
        rng = random.Random(self.seed)

        # Group indices by task type
        task_indices = defaultdict(list)
        for idx, item in enumerate(self._dataset):
            task_indices[item.get("question_type", "unknown")].append(idx)

        if not task_indices:
            return []

        per_task = max(1, self.num_samples // len(task_indices))

        selected = []
        for indices in task_indices.values():
            selected.extend(rng.sample(indices, min(per_task, len(indices))))

        shortfall = self.num_samples - len(selected)
        if shortfall > 0:
            chosen = set(selected)
            pool = [i for i in range(len(self._dataset)) if i not in chosen]
            selected.extend(rng.sample(pool, min(shortfall, len(pool))))
        elif shortfall < 0:
            # Truncating in dict order would drop whole task types; shuffle first.
            rng.shuffle(selected)

        return selected[:self.num_samples]

    def _get_video_path(self, item: Dict) -> Optional[str]:
        """Resolve the local video for ``item`` via its (dataset, scene_name).

        Tries an exact key first, then falls back to a match that ignores
        leading zeros in the scene name (e.g. "0001" matches "1").
        """
        dataset = item.get("dataset", "")
        scene_name = item.get("scene_name", "")

        exact = self._video_paths.get((dataset, scene_name))
        if exact is not None:
            return exact

        # Try without leading zeros or with different formatting
        wanted = str(scene_name).lstrip("0")
        for (ds, sn), path in self._video_paths.items():
            if ds == dataset and (sn == scene_name or sn.lstrip("0") == wanted):
                return path

        return None

    def _process_item(self, idx: int, item: Dict) -> Optional["VSIBenchSample"]:
        """Convert one annotation row into a sample, or None if it must be skipped.

        Skipped rows are pruned items in "debiased" mode and rows whose video
        cannot be resolved.
        """
        # Skip pruned samples if using debiased config
        if self.config == "debiased" and item.get("pruned", False):
            return None

        video_path = self._get_video_path(item)
        if not video_path:
            return None

        question_type = item.get("question_type", "unknown")

        # Options may arrive as a newline-separated string; normalize to a list
        options = item.get("options")
        if options and isinstance(options, str):
            options = [opt.strip() for opt in options.split("\n") if opt.strip()]

        return VSIBenchSample(
            sample_id=str(item.get("id", idx)),
            video_path=video_path,
            question=item["question"],
            ground_truth=str(item["ground_truth"]),
            options=options,
            question_type=question_type,
            task_category=self.TASK_CATEGORIES.get(question_type, "unknown"),
            is_numerical=question_type in self.NUMERICAL_TASKS,
            metadata={
                "dataset": item.get("dataset", "unknown"),
                "scene_name": item.get("scene_name", ""),
            },
        )

    def get_task_breakdown(self) -> Dict[str, int]:
        """Count loaded (post-sampling) annotations per question type."""
        if self._dataset is None:
            return {}
        from collections import Counter
        return dict(Counter(item["question_type"] for item in self._dataset))


# Smoke-test the loader end to end on a tiny (5-sample) subset.
print("Testing VSI-Bench loader...")
test_loader = VSIBenchLoader(
    seed=eval_config.seed,
    num_samples=5,
    config=eval_config.benchmark_config,
)

try:
    test_samples = test_loader.load()
    if not test_samples:
        print("\nNo samples loaded. Videos may need to be downloaded first.")
    else:
        print(f"\nSample question: {test_samples[0].question[:100]}...")
        print(f"Task type: {test_samples[0].question_type}")
        print(f"Video path: {test_samples[0].video_path}")
except Exception as e:
    print(f"Loader test failed: {e}")
    import traceback
    traceback.print_exc()

## 5. Evaluation Metrics

In [None]:
import re
import numpy as np


class VSIBenchMetrics:
    """Evaluation metrics for VSI-Bench.

    MCA (multiple-choice) items are scored by exact letter match. NA
    (numerical) items are scored by Mean Relative Accuracy. In both cases an
    answer that cannot be parsed scores 0 rather than being dropped.
    """

    # Multiple-choice extraction patterns, tried in order. The option letter
    # itself is matched case-sensitively so an English article "a" is not
    # mistaken for option A.
    _MCA_PATTERNS = (
        # Explicit phrasing wins: "Answer: B", "The answer is (C)", "answer is **D**"
        re.compile(r"(?i:answer)[^A-Za-z0-9]*(?:(?i:is)[^A-Za-z0-9]*)?([ABCD])\b"),
        # Response that leads with the letter: "A. sofa", "(B)", "**C**"
        re.compile(r"^[^A-Za-z0-9]*([ABCD])\b"),
        # Any standalone capital option letter
        re.compile(r"\b([ABCD])\b"),
    )
    _NUMBER_PATTERN = re.compile(r"[-+]?\d*\.?\d+")
    # Thousands separator between digit groups: "1,234" -> "1234"
    _THOUSANDS_SEPARATOR = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")

    @staticmethod
    def extract_answer(response: Optional[str], is_numerical: bool) -> Union[str, float, None]:
        """Extract the model's answer from its raw response text.

        Args:
            response: Raw model output (None/empty is treated as no answer).
            is_numerical: True to extract a number, False for an option letter.

        Returns:
            The first number as a float (numerical), an option letter "A"-"D"
            (multiple choice), or None if nothing answer-like is found.
        """
        if not response:
            return None
        response = response.strip()

        if is_numerical:
            cleaned = VSIBenchMetrics._THOUSANDS_SEPARATOR.sub("", response)
            match = VSIBenchMetrics._NUMBER_PATTERN.search(cleaned)
            return float(match.group()) if match else None

        # Bare letter, possibly decorated or lower-case: "b", "(C)", "D.", "**A**"
        bare = response.strip("*().:[] \t\n")
        if len(bare) == 1 and bare.upper() in "ABCD":
            return bare.upper()

        for pattern in VSIBenchMetrics._MCA_PATTERNS:
            match = pattern.search(response)
            if match:
                return match.group(1)

        # Deliberately no "any A-D character" fallback: it turns text such as
        # "I don't know" into "D".
        return None

    @staticmethod
    def accuracy(predictions: List[str], targets: List[str]) -> float:
        """Calculate exact match accuracy for MCA tasks (0.0 for no predictions)."""
        if not predictions:
            return 0.0
        correct = sum(1 for p, t in zip(predictions, targets) if p == t)
        return correct / len(predictions)

    @staticmethod
    def mean_relative_accuracy(
        predictions: List[float],
        targets: List[float],
        start: float = 0.5,
        end: float = 0.95,
        interval: float = 0.05,
    ) -> float:
        """Calculate Mean Relative Accuracy for numerical tasks.

        Based on VSI-Bench implementation:
        https://github.com/vision-x-nyu/thinking-in-space

        A prediction counts as correct at confidence ``c`` when
        ``|pred - target| / |target| <= 1 - c``. The score averages that hit
        rate over ``c`` in ``[start, end]``.

        Items with a zero or NaN target are excluded because they cannot be
        normalized. A NaN/None prediction (unparseable answer) is kept and
        scores 0 at every threshold.
        """
        if len(predictions) == 0 or len(targets) == 0:
            return 0.0

        # dtype=float also maps None -> NaN
        preds = np.asarray(predictions, dtype=float)
        tgts = np.asarray(targets, dtype=float)

        valid = (tgts != 0) & ~np.isnan(tgts)
        if not valid.any():
            return 0.0
        preds = preds[valid]
        tgts = tgts[valid]

        # Normalized absolute distance (NaN for NaN predictions; NaN <= x is False)
        abs_dist_norm = np.abs(preds - tgts) / np.abs(tgts)

        num_pts = int((end - start) / interval) + 2
        thresholds = np.linspace(start, end, num_pts)
        accuracies = [(abs_dist_norm <= 1 - conf).mean() for conf in thresholds]
        return float(np.mean(accuracies))

    @staticmethod
    def aggregate_by_task(
        results: List[Dict],
    ) -> Dict[str, Dict[str, float]]:
        """Aggregate per-sample results into per-question-type metrics.

        Results with ``status == "error"`` are ignored. MCA tasks report
        ``accuracy``; NA tasks report ``mra``. An unparseable prediction
        (None) counts as wrong in both, so a failure to answer cannot raise a
        task's score.

        Returns:
            ``{question_type: {"accuracy"|"mra": float, "count": int, "type": "MCA"|"NA"}}``
        """
        task_results = defaultdict(lambda: {
            "mca_correct": 0,
            "mca_total": 0,
            "na_predictions": [],
            "na_targets": [],
        })

        for r in results:
            if r.get("status") == "error":
                continue

            bucket = task_results[r["question_type"]]
            if r["is_numerical"]:
                if r["target"] is None:
                    continue  # No usable ground truth to score against
                prediction = r["prediction"]
                bucket["na_predictions"].append(float("nan") if prediction is None else prediction)
                bucket["na_targets"].append(r["target"])
            else:
                bucket["mca_total"] += 1
                if r["prediction"] == r["target"]:
                    bucket["mca_correct"] += 1

        metrics = {}
        for task, data in task_results.items():
            if data["mca_total"] > 0:
                metrics[task] = {
                    "accuracy": data["mca_correct"] / data["mca_total"],
                    "count": data["mca_total"],
                    "type": "MCA",
                }
            elif data["na_predictions"]:
                metrics[task] = {
                    "mra": VSIBenchMetrics.mean_relative_accuracy(
                        data["na_predictions"],
                        data["na_targets"],
                    ),
                    "count": len(data["na_predictions"]),
                    "type": "NA",
                }

        return metrics


# Quick sanity checks on the metric helpers.
print("Testing metrics...")
print("Extract 'A' from 'The answer is A': "
      f"{VSIBenchMetrics.extract_answer('The answer is A', is_numerical=False)}")
print("Extract number from '42 meters': "
      f"{VSIBenchMetrics.extract_answer('42 meters', is_numerical=True)}")
print("Accuracy [A,B,A] vs [A,A,A]: "
      f"{VSIBenchMetrics.accuracy(['A', 'B', 'A'], ['A', 'A', 'A']):.2%}")
print("MRA [10,20] vs [12,18]: "
      f"{VSIBenchMetrics.mean_relative_accuracy([10, 20], [12, 18]):.3f}")

## 6. Evaluation Runner

In [None]:
import json
from datetime import datetime


@dataclass
class EvaluationCheckpoint:
    """Progress of one (model, benchmark, config) run, persisted as JSON.

    ``load_or_create`` resumes from an existing file and ``save`` writes it
    back, so an interrupted session can continue where it stopped. Only
    dataclass fields are serialized.
    """

    model_name: str
    benchmark: str
    config: str  # "full" or "debiased"

    completed_ids: List[str] = field(default_factory=list)
    results: List[Dict] = field(default_factory=list)

    started_at: str = ""    # ISO-8601 local time, set when first created
    last_updated: str = ""  # ISO-8601 local time, refreshed by save()

    # Aggregated metrics
    overall_accuracy: float = 0.0
    overall_mra: float = 0.0
    task_metrics: Dict[str, Dict] = field(default_factory=dict)

    def __post_init__(self):
        # Set mirror of completed_ids for O(1) is_completed(). It is a plain
        # attribute, not a dataclass field, so asdict()/save() never write it.
        self._completed_set = set(self.completed_ids)

    @staticmethod
    def _checkpoint_path(checkpoint_dir: str, model_name: str, benchmark: str, config: str) -> Path:
        """Return the checkpoint file path (model name sanitized for filenames)."""
        safe_model_name = model_name.replace("/", "_").replace("-", "_")
        return Path(checkpoint_dir) / f"{safe_model_name}_{benchmark}_{config}.json"

    @classmethod
    def load_or_create(
        cls,
        checkpoint_dir: str,
        model_name: str,
        benchmark: str,
        config: str,
    ) -> "EvaluationCheckpoint":
        """Load the existing checkpoint for this run, or create a fresh one.

        Keys in the file that are not fields of this class (e.g. written by
        an older/newer version of the notebook) are ignored instead of
        crashing the resume.
        """
        from dataclasses import fields

        checkpoint_path = cls._checkpoint_path(checkpoint_dir, model_name, benchmark, config)

        if checkpoint_path.exists():
            with open(checkpoint_path, encoding="utf-8") as fh:
                data = json.load(fh)
            known = {f.name for f in fields(cls)}
            return cls(**{key: value for key, value in data.items() if key in known})

        return cls(
            model_name=model_name,
            benchmark=benchmark,
            config=config,
            started_at=datetime.now().isoformat(),
        )

    def save(self, checkpoint_dir: str):
        """Atomically write the checkpoint to ``checkpoint_dir``.

        The JSON is written to a temporary sibling file and then moved over
        the real file with ``os.replace``. A disconnect mid-write therefore
        leaves the previous checkpoint intact instead of a truncated one.
        """
        self.last_updated = datetime.now().isoformat()
        checkpoint_path = self._checkpoint_path(
            checkpoint_dir, self.model_name, self.benchmark, self.config
        )
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as fh:
            json.dump(asdict(self), fh, indent=2, default=str)
        os.replace(tmp_path, checkpoint_path)

    def is_completed(self, sample_id: str) -> bool:
        """Check if a sample has been evaluated."""
        # Resync if completed_ids was modified without going through add_result()
        if len(self._completed_set) != len(self.completed_ids):
            self._completed_set = set(self.completed_ids)
        return sample_id in self._completed_set

    def add_result(self, sample_id: str, result: Dict):
        """Record a result and mark its sample as complete."""
        self.completed_ids.append(sample_id)
        self._completed_set.add(sample_id)
        self.results.append(result)

    def get_progress(self, total: int) -> float:
        """Return the completed fraction of ``total`` (0.0 when total <= 0)."""
        return len(self.completed_ids) / total if total > 0 else 0.0


class GeminiEvaluator:
    """Evaluate Gemini models on VSI-Bench.

    For each sample, a prompt is built from the question (plus the options for
    multiple-choice samples) and sent with the sample's video through
    ``client.generate_with_video``. The reply is parsed with
    ``VSIBenchMetrics.extract_answer``. Per-sample result dicts accumulate in
    an ``EvaluationCheckpoint`` that is periodically written to
    ``checkpoint_dir``, so long runs can be resumed.
    """

    PROMPT_TEMPLATE_MCA = """Watch this video carefully and answer the following question.

Question: {question}

Options:
{options}

Respond with ONLY the letter of the correct answer (A, B, C, or D). Do not include any explanation."""

    PROMPT_TEMPLATE_NUMERICAL = """Watch this video carefully and answer the following question.

Question: {question}

Respond with ONLY a single number as your answer. Do not include units or explanations."""

    def __init__(
        self,
        client: GeminiClient,
        checkpoint_dir: str,
        save_frequency: int = 10,
    ):
        """
        Args:
            client: Client used to query the model with a video and a prompt.
            checkpoint_dir: Directory where checkpoint JSON files are written.
            save_frequency: Save the checkpoint every this many evaluated
                samples. A value <= 0 disables periodic saves; the final save
                at the end of ``run`` still happens.
        """
        self.client = client
        self.checkpoint_dir = checkpoint_dir
        self.save_frequency = save_frequency

    def format_prompt(self, sample: VSIBenchSample) -> str:
        """Build the prompt text for ``sample``.

        Numerical samples use ``PROMPT_TEMPLATE_NUMERICAL``. All others use
        ``PROMPT_TEMPLATE_MCA`` with the options joined one per line (empty
        when the sample has no options).
        """
        if sample.is_numerical:
            return self.PROMPT_TEMPLATE_NUMERICAL.format(
                question=sample.question
            )
        options_text = "\n".join(sample.options) if sample.options else ""
        return self.PROMPT_TEMPLATE_MCA.format(
            question=sample.question,
            options=options_text,
        )

    @staticmethod
    def _parse_target(sample: VSIBenchSample):
        """Normalise ``sample.ground_truth`` for scoring.

        Numerical samples: ``float(ground_truth)``, or ``None`` if it cannot be
        converted. Multiple-choice samples: the first character of the stripped,
        upper-cased ground truth (e.g. ``"b. left"`` -> ``"B"``).
        """
        if sample.is_numerical:
            try:
                return float(sample.ground_truth)
            except (TypeError, ValueError):
                # TypeError covers a missing (None) ground truth. Previously it
                # escaped to the generic handler and turned the sample into an error.
                return None
        return sample.ground_truth.strip().upper()[:1]

    def evaluate_sample(self, sample: VSIBenchSample) -> Dict:
        """Query the model on one sample and return its result record.

        Any exception raised while querying or parsing is caught and returned
        as a record with ``status == "error"``, with the message in ``error``
        and ``raw_response``. One bad sample therefore does not abort a run.

        On success, ``status == "success"`` and ``correct`` is the exact-match
        flag for multiple-choice samples. For numerical samples ``correct`` is
        ``None``; they are scored with MRA in ``_update_metrics``.
        """
        prompt = self.format_prompt(sample)

        try:
            response = self.client.generate_with_video(
                sample.video_path,
                prompt,
            )

            prediction = VSIBenchMetrics.extract_answer(
                response,
                sample.is_numerical,
            )
            target = self._parse_target(sample)

            return {
                "sample_id": sample.sample_id,
                "question_type": sample.question_type,
                "task_category": sample.task_category,
                "is_numerical": sample.is_numerical,
                "prediction": prediction,
                "target": target,
                "raw_response": response,
                "correct": None if sample.is_numerical else prediction == target,
                "status": "success",
            }

        except Exception as e:
            return {
                "sample_id": sample.sample_id,
                "question_type": sample.question_type,
                "task_category": sample.task_category,
                "is_numerical": sample.is_numerical,
                "prediction": None,
                "target": None,
                "raw_response": str(e),
                "correct": False,
                "status": "error",
                "error": str(e),
            }

    def run(
        self,
        samples: List[VSIBenchSample],
        model_name: str,
        benchmark: str = "vsi_bench",
        config: str = "full",
        resume: bool = True,
        retry_errors: bool = True,
    ) -> EvaluationCheckpoint:
        """Evaluate every sample not yet in the checkpoint and return the checkpoint.

        Args:
            samples: Samples to evaluate.
            model_name: Model identifier; also part of the checkpoint file name.
            benchmark: Benchmark name used in the checkpoint file name.
            config: Config name used in the checkpoint file name.
            resume: If True, continue from an existing checkpoint file.
                Otherwise start a new checkpoint, which overwrites the file on save.
            retry_errors: When resuming, drop previously recorded ``"error"``
                results so those samples are queried again. Previously,
                transient failures (e.g. rate limits) were skipped forever.

        The checkpoint, with refreshed metrics, is saved every
        ``save_frequency`` samples. It is saved once more when the loop ends,
        including when the loop is interrupted (e.g. KeyboardInterrupt), and
        uploaded files are cleaned up at the same point.
        """
        # Load or create checkpoint
        if resume:
            checkpoint = EvaluationCheckpoint.load_or_create(
                self.checkpoint_dir,
                model_name,
                benchmark,
                config,
            )
            if retry_errors:
                self._drop_failed_results(checkpoint)
        else:
            checkpoint = EvaluationCheckpoint(
                model_name=model_name,
                benchmark=benchmark,
                config=config,
                started_at=datetime.now().isoformat(),
            )

        # Build the id set once; checking the completed_ids list per sample was O(n^2).
        done_ids = set(checkpoint.completed_ids)
        remaining = [s for s in samples if s.sample_id not in done_ids]

        print(f"\nEvaluation: {model_name}")
        print(f"Total samples: {len(samples)}")
        print(f"Already completed: {len(checkpoint.completed_ids)}")
        print(f"Remaining: {len(remaining)}")

        if not remaining:
            print("All samples already evaluated!")
            self._update_metrics(checkpoint)
            return checkpoint

        pbar = tqdm(remaining, desc="Evaluating")
        try:
            for i, sample in enumerate(pbar, start=1):
                result = self.evaluate_sample(sample)
                checkpoint.add_result(sample.sample_id, result)

                progress = checkpoint.get_progress(len(samples))
                pbar.set_description(f"Evaluating ({progress:.1%})")

                # Periodic save; save_frequency <= 0 disables it instead of
                # raising ZeroDivisionError.
                if self.save_frequency > 0 and i % self.save_frequency == 0:
                    self._update_metrics(checkpoint)
                    checkpoint.save(self.checkpoint_dir)
        finally:
            # Runs on normal completion and on interruption alike, so finished
            # work is persisted and uploaded files are released either way.
            pbar.close()
            self._update_metrics(checkpoint)
            checkpoint.save(self.checkpoint_dir)
            self.client.cleanup_uploaded_files()

        return checkpoint

    @staticmethod
    def _drop_failed_results(checkpoint: EvaluationCheckpoint) -> None:
        """Remove ``status == "error"`` results and their ids so those samples are re-run."""
        kept = [r for r in checkpoint.results if r.get("status") != "error"]
        n_dropped = len(checkpoint.results) - len(kept)
        if n_dropped:
            checkpoint.results = kept
            # Rebuild from the kept records so ids and results stay paired.
            checkpoint.completed_ids = [r["sample_id"] for r in kept]
            print(f"Retrying {n_dropped} previously failed samples.")

    def _update_metrics(self, checkpoint: EvaluationCheckpoint):
        """Recompute per-task and overall metrics from ``checkpoint.results``.

        Only ``status == "success"`` records count towards the overall metrics.
        Overall values are reset to 0.0 first, so they never carry over from an
        earlier state such as a loaded checkpoint.
        """
        checkpoint.task_metrics = VSIBenchMetrics.aggregate_by_task(
            checkpoint.results
        )

        checkpoint.overall_accuracy = 0.0
        checkpoint.overall_mra = 0.0

        successes = [r for r in checkpoint.results if r["status"] == "success"]
        mca_results = [r for r in successes if not r["is_numerical"]]
        # Filter predictions and targets TOGETHER so each prediction stays
        # paired with its own target. Filtering them independently shifted
        # every later pair whenever one side was None.
        # NOTE(review): unparseable numerical predictions are excluded from MRA,
        # whereas an unparseable MCA prediction simply counts as wrong.
        na_pairs = [
            (r["prediction"], r["target"])
            for r in successes
            if r["is_numerical"]
            and r["prediction"] is not None
            and r["target"] is not None
        ]

        if mca_results:
            checkpoint.overall_accuracy = VSIBenchMetrics.accuracy(
                [r["prediction"] for r in mca_results],
                [r["target"] for r in mca_results],
            )

        if na_pairs:
            predictions = [p for p, _ in na_pairs]
            targets = [t for _, t in na_pairs]
            checkpoint.overall_mra = VSIBenchMetrics.mean_relative_accuracy(
                predictions, targets
            )


print("Evaluation runner ready.")

## 7. Mount Google Drive

In [None]:
# Mount Google Drive for persistent storage.
# Outside Colab (google.colab not importable) results go to a local directory.
# Inside Colab, a failed mount also falls back to the local directory, but the
# reason is printed so a non-persistent run is not mistaken for a Drive run.
# Only ImportError / Exception are caught, so KeyboardInterrupt still propagates.
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is None:
    print("Not running in Colab, using local directory")
    OUTPUT_DIR = "./gemini_vsi_eval"
else:
    try:
        drive.mount('/content/drive')
        print(f"Google Drive mounted. Output dir: {OUTPUT_DIR}")
    except Exception as exc:
        print(f"Google Drive mount failed ({exc}); using local directory")
        OUTPUT_DIR = "./gemini_vsi_eval"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory ready: {OUTPUT_DIR}")

## 8. Execute Evaluation

In [None]:
# Load the VSI-Bench samples for the configured eval mode, stratified across tasks
print("Loading VSI-Bench dataset...")
print(f"Eval mode: {eval_config.mode.value} ({eval_config.num_samples or 'all'} samples)")

loader = VSIBenchLoader(
    config=eval_config.benchmark_config,
    num_samples=eval_config.num_samples,
    seed=eval_config.seed,
    stratified=True,
)
samples = loader.load()

# Count samples per question type and list the counts alphabetically by type
print("\nTask breakdown:")
task_counts = defaultdict(int)
for s in samples:
    task_counts[s.question_type] += 1
for task in sorted(task_counts):
    count = task_counts[task]
    print(f"  {task}: {count}")

In [None]:
# Run evaluation for each model.
# NOTE: the eval mode is folded into the checkpoint's `config` name.
# Previously the name used only `eval_config.benchmark_config`, so tiny, small
# and full runs (which draw different sample subsets) shared one checkpoint
# file. A later run then reported metrics over the union of every sample
# evaluated so far. (Assumes benchmark_config does not already encode the
# mode -- verify.) Checkpoints written under the old naming are not resumed here.
all_results = {}
checkpoint_config = f"{eval_config.benchmark_config}_{eval_config.mode.value}"

for model_name in MODELS_TO_EVALUATE:
    print("\n" + "=" * 60)
    print(f"Evaluating: {model_name}")
    print("=" * 60)

    # Create client for this model
    model_config = GeminiConfig(
        model_name=model_name,
        api_key=GOOGLE_API_KEY,
    )
    client = GeminiClient(model_config)

    # Create evaluator
    evaluator = GeminiEvaluator(
        client=client,
        checkpoint_dir=OUTPUT_DIR,
        save_frequency=eval_config.save_frequency,
    )

    # Run evaluation
    checkpoint = evaluator.run(
        samples=samples,
        model_name=model_name,
        benchmark="vsi_bench",
        config=checkpoint_config,
        resume=RESUME_FROM_CHECKPOINT,
    )

    all_results[model_name] = checkpoint

    # Print summary. Failed records are stored in checkpoint.results but are
    # not counted in the overall accuracy/MRA, so report them separately.
    n_failed = sum(1 for r in checkpoint.results if r.get("status") == "error")
    print(f"\n{model_name} Results:")
    print(f"  Overall Accuracy (MCA): {checkpoint.overall_accuracy:.2%}")
    print(f"  Overall MRA (Numerical): {checkpoint.overall_mra:.3f}")
    print(f"  Samples evaluated: {len(checkpoint.results)}")
    if n_failed:
        print(f"  Failed samples (excluded from overall metrics): {n_failed}")

print("\n" + "=" * 60)
print("Evaluation complete!")
print("=" * 60)

## 9. Results & Visualization

In [None]:
import seaborn as sns
from tabulate import tabulate

# Set style
# Global plotting defaults for the charts below: a matplotlib style sheet,
# then a seaborn colour palette. NOTE(review): keep this order. A style sheet
# applied afterwards could override the palette's colour cycle if it defines
# one. (The plot helpers below also pass explicit tab10 colours.)
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")


def create_results_dataframe(checkpoints: Dict[str, EvaluationCheckpoint]) -> pd.DataFrame:
    """Flatten every checkpoint's per-task metrics into one long-form table.

    Produces one row per (model, task) with these columns:

    - ``model``: last ``/``-separated part of the model name
    - ``task``
    - ``score``: the task's ``accuracy`` if present, else ``mra``, else 0
    - ``type``: defaults to ``"unknown"``
    - ``count``: defaults to 0
    """
    records = [
        {
            "model": name.split("/")[-1],
            "task": task_name,
            "score": task_stats.get("accuracy", task_stats.get("mra", 0)),
            "type": task_stats.get("type", "unknown"),
            "count": task_stats.get("count", 0),
        }
        for name, ckpt in checkpoints.items()
        for task_name, task_stats in ckpt.task_metrics.items()
    ]
    return pd.DataFrame(records)


def plot_task_performance(df: pd.DataFrame, save_path: str = None):
    """Plot a grouped bar chart of each model's score on every task.

    Args:
        df: Long-form results with ``model``, ``task`` and ``score`` columns.
        save_path: Optional file path. When given, the figure is saved there
            at 150 dpi before being shown.

    Tasks missing for a model are drawn as 0. Bars with a positive score are
    labelled as a percentage when the score is <= 1, otherwise with two
    decimals. An empty ``df`` only prints a notice.
    """
    if df.empty:
        print("No data to plot.")
        return

    fig, ax = plt.subplots(figsize=(14, 6))

    model_names = df["model"].unique()
    task_names = df["task"].unique()
    n_models = len(model_names)
    centers = np.arange(len(task_names))
    bar_width = 0.8 / n_models
    palette = plt.cm.tab10(np.linspace(0, 1, n_models))

    for idx, model in enumerate(model_names):
        model_rows = df[df["model"] == model]

        heights = []
        for task in task_names:
            hits = model_rows[model_rows["task"] == task]["score"].values
            heights.append(hits[0] if len(hits) > 0 else 0)

        # Centre each task's group of n_models bars on its tick.
        shift = (idx - n_models / 2 + 0.5) * bar_width
        bars = ax.bar(
            [c + shift for c in centers], heights, bar_width,
            label=model, color=palette[idx],
        )

        for bar, height in zip(bars, heights):
            if not height > 0:
                continue
            label = f"{height:.1%}" if height <= 1 else f"{height:.2f}"
            ax.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                ha="center", va="bottom", fontsize=7, rotation=45,
            )

    ax.set_xlabel("Task")
    ax.set_ylabel("Score")
    ax.set_title("VSI-Bench Performance by Task")
    ax.set_xticks(centers)
    ax.set_xticklabels(task_names, rotation=45, ha="right")
    ax.legend(loc="upper right")
    ax.set_ylim(0, 1.1)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()


def plot_category_radar(checkpoints: Dict[str, EvaluationCheckpoint], save_path: str = None):
    """Draw a radar chart of each model's mean score per VSI-Bench category.

    A category's score is the unweighted mean of its tasks' scores, using
    ``accuracy`` when present, else ``mra``, else 0. Tasks are mapped to
    categories via ``VSIBenchLoader.TASK_CATEGORIES``. Categories with no
    tasks score 0. The figure is saved to ``save_path`` (150 dpi) when given,
    then shown.
    """
    categories = ["configurational", "measurement", "spatiotemporal"]

    def category_mean(task_metrics, category):
        # Mean score over this checkpoint's tasks belonging to `category`.
        scores = [
            stats.get("accuracy", stats.get("mra", 0))
            for task, stats in task_metrics.items()
            if VSIBenchLoader.TASK_CATEGORIES.get(task) == category
        ]
        return np.mean(scores) if scores else 0

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))

    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    angles.append(angles[0])  # repeat the first vertex to close the polygon

    palette = plt.cm.tab10(np.linspace(0, 1, len(checkpoints)))

    for color, (model_name, ckpt) in zip(palette, checkpoints.items()):
        values = [category_mean(ckpt.task_metrics, cat) for cat in categories]
        values.append(values[0])  # close the polygon

        label = model_name.split("/")[-1]
        ax.plot(angles, values, "o-", linewidth=2, label=label, color=color)
        ax.fill(angles, values, alpha=0.1, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels([cat.title() for cat in categories], size=12)
    ax.set_ylim(0, 1)
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
    plt.title("VSI-Bench Category Performance", size=14, y=1.08)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()


# Create visualizations, but only once at least one model has been evaluated
if not all_results:
    print("No results to visualize yet.")
else:
    print("\n" + "=" * 60)
    print("Visualizations")
    print("=" * 60)

    df = create_results_dataframe(all_results)

    # Per-task grouped bars, then the per-category radar chart
    task_plot_path = os.path.join(OUTPUT_DIR, "task_performance.png")
    radar_plot_path = os.path.join(OUTPUT_DIR, "category_radar.png")
    plot_task_performance(df, save_path=task_plot_path)
    plot_category_radar(all_results, save_path=radar_plot_path)

In [None]:
# Create comparison table
def create_leaderboard(checkpoints: Dict[str, EvaluationCheckpoint]) -> pd.DataFrame:
    """Build a one-row-per-model comparison table of pre-formatted strings.

    Columns, in order:

    - ``Model``
    - ``MCA Accuracy`` (percent)
    - ``NA MRA`` (3 decimals)
    - ``Samples`` (number of stored results)
    - one column per category: the percent-formatted mean of its task scores
      (``accuracy``, else ``mra``, else 0), or ``"N/A"`` when the category has
      no tasks
    """
    category_names = ["configurational", "measurement", "spatiotemporal"]

    def category_cell(task_metrics, category):
        # Formatted mean score for one category, or "N/A" if it has no tasks.
        scores = [
            stats.get("accuracy", stats.get("mra", 0))
            for task, stats in task_metrics.items()
            if VSIBenchLoader.TASK_CATEGORIES.get(task) == category
        ]
        if not scores:
            return "N/A"
        return f"{np.mean(scores):.1%}"

    table = []
    for model_name, ckpt in checkpoints.items():
        entry = {
            "Model": model_name.split("/")[-1],
            "MCA Accuracy": f"{ckpt.overall_accuracy:.1%}",
            "NA MRA": f"{ckpt.overall_mra:.3f}",
            "Samples": len(ckpt.results),
        }
        entry.update(
            {cat.title(): category_cell(ckpt.task_metrics, cat) for cat in category_names}
        )
        table.append(entry)

    return pd.DataFrame(table)


if all_results:
    leaderboard = create_leaderboard(all_results)

    # Pretty-print the table, then persist it next to the other outputs
    rule = "=" * 60
    print("\n" + rule)
    print("LEADERBOARD")
    print(rule)
    print(tabulate(leaderboard, headers="keys", tablefmt="fancy_grid", showindex=False))

    csv_path = os.path.join(OUTPUT_DIR, "leaderboard.csv")
    leaderboard.to_csv(csv_path, index=False)
    print(f"\nLeaderboard saved to: {OUTPUT_DIR}/leaderboard.csv")

## 10. Export Results

In [None]:
import zipfile


def create_results_archive(output_dir: str, archive_dir: Optional[str] = None) -> str:
    """Zip every file under ``output_dir`` and return the archive's path.

    Args:
        output_dir: Directory to archive. Entries are stored relative to it.
        archive_dir: Directory in which ``gemini_vsi_eval_<timestamp>.zip`` is
            written. Defaults to ``/content`` when that directory exists (Colab,
            the previous hard-coded location). Otherwise it defaults to the
            parent of ``output_dir``, so the function also works outside Colab
            instead of failing on a missing ``/content``.

    Returns:
        Path of the created zip file.
    """
    if archive_dir is None:
        archive_dir = (
            "/content"
            if os.path.isdir("/content")
            else os.path.dirname(os.path.abspath(output_dir))
        )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    archive_name = f"gemini_vsi_eval_{timestamp}"
    archive_path = os.path.join(archive_dir, f"{archive_name}.zip")
    archive_abspath = os.path.abspath(archive_path)

    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(output_dir):
            for file in files:
                file_path = os.path.join(root, file)
                # Never add the archive to itself (possible when archive_dir
                # lies inside output_dir).
                if os.path.abspath(file_path) == archive_abspath:
                    continue
                zipf.write(file_path, os.path.relpath(file_path, output_dir))

    return archive_path


# Create and download archive
if all_results:
    archive_path = create_results_archive(OUTPUT_DIR)
    print(f"Archive created: {archive_path}")

    # Download (in Colab). Outside Colab the import fails, and a failed browser
    # download is also non-fatal: either way, just report where the archive is.
    # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit
    # propagate.
    try:
        from google.colab import files
        files.download(archive_path)
        print("Download started. Check your browser's download folder.")
    except Exception:
        print(f"Archive available at: {archive_path}")

## Summary

### Evaluation Complete!

#### Results Summary
- All results saved to `OUTPUT_DIR` (on Google Drive when it was mounted, otherwise in the local `./gemini_vsi_eval` directory)
- Comparison tables and visualizations generated
- Checkpoint saved for potential resumption

#### Files Generated:
- `*_vsi_bench_*.json` - Checkpoint files with detailed results
- `leaderboard.csv` - Model comparison table
- `task_performance.png` - Bar chart by task
- `category_radar.png` - Radar chart by category

#### Next Steps:
1. Review task-level performance to identify model strengths/weaknesses
2. Run with `EVAL_MODE = "full"` for complete benchmark results
3. Add more models by updating `MODELS_TO_EVALUATE` list
4. Enable additional Cambrian-S benchmarks via `BENCHMARKS_TO_RUN`