# Setup

## Import dependencies

In [1]:
%pip install -U dspy datasets tabulate duckdb pandas numpy ipywidgets "sqlglot[rs]" wandb --quiet

Note: you may need to restart the kernel to use updated packages.


In [2]:
import dspy
from datasets import load_dataset
import tabulate
import pandas as pd
import os
from dotenv import load_dotenv

In [3]:
# Read credentials from the local dotenv file, then fail fast if either key is missing or empty.
load_dotenv(".env.local")
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY not found in environment variables")

# The W&B key is passed to the GEPA optimizer further down (use_wandb=True).
wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    raise ValueError("WANDB_API_KEY not found in environment variables")

# Single LM handle, registered as the DSPy default for every module in this notebook.
# NOTE(review): temperature=1 / max_tokens=16000 presumably match what DSPy requires for
# OpenAI reasoning models — confirm before changing them.
lm = dspy.LM("openai/gpt-5-mini", api_key=openai_api_key, temperature=1, max_tokens=16000)
dspy.configure(lm=lm)

## Load data

In [4]:
# Load the gretelai/synthetic_text_to_sql dataset; later cells read its "train" and "test" splits
# and the row fields "sql_context", "sql_prompt" and "sql".
ds = load_dataset("gretelai/synthetic_text_to_sql")

# Set up DSPy

## Set up Signature and Modules

In [5]:
# Text-to-SQL task signature: two inputs (schema/seed-data context and the user's request)
# and one output (the SQL query).
# The class docstring is runtime text, not just documentation: it is what
# `program.predict.signature.instructions` prints later in the notebook, so edit it with care.
class ProblemDef(dspy.Signature):
    """You are a database expert. You are provided with context for how some table(s) were constructed, and a natural language prompt for what the user wants. Your job is to write a SQL query to provide them with the required data."""
    
    sql_context: str = dspy.InputField(description="SQL queries for creating the table(s) and loading some data")
    sql_prompt: str = dspy.InputField(description="User's natural language prompt")
    sql: str = dspy.OutputField(description="SQL query that delivers on the user's request. Format as code that can be directly run without any changes – do not use new lines or anything else of that sort.")

# Baseline program; its predictions expose `reasoning` and `sql` (see the demo output below).
program = dspy.ChainOfThought(ProblemDef)

In [6]:
# !pip install duckdb pandas numpy sqlglot --quiet
import duckdb, pandas as pd, numpy as np, re
import sqlglot
from sqlglot import parse_one

# Matches the words ORDER BY (case-insensitive, any whitespace between them).
# evaluate_sql searches the gold SELECT text with it to decide whether row order matters.
_ORDER_BY = re.compile(r"\border\s+by\b", re.IGNORECASE)

def _split_sql_statements(script: str):
    out, buf, q = [], [], None
    i, n = 0, len(script)
    while i < n:
        ch = script[i]
        if q:
            buf.append(ch)
            if ch == q:
                if i + 1 < n and script[i+1] == q:
                    buf.append(script[i+1]); i += 1
                else:
                    q = None
        else:
            if ch in ("'", '"', "`"):
                q = ch; buf.append(ch)
            elif ch == ';':
                s = "".join(buf).strip()
                if s: out.append(s)
                buf = []
            else:
                buf.append(ch)
        i += 1
    tail = "".join(buf).strip()
    if tail: out.append(tail)
    return out

import re
from sqlglot import parse_one

_SQLITE_DATE_RE = re.compile(
    r"""\bdate\s*\(\s*'now'\s*(?:,\s*'([+-])\s*(\d+)\s*(year|month|day)s?'\s*)?\)""",
    re.IGNORECASE,
)
_SQLITE_DATETIME_RE = re.compile(
    r"""\bdatetime\s*\(\s*'now'\s*(?:,\s*'([+-])\s*(\d+)\s*(year|month|day|hour|minute|second)s?'\s*)?\)""",
    re.IGNORECASE,
)

def _normalize_sqlite_dates(sql: str) -> str:
    # date('now') or date('now','-1 year') -> CURRENT_DATE +/- INTERVAL 'N unit'
    def _date_subst(m):
        sign, num, unit = m.group(1), m.group(2), m.group(3)
        if not sign:  # just date('now')
            return "CURRENT_DATE"
        op = "-" if sign == "-" else "+"
        return f"CURRENT_DATE {op} INTERVAL '{num} {unit.lower()}'"
    sql = _SQLITE_DATE_RE.sub(_date_subst, sql)

    # datetime('now') / datetime('now','+/-N unit') -> CURRENT_TIMESTAMP +/- INTERVAL 'N unit'
    def _dt_subst(m):
        sign, num, unit = m.group(1), m.group(2), m.group(3)
        if not sign:
            return "CURRENT_TIMESTAMP"
        op = "-" if sign == "-" else "+"
        return f"CURRENT_TIMESTAMP {op} INTERVAL '{num} {unit.lower()}'"
    sql = _SQLITE_DATETIME_RE.sub(_dt_subst, sql)

    return sql

def _mysql_to_duckdb(stmt: str) -> str:
    """Translate one MySQL-flavoured statement into DuckDB SQL.

    SQLite 'now'-relative date calls are normalised first, then the statement
    is transpiled with sqlglot (read as MySQL, written as DuckDB). If sqlglot
    fails for any reason, a small set of regex rewrites is applied instead as
    a best effort:
      * `ident`                        -> "ident"
      * DATE_SUB/DATE_ADD(<base>, INTERVAL n unit) -> <base> -/+ INTERVAL 'n unit',
        where CURRENT_DATE stays CURRENT_DATE and NOW() becomes
        CURRENT_TIMESTAMP so the time component is not dropped
      * IFNULL(                        -> COALESCE(
      * LOCATE(needle, haystack)       -> STRPOS(haystack, needle)

    Args:
        stmt: A single SQL statement.

    Returns:
        The translated statement text.
    """
    s = _normalize_sqlite_dates(stmt)
    try:
        return parse_one(s, read="mysql").sql(dialect="duckdb")
    except Exception:
        # Deliberate best-effort: any parse/transpile failure falls through to
        # the minimal regex rewrites below rather than aborting the evaluation.
        s = re.sub(r"`([^`]+)`", r'"\1"', s)

        def _date_arith(m):
            op = "-" if m.group(1).upper() == "DATE_SUB" else "+"
            # The pattern is case-insensitive, so compare case-insensitively too.
            is_date = m.group(2).upper().startswith("CURRENT")
            base = "CURRENT_DATE" if is_date else "CURRENT_TIMESTAMP"
            return f"{base} {op} INTERVAL '{m.group(3)} {m.group(4).lower()}'"

        s = re.sub(
            r"(DATE_SUB|DATE_ADD)\s*\(\s*(CURRENT_DATE|NOW\(\))\s*,\s*INTERVAL\s+(\d+)\s+(YEAR|MONTH|DAY)\s*\)",
            _date_arith,
            s, flags=re.IGNORECASE,
        )
        s = re.sub(r"\bIFNULL\s*\(", "COALESCE(", s, flags=re.IGNORECASE)
        s = re.sub(r"\bLOCATE\s*\(\s*([^,]+)\s*,\s*([^)]+)\)", r"STRPOS(\2, \1)", s, flags=re.IGNORECASE)
        return s

def _normalize_df(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    for c in df.columns:
        if df[c].dtype == "O":
            try:
                df[c] = pd.to_numeric(df[c])
            except Exception:
                pass
    return df.replace({np.nan: None})

def _exec_script_capture_last_select(con, script: str):
    """Run every statement of `script` on `con` and keep the last SELECT's result.

    Each statement is translated with `_mysql_to_duckdb` before execution. A
    statement counts as a SELECT when, after leading comments are stripped, it
    starts with SELECT or with WITH ... SELECT.

    Args:
        con: Connection object exposing `execute(sql)`; SELECT results are read
            with `.fetchdf()`.
        script: One or more ';'-separated SQL statements.

    Returns:
        (frame, sql): the normalised DataFrame of the last SELECT and the
        translated text of that SELECT, or (None, None) if no SELECT was run.
    """
    latest_frame = None
    latest_select = None
    for fragment in _split_sql_statements(script):
        translated = _mysql_to_duckdb(fragment)
        # Drop leading line/block comments so the keyword check sees the statement itself.
        body = re.sub(r"^\s*(--[^\n]*\n|/\*.*?\*/\s*)*", "", translated, flags=re.DOTALL)
        is_select = re.match(r"(?is)^\s*(with\b.*?select|select)\b", body) is not None
        if not is_select:
            con.execute(translated)
            continue
        latest_frame = con.execute(translated).fetchdf()
        latest_select = translated
    if latest_frame is None:
        return None, latest_select
    return _normalize_df(latest_frame), latest_select

def evaluate_sql(sql_context: str, golden_sql: str, predicted_sql: str):
    """Score a predicted query by executing it and the gold query in DuckDB.

    A fresh in-memory DuckDB database is populated from `sql_context`. The gold
    script and then the predicted script are run on that same connection, and
    the result of each script's last SELECT is compared. The connection is
    always closed before the function returns.

    Comparison rules:
      * Columns: identical name lists pass as-is. If every gold column name
        appears in the prediction (reordered and/or with extra columns), the
        prediction is projected onto the gold columns. Otherwise, frames with
        the same number of columns are compared positionally; any other case
        is a column mismatch.
      * Rows: order matters only when the gold SELECT text contains ORDER BY.
        Otherwise both frames are sorted by all columns first, falling back to
        the returned order if sorting raises.
      * Values: when both columns are numeric they are compared with
        np.allclose (rtol=1e-6, atol=1e-8, NaN equal to NaN); otherwise with
        ==, treating None as equal to None.

    Args:
        sql_context: Statements that create and populate the tables.
        golden_sql: Reference script; must contain at least one SELECT.
        predicted_sql: Candidate script; must contain at least one SELECT.

    Returns:
        (1, None) when the results match, else (0, info) where `info` is a dict
        with "reason" and "detail", plus "gold_head"/"pred_head" (first 10 rows
        as records) for value mismatches.
    """
    con = duckdb.connect(":memory:")
    try:
        # context
        try:
            for raw in _split_sql_statements(sql_context):
                con.execute(_mysql_to_duckdb(raw))
        except Exception as e:
            return 0, {"reason": "context_error", "detail": str(e)}

        # golden
        try:
            gold_df, gold_last_select = _exec_script_capture_last_select(con, golden_sql)
        except Exception as e:
            return 0, {"reason": "gold_error", "detail": str(e)}
        if gold_df is None:
            return 0, {"reason": "gold_no_select", "detail": "No SELECT in golden_sql."}

        # predicted (runs after gold on the same connection)
        try:
            pred_df, pred_last_select = _exec_script_capture_last_select(con, predicted_sql)
        except Exception as e:
            return 0, {"reason": "pred_error", "detail": str(e)}
        if pred_df is None:
            return 0, {"reason": "pred_no_select", "detail": "No SELECT in predicted_sql."}
    finally:
        # Both frames are fully materialised by now, so the database can be released.
        con.close()

    # column alignment (allow reordered columns and pred supersets; else positional)
    gold_cols, pred_cols = list(gold_df.columns), list(pred_df.columns)
    if gold_cols == pred_cols:
        pass
    elif set(gold_cols).issubset(pred_cols):
        # Covers both "same columns, different order" and "extra pred columns".
        pred_df = pred_df[gold_cols]
    elif gold_df.shape[1] == pred_df.shape[1]:
        new_names = [f"c{i}" for i in range(gold_df.shape[1])]
        gold_df = gold_df.copy(); pred_df = pred_df.copy()
        gold_df.columns = new_names; pred_df.columns = new_names
    else:
        return 0, {"reason": "column_mismatch",
                   "detail": f"Different number of columns: expected {gold_df.shape[1]}, got {pred_df.shape[1]}"}

    # ordering rule from gold's last SELECT
    gold_has_order = bool(_ORDER_BY.search(gold_last_select or ""))
    if not gold_has_order:
        try:
            g = gold_df.sort_values(by=list(gold_df.columns), kind="mergesort").reset_index(drop=True)
            p = pred_df.sort_values(by=list(gold_df.columns), kind="mergesort").reset_index(drop=True)
        except Exception:
            # Unsortable values (e.g. mixed types): compare in the returned order instead.
            g = gold_df.reset_index(drop=True); p = pred_df.reset_index(drop=True)
    else:
        g = gold_df.reset_index(drop=True); p = pred_df.reset_index(drop=True)

    # value compare
    if g.shape != p.shape:
        return 0, {"reason": "shape_mismatch", "detail": f"gold {g.shape} vs pred {p.shape}"}

    for c in g.columns:
        if pd.api.types.is_numeric_dtype(g[c]) and pd.api.types.is_numeric_dtype(p[c]):
            if not np.allclose(g[c].values, p[c].values, rtol=1e-6, atol=1e-8, equal_nan=True):
                return 0, {"reason": "value_mismatch", "detail": f"Numeric mismatch in '{c}'",
                           "gold_head": g.head(10).to_dict("records"),
                           "pred_head": p.head(10).to_dict("records")}
        else:
            eq = [(x == y) or (x is None and y is None) for x, y in zip(g[c].values, p[c].values)]
            if not all(eq):
                return 0, {"reason": "value_mismatch", "detail": f"Mismatch in '{c}'",
                           "gold_head": g.head(10).to_dict("records"),
                           "pred_head": p.head(10).to_dict("records")}
    return 1, None


## Test

In [7]:
# Pick one training row and run the baseline program on it as a smoke test.
demo_index = 4
demo_row = ds['train'][demo_index]
context = demo_row['sql_context']
prompt = demo_row['sql_prompt']
golden_sql = demo_row['sql']

print(f"Context: {context}")
print(f"Prompt: {prompt}")
print(f"Golden sql: {golden_sql}")

# `result` is reused by the evaluator and judge cells below.
result = program(sql_context=context, sql_prompt=prompt)
print(result)

Context: CREATE TABLE upgrades (id INT, cost FLOAT, type TEXT); INSERT INTO upgrades (id, cost, type) VALUES (1, 500, 'Insulation'), (2, 1000, 'HVAC'), (3, 1500, 'Lighting');
Prompt: Find the energy efficiency upgrades with the highest cost and their types.
Golden sql: SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
Prediction(
    reasoning='We need the upgrade(s) that have the maximum cost. Use a subquery to get MAX(cost) and return rows matching that value (including id, type, and cost).',
    sql='SELECT id, type, cost FROM upgrades WHERE cost = (SELECT MAX(cost) FROM upgrades);'
)


In [8]:
# Execute the gold and predicted SQL against the demo context and compare their result sets.
# `info` is None on a match; otherwise it is a dict with "reason" and "detail".
# Note: evaluate_sql accepts a prediction whose columns are a superset of the gold columns,
# so the extra `id` column in the prediction above does not lower the score.
score, info = evaluate_sql(context, golden_sql, result.sql)
print(score, info)


1 None


## The execution environment didn't work, so let's use an LLM as a judge

In [9]:
# LLM-as-judge signature: given the schema/data context, the user's request, the gold query and
# a candidate query, produce a boolean `similar` verdict.
# The docstring and field descriptions are runtime text (presumably sent to the LM as
# instructions, as with ProblemDef — confirm), so changing them changes the judge's behaviour.
class Judge(dspy.Signature):
    """You are required to judge two SQL queries for functional similarity. You will be given a context of how the table(s) and data were created, and the natural language prompt from the user"""

    sql_context: str = dspy.InputField(description="SQL statement(s) creating the table(s) and the input data")
    sql_prompt: str = dspy.InputField(description="Natural language prompt from the user")
    golden_sql: str = dspy.InputField(description="The golden SQL query from our dataset")
    candidate_sql: str = dspy.InputField(description="A SQL query generated by a model for the same prompt")
    similar: bool = dspy.OutputField(description="True if the candidate SQL query is functionally similar to the golden SQL query")

# Chain-of-thought wrapper; its responses expose `reasoning` and `similar`,
# both of which metric_with_feedback reads.
judge = dspy.ChainOfThought(Judge)
    

In [10]:
# Ask the judge about the demo prediction, then echo its inputs next to the verdict so the
# cell output can be read on its own.
judge_response = judge(sql_context=context, sql_prompt=prompt, golden_sql=golden_sql, candidate_sql=result.sql)
print(f"Context: {context}")
print(f"Prompt: {prompt}")
print(f"Golden SQL: {golden_sql}")
print(f"Candidate SQL: {result.sql}")
print(f"Judge Response: {judge_response}")


Context: CREATE TABLE upgrades (id INT, cost FLOAT, type TEXT); INSERT INTO upgrades (id, cost, type) VALUES (1, 500, 'Insulation'), (2, 1000, 'HVAC'), (3, 1500, 'Lighting');
Prompt: Find the energy efficiency upgrades with the highest cost and their types.
Golden SQL: SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
Candidate SQL: SELECT id, type, cost FROM upgrades WHERE cost = (SELECT MAX(cost) FROM upgrades);
Judge Response: Prediction(
    reasoning='Both queries return the upgrade(s) that have the maximum cost and include the type and cost information. Differences:\n- The candidate also returns the id column (extra column not present in the golden query).\n- The golden query uses ROW_NUMBER() and will return a single row (even if there are ties), whereas the candidate uses cost = MAX(cost) and will return all rows that tie for the maximum cost.\n\nDespite these differences in returned columns and tie-handling

# Get ready to GEPA

In [11]:
# pip install datasets dspy-ai
import math, random
from typing import Callable, List, Tuple, Optional
from datasets import Dataset, DatasetDict
from dspy import GEPA

def split_for_gepa(
    ds: Dataset,
    to_example: Callable[[dict], "dspy.Example"],
    val_size: float = 0.15,
    seed: int = 42,
    group_col: Optional[str] = None,
    stratify_col: Optional[str] = None,
) -> Tuple[List["dspy.Example"], List["dspy.Example"]]:
    """
    Split a dataset into (train_set, val_set) lists of dspy.Example.

    Strategy, in priority order:
    - group_col set: whole groups go to one side, so no group appears in both
      splits. If that leaves either side empty, fall back to a plain random
      split over row indices.
    - stratify_col set: `train_test_split` stratified on that column.
    - otherwise: plain `train_test_split`.

    Args:
        ds: Dataset to split.
        to_example: Converts one row dict into a dspy.Example.
        val_size: Fraction for validation, strictly between 0 and 1.
        seed: Seed for the shuffles / library split.
        group_col: Column whose values define groups, or None.
        stratify_col: Column to stratify on, or None.

    Returns:
        (train_set, val_set).
    """
    assert 0.0 < val_size < 1.0, "val_size must be in (0,1)"
    rng = random.Random(seed)

    if group_col:
        # Map every group value to the row indices that carry it.
        members = {}
        for row_idx, group_key in enumerate(ds[group_col]):
            members.setdefault(group_key, []).append(row_idx)

        shuffled_groups = list(members.keys())
        rng.shuffle(shuffled_groups)
        n_val_groups = max(1, math.floor(val_size * len(shuffled_groups)))
        val_groups = set(shuffled_groups[:n_val_groups])

        val_idx = [row_idx for group_key in val_groups for row_idx in members[group_key]]
        train_idx = [
            row_idx
            for group_key in shuffled_groups[n_val_groups:]
            for row_idx in members[group_key]
        ]

        if not (train_idx and val_idx):
            # One side came out empty (e.g. a single huge group): split rows at random instead.
            order = list(range(len(ds)))
            rng.shuffle(order)
            cut = max(1, math.floor(val_size * len(ds)))
            val_idx, train_idx = order[:cut], order[cut:]

        ds_train, ds_val = ds.select(train_idx), ds.select(val_idx)
    else:
        # Library split; stratification is just one extra keyword on the same call.
        split_kwargs = {"test_size": val_size, "seed": seed}
        if stratify_col:
            split_kwargs["stratify_by_column"] = stratify_col
        parts = ds.train_test_split(**split_kwargs)
        ds_train, ds_val = parts["train"], parts["test"]

    return [to_example(r) for r in ds_train], [to_example(r) for r in ds_val]

def to_dspy_example(row):
    """Convert one dataset row into a dspy.Example.

    `sql_prompt` and `sql_context` are marked as inputs; `sql` is kept on the
    example as the gold label.
    """
    fields = {
        "sql_prompt": row["sql_prompt"],
        "sql_context": row["sql_context"],
        "sql": row["sql"],
    }
    example = dspy.Example(**fields)
    return example.with_inputs("sql_prompt", "sql_context")


# Split ds["train"] into train/validation lists of dspy.Example.
# val_size=0.5 sends half of the rows to each list; with group_col and stratify_col both None,
# split_for_gepa takes its plain random-split path. Every row is converted here, although the
# next cell only keeps the first `train_set_size` / `val_set_size` examples of each list.
train_set, val_set = split_for_gepa(
    ds["train"],
    to_dspy_example,          # row -> dspy.Example converter defined above
    val_size=0.5,
    seed=42,
    group_col=None,      # e.g., "db_id" if available
    stratify_col=None,   # or a column like "op_class" if you want stratification
)

In [12]:
# Experiment knobs for the GEPA run.
max_variants_to_try = 20 # number of variants to test; NOTE(review): not referenced by any visible cell
mini_batch_size = 20 # passed to GEPA as reflection_minibatch_size
val_set_size = 200 # number of validation examples kept for tracking
train_set_size = 200 # number of training examples handed to the optimizer

def budget_for_variants(N, V, k, slack=2):
    """Return the budget estimate V + N * (k + slack).

    Presumably N is the number of variants, V the validation-set size and k the
    mini-batch size — TODO confirm against the caller. `slack` is added to k for
    each variant to leave room for occasional extra probes/promotions.
    """
    per_variant = k + slack
    return V + N * per_variant

def metric_with_feedback(example, pred, trace=None, pred_name=None, pred_trace=None):
    """Metric for GEPA: have the LLM judge compare the prediction to the gold SQL.

    Args:
        example: Carries `sql_context`, `sql_prompt` and the gold `sql`.
        pred: Program output; its `sql` attribute is the candidate query.
        trace, pred_name, pred_trace: Accepted but not used here.

    Returns:
        dspy.Prediction with `score` (1 if the judge says similar, else 0) and
        `feedback` (the judge's reasoning text).
    """
    verdict = judge(
        sql_context=example.sql_context,
        sql_prompt=example.sql_prompt,
        golden_sql=example.sql,
        candidate_sql=pred.sql,
    )
    return dspy.Prediction(
        score=1 if verdict.similar else 0,
        feedback=verdict.reasoning,
    )

# Cap both lists to the sizes chosen above before handing them to GEPA.
val_for_tracking = val_set[:val_set_size]   # 128–512 is a good range
train_set_for_optimization = train_set[:train_set_size]
# Configure the optimizer. The judge-based metric supplies both the score and textual feedback.
# The task LM `lm` is also passed as the reflection LM, and the run is logged to
# Weights & Biases plus the local "logs" directory.
optimizer = GEPA(
    metric=metric_with_feedback,
    num_threads=32,
    track_stats=True,
    reflection_minibatch_size=mini_batch_size,
    reflection_lm=lm,
    use_wandb=True,
    wandb_api_key=wandb_api_key,
    log_dir="logs",
    auto="light"   
)

# Run GEPA

In [13]:
# Run the optimization on the capped train/validation lists.
# Long-running cell: the log captured below spans more than 35 minutes of LM calls.
optimized_program = optimizer.compile(
    program,
    trainset=train_set_for_optimization,
    valset=val_for_tracking,
)

2025/10/14 10:04:55 INFO dspy.teleprompt.gepa.gepa: Running GEPA for approx 1180 metric calls of the program. This amounts to 2.95 full evals on the train+val set.
2025/10/14 10:04:55 INFO dspy.teleprompt.gepa.gepa: Using 200 examples for tracking Pareto scores. You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget.


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/raveesh/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mraveeshbhalla90[0m ([33mraveeshbhalla90-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [dspy, litellm, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/
GEPA Optimization:   0%|          | 0/1180 [00:00<?, ?rollouts/s]2025/10/14 10:04:58 INFO dspy.evaluate.evaluate: Average Metric: 111.0 / 200 (55.5%)
2025/10/14 10:04:58 INFO dspy.teleprompt.gepa.gepa: Iteration 0: Base program full valset score: 0.555
GEPA Optimization:  17%|█▋        | 200/1180 [00:01<00:06, 140.08rollouts/s]2025/10/14 10:04:58 INFO dspy.teleprompt.gepa.gepa: Iteration 1: Selected program 0 score: 0.555


Average Metric: 14.00 / 20 (70.0%): 100%|██████████| 20/20 [00:38<00:00,  1.90s/it] 

2025/10/14 10:05:36 INFO dspy.evaluate.evaluate: Average Metric: 14.0 / 20 (70.0%)





2025/10/14 10:06:37 INFO dspy.teleprompt.gepa.gepa: Iteration 1: Proposed new text for predict: You are a SQL expert assistant whose job is: given (1) a short natural-language request (sql_prompt) and (2) a small SQL schema and seed data snippet (sql_context), produce a single correct SQL statement that answers the prompt plus a brief reasoning explanation of how you formed the query.

Output format (always):
- A short "reasoning" paragraph describing the approach, assumptions, and any edge-cases handled.
- The SQL statement (runnable against the provided schema). Do not produce extraneous SQL or multiple alternative queries unless the prompt explicitly asks for options.

General rules and intent:
1. Preserve the semantics of the user's request exactly. Do not add, remove, or weaken required filters/constraints unless:
   - The schema does not contain needed columns/tables, in which case explicitly state the assumption you must make.
   - The prompt is ambiguous and you state the chose

Average Metric: 14.00 / 20 (70.0%): 100%|██████████| 20/20 [00:54<00:00,  2.75s/it] 

2025/10/14 10:11:37 INFO dspy.evaluate.evaluate: Average Metric: 14.0 / 20 (70.0%)





2025/10/14 10:13:16 INFO dspy.teleprompt.gepa.gepa: Iteration 2: Proposed new text for predict: You are a SQL expert assistant. You receive two inputs:
- sql_prompt: a short natural-language request describing exactly what SQL result or change the user wants.
- sql_context: a small SQL schema and (optional) seed data snippet, using CREATE TABLE / INSERT statements that define available tables, columns, and example values.

Your job: produce exactly two labeled sections in this order:
1) reasoning — a concise 1–4 sentence paragraph describing your approach, any assumptions, and any non-obvious decisions or edge-cases handled.
2) sql — a single SQL statement (runnable against the provided schema) that implements the request.

Formatting and content rules (follow strictly):

Overall output format
- Always output exactly two sections prefixed by the words "reasoning" and "sql" (lowercase), each followed by a newline and then the content. Do not include additional sections, explanations, or

Average Metric: 13.00 / 20 (65.0%): 100%|██████████| 20/20 [00:52<00:00,  2.63s/it]

2025/10/14 10:15:24 INFO dspy.evaluate.evaluate: Average Metric: 13.0 / 20 (65.0%)





2025/10/14 10:16:19 INFO dspy.teleprompt.gepa.gepa: Iteration 3: Proposed new text for predict: You are a SQL expert assistant. For every request you receive you must follow the precise input/output contract and rules below.

Input format (what the user will provide):
- sql_prompt: a short natural-language request describing the desired result.
- sql_context: a small SQL schema and optional seed INSERTs that define available tables, columns, types and sample data.

Required output format (always):
1. A "reasoning" paragraph (1–4 short sentences) that:
   - briefly explains the approach and any assumptions or interpretation choices you made,
   - explicitly names any edge-cases handled (e.g., empty-result handling, ties, case-insensitivity),
   - states any schema gaps and the assumption you make if you must invent a column/table (or clearly ask for clarification instead).
2. The SQL statement, labeled "sql", containing exactly one runnable SQL statement (no extra statements, no transac

Average Metric: 12.00 / 20 (60.0%): 100%|██████████| 20/20 [00:56<00:00,  2.81s/it]

2025/10/14 10:19:24 INFO dspy.evaluate.evaluate: Average Metric: 12.0 / 20 (60.0%)





2025/10/14 10:20:50 INFO dspy.teleprompt.gepa.gepa: Iteration 4: Proposed new text for predict: You are a SQL expert assistant. For each task you are given, you will receive:
- sql_prompt: a short natural-language request describing exactly what SQL the user wants.
- sql_context: a small SQL schema and optional seed data snippet that defines available tables, columns, types and example values.

Your job: produce exactly two things in the exact order and format below (no extra text, no extra SQL statements):
1) A brief "reasoning" paragraph (1–4 sentences) describing the approach, any assumptions or ambiguity resolutions, and any important edge-cases handled.
2) A single SQL statement that implements the request and is runnable against the provided schema.

Output format (must match exactly):
reasoning
<1–4 concise sentences>

sql
<single SQL statement>

Hard rules and behavior (must follow these precisely):
- Preserve the prompt's semantics exactly. Do not add, remove, or weaken filter

Average Metric: 12.00 / 20 (60.0%): 100%|██████████| 20/20 [01:19<00:00,  4.00s/it]

2025/10/14 10:27:53 INFO dspy.evaluate.evaluate: Average Metric: 12.0 / 20 (60.0%)





2025/10/14 10:29:01 INFO dspy.teleprompt.gepa.gepa: Iteration 5: Proposed new text for predict: You are a SQL expert assistant. For each task you are given you will receive two inputs:
- sql_prompt: a short natural-language request describing exactly what SQL the user wants.
- sql_context: a small SQL schema and optional seed data snippet that defines available tables, columns, types and example values.

Your job: produce exactly two things in the exact order and exact textual format below (no extra text, no extra SQL statements, no comments, no prose outside these labels):
1) A brief "reasoning" paragraph of 1–4 concise sentences describing approach, any assumptions or ambiguity resolutions, and any important edge-cases handled.
2) A single SQL statement that implements the request and is runnable against the provided schema.

Output format (must match exactly):
reasoning
<1–4 concise sentences>

sql
<single SQL statement>

Hard rules you must follow exactly:
- Do not output anything 

Average Metric: 12.00 / 20 (60.0%): 100%|██████████| 20/20 [00:59<00:00,  2.98s/it]

2025/10/14 10:36:13 INFO dspy.evaluate.evaluate: Average Metric: 12.0 / 20 (60.0%)





2025/10/14 10:37:02 INFO dspy.teleprompt.gepa.gepa: Iteration 6: Proposed new text for predict: You are a SQL expert assistant. For every task you are given you will receive two inputs:
- sql_prompt: a short natural-language request describing exactly what SQL the user wants.
- sql_context: a small SQL schema and optional seed data snippet that defines available tables, columns, types and example values.

Your job: produce exactly two things in the exact order and format below (no extra text, no extra SQL statements):
1) A concise "reasoning" paragraph of 1–4 sentences describing approach, any assumptions or ambiguity resolutions, and important edge-cases handled.
2) A single SQL statement that implements the request and is runnable against the provided schema.

Output format (must match exactly):
reasoning
<1–4 concise sentences>

sql
<single SQL statement>

Strict rules you must follow (read carefully — these are mandatory):

General behavior
- Preserve the prompt's semantics exactly

Average Metric: 13.00 / 20 (65.0%): 100%|██████████| 20/20 [01:06<00:00,  3.30s/it] 

2025/10/14 10:39:38 INFO dspy.evaluate.evaluate: Average Metric: 13.0 / 20 (65.0%)





2025/10/14 10:40:38 INFO dspy.teleprompt.gepa.gepa: Iteration 7: Proposed new text for predict: You are a SQL expert assistant. For every request you receive, you will be given:
- sql_prompt: a short natural-language request describing the SQL the user wants, and
- sql_context: a small SQL schema + seed-data snippet (CREATE TABLE / INSERT statements) that defines available tables/columns and sample values.

Your job: produce exactly two things in this exact output format (and nothing else):
1) A short "reasoning" paragraph (1–4 sentences) that explains your approach, any assumptions you made, how you handled edge-cases, and any non-obvious decisions (e.g., use of COALESCE, case-insensitive matching, dialect-specific features). Keep this concise.
2) A single SQL statement (runnable against the provided schema). Precede it with the token "sql" on its own line. The SQL must be one statement only (unless the user's prompt explicitly asks for multiple statements). Do NOT output additional S

# Review original and optimized prompts

In [14]:
# Instructions of the baseline program — the ProblemDef docstring.
print(program.predict.signature.instructions)

You are a database expert. You are provided with context for how some table(s) were constructed, and a natural language prompt for what the user wants. Your job is to write a SQL query to provide them with the required data.


In [15]:
# Instructions carried by the program that optimizer.compile returned.
print(optimized_program.predict.signature.instructions)

You are a SQL expert assistant. For every request you receive, you will be given:
- sql_prompt: a short natural-language request describing the SQL the user wants, and
- sql_context: a small SQL schema + seed-data snippet (CREATE TABLE / INSERT statements) that defines available tables/columns and sample values.

Your job: produce exactly two things in this exact output format (and nothing else):
1) A short "reasoning" paragraph (1–4 sentences) that explains your approach, any assumptions you made, how you handled edge-cases, and any non-obvious decisions (e.g., use of COALESCE, case-insensitive matching, dialect-specific features). Keep this concise.
2) A single SQL statement (runnable against the provided schema). Precede it with the token "sql" on its own line. The SQL must be one statement only (unless the user's prompt explicitly asks for multiple statements). Do NOT output additional SQL statements, transaction commands (COMMIT), or alternative queries.

Output layout exactly:
re

# Evals

In [14]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import Dataset
from time import perf_counter
from typing import Dict, Any, Optional

def evaluate_program(
    program,
    ds_test: Dataset,
    limit: int = 100,
    max_workers: int = 8,
    field_map: Optional[Dict[str, str]] = None,
    seed: Optional[int] = 42,
    max_failures: int = 20,
) -> Dict[str, Any]:
    """
    Evaluate a DSPy program on a random sample of up to `limit` rows of a HF Dataset split.

    The split is shuffled with `seed` before sampling, so two calls with the
    same dataset, `limit` and `seed` score the same rows. This is what makes
    the accuracy of two different programs comparable.

    Args:
        program: a DSPy Module with signature program(sql_prompt=..., sql_context=...)
        ds_test: Hugging Face Dataset (e.g., ds["test"])
        limit: maximum number of rows to evaluate (default 100)
        max_workers: parallel threads for I/O-bound LM + judge
        field_map: optional mapping if your column names differ:
                   {"sql_prompt": "...", "sql_context": "...", "sql": "..."}
        seed: shuffle seed for picking the evaluated rows; pass None for a
              different random sample on every call
        max_failures: maximum number of failure records to return (default 20)

    Returns:
        {
          "accuracy": float,
          "correct": int,
          "total": int,
          "avg_latency_s": float,   # total wall-clock seconds / total rows; rows run
                                    # concurrently, so this is not per-call latency
          "failures": [ {idx, reason, pred_sql, feedback} ... up to max_failures ],
        }
    """
    if field_map is None:
        field_map = {"sql_prompt": "sql_prompt", "sql_context": "sql_context", "sql": "sql"}

    shuffled = ds_test.shuffle(seed=seed)
    n = max(0, min(limit, len(shuffled)))
    subset = shuffled.select(range(n))
    start = perf_counter()

    def _eval_one(i_row):
        # Run the program and the judge for one row; never raises, so one bad
        # row cannot abort the whole evaluation.
        i, row = i_row
        try:
            pred = program(
                sql_prompt=row[field_map["sql_prompt"]],
                sql_context=row[field_map["sql_context"]],
            )
            pred_sql = getattr(pred, "sql", None) or (pred.get("sql") if isinstance(pred, dict) else None) or ""
            jr = judge(
                sql_context=row[field_map["sql_context"]],
                sql_prompt=row[field_map["sql_prompt"]],
                golden_sql=row[field_map["sql"]],
                candidate_sql=pred_sql,
            )
            ok = bool(getattr(jr, "similar", False))
            feedback = getattr(jr, "reasoning", "") or ""
            return (i, ok, pred_sql, feedback, None)
        except Exception as e:
            return (i, False, "", "", f"{type(e).__name__}: {e}")

    # Threaded evaluation (I/O bound: LM + judge). Tune max_workers to your provider limits.
    # Executor.map yields results in input order, so no re-sorting is needed.
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        results = list(ex.map(_eval_one, ((i, subset[i]) for i in range(n))))

    correct = sum(1 for _, ok, *_ in results if ok)
    total = n
    acc = correct / total if total else 0.0
    elapsed = perf_counter() - start
    avg_lat = elapsed / total if total else 0.0

    failures = [
        {
            "idx": i,
            "reason": ("error: " + err) if err else "mismatch",
            "pred_sql": pred_sql,
            "feedback": feedback,
        }
        for i, ok, pred_sql, feedback, err in results
        if not ok
    ][:max(0, max_failures)]

    return {
        "accuracy": acc,
        "correct": correct,
        "total": total,
        "avg_latency_s": avg_lat,
        "failures": failures,
    }
    
# Shuffle the test split once with a fixed seed so the row order is the same on every kernel run.
test_split = ds["test"].shuffle(seed=42)

In [None]:
# Evaluate the original (unoptimized) program on up to 500 rows of the shuffled test split.

orig_metrics = evaluate_program(program, test_split, limit=500, max_workers=32)

print("Original:", orig_metrics["accuracy"], f"({orig_metrics['correct']}/{orig_metrics['total']})")

Original: 0.634 (317/500)


In [18]:
# Score the GEPA-optimized program with the same limit and worker count as the original run above.
opt_metrics  = evaluate_program(optimized_program, test_split, limit=500, max_workers=32)
print("Optimized:", opt_metrics["accuracy"], f"({opt_metrics['correct']}/{opt_metrics['total']})")

Optimized: 0.646 (323/500)


# OpenAI optimized

In [19]:
# Signature variant carrying a long, hand-structured prompt. It declares the same
# three fields as ProblemDef, so it can be scored by the same evaluate_program call.
# DSPy takes a Signature's class docstring as the task instructions, so the text
# below is runtime prompt content rather than documentation: any edit to it
# changes model behavior and invalidates the accuracy recorded in this notebook.
# NOTE(review): the prompt asks for the literal word `sql` on its own line before
# the query, while the `sql` output field asks for directly runnable SQL. Failure
# samples printed by the reload cells show pred_sql values that begin with "sql";
# confirm whether that prefix is intended.
class OpenAIOptimized(dspy.Signature):
    """
    Developer: # Role and Objective
    You are a SQL expert assistant. For every request, you will receive two inputs:
    - `sql_prompt`: a concise natural-language description of the SQL the user wants.
    - `sql_context`: a small SQL schema and seed data (CREATE TABLE/INSERT statements) that define available tables, columns, and example values.

    # Instructions
    Begin with a concise checklist (3–7 bullets) outlining your planned approach: key sub-tasks or decision points for producing the SQL statement. Keep these conceptual, not implementation-level. For every incoming request, produce your response in the following exact format (and nothing else):
    1. **reasoning**
    - A short (1–4 sentences) paragraph explaining your approach, key assumptions, handling of edge cases, and any noteworthy decisions (e.g., use of COALESCE, case-insensitive matching, dialect-specific features). Keep this concise.

    2. **sql**
    - The SQL statement that fulfills the prompt using only the context provided. Start this section with the word `sql` on its own line, followed by the query. Only output a single SQL statement unless the prompt explicitly requests multiple statements.

    Example:

    reasoning
    <Concise, 1–4 sentence paragraph with approach and key decisions.>

    sql
    <Single SQL statement>

    ---

    # Core Rules, Constraints, and Style

    A. **Preserve user semantics strictly**
    - Do not add, remove, or modify required filters, joins, grouping, or result structures.
    - Deviate only if necessary due to missing schema elements or ambiguity—explicitly state any such assumptions in reasoning.
    - Preserve output shape as implied by the prompt (e.g., single scalar vs. grouped rows).

    B. **Use standard and portable SQL where possible**
    - Prefer ANSI SQL features. If using dialect-specific syntax (e.g., INTERVAL, DATE_SUB), mention it in reasoning.

    C. **DML Statements**
    - For INSERT: always use explicit column lists unless the schema provides explicit column order and prompt implies omission.
    - For UPDATE/DELETE: exactly mirror the WHERE clauses the user specifies.
    - If the prompt requests multiple dependent statements (e.g., insert parent and then child row), either:
    - Produce a single statement if possible,
    - State assumptions (e.g., chosen IDs), or
    - Ask for user clarification in reasoning if no safe assumption can be made.

    D. **Aggregation, grouping, and expected result shape**
    - Use GROUP BY and HAVING only as the prompt implies.
    - Use COALESCE(..., 0) for aggregates only if the user is likely to expect zero instead of NULL, and note this in reasoning.

    E. **Joins**
    - Use LEFT JOIN when prompted for parent rows that may have no children; otherwise, use INNER JOIN. Justify your choice in reasoning if ambiguous.

    F. **Anti-joins**
    - Prefer NOT EXISTS for anti-join semantics; use LEFT JOIN ... WHERE child.key IS NULL where appropriate. Note potential NULL handling issues in reasoning.

    G. **Case-insensitive matches and NULL handling**
    - Implement case-insensitive matches using LOWER(column) LIKE '%term%', stating this in reasoning.
    - Quote string/date literals with single quotes.
    - Mention if NULLs are handled/converted by your SQL.

    H. **Dates and times**
    - Use the schema’s date literal format (prefer 'YYYY-MM-DD').
    - For relative dates or ambiguous ranges, state your interpretation clearly.

    I. **Top-N and ordering**
    - Use ORDER BY ... DESC LIMIT n (or FETCH FIRST n ROWS ONLY for ANSI SQL). Mention your choice in reasoning if non-standard.

    J. **Missing schema requirements**
    - Never invent tables or columns. If critical elements are missing, either:
    - Explicitly state your assumption and proceed, or
    - Ask for user clarification in reasoning.

    K. **Output restrictions**
    - Only emit a single SQL statement (unless multiple are explicitly requested), and never output transaction commands, alternative queries, or extraneous commentary.
    - Keep reasoning concise, informative, and within 1–4 sentences.

    # Common SQL Patterns to Use (as appropriate)
    - Aggregation: `SELECT SUM(col) AS total FROM table WHERE ...;`
    - Conditional aggregation: `SUM(CASE WHEN condition THEN value ELSE 0 END)`
    - Including unmatched parents: `FROM parent LEFT JOIN child ... GROUP BY parent.id`
    - Anti-join: `WHERE NOT EXISTS (SELECT 1 FROM child ... )`
    - Percentages: `ROUND(100.0 * COUNT(*) / SUM(COUNT(*)) OVER (), 2)`
    - Case-insensitive search: `LOWER(name) LIKE '%term%'`
    - Null-to-zero: `COALESCE(SUM(...), 0)`
    - Top-N: `ORDER BY ... DESC LIMIT N`

    # Explicit Pitfalls to Avoid
    - Never omit essential WHERE, GROUP BY, or HAVING clauses.
    - Do not change result shape (e.g., returning a scalar when a multi-row result is expected).
    - Never return multiple SQL statements, unless explicitly requested.
    - Never invent missing schema elements—justify any assumed values in reasoning.

    # Behavior on Ambiguity
    - If the user request or schema is ambiguous, clearly state your interpretation in reasoning and proceed accordingly.
    - Limit assumptions to what is evident; document all assumptions.

    # Output Format
    - Adhere strictly to: "reasoning" (1–4 sentences explaining approach and choices), then a blank line, then "sql" followed by the single SQL statement.

    # Verbosity
    - Default mode: concise, clear, and explicit parsing of requirements in reasoning.
    - For code: maintain clarity—use readable structure, explicit references, and comments when appropriate.

    # Stop Conditions
    - Hand back output immediately after producing the reasoning and SQL sections in exact format. Do not add further commentary or suggestions.

    After forming the SQL statement, briefly validate that the chosen SQL matches all explicit user requirements and schema details. If issues are found, correct and update your response accordingly. If requirements are ambiguous or cannot be met, explicitly state this in reasoning and request clarification.
    """

    # Field names, types, and descriptions are identical to those declared on ProblemDef.
    sql_context: str = dspy.InputField(description="SQL queries for creating the table(s) and loading some data")
    sql_prompt: str = dspy.InputField(description="User's natural language prompt")
    sql: str = dspy.OutputField(description="SQL query that delivers on the user's request. Format as code that can be directly run without any changes – do not use new lines or anything else of that sort.")

# Wrap the signature in ChainOfThought, the same module type used to build `program`.
openai_optimized = dspy.ChainOfThought(OpenAIOptimized)

In [None]:
# Score the hand-written OpenAIOptimized prompt with the same split and settings
# as the other programs. The result gets its own variable so the baseline's
# `orig_metrics` is not overwritten, and its own label so the printed line is
# not confused with the DSPy-optimized run.
openai_metrics = evaluate_program(openai_optimized, test_split, limit=500, max_workers=32)

print("OpenAI optimized:", openai_metrics["accuracy"], f"({openai_metrics['correct']}/{openai_metrics['total']})")



Optimized: 0.542 (271/500)


# Store programs and run evals again

In [35]:
# Persist each program's state (save_program=False) as JSON. The target
# directory is created first so the cell also works on a fresh checkout where
# none of these folders exist yet.
for _prog, _out_dir in (
    (program, "./dspy_program"),
    (optimized_program, "./optimized_program"),
    (openai_optimized, "./openai_optimized_program"),
):
    os.makedirs(_out_dir, exist_ok=True)
    _prog.save(f"{_out_dir}/program.json", save_program=False)

## Reload programs and run evals

In [15]:
# Rebuild the baseline module, restore its saved state from disk, and re-score it.
og_program = dspy.ChainOfThought(ProblemDef)
og_program.load("./dspy_program/program.json")
og_metrics = evaluate_program(
    og_program,
    test_split,
    limit=500,
    max_workers=32,
)
print("Original Program: " + str(og_metrics))

Original Program: {'accuracy': 0.584, 'correct': 292, 'total': 500, 'avg_latency_s': 0.535394384418003, 'failures': [{'idx': 0, 'reason': 'mismatch', 'pred_sql': "SELECT (SELECT COUNT(*) FROM disability_services.students WHERE accommodation IS NOT NULL AND TRIM(accommodation) <> '' AND LOWER(TRIM(accommodation)) <> 'accessibility_parking') + (SELECT COUNT(*) FROM disability_services.staff WHERE accommodation IS NOT NULL AND TRIM(accommodation) <> '' AND LOWER(TRIM(accommodation)) <> 'accessibility_parking') AS total_with_accommodations;", 'feedback': "The two queries are not functionally equivalent.\n\n- Output shape: The golden SQL uses UNION of two COUNT(*) queries, returning two rows (one count for staff, one for students). The candidate returns a single scalar (the sum of the two counts) as total_with_accommodations — matching the prompt's wording for a single total, but different from the golden's two-row result.\n- Filtering differences: The golden WHERE clause only excludes rows

In [16]:
# Rebuild a module, restore the DSPy-optimized state from disk, and re-score it.
opt_program = dspy.ChainOfThought(ProblemDef)
opt_program.load("./optimized_program/program.json")
opt_metrics = evaluate_program(opt_program, test_split, limit=500, max_workers=32)
# The label names the program actually evaluated (it previously read "Original Program").
print(f"Optimized Program: {opt_metrics}")

Original Program: {'accuracy': 0.642, 'correct': 321, 'total': 500, 'avg_latency_s': 0.9665514680820052, 'failures': [{'idx': 0, 'reason': 'mismatch', 'pred_sql': 'sql\nSELECT country, COUNT(*) AS num_factories FROM FairTradeFactories GROUP BY country HAVING COUNT(*) = (SELECT MAX(cnt) FROM (SELECT COUNT(*) AS cnt FROM FairTradeFactories GROUP BY country) AS sub);', 'feedback': 'The golden query returns the top 5 countries by factory count (ordered descending, limited to 5). The candidate query returns only the country or countries whose count equals the overall maximum (i.e., the countries tied for the single highest count). These behaviors differ (candidate yields only the top tied maxima, not the top 5), so they are not functionally equivalent.'}, {'idx': 5, 'reason': 'mismatch', 'pred_sql': 'sql\nSELECT o.mission_area, COUNT(d.donation_id) AS total_donations FROM Organizations o LEFT JOIN Donations d ON o.org_id = d.org_id GROUP BY o.mission_area;', 'feedback': 'The two queries are

In [17]:
# Rebuild a module, restore the OpenAI-optimized state from disk, and re-score it.
# NOTE(review): the module is constructed from ProblemDef and relies on load()
# restoring the saved OpenAIOptimized instructions — confirm against DSPy's load semantics.
openai_opt_program = dspy.ChainOfThought(ProblemDef)
openai_opt_program.load("./openai_optimized_program/program.json")
openai_opt_metrics = evaluate_program(openai_opt_program, test_split, limit=500, max_workers=32)
# The label names the program actually evaluated (it previously read "Original Program").
print(f"OpenAI Optimized Program: {openai_opt_metrics}")

Original Program: {'accuracy': 0.606, 'correct': 303, 'total': 500, 'avg_latency_s': 1.1225918439160014, 'failures': [{'idx': 1, 'reason': 'mismatch', 'pred_sql': "sql SELECT DISTINCT o.Id, o.Name, o.Sector FROM Organizations o JOIN Donations d ON d.OrganizationId = o.Id JOIN Donors dr ON dr.Id = d.DonorId JOIN Countries c ON c.Id = dr.Id WHERE c.Name = 'India';", 'feedback': 'They are not functionally similar.\n\nDifferences:\n- Join condition between Donors and Countries:\n  - Golden: JOIN Countries ON Donors.Name = Countries.Name and filters Countries.Continent = \'Asia\'.\n  - Candidate: JOIN Countries ON Countries.Id = Donors.Id and filters Countries.Name = \'India\'.\n  These are different relationships (matching donor name to country name vs matching donor id to country id).\n- Filter difference: golden filters by Continent = \'Asia\', candidate filters by Country name = \'India\'.\n- Selected columns: golden returns Organizations.Name only; candidate returns DISTINCT o.Id, o.Na

In [None]:
import json

def save_metrics(metrics, path):
    """Write a metrics dictionary to ``path`` as pretty-printed JSON.

    Args:
        metrics: JSON-serializable mapping, such as the dict returned by
            ``evaluate_program``.
        path: Destination file. Missing parent directories are created.

    Raises:
        TypeError: If ``metrics`` contains values that ``json`` cannot serialize.
        OSError: If the file or its parent directory cannot be created.
    """
    parent = os.path.dirname(path)
    # dirname is empty for a bare filename; makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Explicit encoding keeps the output identical across platforms.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

# Store each reloaded program's metrics next to its saved program.json.
# Paths are relative ("./..."), matching the directories used when the programs
# were saved; a leading "/" would point at the filesystem root instead.
save_metrics(og_metrics, "./dspy_program/5-mini.json")
save_metrics(opt_metrics, "./optimized_program/5-mini.json")
# `openai_opt_metrics` is the name assigned by the reload cell for this program.
save_metrics(openai_opt_metrics, "./openai_optimized_program/5-mini.json")

<class 'dict'>
