# Synth GEPA Demo - Banking77

Prompt optimization using Synth's GEPA algorithm on the Banking77 intent classification task.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/synth-laboratories/synth-ai/blob/main/demos/gepa_banking77/gepa_banking77_prompt_optimization.ipynb)

**Structure:**
1. **Setup** - Install dependencies and configure
2. **Task Definition** - Banking77 classification task
3. **Local API** - Expose the task for optimization
4. **Optimize** - Run GEPA to discover better prompts
5. **Evaluate** - Formal eval on held-out data

In [None]:
# Step 0: Install dependencies (Colab only)
# When running inside Google Colab this cell installs the Python packages and
# the cloudflared binary once, records that in a marker file, and then kills
# the current process. Outside Colab it only prints a notice.
import sys

if "google.colab" in sys.modules:
    import os

    # Presence of this file means the install commands below have already run.
    _INSTALLED_MARKER = "/content/.synth_deps_v2"

    if os.path.exists(_INSTALLED_MARKER):
        print("Dependencies ready.")
    else:
        print("Installing dependencies...")
        
        # Simple install - pip handles versioning
        !pip install -q "synth-ai>=0.6.3" httpx fastapi uvicorn datasets nest_asyncio

        # Install cloudflared
        # (wget's stderr is discarded, so a failed download prints nothing here)
        !wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O /usr/local/bin/cloudflared 2>/dev/null
        !chmod +x /usr/local/bin/cloudflared

        # NOTE(review): the marker is written without checking whether the
        # commands above succeeded - confirm that is acceptable.
        with open(_INSTALLED_MARKER, 'w') as f:
            f.write("ok")

        print("Done! Restarting runtime...")
        # Send signal 9 to this process; presumably Colab then restarts the
        # runtime so the freshly installed packages are importable.
        os.kill(os.getpid(), 9)
else:
    print("Not in Colab - assuming dependencies installed")

## Step 1: Setup

In [None]:
# Step 1: Setup - All imports, config, and API keys
import os, sys, json, asyncio
import httpx
import nest_asyncio

nest_asyncio.apply()

from datasets import load_dataset
from openai import AsyncOpenAI
from synth_ai.core.utils.env import mint_demo_api_key
from synth_ai.core.utils.urls import BACKEND_URL_BASE

# Production backend
SYNTH_API_BASE = BACKEND_URL_BASE
# Ports are optional - will auto-find available ports if not specified
# NOTE(review): neither port constant is referenced by the cells visible in
# this notebook (the tunnel calls pass local_port=None) - confirm whether they
# are still needed.
LOCAL_API_PORT = 8001  # Optional: specify a port, or None to auto-select
OPTIMIZED_LOCAL_API_PORT = 8002  # Optional: specify a port, or None to auto-select

# Always mint a demo key for this notebook
print("\nMinting demo SYNTH_API_KEY for this demo...")
API_KEY = mint_demo_api_key()
# Only the first 25 characters of the key are echoed to the cell output.
print(f"Demo API Key: {API_KEY[:25]}...")

# Set API key in environment for SDK to use
os.environ["SYNTH_API_KEY"] = API_KEY

# Synth inference URL - all LLM calls go through Synth's hosted inference
# Uses the inference proxy endpoint (OpenAI-compatible)
SYNTH_INFERENCE_URL = f"{SYNTH_API_BASE}/api/inference/v1"
print(f"\nUsing Synth hosted inference: {SYNTH_INFERENCE_URL}")

print("\n" + "=" * 50)
print("SETUP COMPLETE")
print("=" * 50)

## Step 2: Task Definition

Banking77 is an intent classification task with 77 possible intents.

In [None]:
# Intent labels offered to the model when the caller does not supply its own
# list (classify_banking77_query falls back to this constant).
# NOTE(review): "Refund_not_showing_up" (capital R) and
# "reverted_card_payment?" (trailing "?") look irregular; presumably they
# mirror the upstream dataset's category strings - confirm before normalizing.
BANKING77_LABELS = [
    "activate_my_card",
    "age_limit",
    "apple_pay_or_google_pay",
    "atm_support",
    "automatic_top_up",
    "balance_not_updated_after_bank_transfer",
    "balance_not_updated_after_cheque_or_cash_deposit",
    "beneficiary_not_allowed",
    "cancel_transfer",
    "card_about_to_expire",
    "card_acceptance",
    "card_arrival",
    "card_delivery_estimate",
    "card_linking",
    "card_not_working",
    "card_payment_fee_charged",
    "card_payment_not_recognised",
    "card_payment_wrong_exchange_rate",
    "card_swallowed",
    "cash_withdrawal_charge",
    "cash_withdrawal_not_recognised",
    "change_pin",
    "compromised_card",
    "contactless_not_working",
    "country_support",
    "declined_card_payment",
    "declined_cash_withdrawal",
    "declined_transfer",
    "direct_debit_payment_not_recognised",
    "disposable_card_limits",
    "edit_personal_details",
    "exchange_charge",
    "exchange_rate",
    "exchange_via_app",
    "extra_charge_on_statement",
    "failed_transfer",
    "fiat_currency_support",
    "get_disposable_virtual_card",
    "get_physical_card",
    "getting_spare_card",
    "getting_virtual_card",
    "lost_or_stolen_card",
    "lost_or_stolen_phone",
    "order_physical_card",
    "passcode_forgotten",
    "pending_card_payment",
    "pending_cash_withdrawal",
    "pending_top_up",
    "pending_transfer",
    "pin_blocked",
    "receiving_money",
    "Refund_not_showing_up",
    "request_refund",
    "reverted_card_payment?",
    "supported_cards_and_currencies",
    "terminate_account",
    "top_up_by_bank_transfer_charge",
    "top_up_by_card_charge",
    "top_up_by_cash_or_cheque",
    "top_up_failed",
    "top_up_limits",
    "top_up_reverted",
    "topping_up_by_card",
    "transaction_charged_twice",
    "transfer_fee_charged",
    "transfer_into_account",
    "transfer_not_received_by_recipient",
    "transfer_timing",
    "unable_to_verify_identity",
    "verify_my_identity",
    "verify_source_of_funds",
    "verify_top_up",
    "virtual_card_not_working",
    "visa_or_mastercard",
    "why_verify_identity",
    "wrong_amount_of_cash_received",
    "wrong_exchange_rate_for_cash_withdrawal",
]

# Name of the single function tool the model is forced to call.
TOOL_NAME = "banking77_classify"
# OpenAI-style function-tool definition: one required string argument,
# "intent", carrying the predicted label. The schema does not restrict the
# string to the known label set.
TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": TOOL_NAME,
        "description": "Return the predicted banking77 intent label.",
        "parameters": {
            "type": "object",
            "properties": {"intent": {"type": "string"}},
            "required": ["intent"],
        },
    },
}


def format_available_intents(label_names: list) -> str:
    """Render the labels as a 1-based numbered list, one ``N. label`` per line."""
    return "\n".join(
        f"{position}. {label}"
        for position, label in enumerate(label_names, start=1)
    )


async def classify_banking77_query(
    query: str,
    system_prompt: str,
    available_intents: str | None = None,
    model: str = "gpt-4o-mini",
    api_key: str | None = None,
    inference_url: str | None = None,
) -> str:
    """Classify a banking query into an intent.

    Sends one chat-completion request that forces a call to the
    ``banking77_classify`` tool and returns the ``intent`` argument of that
    call. The HTTP client created for the request is closed before returning.

    Args:
        query: The customer query to classify
        system_prompt: System prompt for the model
        available_intents: Formatted list of available intents; defaults to
            the numbered rendering of ``BANKING77_LABELS``
        model: Model to use (e.g., "gpt-4o-mini")
        api_key: API key for authentication; only used when ``inference_url``
            is given, where it is sent as the ``X-API-Key`` header
        inference_url: Inference URL (interceptor URL during optimization, regular URL otherwise)

    Returns:
        The predicted intent label

    Raises:
        ValueError: If the response contains no choices or no tool call.
        json.JSONDecodeError: If the tool-call arguments are not valid JSON.
        KeyError: If the tool-call arguments lack an ``intent`` field.
    """
    if available_intents is None:
        available_intents = format_available_intents(BANKING77_LABELS)

    user_msg = (
        f"Customer Query: {query}\n\n"
        f"Available Intents:\n{available_intents}\n\n"
        f"Classify this query into one of the above banking intents using the tool call."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_msg},
    ]

    if inference_url:
        # Use interceptor URL - pass Synth API key via X-API-Key header
        default_headers = {"X-API-Key": api_key} if api_key else {}
        client = AsyncOpenAI(
            base_url=inference_url,
            api_key="synth-interceptor",  # Dummy - interceptor uses its own key
            default_headers=default_headers,
        )
    else:
        # Fallback to Synth's hosted inference for standalone usage
        client = AsyncOpenAI(
            base_url=SYNTH_INFERENCE_URL,
            api_key=API_KEY,
        )

    # A fresh client is built on every call, so close it once the request is
    # done; otherwise each classification leaves an open connection pool behind.
    async with client:
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            tools=[TOOL_SCHEMA],
            tool_choice={"type": "function", "function": {"name": TOOL_NAME}},
        )

    # Fail with a readable message instead of an IndexError/TypeError when the
    # backend returns no choices or omits the tool call.
    if not response.choices:
        raise ValueError(f"Model {model!r} returned no choices for query: {query!r}")
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        raise ValueError(
            f"Model {model!r} did not call the {TOOL_NAME!r} tool for query: {query!r}"
        )

    args = json.loads(tool_calls[0].function.arguments)
    return args["intent"]


# Load Banking77 from source CSV (HuggingFace dataset scripts no longer supported)
# The test CSV is the only file passed to the csv loader, which is why its
# rows are requested under the split name "train".
dataset = load_dataset(
    "csv",
    data_files="https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv",
    split="train"
)
# CSV has 'category' column (string labels) instead of 'label' (int)
# label_names holds the distinct category strings found in this file, sorted.
label_names = sorted(set(dataset["category"]))
print(f"Loaded {len(dataset)} test samples with {len(label_names)} intent labels")

print("\n" + "=" * 50)
print("BUSINESS LOGIC READY")
print("=" * 50)
print("\nclassify_banking77_query(query, system_prompt) -> intent")
print("\nThis is the core app. Now let's see how prompts affect performance...")

## Step 2b: Before/After Preview

Compare a **baseline prompt** (~78% accuracy) with an **optimized prompt** (~92% accuracy) on 50 test samples; exact numbers vary from run to run.

In [None]:
# Before/After Comparison (preview)
#
# Compare baseline vs optimized prompts on 50 test samples.
# The optimized prompt was discovered by GEPA - it achieves ~92% vs ~78% baseline.
# Uses Synth's hosted inference - no OPENAI_API_KEY needed!

# Starting-point system prompt: a single sentence with no classification guidance.
BASELINE_SYSTEM_PROMPT = """You are an expert banking assistant that classifies customer queries into banking intents. Given a customer message, respond with exactly one intent label from the provided list using the `banking77_classify` tool."""

# This optimized prompt was discovered by GEPA - it adds classification strategy and key distinctions
OPTIMIZED_SYSTEM_PROMPT = """You are a precise banking intent classifier. Analyze customer queries and classify them into exactly one of the 77 predefined banking intents.

Classification Strategy:
1. IDENTIFY THE PRIMARY ACTION: What does the customer want to DO? (activate, cancel, check, transfer, verify, etc.)
2. IDENTIFY THE SUBJECT: What is it about? (card, transfer, payment, account, etc.)
3. IDENTIFY THE STATE: Is it about something pending, failed, declined, or completed?

Key Intent Distinctions:
- "card_arrival" vs "card_delivery_estimate": Both about card delivery. Use "card_arrival" for "where is my card?" and "card_delivery_estimate" for "how long will it take?"
- "get_physical_card" vs "order_physical_card": Use "order_physical_card" for placing an order, "get_physical_card" for asking HOW to get one
- "pending_*" intents: Transaction is IN PROGRESS, not yet complete
- "failed_*" or "declined_*" intents: Transaction was REJECTED
- "*_not_recognised" intents: Customer doesn't recognize a transaction on their statement
- "verify_*" intents: About verification/authentication processes
- "top_up_*" intents: About adding money TO the account
- "transfer_*" intents: About moving money between accounts

Output the single most appropriate intent using the banking77_classify tool."""

# Test on 50 held-out samples
# (row indices 100..149 of the test CSV loaded into `dataset`)
TEST_INDICES = list(range(100, 150))

async def score_prompt(system_prompt: str, indices: list[int], prompt_name: str) -> float:
    """Score a prompt on a set of test samples.

    Each index selects one row of the module-level ``dataset``; the row's
    ``text`` is classified with ``system_prompt`` and compared with its
    ``category``. Samples are classified one at a time, and a progress line is
    printed after every tenth sample.

    Args:
        system_prompt: System prompt to evaluate.
        indices: Row indices into ``dataset`` to score.
        prompt_name: Label used as the prefix of the progress lines.

    Returns:
        Fraction of samples classified correctly, or ``0.0`` when ``indices``
        is empty.
    """
    total = len(indices)
    if total == 0:
        # Nothing to score: return 0.0 rather than dividing by zero below.
        return 0.0

    correct = 0
    for done, idx in enumerate(indices, start=1):
        sample = dataset[idx]
        query = sample["text"]
        expected = sample["category"]  # CSV has string labels directly

        predicted = await classify_banking77_query(
            query=query,
            system_prompt=system_prompt,
            model="gpt-4o-mini",
        )

        # Normalize for comparison: case-insensitive, underscores treated as spaces
        pred_norm = predicted.lower().replace("_", " ").strip()
        exp_norm = expected.lower().replace("_", " ").strip()
        if pred_norm == exp_norm:
            correct += 1

        if done % 10 == 0:
            print(
                f"  {prompt_name}: {done}/{total} done, {correct}/{done} correct ({correct / done:.0%})"
            )

    return correct / total

print(
    f"Testing on {len(TEST_INDICES)} samples (indices {TEST_INDICES[0]}-{TEST_INDICES[-1]})...\n"
)

print("Scoring BASELINE prompt...")
baseline_score = await score_prompt(BASELINE_SYSTEM_PROMPT, TEST_INDICES, "Baseline")

print("\nScoring OPTIMIZED prompt...")
optimized_score = await score_prompt(OPTIMIZED_SYSTEM_PROMPT, TEST_INDICES, "Optimized")

print("\n" + "=" * 60)
print("BEFORE/AFTER COMPARISON")
print("=" * 60)
print(f"\nBASELINE PROMPT:")
print(f'  "{BASELINE_SYSTEM_PROMPT[:80]}..."')
print(
    f"  Accuracy: {baseline_score:.0%} ({int(baseline_score * len(TEST_INDICES))}/{len(TEST_INDICES)})"
)

print(f"\nOPTIMIZED PROMPT (from GEPA):")
print(f'  "{OPTIMIZED_SYSTEM_PROMPT[:80]}..."')
print(
    f"  Accuracy: {optimized_score:.0%} ({int(optimized_score * len(TEST_INDICES))}/{len(TEST_INDICES)})"
)

lift = optimized_score - baseline_score
print(f"\nLIFT: {lift:+.0%}")

if lift > 0:
    print("\n>>> Better prompts = better results!")
    print(">>> Now let's see how Synth finds these optimized prompts...")

## Step 3: Local API

Expose the task via HTTP so Synth can run optimization against it.

In [None]:
from synth_ai.sdk.localapi import LocalAPIConfig, create_local_api
from synth_ai.sdk.localapi._impl.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
from synth_ai.sdk.localapi._impl.trace_correlation_helpers import extract_trace_correlation_id
from synth_ai.core.tunnels import TunnelBackend, TunneledLocalAPI
from synth_ai.data.enums import SuccessStatus

# Identifier and display name passed to the local API config and task info.
APP_ID = "banking77"
APP_NAME = "Banking77 Intent Classification"

# NOTE(review): this re-assigns BASELINE_SYSTEM_PROMPT with the same text used
# in the Before/After cell - keep the two copies in sync.
BASELINE_SYSTEM_PROMPT = """You are an expert banking assistant that classifies customer queries into banking intents. Given a customer message, respond with exactly one intent label from the provided list using the `banking77_classify` tool."""

# User-message template with {query} and {available_intents} placeholders; its
# wording matches the user message assembled in classify_banking77_query.
USER_PROMPT = "Customer Query: {query}\n\nAvailable Intents:\n{available_intents}\n\nClassify this query into one of the above banking intents using the tool call."


class Banking77Dataset:
    """Banking77 CSV splits, downloaded on first use and cached per instance."""

    # Load directly from GitHub CSV (HuggingFace dataset scripts no longer supported)
    _DATA_URLS = {
        "train": "https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/train.csv",
        "test": "https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv",
    }

    def __init__(self):
        # split name -> loaded dataset
        self._cache = {}
        # sorted distinct categories, filled in by the first split that is loaded
        self._label_names = None

    def _load_split(self, split: str):
        """Return the rows for ``split``, fetching and caching them if needed.

        Raises:
            ValueError: If ``split`` has no configured URL.
        """
        if split in self._cache:
            return self._cache[split]

        url = self._DATA_URLS.get(split)
        if not url:
            raise ValueError(f"Unknown split: {split}. Available: {list(self._DATA_URLS.keys())}")

        loaded = load_dataset("csv", data_files=url, split="train")
        self._cache[split] = loaded
        # Build label names from unique categories
        if self._label_names is None:
            self._label_names = sorted(set(loaded["category"]))
        return loaded

    def ensure_ready(self, splits):
        """Load every split named in ``splits`` that is not cached yet."""
        for name in splits:
            self._load_split(name)

    def size(self, split: str) -> int:
        """Return the number of rows in ``split``."""
        rows = self._load_split(split)
        return len(rows)

    def sample(self, *, split: str, index: int) -> dict:
        """Return one row of ``split`` as a dict; ``index`` wraps modulo the split size."""
        rows = self._load_split(split)
        wrapped = index % len(rows)
        record = rows[wrapped]
        return {
            "index": wrapped,
            "split": split,
            "text": str(record.get("text", "")),
            # CSV has 'category' field with string labels
            "label": record.get("category", "unknown"),
        }

    @property
    def label_names(self) -> list:
        """Sorted label names, loading the train split first if none are known yet."""
        if self._label_names is None:
            self._load_split("train")
        names = self._label_names
        return names if names else []


def create_banking77_local_api(system_prompt: str):
    """Build the Banking77 local API app that scores ``system_prompt``.

    Both dataset splits are loaded up front. The returned app is whatever
    ``create_local_api`` produces for a config exposing three callbacks:
    a rollout handler, a taskset description, and a task-instance generator.

    Args:
        system_prompt: System prompt used for every rollout served by this app.
    """
    bank_data = Banking77Dataset()
    bank_data.ensure_ready(["train", "test"])

    def _canonical(label: str) -> str:
        # Case-insensitive comparison form with underscores treated as spaces.
        return label.lower().replace("_", " ").strip()

    async def run_rollout(request: RolloutRequest, fastapi_request) -> RolloutResponse:
        """Classify the sample selected by the request seed; reward 1.0 on a match."""
        split = request.env.config.get("split", "train")
        seed = request.env.seed
        sample = bank_data.sample(split=split, index=seed)

        # Extract inference_url and api_key from policy config
        # During GEPA optimization, inference_url points to the interceptor
        policy_config = request.policy.config or {}
        inference_url = policy_config.get("inference_url")
        inference_url_text = str(inference_url or "")

        # Call the classifier with the interceptor URL
        predicted_intent = await classify_banking77_query(
            query=sample["text"],
            system_prompt=system_prompt,
            available_intents=format_available_intents(bank_data.label_names),
            model=policy_config.get("model", "gpt-4o-mini"),
            api_key=policy_config.get("api_key"),
            inference_url=inference_url,
        )

        if _canonical(predicted_intent) == _canonical(sample["label"]):
            reward = 1.0
        else:
            reward = 0.0

        # Extract trace correlation ID from policy config/inference_url,
        # leaving out any trace fields already present in the config
        excluded_keys = {"trace_correlation_id", "trace"}
        policy_cfg_for_trace = {
            name: setting
            for name, setting in policy_config.items()
            if name not in excluded_keys
        }
        trace_correlation_id = extract_trace_correlation_id(
            policy_config=policy_cfg_for_trace,
            inference_url=inference_url_text,
        )

        return RolloutResponse(
            trace_correlation_id=trace_correlation_id,
            reward_info=RolloutMetrics(outcome_reward=reward),
            trace=None,
            inference_url=inference_url_text,
            success_status=SuccessStatus.SUCCESS,
        )

    def provide_taskset_description():
        """Report the available splits and how many rows each one has."""
        split_names = ["train", "test"]
        return {
            "splits": split_names,
            "sizes": {name: bank_data.size(name) for name in split_names},
        }

    def provide_task_instances(seeds):
        """Yield one TaskInfo per seed, always drawn from the train split."""
        for seed in seeds:
            row = bank_data.sample(split="train", index=seed)
            yield TaskInfo(
                task={"id": APP_ID, "name": APP_NAME},
                dataset={"id": APP_ID, "split": row["split"], "index": row["index"]},
                inference={"tool": TOOL_NAME},
                limits={"max_turns": 1},
                task_metadata={"query": row["text"], "expected_intent": row["label"]},
            )

    api_config = LocalAPIConfig(
        app_id=APP_ID,
        name=APP_NAME,
        description=f"{APP_NAME} local API for classifying customer queries into banking intents.",
        provide_taskset_description=provide_taskset_description,
        provide_task_instances=provide_task_instances,
        rollout=run_rollout,
        cors_origins=["*"],
    )
    return create_local_api(api_config)


# Build the baseline app and publish it through a Cloudflare quick tunnel; the
# tunnel URL is what the GEPA job and the baseline eval job are pointed at.
print("Starting local API...")
baseline_app = create_banking77_local_api(BASELINE_SYSTEM_PROMPT)

# Create tunnel - handles server startup, health check, and tunnel creation automatically
print("\nStarting server and provisioning Cloudflare tunnel...")
baseline_tunnel = await TunneledLocalAPI.create_for_app(
    app=baseline_app,
    local_port=None,  # no fixed port requested
    backend=TunnelBackend.CloudflareQuickTunnel,
    progress=True,
)
BASELINE_LOCAL_API_URL = baseline_tunnel.url

print(f"\n" + "=" * 50)
print("LOCAL API READY")
print("=" * 50)
print(f"URL: {BASELINE_LOCAL_API_URL}")

## Step 4: Run GEPA

GEPA evolves prompts over multiple generations, selecting the best performers.

In [None]:
import random
from synth_ai.sdk.optimization.internal.prompt_learning import PromptLearningJob

# Banking77 train split has 10,003 samples - sample randomly for better coverage
# Fixed seed so the same shuffle (and therefore the same seeds) is produced on
# every run. Note this seeds the global `random` generator.
random.seed(42)
# Hard-coded rather than read from the dataset; Banking77Dataset.sample wraps
# indices modulo the split size, so an out-of-range seed still maps to a row.
DATASET_SIZE = 10003
all_indices = list(range(DATASET_SIZE))
random.shuffle(all_indices)

# 100 train seeds, 200 validation seeds - randomly sampled, non-overlapping
# (disjoint slices of the same shuffled index list)
TRAIN_SEEDS = all_indices[:100]
VALIDATION_SEEDS = all_indices[100:300]


async def run_gepa():
    """Submit a GEPA prompt-learning job and stream it until it finishes.

    The job is configured from module-level values: the baseline tunnel URL,
    the baseline system prompt and user template, and the train/validation
    seed lists. Progress and the final status are printed along the way.

    Returns:
        The result object produced by ``stream_until_complete_async``.
    """
    # Prompt template being optimized: system + user message patterns.
    prompt_template = {
        "messages": [
            {"role": "system", "order": 0, "pattern": BASELINE_SYSTEM_PROMPT},
            {"role": "user", "order": 1, "pattern": USER_PROMPT},
        ],
        "wildcards": {"query": "REQUIRED", "available_intents": "OPTIONAL"},
    }

    # Model settings used for the rollouts.
    policy_settings = {
        "model": "gpt-4.1-nano",
        "provider": "openai",
        "inference_mode": "synth_hosted",
        "temperature": 0.0,
        "max_completion_tokens": 256,
    }

    # GEPA search settings.
    gepa_settings = {
        "env_name": "banking77",
        "evaluation": {
            "seeds": TRAIN_SEEDS,  # 100 random training seeds
            "validation_seeds": VALIDATION_SEEDS,  # 200 random validation seeds
        },
        "rollout": {
            "budget": 300,  # Higher budget for more samples
            "max_concurrent": 32,  # High parallelism for speed
            "minibatch_size": 32,
        },
        "proposer_effort": "MEDIUM",
        "proposer_output_tokens": "FAST",
        "mutation": {"rate": 0.3},
        "population": {
            "initial_size": 4,
            "num_generations": 3,
            "children_per_generation": 3,
        },
        "archive": {"size": 5, "pareto_set_size": 10},
    }

    config_body = {
        "prompt_learning": {
            "algorithm": "gepa",
            "task_app_url": BASELINE_LOCAL_API_URL,
            "env_name": "banking77",
            "initial_prompt": prompt_template,
            "policy": policy_settings,
            "gepa": gepa_settings,
        },
    }

    for banner_line in (
        "Creating GEPA job...",
        f"  Train seeds: {len(TRAIN_SEEDS)} (randomly sampled)",
        f"  Validation seeds: {len(VALIDATION_SEEDS)} (randomly sampled)",
        "  Parallelism: 32 concurrent rollouts",
    ):
        print(banner_line)

    job = PromptLearningJob.from_dict(config_dict=config_body, skip_health_check=True)

    submitted_id = job.submit()
    print(f"Job ID: {submitted_id}")
    print()

    # Use streaming for real-time progress updates
    outcome = await job.stream_until_complete_async(timeout=3600.0)

    print(f"\nFINAL: {outcome.status.value}")

    if outcome.succeeded:
        print(f"BEST SCORE: {outcome.best_score}")
    elif outcome.failed:
        print(f"ERROR: {outcome.error}")

    return outcome


# Run the optimization now; `result` is read by the evaluation cell below.
result = await run_gepa()

## Step 5: Evaluate

Compare baseline vs optimized prompts on held-out test samples.

In [None]:
from synth_ai.sdk.eval.job import EvalJob, EvalJobConfig, EvalResult

EVAL_SEEDS = list(range(100, 150))  # Held-out test samples


def run_eval_job(local_api_url: str, seeds: list[int], mode: str) -> EvalResult:
    """Submit an eval job for ``local_api_url`` and block until it completes.

    Rollouts are requested on the "test" split with up to 10 running
    concurrently; polling happens every 2 seconds for at most 600 seconds.

    Args:
        local_api_url: URL of the local API (tunnel) to evaluate.
        seeds: Seeds to evaluate.
        mode: Caller-supplied label for the run; not used by this function.

    Returns:
        The result returned by ``poll_until_complete``.
    """
    eval_job = EvalJob(
        EvalJobConfig(
            local_api_url=local_api_url,
            backend_url=SYNTH_API_BASE,
            api_key=API_KEY,
            env_name="banking77",
            seeds=seeds,
            policy_config={"model": "gpt-4.1-nano", "provider": "openai"},
            env_config={"split": "test"},
            concurrency=10,
        )
    )
    eval_job.submit()
    outcome = eval_job.poll_until_complete(timeout=600.0, interval=2.0, progress=True)
    return outcome


if result.succeeded:
    # Get the optimized system prompt - one simple method call
    optimized_prompt = result.get_system_prompt()
    best_score = result.best_score

    if optimized_prompt:
        print("\n" + "=" * 60)
        print("OPTIMIZATION RESULTS")
        print("=" * 60)
        print(f"\nBest Train Reward: {best_score:.1%}")
        print(f"\nOptimized Prompt:")
        print(optimized_prompt[:400] + "..." if len(optimized_prompt) > 400 else optimized_prompt)

        # Create optimized API and run final evaluation
        optimized_app = create_banking77_local_api(optimized_prompt)
        optimized_tunnel = await TunneledLocalAPI.create_for_app(
            app=optimized_app,
            local_port=None,
            backend=TunnelBackend.CloudflareQuickTunnel,
            progress=True,
        )
        OPTIMIZED_LOCAL_API_URL = optimized_tunnel.url

        baseline_result = run_eval_job(BASELINE_LOCAL_API_URL, EVAL_SEEDS, "baseline")
        optimized_result = run_eval_job(OPTIMIZED_LOCAL_API_URL, EVAL_SEEDS, "optimized")

        # Final results
        if baseline_result.succeeded and optimized_result.succeeded:
            eval_lift = optimized_result.mean_reward - baseline_result.mean_reward
            print("\n" + "=" * 60)
            print("FINAL EVALUATION")
            print("=" * 60)
            print(f"Training Score:  {best_score:.1%}")
            print(
                f"Held-Out Test:   {optimized_result.mean_reward:.1%} (baseline: {baseline_result.mean_reward:.1%})"
            )
            print(f"Improvement:     {eval_lift:+.1%}")

            if eval_lift > 0:
                print("\n✓ Optimization generalizes to held-out data!")
            elif eval_lift == 0:
                print("\n= Same performance on held-out data")
            else:
                print("\n⚠ Possible overfitting (baseline better on held-out)")
    else:
        print("\n" + "=" * 60)
        print("OPTIMIZATION COMPLETED")
        print("=" * 60)
        print(f"\nBest Score: {result.best_score:.1%}")
        print("\nNote: Could not retrieve the optimized prompt text.")
else:
    print(f"Optimization failed: {result.error}")

In [None]:
from synth_ai.core.tunnels import cleanup_all

# Tear down the tunnels started by the earlier cells via the SDK's cleanup_all().
print("Cleaning up cloudflared processes...")
cleanup_all()
print("Demo complete!")