# Crafter VLM GEPA Demo

This demo runs GEPA prompt optimization for a Crafter vision-language agent that uses image-only observations.

**What this demo does:**
1. Creates a local task app for the Crafter VLM agent
2. Runs GEPA prompt optimization to find the best system prompt
3. Extracts the optimized prompt from results
4. Runs eval jobs comparing the baseline and optimized prompts on held-out seeds
5. Displays comparison results

In [1]:
# Default run parameters; papermill may inject overriding values in a later cell.
BACKEND_URL: str = "https://api.usesynth.ai"  # production backend unless overridden
API_KEY: "str | None" = None  # when None, SYNTH_API_KEY is read from the environment later
POLICY_MODEL: str = "gpt-4.1-nano"  # vision-capable model driving the agent
VERIFIER_MODEL: str = "gpt-5-nano"  # verifier model (must be on the backend's allowed list)
ROLLOUT_BUDGET: int = 30  # total rollouts GEPA may spend
NUM_GENERATIONS: int = 2  # number of GEPA generations
USE_TUNNEL: bool = True  # expose local APIs through cloudflared tunnels (required for prod)

In [2]:
# Parameter values injected by papermill for this particular run.
BACKEND_URL: str = "https://api.usesynth.ai"
POLICY_MODEL: str = "gpt-4.1-nano"
VERIFIER_MODEL: str = "gpt-5-nano"
ROLLOUT_BUDGET: int = 6
NUM_GENERATIONS: int = 1
USE_TUNNEL: bool = True


In [3]:
# Step 1: Imports and Setup
from __future__ import annotations

import asyncio
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from dotenv import load_dotenv
from openai import AsyncOpenAI

load_dotenv()  # populate os.environ from a local .env file, if one exists

# Put the directory two levels above the current working directory at the
# front of sys.path so the synth_ai imports below resolve against it.
_repo_root = Path('.').resolve().parent.parent
sys.path.insert(0, str(_repo_root))

from synth_ai.sdk.api.eval import EvalJob, EvalJobConfig
from synth_ai.sdk.api.train.prompt_learning import PromptLearningJob
from synth_ai.sdk.learning.prompt_learning_client import PromptLearningClient
from synth_ai.sdk.learning.rl import mint_environment_api_key, setup_environment_api_key
from synth_ai.sdk.localapi import LocalAPIConfig, create_local_api
from synth_ai.sdk.task import TaskInfo, run_server_background
from synth_ai.sdk.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse
from synth_ai.sdk.tunnels import wait_for_health_check
from synth_ai.sdk.tunnels.tunneled_api import TunneledLocalAPI, TunnelBackend

from crafter_logic import (
    ACTION_STRING_TO_INT,
    CRAFTER_ALLOWED_ACTIONS,
    CrafterEnvironmentWrapper,
    CrafterScorer,
    CrafterVLMReActPolicy,
    normalize_action_name,
)

print("Imports loaded successfully")  # confirms every import above succeeded

  class StructuredOutputConfig(BaseModel):


[celery_app] EXPERIMENT_QUEUE_DB_PATH not set, will use default path


[celery_app] Using default database path: /Users/joshpurtell/.synth_ai/experiment_queue.db


[celery_app] Initializing with database: /Users/joshpurtell/.synth_ai/experiment_queue.db (broker: redis://localhost:6379/0)


Imports loaded successfully


In [4]:
# Step 2: Configuration

# Honour the BACKEND_URL parameter (papermill may override it) instead of
# hard-coding production. Strip any trailing slash so URLs built as
# f'{SYNTH_API_BASE}/health' never contain '//'.
SYNTH_API_BASE = BACKEND_URL.rstrip('/')
LOCAL_API_PORT = 8001  # port for the baseline-prompt local API
OPTIMIZED_LOCAL_API_PORT = 8002  # port for the optimized-prompt local API

print(f'Backend: {SYNTH_API_BASE}')
print(f'Local API Ports: {LOCAL_API_PORT}, {OPTIMIZED_LOCAL_API_PORT}')

Backend: https://api.usesynth.ai
Local API Ports: 8001, 8002


In [5]:
# Step 3: Get API Key and Check Backend Health

# An explicit API_KEY parameter wins; otherwise fall back to the environment.
# Both sources are stripped so stray whitespace/newlines do not break auth.
SYNTH_API_KEY = (API_KEY or os.environ.get('SYNTH_API_KEY', '')).strip()

if not SYNTH_API_KEY:
    raise RuntimeError('SYNTH_API_KEY not set. Please set it in environment or pass as parameter.')

# Show only a short prefix/suffix. Notebook outputs are routinely committed
# and shared, and echoing 20 characters exposes a large part of the secret.
print(f'Using API Key: {SYNTH_API_KEY[:8]}...{SYNTH_API_KEY[-4:]}')

# Check backend health before doing any work against it.
r = httpx.get(f'{SYNTH_API_BASE}/health', timeout=30)
if r.status_code != 200:
    raise RuntimeError(f'Backend not healthy: status {r.status_code}')
print(f'Backend health: {r.json()}')

Using API Key: sk_live_ace8b968-a52...


Backend health: {'status': 'ok', 'database': 'connected', 'details': {}}


In [6]:
# Step 5: Local API Factory

APP_ID = "crafter_vlm"
APP_NAME = "Crafter VLM ReAct Agent"
TOOL_NAME = "crafter_interact"

def create_crafter_vlm_local_api(system_prompt: str, env_api_key: str):
    """Factory to create a Crafter VLM task app with a specific system prompt."""
    # Import inside factory to ensure availability in closure
    from crafter_logic import (
        ACTION_STRING_TO_INT,
        CRAFTER_ALLOWED_ACTIONS,
        CrafterEnvironmentWrapper,
        CrafterScorer,
        CrafterVLMReActPolicy,
        normalize_action_name,
    )
    
    os.environ['ENVIRONMENT_API_KEY'] = env_api_key

    async def run_rollout(request: RolloutRequest, fastapi_request) -> RolloutResponse:
        policy_config = request.policy.config or {}
        seed = request.env.seed or 0
        env_config = request.env.config or {}
        max_steps = int(env_config.get('max_steps_per_episode', 200))
        max_turns = int(env_config.get('max_turns', 50))

        env = CrafterEnvironmentWrapper(seed=seed, max_steps=max_steps)
        observation = await env.reset()

        policy = CrafterVLMReActPolicy(
            system_prompt=system_prompt,
            use_vision=True,
            image_only_mode=True,
        )

        # Route OpenAI calls through Synth's inference proxy for trace reconstruction
        inference_url = policy_config.get('inference_url', '')
        if inference_url:
            os.environ['OPENAI_BASE_URL'] = inference_url
        
        # Use policy_config api_key (from Synth proxy) or fall back to OPENAI_API_KEY env var
        api_key = policy_config.get('api_key') or os.environ.get('OPENAI_API_KEY')
        if not api_key:
            raise ValueError("No API key available: policy_config['api_key'] and OPENAI_API_KEY env var are both empty")
        client = AsyncOpenAI(api_key=api_key)

        history: List[Dict[str, Any]] = []
        episode_rewards: List[float] = []

        for turn in range(max_turns):
            messages = policy.build_messages(observation, history)
            
            response = await client.chat.completions.create(
                model=policy_config.get('model', POLICY_MODEL),
                messages=messages,
                tools=policy.tools,
                tool_choice='required',
                max_completion_tokens=policy_config.get('max_completion_tokens', 512),
            )
            
            message = response.choices[0].message
            response_text = message.content or ''
            tool_calls = [
                {'id': tc.id, 'type': 'function', 'function': {'name': tc.function.name, 'arguments': tc.function.arguments}}
                for tc in (message.tool_calls or [])
            ]

            next_observation = observation
            tool_responses: List[Dict[str, Any]] = []
            
            if tool_calls:
                for tc in tool_calls:
                    tool_call_id = tc['id']
                    tool_name = tc['function']['name']
                    actions_list: List[str] = []
                    
                    if tool_name == TOOL_NAME:
                        try:
                            args = json.loads(tc['function']['arguments'])
                            raw_actions = args.get('actions_list', [])
                            actions_list = [str(a) for a in raw_actions if str(a).strip()][:5]
                        except Exception:
                            pass
                    
                    if not actions_list:
                        actions_list = ['noop']

                    normalized_actions = []
                    action_results = []

                    for action_str in actions_list:
                        normalized = normalize_action_name(action_str) or 'noop'
                        normalized_actions.append(normalized)
                        action = ACTION_STRING_TO_INT.get(normalized, 0)
                        next_observation = await env.step(action)
                        reward = next_observation.get('reward', 0.0)
                        episode_rewards.append(float(reward))
                        action_results.append({
                            'action': normalized,
                            'reward': reward,
                            'terminated': next_observation.get('terminated'),
                            'truncated': next_observation.get('truncated'),
                        })
                        if next_observation.get('terminated') or next_observation.get('truncated'):
                            break

                    tool_responses.append({'tool_call_id': tool_call_id, 'actions': normalized_actions, 'results': action_results})
                    if next_observation.get('terminated') or next_observation.get('truncated'):
                        break
            else:
                next_observation = await env.step(0)
                episode_rewards.append(float(next_observation.get('reward', 0.0)))

            history.append({'role': 'assistant', 'content': response_text, 'tool_calls': tool_calls})
            for resp in tool_responses:
                history.append({
                    'role': 'tool',
                    'tool_call_id': resp['tool_call_id'],
                    'content': json.dumps({'actions': resp['actions'], 'results': resp['results']}),
                })

            observation = next_observation
            if observation.get('terminated') or observation.get('truncated'):
                break

        score, details = CrafterScorer.score_episode(observation, len(episode_rewards), max_steps)

        return RolloutResponse(
            run_id=request.run_id,
            metrics=RolloutMetrics(outcome_reward=score, details=details),
            trace=None,  # Synth reconstructs from inference proxy
            trace_correlation_id=policy_config.get('trace_correlation_id'),
        )

    def provide_taskset_description():
        return {'splits': ['train', 'test']}

    def provide_task_instances(seeds):
        for seed in seeds:
            yield TaskInfo(
                task={'id': APP_ID, 'name': APP_NAME},
                dataset={'id': APP_ID, 'split': 'train', 'index': seed},
                inference={'tool': TOOL_NAME},
                limits={'max_turns': 50},
                task_metadata={'seed': seed},
            )

    return create_local_api(LocalAPIConfig(
        app_id=APP_ID,
        name=APP_NAME,
        description=f'{APP_NAME} local API for VLM agent with image-only observations.',
        provide_taskset_description=provide_taskset_description,
        provide_task_instances=provide_task_instances,
        rollout=run_rollout,
        cors_origins=['*'],
    ))


print('Local API factory defined')

Local API factory defined


In [7]:
# Step 6: GEPA Job Runner

def run_gepa_job(
    *,
    api_key: str,
    local_api_url: str,
    local_api_key: str,
    baseline_system_prompt: str,
):
    """Submit a GEPA prompt-optimization job and block until it finishes.

    Args:
        api_key: Synth API key for the backend.
        local_api_url: URL of the task app that serves Crafter rollouts.
        local_api_key: Environment API key passed along for the task app.
        baseline_system_prompt: System prompt GEPA starts optimizing from.

    Returns:
        The result of ``PromptLearningJob.poll_until_complete``.
    """
    system_message = {'role': 'system', 'order': 0, 'pattern': baseline_system_prompt}

    policy_section = {
        'inference_mode': 'synth_hosted',
        'model': POLICY_MODEL,
        'provider': 'openai',
        'temperature': 0.0,
        'max_completion_tokens': 512,
    }

    gepa_section = {
        'env_name': 'crafter',
        'evaluation': {
            'seeds': list(range(30)),
            'validation_seeds': list(range(50, 56)),
        },
        'rollout': {
            'budget': ROLLOUT_BUDGET,
            'max_concurrent': 3,
            'minibatch_size': 3,
        },
        'mutation': {'rate': 0.3},
        'population': {
            'initial_size': 3,
            'num_generations': NUM_GENERATIONS,
            'children_per_generation': 2,
        },
        'archive': {'size': 5, 'pareto_set_size': 10},
        'token': {'max_limit': 4000, 'counting_model': 'gpt-4', 'max_spend_usd': 50.0},
    }

    prompt_learning = {
        'algorithm': 'gepa',
        'task_app_url': local_api_url,
        'task_app_api_key': local_api_key,
        'env_name': 'crafter',
        'initial_prompt': {'messages': [system_message], 'wildcards': {}},
        'policy': policy_section,
        'gepa': gepa_section,
        'env': {
            'max_turns': 20,  # upper bound on VLM calls per rollout
            'max_steps_per_episode': 200,
        },
        # Scores come from the task app's own reward; no verifier model involved.
        'verifier': {'enabled': False, 'reward_source': 'task_app'},
    }

    job = PromptLearningJob.from_dict(
        config_dict={'prompt_learning': prompt_learning},
        backend_url=SYNTH_API_BASE,
        api_key=api_key,
        task_app_api_key=local_api_key,
        skip_health_check=True,
    )
    print(f'GEPA job created: {job.submit()}')

    outcome = job.poll_until_complete(timeout=3600.0, interval=3.0, progress=True)
    print(f'GEPA job finished: {outcome.status.value}')
    return outcome


print('GEPA job runner defined')

GEPA job runner defined


In [8]:
# Step 7: Eval Job Runner

# Eval jobs call OpenAI with a real model name (eval doesn't support synth_hosted).
# NOTE(review): this differs from POLICY_MODEL, the model GEPA optimized for;
# confirm that comparing prompts under a different model is intended.
EVAL_MODEL = "gpt-4o-mini"
# OpenAI key forwarded to the task app through policy_config ('' when unset).
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')


def run_eval_job(*, local_api_url: str, local_api_key: str, seeds: list[int], mode: str):
    """Submit an eval job against one task app and wait for it to complete.

    Args:
        local_api_url: URL of the task app to evaluate.
        local_api_key: Environment API key for that task app.
        seeds: Episode seeds to evaluate.
        mode: Label used only in the progress message (e.g. 'baseline').

    Returns:
        The result of ``EvalJob.poll_until_complete``.
    """
    policy_config = {
        'model': EVAL_MODEL,
        'provider': 'openai',
        'api_key': OPENAI_API_KEY,
    }
    env_config = {'max_steps_per_episode': 200, 'max_turns': 20}

    job = EvalJob(
        EvalJobConfig(
            task_app_url=local_api_url,
            backend_url=SYNTH_API_BASE,
            api_key=SYNTH_API_KEY,
            task_app_api_key=local_api_key,
            env_name='crafter',
            seeds=seeds,
            policy_config=policy_config,
            env_config=env_config,
            concurrency=5,
        )
    )
    print(f'  {mode} eval job: {job.submit()}')
    return job.poll_until_complete(timeout=600.0, interval=2.0, progress=True)


print('Eval job runner defined')

Eval job runner defined


In [9]:
# Step 8: Setup Environment API Key

# Mint a fresh environment key and export it for this process.
environment_api_key = mint_environment_api_key()
os.environ['ENVIRONMENT_API_KEY'] = environment_api_key
print(f'Minted: {environment_api_key[:12]}...{environment_api_key[-4:]}')

# Best effort: register the key with the backend; a failure is reported, not raised.
try:
    setup_environment_api_key(SYNTH_API_BASE, SYNTH_API_KEY, token=environment_api_key)
except Exception as exc:
    print(f'Warning: failed to upload ENVIRONMENT_API_KEY: {exc}')
else:
    print('Environment API key uploaded')

Minted: b8ab02fef61f...9565


[env-keys] public_key: b64_len=44 sha256=e173cb5664e240ba1b0d97a36000b0c4c15f42506a227a82cd2b2c00be119c57 head=LYqK7q2klATZWxRq tail=bgj7pMsckDLn1j8=
[env-keys] plaintext: len=64 preview=b8ab02……9565 has_ws=False
[env-keys] ciphertext: b64_len=152 sha256=d671581b1db8ee7e4b1b2dc4dbafed30364e5aa9e08f91f5d788ae76520c8223 head=X+wZu5nAVKraPPJw tail=8XNPObZedsKSsg==


Environment API key uploaded


In [10]:
# Step 9: Define Baseline Prompt

allowed_actions = ', '.join(CRAFTER_ALLOWED_ACTIONS)
baseline_prompt = (
    'You are an agent playing Crafter, a survival crafting game. '
    'Your goal is to survive and unlock achievements by exploring, crafting, and building. '
    'You can see the game state through images. Analyze each image carefully to understand '
    'your surroundings, inventory, health, and available resources. '
    'Use the crafter_interact tool to execute actions. '
    # In Crafter 'do' acts on the facing tile. It also drinks from water and can
    # yield a sapling on grass, so the prompt must not claim it does nothing
    # there: an agent told never to 'do' on water can never quench thirst.
    "Key mechanics: use 'do' while facing a resource (tree, stone, cow, plant) to collect or eat it, "
    'and while facing water to drink; on grass it only occasionally yields a sapling. '
    'Craft progression: wood -> table -> wood_pickaxe -> stone -> stone_pickaxe -> iron tools. '
    'Sleep when energy is low to restore and unlock wake_up. '
    f'Available actions: {allowed_actions}. '
    'Only use these action names and return 2-5 actions per decision. '
    'Strategy: move toward trees to collect wood; place a table once you have wood; '
    'craft a wood pickaxe, then collect stone and craft a stone pickaxe; '
    'progress toward iron tools and combat when safe.'
)

print('Baseline prompt:')
print(baseline_prompt[:200] + '...')

Baseline prompt:
You are an agent playing Crafter, a survival crafting game. Your goal is to survive and unlock achievements by exploring, crafting, and building. You can see the game state through images. Analyze eac...


In [None]:
# Step 10: Start Baseline Local API

baseline_app = create_crafter_vlm_local_api(baseline_prompt, environment_api_key)
run_server_background(baseline_app, port=LOCAL_API_PORT)
await wait_for_health_check('127.0.0.1', LOCAL_API_PORT, environment_api_key, timeout=60.0)

if USE_TUNNEL:
    # Create tunnel to expose local API to the internet
    print(f'Creating tunnel for port {LOCAL_API_PORT}...')
    baseline_tunnel = await TunneledLocalAPI.create(
        local_port=LOCAL_API_PORT,
        backend=TunnelBackend.CloudflareManagedTunnel,
        api_key=SYNTH_API_KEY,
        env_api_key=environment_api_key,
        backend_url=SYNTH_API_BASE,
        progress=True,
    )
    BASELINE_LOCAL_API_URL = baseline_tunnel.url
else:
    BASELINE_LOCAL_API_URL = f'http://localhost:{LOCAL_API_PORT}'

print(f'Baseline local API URL: {BASELINE_LOCAL_API_URL}')

In [None]:
# Step 11: Run GEPA Optimization

print('Starting GEPA optimization...')
print(f'  Rollout budget: {ROLLOUT_BUDGET}')
print(f'  Generations: {NUM_GENERATIONS}')

# Optimize starting from the baseline prompt served by the baseline local API.
job_result = run_gepa_job(
    api_key=SYNTH_API_KEY,
    local_api_url=BASELINE_LOCAL_API_URL,
    local_api_key=environment_api_key,
    baseline_system_prompt=baseline_prompt,
)

print(f'\nGEPA Status: {job_result.status.value}')
print('GEPA optimization succeeded!' if job_result.succeeded else f'GEPA failed: {job_result.error}')

In [None]:
# Step 12: Extract Optimized Prompt

def extract_system_prompt(best_prompt: Optional[Dict[str, Any]]) -> Optional[str]:
    """Return the system-prompt text from a prompt-learning ``best_prompt`` payload.

    Searches ``messages`` first (text under ``pattern``, then ``content``) and
    then ``sections`` (text under ``content``). A system entry without text is
    skipped instead of ending the search, so an empty system message cannot
    hide a usable one later in the payload. Missing or ``None`` lists are
    treated as empty.

    Returns:
        The first non-empty system prompt found, or ``None``.
    """
    if not best_prompt:
        return None
    for msg in best_prompt.get('messages') or []:
        if msg.get('role') == 'system':
            text = msg.get('pattern') or msg.get('content')
            if text:
                return text
    for sec in best_prompt.get('sections') or []:
        if sec.get('role') == 'system':
            text = sec.get('content')
            if text:
                return text
    return None

optimized_prompt = None

if job_result.succeeded:
    pl_client = PromptLearningClient(SYNTH_API_BASE, SYNTH_API_KEY)
    prompt_results = await pl_client.get_prompts(job_result.job_id)
    optimized_prompt = extract_system_prompt(prompt_results.best_prompt)
    
    if optimized_prompt:
        print('=' * 60)
        print('OPTIMIZED PROMPT')
        print('=' * 60)
        print(optimized_prompt[:800] + '...' if len(optimized_prompt) > 800 else optimized_prompt)
        print('=' * 60)
        
        # Save to results directory
        results_dir = Path('results')
        results_dir.mkdir(exist_ok=True)
        with open(results_dir / 'optimized_prompt.txt', 'w') as f:
            f.write(optimized_prompt)
        print(f'\nSaved optimized prompt to: {results_dir / "optimized_prompt.txt"}')
    else:
        print('Failed to extract optimized prompt from results')
else:
    print('Skipping prompt extraction (GEPA did not succeed)')

In [None]:
# Step 13: Run Evaluation (Baseline vs Optimized)

EVAL_SEEDS = list(range(100, 120))  # 20 held-out test samples

if optimized_prompt:
    # Start optimized local API
    optimized_app = create_crafter_vlm_local_api(optimized_prompt, environment_api_key)
    run_server_background(optimized_app, port=OPTIMIZED_LOCAL_API_PORT)
    await wait_for_health_check('127.0.0.1', OPTIMIZED_LOCAL_API_PORT, environment_api_key, timeout=60.0)
    
    if USE_TUNNEL:
        # Create tunnel for optimized API
        print(f'Creating tunnel for port {OPTIMIZED_LOCAL_API_PORT}...')
        optimized_tunnel = await TunneledLocalAPI.create(
            local_port=OPTIMIZED_LOCAL_API_PORT,
            backend=TunnelBackend.CloudflareManagedTunnel,
            api_key=SYNTH_API_KEY,
            env_api_key=environment_api_key,
            backend_url=SYNTH_API_BASE,
            progress=True,
        )
        OPTIMIZED_LOCAL_API_URL = optimized_tunnel.url
    else:
        OPTIMIZED_LOCAL_API_URL = f'http://localhost:{OPTIMIZED_LOCAL_API_PORT}'
    
    print(f'Optimized local API URL: {OPTIMIZED_LOCAL_API_URL}')
    print(f'\nRunning evaluation on {len(EVAL_SEEDS)} seeds...')

    # Run baseline eval
    print('\nRunning BASELINE eval...')
    baseline_eval = run_eval_job(
        local_api_url=BASELINE_LOCAL_API_URL,
        local_api_key=environment_api_key,
        seeds=EVAL_SEEDS,
        mode='baseline',
    )

    # Run optimized eval
    print('\nRunning OPTIMIZED eval...')
    optimized_eval = run_eval_job(
        local_api_url=OPTIMIZED_LOCAL_API_URL,
        local_api_key=environment_api_key,
        seeds=EVAL_SEEDS,
        mode='optimized',
    )

    # Display results
    print('\n' + '=' * 60)
    print('EVALUATION RESULTS')
    print('=' * 60)
    print(f'Baseline: {baseline_eval.raw}')
    print(f'Optimized: {optimized_eval.raw}')
    
    # Save results
    results_dir = Path('results')
    results_dir.mkdir(exist_ok=True)
    with open(results_dir / 'eval_results.json', 'w') as f:
        json.dump({'baseline': baseline_eval.raw, 'optimized': optimized_eval.raw}, f, indent=2)
    print(f'\nSaved eval results to: {results_dir / "eval_results.json"}')
else:
    print('Skipping evaluation (no optimized prompt)')

In [None]:
# Step 14: Done -- final status message for the run.
print("Demo complete!")