In [4]:
!rm -rf /content/SQL-of-Thought
!git clone https://github.com/shollercoaster/SQL-of-Thought.git
!pip install -q openai gdown

import gdown, zipfile, os
if not os.path.exists('/content/spider'):
    gdown.download(id='1TqleXec_OykOYFREKKtschzY29dUcVAQ', output='/content/spider.zip', quiet=False)
    with zipfile.ZipFile('/content/spider.zip','r') as z:
        z.extractall('/content/')

Cloning into 'SQL-of-Thought'...
remote: Enumerating objects: 11, done.
remote: Counting objects: 100% (11/11), done.
remote: Compressing objects: 100% (11/11), done.
remote: Total 11 (delta 2), reused 3 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (11/11), 19.01 KiB | 973.00 KiB/s, done.
Resolving deltas: 100% (2/2), done.


In [5]:
import os, json, shutil
os.chdir('/content/SQL-of-Thought')

# Grab the error taxonomy file whatever it's called and copy it to the
# canonical name `error_taxonomy.json`.
for f in os.listdir('.'):
    if 'error_taxonomy' in f and f.endswith('.json') and f != 'error_taxonomy.json':
        shutil.copy(f, 'error_taxonomy.json')
        break

os.makedirs('ablations_actual', exist_ok=True)
if not os.path.exists('nl2sql_bugs.json'):
    # utils.py opens this file (its path is rewritten below); write an empty
    # list as a placeholder. `with` makes sure the handle is flushed and closed.
    with open('nl2sql_bugs.json', 'w') as fh:
        json.dump([], fh)

# The repo doesn't ship this module, so write a minimal version.
with open('analyze_by_subproblems.py', 'w') as f:
    f.write("""import json


def _clause(sp):
    # Upper-cased, stripped `clause` of one subproblem entry ("" if missing/malformed).
    if not isinstance(sp, dict):
        return ""
    return str(sp.get("clause") or "").strip().upper()


def parse_subproblems(s):
    # Accepts a JSON string or an already-parsed dict with a "subproblems" list;
    # returns the non-empty clause names, or [] for anything unparseable.
    if isinstance(s, str):
        try:
            data = json.loads(s)
        except ValueError:
            return []
    elif isinstance(s, dict):
        data = s
    else:
        return []
    if not isinstance(data, dict):
        return []
    subs = data.get("subproblems", [])
    if not isinstance(subs, list):
        return []
    return [c for c in map(_clause, subs) if c]
""")

# Source of the init_clients() helper injected into utils.py just before
# load_local_model(); it builds the API clients lazily instead of at import.
INIT_CLIENTS_SRC = """
def init_clients(use_openai=True, use_anthropic=False, use_local=False, base_url=None):
    global openai_client, client, model, tokenizer
    if use_openai:
        kw = {"api_key": OPENAI_API_KEY}
        if base_url: kw["base_url"] = base_url
        openai_client = OpenAI(**kw)
    if use_anthropic and anthropic:
        client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
    if use_local:
        model, tokenizer = load_local_model()

"""

# Patch utils.py: fix paths, make optional deps optional, defer client init.
with open('utils.py', 'r') as f:
    lines = f.readlines()

# The rewrites below are NOT idempotent: an already-wrapped `    import anthropic`
# still strips to `import anthropic` and would get a second `try:`, and
# init_clients would be injected twice. Skip if a previous run already patched.
if any(l.startswith('def init_clients(') for l in lines):
    print("utils.py already patched; skipping")
else:
    out = []
    i = 0
    while i < len(lines):
        ln = lines[i].replace('../../spider/', '/content/spider/')
        stripped = ln.strip()

        # anthropic is optional: fall back to None when it isn't installed.
        if stripped == 'import anthropic':
            out += ['try:\n', '    import anthropic\n', 'except ImportError:\n', '    anthropic = None\n']
            i += 1
            continue

        # Defer OpenAI client construction to init_clients().
        if stripped.startswith('openai_client = OpenAI('):
            out.append('openai_client = None\n')
            i += 1
            continue

        # torch/transformers are only needed for local models; make them optional.
        if stripped == 'import torch' and i + 1 < len(lines) and 'from transformers' in lines[i + 1]:
            out += ['try:\n', '    import torch\n',
                    '    from transformers import AutoTokenizer, AutoModelForCausalLM\n',
                    'except ImportError:\n', '    torch = None\n',
                    '    AutoTokenizer = None\n', '    AutoModelForCausalLM = None\n']
            i += 2
            continue

        # Inject init_clients() right before load_local_model's definition.
        if ln.startswith('def load_local_model('):
            out.append(INIT_CLIENTS_SRC)
            out.append(ln)
            i += 1
            continue

        # Don't load a local model or build an Anthropic client at import time.
        if stripped == 'model, tokenizer = load_local_model()':
            out.append('model, tokenizer = None, None\n')
            i += 1
            continue
        if 'client = anthropic.Anthropic(' in ln:
            out.append('client = None\n')
            i += 1
            continue

        # Replace the hard-coded "gpt-5" with `model` (presumably the enclosing
        # function's model argument — confirm in utils.py).
        if 'model= "gpt-5"' in ln:
            ln = ln.replace('model= "gpt-5", # "gpt-3.5-turbo",', 'model=model,')
        if 'open("../../nl2sql_bugs.json")' in ln:
            ln = ln.replace('open("../../nl2sql_bugs.json")', 'open("nl2sql_bugs.json")')

        out.append(ln)
        i += 1

    with open('utils.py', 'w') as f:
        f.writelines(out)

# Same path fix for the eval script (already idempotent: no '../../spider/' remains).
with open('run_eval_single_schemalink.py', 'r') as f:
    src = f.read()
with open('run_eval_single_schemalink.py', 'w') as f:
    f.write(src.replace('../../spider/', '/content/spider/'))

print("done patching")

done patching


In [6]:
import os, sys, json
from getpass import getpass

# Run configuration, shared by every evaluation cell below.
N = 100                                 # Spider dev questions per variant
MAX_FIX = 3                             # max correction rounds per question
MDL = "deepseek-chat"                   # model name passed to every call_agent
BASE_URL = "https://api.deepseek.com"   # OpenAI-compatible endpoint

os.chdir('/content/SQL-of-Thought')

# The DeepSeek key is stored under OPENAI_API_KEY because the patched utils
# talks to DeepSeek through the OpenAI client (with base_url). Assumes utils
# reads OPENAI_API_KEY from the environment at import — confirm in utils.py.
# Only prompt when the key isn't already set, so re-running doesn't re-ask.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("api key: ")

# Drop cached modules so the freshly patched utils.py / stub module re-import.
for m in ['utils', 'prompts', 'analyze_by_subproblems']:
    sys.modules.pop(m, None)

from utils import (init_clients, load_spider, load_schema, call_agent,
                   postprocess_sql, query_execution, exec_query,
                   clean_json, clause_specific_prompts)
from prompts import (alt_schema_linking_agent_prompt, subproblem_agent_prompt,
                     query_plan_agent_prompt, sql_agent_prompt,
                     correction_plan_agent_prompt, correction_sql_agent_prompt)
from analyze_by_subproblems import parse_subproblems

init_clients(use_openai=True, base_url=BASE_URL)

# Quick sanity check that the key, endpoint and model name all work.
print(call_agent("Say hi.", MDL))

api key: ··········
Hello! 👋 It's great to meet you. How can I help you today?


In [7]:
def run(gen_fn, tag, correct=True):
    """Evaluate one pipeline variant on the first ``N`` Spider dev questions.

    For each question: build SQL with ``gen_fn``, optionally run up to
    ``MAX_FIX`` correction rounds while ``query_execution`` reports failure,
    then score:
      * EM  — lower-cased, stripped string equality with the post-processed gold;
      * EA  — the final ``ok`` flag returned by ``query_execution``;
      * VS  — the SQL is non-empty and ``exec_query`` reports no error on the
              question's SQLite database.

    Args:
        gen_fn: callable ``(question, schema, db_id) -> (sql, ctx)``; ``ctx`` is
            forwarded to the correction agents.
        tag: variant label used in the banner and in the output filename.
        correct: if True, enable the correction loop.

    Returns:
        dict with counts ``total``/``em``/``vs``/``ea`` and rates
        ``em%``/``vs%``/``ea%`` (rounded to 4 places; 0.0 if nothing ran).
        The summary plus per-question rows are also written to
        ``ablations_actual/{N}_deepseek_{tag}.json``.
    """
    items = load_spider(dev=True)[:N]
    n_items = len(items)  # may be smaller than N if the split is shorter
    tot = em = vs = ea = 0
    rows = []

    for i, item in enumerate(items):
        q, gold, db = item['question'], item['query'], item['db_id']
        schema = load_schema(db)
        sql = ""
        ok = False
        print(f"\n[{i+1}/{n_items}] {q}")

        try:
            sql, ctx = gen_fn(q, schema, db)
            print(f"  -> {sql}")
            # NOTE(review): query_execution receives the whole dev item (gold
            # query included) and its `ok` is scored as EA below, so gating the
            # correction loop on `ok` looks like gold-result feedback to the
            # corrector. Confirm in utils.query_execution before comparing
            # "Full SoT" against "w/o correction".
            ok, err = query_execution(item, sql)

            if correct:
                att = 0
                while not ok and att < MAX_FIX:
                    cp = call_agent(correction_plan_agent_prompt(q, sql, ctx, err), MDL)
                    sql = postprocess_sql(call_agent(
                        correction_sql_agent_prompt(q, ctx, cp, sql), MDL))
                    print(f"  fix{att+1}: {sql}")
                    ok, err = query_execution(item, sql)
                    att += 1
        except Exception as e:
            print(f"  ERR: {e}")

        gc = postprocess_sql(gold)
        is_em = sql.strip().lower() == gc.strip().lower()

        # Validity: a failed generation leaves sql == "", which is never a valid
        # answer (depending on exec_query it could otherwise run as a no-op and
        # be counted). A crash here is logged instead of aborting the whole run.
        is_vs = False
        if sql.strip():
            try:
                _, exec_err = exec_query(f"/content/spider/database/{db}/{db}.sqlite", sql)
                is_vs = exec_err is None
            except Exception as e:
                print(f"  VALIDITY ERR: {e}")

        if is_em: em += 1
        if is_vs: vs += 1
        if ok: ea += 1
        tot += 1
        rows.append({"question":q, "db_id":db, "gold":gc, "gen":sql,
                      "em":is_em, "valid":is_vs, "exec":ok})
        print(f"  gold: {gc}")
        print(f"  EM={em}/{tot} EA={ea}/{tot} VS={vs}/{tot}")

    def rate(k):
        # Guard against an empty run (N == 0 or empty split).
        return round(k / tot, 4) if tot else 0.0

    s = {"total":tot, "em":em, "vs":vs, "ea":ea,
         "em%":rate(em), "vs%":rate(vs), "ea%":rate(ea)}
    path = f"ablations_actual/{N}_deepseek_{tag}.json"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # `with` so the results file is flushed and closed before returning.
    with open(path, "w") as fh:
        json.dump({"summary":s, "results":rows}, fh, indent=2)
    print(f"\n=== {tag} === {json.dumps(s)}")
    return s


# generation functions for each variant

def _decompose(q, ctx):
    """Run the subproblem agent over ``ctx`` and return its cleaned JSON.

    The de-duplicated clause list is also passed to ``clause_specific_prompts``.
    """
    sj = clean_json(call_agent(subproblem_agent_prompt(q, ctx), MDL))
    # sorted(set(...)) rather than list(set(...)): iteration order of a set of
    # strings changes between interpreter runs (hash randomization), which made
    # the clause order non-reproducible across ablation runs.
    clauses = sorted(set(parse_subproblems(sj)))
    # NOTE(review): the return value is discarded (as before), so unless
    # clause_specific_prompts has side effects, the clause guidance never
    # reaches the plan/SQL prompts. Confirm in utils.
    clause_specific_prompts(clauses)
    return sj


def _plan_and_write(q, ctx, sj):
    """Query-plan agent, then SQL agent, over ``ctx``; returns post-processed SQL."""
    plan = call_agent(query_plan_agent_prompt(q, ctx, sj), MDL)
    return postprocess_sql(call_agent(sql_agent_prompt(q, plan, ctx), MDL))


def gen_sot(q, sch, db):
    """Full SQL-of-Thought: schema linking -> subproblems -> plan -> SQL.

    Returns ``(sql, cs)`` where ``cs`` is the schema-linking output, used as the
    correction context. ``db`` is unused; kept for the common gen_fn signature.
    """
    cs = call_agent(alt_schema_linking_agent_prompt(q, sch), MDL)
    sj = _decompose(q, cs)
    return _plan_and_write(q, cs, sj), cs


def gen_nosub(q, sch, db):
    """Ablation without the subproblem agent: the planner gets ``"{}"`` instead.

    Returns ``(sql, cs)``; ``db`` is unused.
    """
    cs = call_agent(alt_schema_linking_agent_prompt(q, sch), MDL)
    return _plan_and_write(q, cs, "{}"), cs


def gen_noschema(q, sch, db):
    """Ablation without schema linking: the raw schema is used at every step.

    Returns ``(sql, sch)``; ``db`` is unused.
    """
    sj = _decompose(q, sch)
    return _plan_and_write(q, sch, sj), sch


def gen_direct(q, sch, db):
    """Single-prompt baseline: schema + question straight to SQL.

    Returns ``(sql, sch)``; ``db`` is unused.
    """
    p = f"Given this schema:\n{sch}\n\nWrite a SQLite query for: {q}\n\nOutput only SQL."
    return postprocess_sql(call_agent(p, MDL)), sch

In [8]:
# Full pipeline (schema linking + subproblems + plan + SQL) with the correction loop.
r1 = run(gen_fn=gen_sot, tag="sot_full")


[1/100] How many singers do we have?
  -> select count(distinct singer.singer_id) from singer
  gold: select count(*) from singer
  EM=0/1 EA=1/1 VS=1/1

[2/100] What is the total number of singers?
  -> select count(*) from singer
  gold: select count(*) from singer
  EM=1/2 EA=2/2 VS=2/2

[3/100] Show name, country, age for all singers ordered by age from the oldest to the youngest.
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=2/3 EA=3/3 VS=3/3

[4/100] What are the names, countries, and ages for every singer in descending order of age?
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=3/4 EA=4/4 VS=4/4

[5/100] What is the average, minimum, and maximum age of all singers from France?
  -> select avg(singer.age), min(singer.age), max(singer.age) from singer where singer.country = 'france'
  gold: select avg(age), min(age

In [9]:
# Same pipeline as r1 but with the correction loop disabled.
r2 = run(gen_fn=gen_sot, tag="no_correction", correct=False)


[1/100] How many singers do we have?
  -> select count(distinct singer.singer_id) from singer
  gold: select count(*) from singer
  EM=0/1 EA=1/1 VS=1/1

[2/100] What is the total number of singers?
  -> select count(distinct singer.singer_id) from singer
  gold: select count(*) from singer
  EM=0/2 EA=2/2 VS=2/2

[3/100] Show name, country, age for all singers ordered by age from the oldest to the youngest.
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=1/3 EA=3/3 VS=3/3

[4/100] What are the names, countries, and ages for every singer in descending order of age?
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=2/4 EA=4/4 VS=4/4

[5/100] What is the average, minimum, and maximum age of all singers from France?
  -> select avg(singer.age), min(singer.age), max(singer.age) from singer where singer.country = 'france'
  gold: 

In [10]:
# Ablation: drop the subproblem-decomposition agent.
r3 = run(gen_fn=gen_nosub, tag="no_subproblem")


[1/100] How many singers do we have?
  -> select count(distinct singer_id) from singer
  gold: select count(*) from singer
  EM=0/1 EA=1/1 VS=1/1

[2/100] What is the total number of singers?
  -> select count(singer.singer_id) from singer
  gold: select count(*) from singer
  EM=0/2 EA=2/2 VS=2/2

[3/100] Show name, country, age for all singers ordered by age from the oldest to the youngest.
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=1/3 EA=3/3 VS=3/3

[4/100] What are the names, countries, and ages for every singer in descending order of age?
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=2/4 EA=4/4 VS=4/4

[5/100] What is the average, minimum, and maximum age of all singers from France?
  -> select avg(singer.age) as average_age, min(singer.age) as minimum_age, max(singer.age) as maximum_age from singer where singe

In [11]:
# Ablation: drop schema linking (raw schema everywhere).
r4 = run(gen_fn=gen_noschema, tag="no_schemalink")


[1/100] How many singers do we have?
  -> select count(distinct singer_id) from singer
  gold: select count(*) from singer
  EM=0/1 EA=1/1 VS=1/1

[2/100] What is the total number of singers?
  -> select count(*) from singer
  gold: select count(*) from singer
  EM=1/2 EA=2/2 VS=2/2

[3/100] Show name, country, age for all singers ordered by age from the oldest to the youngest.
  -> select s.name, s.country, s.age from singer s order by s.age desc
  gold: select name, country, age from singer order by age desc
  EM=1/3 EA=3/3 VS=3/3

[4/100] What are the names, countries, and ages for every singer in descending order of age?
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=2/4 EA=4/4 VS=4/4

[5/100] What is the average, minimum, and maximum age of all singers from France?
  -> select avg(singer.age), min(singer.age), max(singer.age) from singer where singer.country = 'france'
  gold: select avg(age), min(

In [12]:
# Baseline: one direct prompt, no agents, no correction.
r5 = run(gen_fn=gen_direct, tag="baseline", correct=False)


[1/100] How many singers do we have?
  -> select count(*) from singer
  gold: select count(*) from singer
  EM=1/1 EA=1/1 VS=1/1

[2/100] What is the total number of singers?
  -> select count(*) from singer
  gold: select count(*) from singer
  EM=2/2 EA=2/2 VS=2/2

[3/100] Show name, country, age for all singers ordered by age from the oldest to the youngest.
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=3/3 EA=3/3 VS=3/3

[4/100] What are the names, countries, and ages for every singer in descending order of age?
  -> select name, country, age from singer order by age desc
  gold: select name, country, age from singer order by age desc
  EM=4/4 EA=4/4 VS=4/4

[5/100] What is the average, minimum, and maximum age of all singers from France?
  -> select avg(age) as average_age, min(age) as minimum_age, max(age) as maximum_age from singer where country = 'france'
  gold: select avg(age), min(age), max(

In [13]:
# Ablation summary. Every metric column is 8 characters wide, matching the
# header (the previous `>7.1%` values sat one column left of each header).
print()
print("="*55)
print(f"{'config':<28} {'EA':>8} {'EM':>8} {'Valid':>8}")
print("-"*55)
for name, r in [("Full SoT", r1),
                ("w/o correction", r2),
                ("w/o subproblem", r3),
                ("w/o schema linking", r4),
                ("baseline", r5)]:
    print(f"{name:<28} {r['ea%']:>8.1%} {r['em%']:>8.1%} {r['vs%']:>8.1%}")
print("="*55)


config                             EA       EM    Valid
-------------------------------------------------------
Full SoT                       89.0%   11.0%   99.0%
w/o correction                 85.0%    6.0%  100.0%
w/o subproblem                 90.0%   10.0%   99.0%
w/o schema linking             89.0%   12.0%  100.0%
baseline                       87.0%   21.0%  100.0%
