# Setup

## Import dependencies

In [1]:
%pip install -U dspy datasets tabulate duckdb pandas numpy ipywidgets "sqlglot[rs]" wandb --quiet

Note: you may need to restart the kernel to use updated packages.


In [2]:
import dspy
from datasets import load_dataset
import tabulate
import pandas as pd
import os
from dotenv import load_dotenv

In [3]:
load_dotenv(".env.local")
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY not found in environment variables")

wandb_api_key = os.getenv("WANDB_API_KEY")
if not wandb_api_key:
    raise ValueError("WANDB_API_KEY not found in environment variables")

lm = dspy.LM("openai/gpt-4.1-mini", api_key=openai_api_key, temperature=1, max_tokens=16000)
reflection_lm = dspy.LM("openai/gpt-5-mini", api_key=openai_api_key, temperature=1, max_tokens=16000)
dspy.configure(lm=lm)

## Load data

In [4]:
ds = load_dataset("gretelai/synthetic_text_to_sql")

# Set up DSPy

## Set up Signature and Modules

In [5]:
class ProblemDef(dspy.Signature):
    """You are a database expert. You are provided with context for how some table(s) were constructed, and a natural language prompt for what the user wants. Your job is to write a SQL query to provide them with the required data."""
    
    sql_context: str = dspy.InputField(description="SQL queries for creating the table(s) and loading some data")
    sql_prompt: str = dspy.InputField(description="User's natural language prompt")
    sql: str = dspy.OutputField(description="SQL query that delivers on the user's request. Format as code that can be directly run without any changes – do not use new lines or anything else of that sort.")

program = dspy.ChainOfThought(ProblemDef)

In [6]:
# !pip install duckdb pandas numpy sqlglot --quiet
import duckdb, pandas as pd, numpy as np, re
import sqlglot
from sqlglot import parse_one

_ORDER_BY = re.compile(r"\border\s+by\b", re.IGNORECASE)

def _split_sql_statements(script: str):
    out, buf, q = [], [], None
    i, n = 0, len(script)
    while i < n:
        ch = script[i]
        if q:
            buf.append(ch)
            if ch == q:
                if i + 1 < n and script[i+1] == q:
                    buf.append(script[i+1]); i += 1
                else:
                    q = None
        else:
            if ch in ("'", '"', "`"):
                q = ch; buf.append(ch)
            elif ch == ';':
                s = "".join(buf).strip()
                if s: out.append(s)
                buf = []
            else:
                buf.append(ch)
        i += 1
    tail = "".join(buf).strip()
    if tail: out.append(tail)
    return out

import re
from sqlglot import parse_one

_SQLITE_DATE_RE = re.compile(
    r"""\bdate\s*\(\s*'now'\s*(?:,\s*'([+-])\s*(\d+)\s*(year|month|day)s?'\s*)?\)""",
    re.IGNORECASE,
)
_SQLITE_DATETIME_RE = re.compile(
    r"""\bdatetime\s*\(\s*'now'\s*(?:,\s*'([+-])\s*(\d+)\s*(year|month|day|hour|minute|second)s?'\s*)?\)""",
    re.IGNORECASE,
)

def _normalize_sqlite_dates(sql: str) -> str:
    # date('now') or date('now','-1 year') -> CURRENT_DATE +/- INTERVAL 'N unit'
    def _date_subst(m):
        sign, num, unit = m.group(1), m.group(2), m.group(3)
        if not sign:  # just date('now')
            return "CURRENT_DATE"
        op = "-" if sign == "-" else "+"
        return f"CURRENT_DATE {op} INTERVAL '{num} {unit.lower()}'"
    sql = _SQLITE_DATE_RE.sub(_date_subst, sql)

    # datetime('now') / datetime('now','+/-N unit') -> CURRENT_TIMESTAMP +/- INTERVAL 'N unit'
    def _dt_subst(m):
        sign, num, unit = m.group(1), m.group(2), m.group(3)
        if not sign:
            return "CURRENT_TIMESTAMP"
        op = "-" if sign == "-" else "+"
        return f"CURRENT_TIMESTAMP {op} INTERVAL '{num} {unit.lower()}'"
    sql = _SQLITE_DATETIME_RE.sub(_dt_subst, sql)

    return sql

def _mysql_to_duckdb(stmt: str) -> str:
    s = _normalize_sqlite_dates(stmt)  # <-- NEW: normalize SQLite first
    try:
        return parse_one(s, read="mysql").sql(dialect="duckdb")
    except Exception:
        # minimal fallbacks for MySQLisms if parse fails
        s = re.sub(r"`([^`]+)`", r'"\1"', s)
        s = re.sub(
            r"DATE_SUB\s*\(\s*(CURRENT_DATE|NOW\(\))\s*,\s*INTERVAL\s+(\d+)\s+(YEAR|MONTH|DAY)\s*\)",
            lambda m: f"{'CURRENT_DATE' if m.group(1).startswith('CURRENT') else 'CURRENT_DATE'} - INTERVAL '{m.group(2)} {m.group(3).lower()}'",
            s, flags=re.IGNORECASE,
        )
        s = re.sub(
            r"DATE_ADD\s*\(\s*(CURRENT_DATE|NOW\(\))\s*,\s*INTERVAL\s+(\d+)\s+(YEAR|MONTH|DAY)\s*\)",
            lambda m: f"{'CURRENT_DATE' if m.group(1).startswith('CURRENT') else 'CURRENT_DATE'} + INTERVAL '{m.group(2)} {m.group(3).lower()}'",
            s, flags=re.IGNORECASE,
        )
        s = re.sub(r"\bIFNULL\s*\(", "COALESCE(", s, flags=re.IGNORECASE)
        s = re.sub(r"\bLOCATE\s*\(\s*([^,]+)\s*,\s*([^)]+)\)", r"STRPOS(\2, \1)", s, flags=re.IGNORECASE)
        return s

def _normalize_df(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    for c in df.columns:
        if df[c].dtype == "O":
            try:
                df[c] = pd.to_numeric(df[c])
            except Exception:
                pass
    return df.replace({np.nan: None})

def _exec_script_capture_last_select(con, script: str):
    last_df, last_sel_sql = None, None
    for raw in _split_sql_statements(script):
        stmt = _mysql_to_duckdb(raw)
        # detect SELECT after minimal comment strip
        s = re.sub(r"^\s*(--[^\n]*\n|/\*.*?\*/\s*)*", "", stmt, flags=re.DOTALL)
        if re.match(r"(?is)^\s*(with\b.*?select|select)\b", s):
            last_df = con.execute(stmt).fetchdf()
            last_sel_sql = stmt
        else:
            con.execute(stmt)
    if last_df is not None:
        last_df = _normalize_df(last_df)
    return last_df, last_sel_sql

def evaluate_sql(sql_context: str, golden_sql: str, predicted_sql: str):
    con = duckdb.connect(":memory:")

    # context
    try:
        for raw in _split_sql_statements(sql_context):
            con.execute(_mysql_to_duckdb(raw))
    except Exception as e:
        return 0, {"reason": "context_error", "detail": str(e)}

    # golden
    try:
        gold_df, gold_last_select = _exec_script_capture_last_select(con, golden_sql)
    except Exception as e:
        return 0, {"reason": "gold_error", "detail": str(e)}
    if gold_df is None:
        return 0, {"reason": "gold_no_select", "detail": "No SELECT in golden_sql."}

    # predicted
    try:
        pred_df, pred_last_select = _exec_script_capture_last_select(con, predicted_sql)
    except Exception as e:
        return 0, {"reason": "pred_error", "detail": str(e)}
    if pred_df is None:
        return 0, {"reason": "pred_no_select", "detail": "No SELECT in predicted_sql."}

    # column alignment (allow pred supersets; else try set/positional)
    gold_cols, pred_cols = list(gold_df.columns), list(pred_df.columns)
    if gold_cols == pred_cols:
        pass
    elif set(gold_cols).issubset(pred_cols):
        pred_df = pred_df[gold_cols]
    elif set(gold_cols) == set(pred_cols):
        pred_df = pred_df[gold_cols]
    elif gold_df.shape[1] == pred_df.shape[1]:
        new_names = [f"c{i}" for i in range(gold_df.shape[1])]
        gold_df = gold_df.copy(); pred_df = pred_df.copy()
        gold_df.columns = new_names; pred_df.columns = new_names
    else:
        return 0, {"reason": "column_mismatch",
                   "detail": f"Different number of columns: expected {gold_df.shape[1]}, got {pred_df.shape[1]}"}

    # ordering rule from gold's last SELECT
    gold_has_order = bool(_ORDER_BY.search(gold_last_select or ""))
    if not gold_has_order:
        try:
            g = gold_df.sort_values(by=list(gold_df.columns), kind="mergesort").reset_index(drop=True)
            p = pred_df.sort_values(by=list(gold_df.columns), kind="mergesort").reset_index(drop=True)
        except Exception:
            g = gold_df.reset_index(drop=True); p = pred_df.reset_index(drop=True)
    else:
        g = gold_df.reset_index(drop=True); p = pred_df.reset_index(drop=True)

    # value compare
    if g.shape != p.shape:
        return 0, {"reason": "shape_mismatch", "detail": f"gold {g.shape} vs pred {p.shape}"}

    for c in g.columns:
        if pd.api.types.is_numeric_dtype(g[c]) and pd.api.types.is_numeric_dtype(p[c]):
            if not np.allclose(g[c].values, p[c].values, rtol=1e-6, atol=1e-8, equal_nan=True):
                return 0, {"reason": "value_mismatch", "detail": f"Numeric mismatch in '{c}'",
                           "gold_head": g.head(10).to_dict("records"),
                           "pred_head": p.head(10).to_dict("records")}
        else:
            eq = [(x == y) or (x is None and y is None) for x, y in zip(g[c].values, p[c].values)]
            if not all(eq):
                return 0, {"reason": "value_mismatch", "detail": f"Mismatch in '{c}'",
                           "gold_head": g.head(10).to_dict("records"),
                           "pred_head": p.head(10).to_dict("records")}
    return 1, None


## Test

In [7]:
demo_index = 4
context = ds['train'][demo_index]['sql_context']
prompt = ds['train'][demo_index]['sql_prompt']
golden_sql = ds['train'][demo_index]['sql']

print(f"Context: {context}")
print(f"Prompt: {prompt}")
print(f"Golden sql: {golden_sql}")
result = program(sql_context=context, sql_prompt=prompt)
print(result)

Context: CREATE TABLE upgrades (id INT, cost FLOAT, type TEXT); INSERT INTO upgrades (id, cost, type) VALUES (1, 500, 'Insulation'), (2, 1000, 'HVAC'), (3, 1500, 'Lighting');
Prompt: Find the energy efficiency upgrades with the highest cost and their types.
Golden sql: SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
Prediction(
    reasoning="To find the energy efficiency upgrades with the highest cost, we need to identify the maximum value in the 'cost' column and then select all rows that have that cost, including their types.",
    sql='SELECT type, cost FROM upgrades WHERE cost = (SELECT MAX(cost) FROM upgrades);'
)


In [8]:
score, info = evaluate_sql(context, golden_sql, result.sql)
print(score, info)


1 None


## Environment didn't work, let's use LLM as Judge

In [9]:
class Judge(dspy.Signature):
    """You are required to judge two SQL queries for functional similarity. You will be given a context of how the table(s) and data were created, and the natural language prompt from the user"""

    sql_context: str = dspy.InputField(description="SQL statement(s) creating the table(s) and the input data")
    sql_prompt: str = dspy.InputField(description="Natural language prompt from the user")
    golden_sql: str = dspy.InputField(description="The golden SQL query from our dataset")
    candidate_sql: str = dspy.InputField(description="A SQL query generated by a model for the same prompt")
    similar: bool = dspy.OutputField(description="True if the candidate SQL query is functionally similar to the golden SQL query")

judge = dspy.ChainOfThought(Judge)
judge.lm = reflection_lm
    

In [10]:
judge_response = judge(sql_context=context, sql_prompt=prompt, golden_sql=golden_sql, candidate_sql=result.sql)
print(f"Context: {context}")
print(f"Prompt: {prompt}")
print(f"Golden SQL: {golden_sql}")
print(f"Candidate SQL: {result.sql}")
print(f"Judge Response: {judge_response}")


Context: CREATE TABLE upgrades (id INT, cost FLOAT, type TEXT); INSERT INTO upgrades (id, cost, type) VALUES (1, 500, 'Insulation'), (2, 1000, 'HVAC'), (3, 1500, 'Lighting');
Prompt: Find the energy efficiency upgrades with the highest cost and their types.
Golden SQL: SELECT type, cost FROM (SELECT type, cost, ROW_NUMBER() OVER (ORDER BY cost DESC) as rn FROM upgrades) sub WHERE rn = 1;
Candidate SQL: SELECT type, cost FROM upgrades WHERE cost = (SELECT MAX(cost) FROM upgrades);
Judge Response: Prediction(
    reasoning='Both the golden SQL and the candidate SQL retrieve the type and cost of the upgrade(s) with the highest cost from the upgrades table. The golden SQL uses the ROW_NUMBER() window function ordered by cost descending and filters for the first row, effectively picking the single highest cost upgrade. The candidate SQL uses a subquery to find the maximum cost, then selects all upgrades with that max cost. The key behavioral difference is that the golden SQL returns exactly

# Get ready to GEPA

In [11]:
# pip install datasets dspy-ai
import math, random
from typing import Callable, List, Tuple, Optional
from datasets import Dataset, DatasetDict
from dspy import GEPA

def split_for_gepa(
    ds: Dataset,
    to_example: Callable[[dict], "dspy.Example"],
    val_size: float = 0.15,
    seed: int = 42,
    group_col: Optional[str] = None,
    stratify_col: Optional[str] = None,
) -> Tuple[List["dspy.Example"], List["dspy.Example"]]:
    """
    Return (train_set, val_set) as lists of dspy.Example.
    - If group_col is set: group-aware split (no group leakage).
    - Else if stratify_col is set: use HF stratified split.
    - Else: random split.
    """
    assert 0.0 < val_size < 1.0, "val_size must be in (0,1)"
    rng = random.Random(seed)

    # --- Group-aware split (preferred for text2sql) ---
    if group_col:
        groups = ds[group_col]
        # Build group -> indices
        g2idx = {}
        for i, g in enumerate(groups):
            g2idx.setdefault(g, []).append(i)
        uniq_groups = list(g2idx.keys())
        rng.shuffle(uniq_groups)
        n_val_groups = max(1, math.floor(val_size * len(uniq_groups)))
        val_groups = set(uniq_groups[:n_val_groups])

        val_idx = [i for g in val_groups for i in g2idx[g]]
        train_idx = [i for g in uniq_groups[n_val_groups:] for i in g2idx[g]]

        # Edge case: if a group is gigantic, ensure both splits non-empty
        if not train_idx or not val_idx:
            # fallback: plain random split
            perm = list(range(len(ds)))
            rng.shuffle(perm)
            cut = max(1, math.floor(val_size * len(ds)))
            val_idx, train_idx = perm[:cut], perm[cut:]

        ds_train = ds.select(train_idx)
        ds_val = ds.select(val_idx)

    # --- Stratified split (when you have a label/cluster column) ---
    elif stratify_col:
        # HF does stratify on categorical-like columns
        parts: DatasetDict = ds.train_test_split(
            test_size=val_size,
            seed=seed,
            stratify_by_column=stratify_col,
        )
        ds_train, ds_val = parts["train"], parts["test"]

    # --- Simple random split ---
    else:
        parts: DatasetDict = ds.train_test_split(test_size=val_size, seed=seed)
        ds_train, ds_val = parts["train"], parts["test"]

    # Map to dspy.Example lists
    train_set = [to_example(r) for r in ds_train]
    val_set = [to_example(r) for r in ds_val]
    return train_set, val_set

def to_dspy_example(row):
    # mark inputs; leave gold 'sql' as label
    return dspy.Example(
        sql_prompt=row["sql_prompt"],
        sql_context=row["sql_context"],
        sql=row["sql"],          # gold label
    ).with_inputs("sql_prompt", "sql_context")


# call function that splits ds['train'] into train_set and val_set as needed
# ds is your loaded HF dataset dict; we split ds["train"]
train_set, val_set = split_for_gepa(
    ds["train"],
    to_dspy_example,          # your to_dspy_example(row)
    val_size=0.5,
    seed=42,
    group_col=None,      # e.g., "db_id" if available
    stratify_col=None,   # or a column like "op_class" if you want stratification
)

In [12]:
max_variants_to_try = 20 # number of variants to test
mini_batch_size = 20 # mini-batch size
val_set_size = 200 # val-set size
train_set_size = 200 # train-set size

def budget_for_variants(N, V, k, slack=2):
    # slack handles occasional extra probes/promotions
    return V + N * (k + slack)

def metric_with_feedback(example, pred, trace=None, pred_name=None, pred_trace=None):
    judge_response = judge(sql_context=example.sql_context, sql_prompt=example.sql_prompt, golden_sql=example.sql, candidate_sql=pred.sql)
    score = 0
    if (judge_response.similar):
        score = 1
    return dspy.Prediction(score=score, feedback=judge_response.reasoning)

val_for_tracking = val_set[:val_set_size]   # 128–512 is a good range
train_set_for_optimization = train_set[:train_set_size]
optimizer = GEPA(
    metric=metric_with_feedback,
    num_threads=32,
    track_stats=True,
    reflection_minibatch_size=mini_batch_size,
    reflection_lm=reflection_lm,
    use_wandb=True,
    wandb_api_key=wandb_api_key,
    log_dir="logs",
    auto="light"   
)

# Run GEPA

In [13]:
optimized_program = optimizer.compile(
    program,
    trainset=train_set_for_optimization,
    valset=val_for_tracking,
)

2025/10/14 16:00:34 INFO dspy.teleprompt.gepa.gepa: Running GEPA for approx 1180 metric calls of the program. This amounts to 2.95 full evals on the train+val set.
2025/10/14 16:00:34 INFO dspy.teleprompt.gepa.gepa: Using 200 examples for tracking Pareto scores. You can consider using a smaller sample of the valset to allow GEPA to explore more diverse solutions within the same budget.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/raveesh/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mraveeshbhalla90[0m ([33mraveeshbhalla90-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [dspy, litellm, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/
GEPA Optimization:   0%|          | 0/1180 [00:00<?, ?rollouts/s]2025/10/14 16:00:37 INFO dspy.evaluate.evaluate: Average Metric: 136.0 / 200 (68.0%)
2025/10/14 16:00:37 INFO dspy.teleprompt.gepa.gepa: Iteration 0: Base program full valset score: 0.68
GEPA Optimization:  17%|█▋        | 200/1180 [00:01<00:07, 136.71rollouts/s]2025/10/14 16:00:37 INFO dspy.teleprompt.gepa.gepa: Iteration 1: Selected program 0 score: 0.68


Average Metric: 14.00 / 20 (70.0%): 100%|██████████| 20/20 [00:00<00:00, 157.90it/s]

2025/10/14 16:00:37 INFO dspy.evaluate.evaluate: Average Metric: 14.0 / 20 (70.0%)





2025/10/14 16:01:19 INFO dspy.teleprompt.gepa.gepa: Iteration 1: Proposed new text for predict: You are a SQL-writing database expert. You will be given two inputs:
- sql_context: one or more CREATE TABLE and INSERT statements describing the schema and sample data.
- sql_prompt: a natural-language request describing the data or change the user wants.

Your job: produce a single correct SQL statement that answers the sql_prompt using the tables and columns exactly as defined in sql_context. Also include a very short (1–3 sentence) plain-language reasoning comment explaining your interpretation/assumptions. If the prompt is an action (INSERT/UPDATE/DELETE) produce that DML statement; if it's a question, produce a SELECT that returns the requested result.

Output rules and best practices (apply these on every task):

1. Output format
   - Begin with a one- or two-sentence reasoning comment describing how you interpreted the prompt and any important assumptions.
   - Then output the SQL qu

Average Metric: 15.00 / 20 (75.0%): 100%|██████████| 20/20 [00:10<00:00,  1.84it/s]

2025/10/14 16:02:10 INFO dspy.evaluate.evaluate: Average Metric: 15.0 / 20 (75.0%)





2025/10/14 16:02:39 INFO dspy.teleprompt.gepa.gepa: Iteration 2: Proposed new text for predict: You are a SQL-writing database expert. You will be given two inputs:
- sql_context: one or more CREATE TABLE and INSERT statements describing the schema and sample data.
- sql_prompt: a natural-language request describing the data or change the user wants.

Your job: produce exactly one syntactically correct SQL statement that answers the sql_prompt using the tables and columns exactly as defined in sql_context, plus a very short plain-language reasoning comment (1–3 sentences) immediately before the SQL. Follow these rules strictly.

1) Output shape and format
   - Always begin with a 1–3 sentence reasoning comment explaining how you interpreted the prompt and any assumptions. The comment may be formatted as an SQL comment line (e.g., -- reasoning ...) or as a single plain-text line immediately before the SQL.
   - After that comment, output only one SQL statement and nothing else. Do not i

Average Metric: 15.00 / 20 (75.0%): 100%|██████████| 20/20 [00:08<00:00,  2.39it/s]

2025/10/14 16:03:01 INFO dspy.evaluate.evaluate: Average Metric: 15.0 / 20 (75.0%)





2025/10/14 16:03:31 INFO dspy.teleprompt.gepa.gepa: Iteration 3: Proposed new text for predict: You are a SQL-writing database expert assistant. You will be given two inputs:
- sql_context: one or more CREATE TABLE and INSERT statements that fully describe the available schema and any sample data.
- sql_prompt: a natural-language request describing the data to retrieve or the change to make.

Your job: produce a single, syntactically-correct SQL statement (and nothing else) that answers the sql_prompt using only the tables and columns exactly as defined in sql_context. Precede the SQL with a very short (1–3 sentence) plain-language reasoning comment that (a) says how you interpreted the prompt and (b) lists any important assumptions. The reasoning may be formatted either as a SQL comment line(s) (e.g., -- reasoning ...) or as a plain text line immediately before the SQL. After that single reasoning line(s), output only the SQL statement and nothing else.

Strict rules and expectations 

Average Metric: 15.00 / 20 (75.0%): 100%|██████████| 20/20 [00:08<00:00,  2.43it/s]

2025/10/14 16:04:03 INFO dspy.evaluate.evaluate: Average Metric: 15.0 / 20 (75.0%)





2025/10/14 16:04:32 INFO dspy.teleprompt.gepa.gepa: Iteration 4: Proposed new text for predict: You are a SQL-writing database expert assistant. You will be given two inputs:
- sql_context: one or more CREATE TABLE / INSERT / CREATE VIEW statements that define the schema and any sample data.
- sql_prompt: a natural-language request describing the data to return or a data change to perform.

Your job: produce a single, syntactically-correct SQL statement that answers sql_prompt using the tables and columns exactly as defined in sql_context, and a very short (1–3 sentence) plain-language reasoning comment describing your interpretation and any assumptions.

Strict output format (must follow exactly)
1. Begin with a one- to three-sentence reasoning comment describing how you interpreted the prompt and any important assumptions. This reasoning may be:
   - A SQL comment line(s) starting with -- (recommended), OR
   - A single plaintext line immediately before the SQL.
   Keep it concise (1

Average Metric: 16.00 / 20 (80.0%): 100%|██████████| 20/20 [00:06<00:00,  3.06it/s]

2025/10/14 16:04:48 INFO dspy.evaluate.evaluate: Average Metric: 16.0 / 20 (80.0%)





2025/10/14 16:05:16 INFO dspy.teleprompt.gepa.gepa: Iteration 5: Proposed new text for predict: You are a SQL-writing database expert assistant. You will be given two inputs:
- sql_context: one or more CREATE TABLE and INSERT statements that fully describe the schema and any sample data.
- sql_prompt: a natural-language request describing the data to retrieve or the change to make.

Your job: produce a single, correct SQL statement that answers the sql_prompt using only the tables and columns exactly as defined in sql_context, plus a very short (1–3 sentence) plain-language reasoning comment explaining your interpretation and any assumptions.

Required output format and behavior:
1. Output structure
   - Begin with a one- or two-sentence reasoning comment that explains how you interpreted the prompt and any important assumptions (e.g., how you interpret ambiguous geographic regions, date ranges, or whether to include zero-counts). The reasoning may be formatted either as a SQL comment 

Average Metric: 10.00 / 20 (50.0%): 100%|██████████| 20/20 [00:06<00:00,  2.94it/s]

2025/10/14 16:05:41 INFO dspy.evaluate.evaluate: Average Metric: 10.0 / 20 (50.0%)





2025/10/14 16:06:14 INFO dspy.teleprompt.gepa.gepa: Iteration 6: Proposed new text for predict: You are a SQL/query-writing assistant (a database expert). For each task you will be given:
- sql_context: DDL and sample data (CREATE SCHEMA / CREATE TABLE / CREATE VIEW / INSERT ...). This is the authoritative schema and example data you must use.
- sql_prompt: a natural-language request describing the data the user wants.

Your job:
- Produce a short, clear reasoning paragraph explaining how you derived the query and any assumptions you made from incomplete/ambiguous prompts or from the provided context.
- Produce a single SQL statement that implements the requested operation against the provided schema and data. The SQL must be syntactically valid and appropriate for the task described in the prompt.

Formatting requirements:
- Return two clearly labeled sections in your response: "reasoning" (one or two short paragraphs) and "sql" (the SQL statement).
- The "sql" section must contain on

Average Metric: 12.00 / 20 (60.0%): 100%|██████████| 20/20 [00:08<00:00,  2.41it/s]

2025/10/14 16:07:03 INFO dspy.evaluate.evaluate: Average Metric: 12.0 / 20 (60.0%)





2025/10/14 16:07:31 INFO dspy.teleprompt.gepa.gepa: Iteration 7: Proposed new text for predict: You are a SQL/database expert assistant. Your job is to take two inputs — a natural-language request (sql_prompt) and a schema + sample data context (sql_context) — and produce a correct, concise SQL statement (or statements) that answer the user's request plus a short reasoning section that explains any assumptions or interpretations you made.

Required input format (you will always receive these):
- sql_prompt: a single natural-language sentence or short paragraph describing what the user wants.
- sql_context: one or more CREATE TABLE / INSERT statements that define the available tables, columns, and sample data. Column names and types in this context are authoritative for composing queries, but may include typos or obvious mistakes.

Output format (always include both parts):
1. A brief "reasoning" section (1–6 sentences) that:
   - States the approach taken (joins/aggregations/filters).


Average Metric: 16.00 / 20 (80.0%): 100%|██████████| 20/20 [00:08<00:00,  2.43it/s] 

2025/10/14 16:08:20 INFO dspy.evaluate.evaluate: Average Metric: 16.0 / 20 (80.0%)





2025/10/14 16:09:01 INFO dspy.teleprompt.gepa.gepa: Iteration 8: Proposed new text for predict: You are a SQL-writing assistant. You will be given two inputs:
- sql_context: DDL (CREATE TABLE) and optional INSERT statements describing table schemas and sample data.
- sql_prompt: a natural-language request for data or a data modification operation.

Your job: produce a correct SQL statement that answers the sql_prompt using the schema(s) in sql_context. Also provide a very brief reasoning (1–3 sentences) that lists any assumptions you made (dates/time interpretation, case-sensitivity, SQL dialect if nonstandard syntax used, or when context is ambiguous).

Rules, strategies, and domain knowledge to apply (follow these carefully):

1. Parse the context
   - Use the provided CREATE TABLE definitions to know available tables and columns.
   - Treat the INSERTs only as examples of data; do not hardcode values from them unless the prompt explicitly asks for them.

2. Match the prompt precisel

Average Metric: 15.00 / 20 (75.0%): 100%|██████████| 20/20 [00:09<00:00,  2.16it/s]

2025/10/14 16:09:18 INFO dspy.evaluate.evaluate: Average Metric: 15.0 / 20 (75.0%)





2025/10/14 16:09:44 INFO dspy.teleprompt.gepa.gepa: Iteration 9: Proposed new text for predict: You are a SQL-writing database expert. You will be given two inputs:
- sql_context: one or more CREATE TABLE and INSERT statements that define the schema and sample data.
- sql_prompt: a natural-language request describing the data wanted or a change to perform.

Your job: produce exactly one syntactically correct SQL statement (a single SELECT, INSERT, UPDATE or DELETE) that answers the sql_prompt using only table and column names from sql_context. Precede the SQL with a very short (1–3 sentence) plain-language reasoning comment that states how you interpreted the prompt and any important assumptions. Follow these rules precisely every time.

1) Output format
- Start with a 1–3 sentence reasoning line or lines. This may be formatted as a SQL comment (e.g., -- reasoning ...) or as plain text immediately before the SQL.
- After the reasoning, output only the single SQL statement (no other tex

Average Metric: 15.00 / 20 (75.0%): 100%|██████████| 20/20 [00:07<00:00,  2.85it/s]

2025/10/14 16:09:59 INFO dspy.evaluate.evaluate: Average Metric: 15.0 / 20 (75.0%)





2025/10/14 16:10:21 INFO dspy.teleprompt.gepa.gepa: Iteration 10: Proposed new text for predict: You are a SQL/query-writing assistant (a database expert). For each task you will be given two inputs:
- sql_context: DDL and sample data (CREATE SCHEMA / CREATE TABLE / CREATE VIEW / INSERT ...) — this is the authoritative schema and sample data you must use.
- sql_prompt: a natural-language request describing the data the user wants.

Your job:
- Produce two clearly labeled sections in your response: "reasoning" and "sql".
  - "reasoning": one or two short paragraphs explaining how you derived the query and any assumptions you made (including any handling of ambiguous or missing information). If you use any dialect-specific functions (non-ANSI-standard SQL), state the dialect choice in this reasoning section.
  - "sql": a single, syntactically valid SQL statement that implements the requested operation against the provided schema and data. The "sql" section must contain only the SQL state

# Review original and optimized prompts

In [14]:
print(program.predict.signature.instructions)

You are a database expert. You are provided with context for how some table(s) were constructed, and a natural language prompt for what the user wants. Your job is to write a SQL query to provide them with the required data.


In [15]:
print(optimized_program.predict.signature.instructions)

You are a SQL/query-writing assistant (a database expert). For each task you will be given:
- sql_context: DDL and sample data (CREATE SCHEMA / CREATE TABLE / CREATE VIEW / INSERT ...). This is the authoritative schema and example data you must use.
- sql_prompt: a natural-language request describing the data the user wants.

Your job:
- Produce a short, clear reasoning paragraph explaining how you derived the query and any assumptions you made from incomplete/ambiguous prompts or from the provided context.
- Produce a single SQL statement that implements the requested operation against the provided schema and data. The SQL must be syntactically valid and appropriate for the task described in the prompt.

Formatting requirements:
- Return two clearly labeled sections in your response: "reasoning" (one or two short paragraphs) and "sql" (the SQL statement).
- The "sql" section must contain only the SQL statement (no additional commentary in that section). Use standard SQL where possible

# Store Programs

In [16]:
program.save("./dspy_program/program.json", save_program=False)
optimized_program.save("./optimized_program/program.json", save_program=False)

# Evals

In [22]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import Dataset
from time import perf_counter
from typing import Dict, Any, Optional

def evaluate_program(
    program,
    ds_test: Dataset,
    limit: int = 100,
    max_workers: int = 8,
    field_map: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """
    Evaluate a DSPy program on the first `limit` rows of a HF Dataset split.

    Args:
        program: a DSPy Module with signature program(sql_prompt=..., sql_context=...)
        ds_test: Hugging Face Dataset (e.g., ds["test"])
        limit: number of rows to evaluate (default 100)
        max_workers: parallel threads for I/O-bound LM + judge
        field_map: optional mapping if your column names differ:
                   {"sql_prompt": "...", "sql_context": "...", "sql": "..."}

    Returns:
        {
          "accuracy": float,
          "correct": int,
          "total": int,
          "avg_latency_s": float,
          "failures": [ {idx, reason, pred_sql, feedback} ... up to 20 ],
        }
    """
    if field_map is None:
        field_map = {"sql_prompt": "sql_prompt", "sql_context": "sql_context", "sql": "sql"}

    n = min(limit, len(ds_test))
    subset = ds_test.select(range(n))
    start = perf_counter()

    def _eval_one(i_row):
        i, row = i_row
        try:
            pred = program(
                sql_prompt=row[field_map["sql_prompt"]],
                sql_context=row[field_map["sql_context"]],
            )
            pred_sql = getattr(pred, "sql", None) or (pred.get("sql") if isinstance(pred, dict) else None) or ""
            jr = judge(
                sql_context=row[field_map["sql_context"]],
                sql_prompt=row[field_map["sql_prompt"]],
                golden_sql=row[field_map["sql"]],
                candidate_sql=pred_sql,
            )
            ok = bool(getattr(jr, "similar", False))
            feedback = getattr(jr, "reasoning", "") or ""
            return (i, ok, pred_sql, feedback, None)
        except Exception as e:
            return (i, False, "", "", f"{type(e).__name__}: {e}")

    results = []
    # Threaded evaluation (I/O bound: LM + judge). Tune max_workers to your provider limits.
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(_eval_one, (i, subset[i])) for i in range(n)]
        for f in as_completed(futures):
            results.append(f.result())

    # Sort back to input order
    results.sort(key=lambda x: x[0])

    correct = sum(1 for _, ok, *_ in results if ok)
    total = n
    acc = correct / total if total else 0.0
    elapsed = perf_counter() - start
    avg_lat = elapsed / total if total else 0.0

    failures = []
    for i, ok, pred_sql, feedback, err in results:
        if not ok and len(failures) < 20:
            failures.append({
                "idx": i,
                "reason": ("error: " + err) if err else "mismatch",
                "pred_sql": pred_sql,
                "feedback": feedback,
            })

    return {
        "accuracy": acc,
        "correct": correct,
        "total": total,
        "avg_latency_s": avg_lat,
        "failures": failures,
    }
    
test_split = ds["test"]
test_split = test_split.shuffle()

In [23]:
og_metrics = evaluate_program(program, test_split, limit=500, max_workers=32)
print(f"Original Program: {og_metrics}")

Original Program: {'accuracy': 0.694, 'correct': 347, 'total': 500, 'avg_latency_s': 0.19754389458199148, 'failures': [{'idx': 9, 'reason': 'mismatch', 'pred_sql': "SELECT COUNT(*) AS total_articles FROM ny_times WHERE article_date >= '2021-01-01' AND article_date < '2021-03-01';", 'feedback': 'The golden SQL query returns two counts separately: one count of articles published in January 2021 and another count of articles published in February 2021. It uses a UNION ALL to list these counts as two separate rows. In contrast, the candidate SQL query provides a single count of all articles published in the combined date range from January 1, 2021, to just before March 1, 2021, which effectively covers both January and February 2021 collectively. Although the candidate query sums the two months together and returns one single total count, while the golden SQL returns two separate counts, the underlying aggregation of the number of articles for the two months combined is effectively the sam

In [20]:
opt_metrics = evaluate_program(optimized_program, test_split, limit=500, max_workers=32)
print(f"Original Program: {opt_metrics}")

Original Program: {'accuracy': 0.736, 'correct': 368, 'total': 500, 'avg_latency_s': 0.1709705568339996, 'failures': [{'idx': 4, 'reason': 'mismatch', 'pred_sql': "UPDATE peacekeeping_operations SET troops = 850 WHERE country = 'Afghanistan' AND year = 2005;", 'feedback': "The golden SQL query uses a common table expression (CTE) with an UPDATE statement and then inserts the affected rows back into the same table. This is an unusual and likely incorrect approach for updating a value, as it tries to insert rows back into the table after updating them, which may lead to errors or duplicate entries if the primary key is not changed or not handled properly.\n\nThe candidate SQL query is a straightforward UPDATE statement that sets the 'troops' value to 850 for the row where the country is 'Afghanistan' and the year is 2005. This is the standard and correct way to update a row in a table.\n\nFunctionally, both queries intend to update the troops value for the specific record. However, the g

# Store Eval Results

In [None]:
import json

def save_metrics(metrics, path):
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True) 

save_metrics(og_metrics, "./dspy_program/4.1-mini.json")
save_metrics(opt_metrics,"./optimized_program/4.1-mini.json")