## Spider

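Download the dataset from <https://yale-lily.github.io/spider> and extract it so that its files are available under `./data/spider/` (the path used by the cells below).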

In [1]:
import json

# JSON file holding the examples that are inspected below.
train_json_path = "./data/spider/train_spider.json"

with open(train_json_path) as f:
    data = json.load(f)

# Pull the database id, gold SQL and question out of a single example.
query_id = 0
example = data[query_id]
db_id, gold_query, question = (
    example[key] for key in ("db_id", "query", "question")
)

print(db_id, gold_query, question, sep="\n")

department_management
SELECT count(*) FROM head WHERE age  >  56
How many heads of the departments are older than 56 ?


In [2]:
# Print the schema.sql of the database identified by db_id, with the
# surrounding whitespace removed from every line.
schema_path = f"./data/spider/database/{db_id}/schema.sql"

with open(schema_path) as f:
    for line in map(str.strip, f):
        print(line)

PRAGMA foreign_keys=ON;
BEGIN TRANSACTION;
CREATE TABLE IF NOT EXISTS "department" (
"Department_ID" int,
"Name" text,
"Creation" text,
"Ranking" int,
"Budget_in_Billions" real,
"Num_Employees" real,
PRIMARY KEY ("Department_ID")
);
INSERT INTO department VALUES(1,'State','1789','1',9.9600000000000008526,30265.999999999999999);
INSERT INTO department VALUES(2,'Treasury','1789','2',11.099999999999999644,115896.99999999999999);
INSERT INTO department VALUES(3,'Defense','1947','3',439.30000000000001135,3000000.0);
INSERT INTO department VALUES(4,'Justice','1870','4',23.399999999999998578,112556.99999999999999);
INSERT INTO department VALUES(5,'Interior','1849','5',10.699999999999999289,71436.000000000000002);
INSERT INTO department VALUES(6,'Agriculture','1889','6',77.599999999999994316,109831.99999999999999);
INSERT INTO department VALUES(7,'Commerce','1903','7',6.2000000000000001776,35999.999999999999999);
INSERT INTO department VALUES(8,'Labor','1913','8',59.700000000000002843,17346.99

In [3]:
import subprocess

db_id = "department_management"
db_path = f"./data/spider/database/{db_id}/{db_id}.sqlite"

# The query below uses a column name (age2) that the database rejects, so the
# run shows how the sqlite3 CLI reports errors: stdout stays empty and the
# message arrives on stderr. Put gold_query in its place to run a valid query.
cmd = ["sqlite3", db_path, "SELECT count(*) FROM head WHERE age2  >  56"]

result = subprocess.run(cmd, capture_output=True, text=True)
out, err = result.stdout.strip(), result.stderr.strip()
for stream_text in (out, err):
    print(type(stream_text), stream_text)

<class 'str'> 
<class 'str'> Error: no such column: age2


## Testbed

In [4]:
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Read the Hugging Face access token from the environment instead of embedding
# it in the notebook. When HF_TOKEN is unset this is None, and the libraries
# fall back to whatever credentials `huggingface-cli login` stored locally.
# NOTE(review): a literal token used to be committed in this cell; it must be
# treated as leaked and revoked in the Hugging Face account settings.
access_token = os.environ.get("HF_TOKEN")

checkpoint = "bigcode/starcoder"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(
    checkpoint, padding_side="left", token=access_token
)
model = AutoModelForCausalLM.from_pretrained(checkpoint, token=access_token).to(device)
# Use the end-of-sequence token for padding as well.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Text-generation pipeline shared by the evaluation cells below.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it]


In [5]:
def get_prompt(spider_path: str, db_id: str, question: str) -> str:
    """Build the text-to-SQL prompt for one question.

    The prompt consists of a condensed copy of the database schema followed by
    the instruction, the question and an open ``SQL: `` slot.

    Only three kinds of lines are kept from ``schema.sql``: ``CREATE TABLE``
    headers, lines starting with a double quote (quoted column definitions)
    and the closing ``);``. Everything else (PRAGMA, INSERT, PRIMARY KEY /
    FOREIGN KEY constraints, ...) is dropped, which is why the last column of
    a table may keep its trailing comma.

    Args:
        spider_path: Root directory of the Spider dataset.
        db_id: Name of the database directory under ``<spider_path>/database``.
        question: Natural-language question to translate.

    Returns:
        The prompt string, ending with ``"SQL: "``.

    Raises:
        OSError: If ``schema.sql`` cannot be opened.
    """
    import os

    schema_path = os.path.join(spider_path, "database", db_id, "schema.sql")
    with open(schema_path) as f:
        # Strip before filtering so that indented column definitions are kept,
        # and match the CREATE TABLE keyword case-insensitively because SQL
        # keywords may be written in lower case.
        stripped_lines = (line.strip() for line in f)
        lines = [
            line
            for line in stripped_lines
            if line.upper().startswith("CREATE TABLE")
            or line.startswith('"')
            or line.startswith(");")
        ]

    schema = "\n".join(lines)

    prompt = f"""{schema}

Translate the following question into SQL.

Question: {question}
SQL: """
    return prompt


# Build and display the prompt for one example question.
prompt = get_prompt(
    spider_path="./data/spider/",
    db_id="department_management",
    question="How many heads of the departments are older than 56 ?",
)
print(prompt)

CREATE TABLE IF NOT EXISTS "department" (
"Department_ID" int,
"Name" text,
"Creation" text,
"Ranking" int,
"Budget_in_Billions" real,
"Num_Employees" real,
);
CREATE TABLE IF NOT EXISTS "head" (
"head_ID" int,
"name" text,
"born_state" text,
"age" real,
);
CREATE TABLE IF NOT EXISTS "management" (
"department_ID" int,
"head_ID" int,
"temporary_acting" text,
);

Translate the following question into SQL.

Question: How many heads of the departments are older than 56 ?
SQL: 


In [6]:
def get_query(prompt: str, output: dict) -> str:
    """Extract the first SQL statement the model produced after ``prompt``.

    ``output["generated_text"]`` is expected to start with ``prompt``; the
    first ``len(prompt)`` characters are cut off and the remainder is kept up
    to and including the first ``;``. Runs of whitespace (including newlines)
    are collapsed to single spaces.

    Args:
        prompt: The prompt that was passed to the generator.
        output: One generator result containing a ``"generated_text"`` entry.

    Returns:
        The normalized query ending in ``;``, or ``""`` when no query could be
        extracted (missing key, no semicolon in the continuation, ...). In
        that case a diagnostic is printed.
    """
    # Bound before the try block so that the error path can always report it,
    # even when reading output["generated_text"] itself is what failed.
    generated = None
    try:
        generated = output["generated_text"][len(prompt) :].strip()
        query = " ".join(generated[: generated.index(";") + 1].split())
    except Exception as e:
        # Best effort: report the problem and let the caller treat the empty
        # string as "no query".
        print("get_query error:", e)
        print(f"{generated=}")
        return ""
    return query


# Hard-coded example for exercising get_query: `prompt` is a prompt string and
# `outputs[0]["generated_text"]` starts with that same prompt, followed by a
# continuation. The continuation holds the answer to the question plus further
# Question/SQL pairs, and ends mid-sentence.
prompt = 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\n);\n\nTranslate the following question into SQL.\n\nQuestion: How many heads of the departments are older than 56 ?\nSQL: '
outputs = [
    {
        "generated_text": 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\n);\n\nTranslate the following question into SQL.\n\nQuestion: How many heads of the departments are older than 56 ?\nSQL: \nSELECT COUNT(*)\nFROM head\nWHERE age > 56;\n\nQuestion: How many employees are in the department with the highest budget?\nSQL: \nSELECT MAX(Num_Employees)\nFROM department;\n\nQuestion: How many employees are in the department with the lowest budget?\nSQL: \nSELECT MIN(Num_Employees)\nFROM department;\n\nQuestion: How many employees are in the department with the highest ranking?\nSQL: \nSELECT MAX(Ranking)\nFROM department;\n\nQuestion: How many employees are in the department with the lowest ranking?\nSQL: \nSELECT MIN(Ranking)\nFROM department;\n\nQuestion: How many employees are in the department with the highest budget and the highest ranking?\nSQL: \nSELECT MAX(Num_Employees)\nFROM department\nWHERE Ranking = (SELECT MAX(Ranking) FROM department);\n\nQuestion: How many employees are in the department with the highest budget and the lowest ranking?\nSQL: \nSELECT MAX(Num_Employees)\nFROM department\nWHERE Ranking = (SELECT MIN(Ranking) FROM department);\n\nQuestion: How many employees are in the department with the lowest budget and the highest ranking'
    }
]
# Print the query that get_query extracts from the first output.
print(get_query(prompt, outputs[0]))

SELECT COUNT(*) FROM head WHERE age > 56;


In [7]:
def run_sqlite(spider_path: str, db_id: str, query: str, timeout=60.0):
    """Run ``query`` against a Spider database through the ``sqlite3`` CLI.

    The query may be model-generated and is therefore treated as untrusted:

    * the database is opened read-only, so a generated ``DROP``/``DELETE``/
      ``UPDATE`` cannot alter the benchmark data;
    * arguments starting with ``.`` are refused, because the CLI would run
      them as dot-commands (``.shell``, ``.output``, ...) rather than as SQL;
    * execution is bounded by ``timeout`` so that one pathological query
      cannot stall a whole evaluation run.

    Args:
        spider_path: Root directory of the Spider dataset.
        db_id: Database name; the file used is
            ``<spider_path>/database/<db_id>/<db_id>.sqlite``.
        query: SQL text to execute.
        timeout: Maximum run time in seconds, or ``None`` for no limit.

    Returns:
        A ``subprocess.CompletedProcess`` with text ``stdout`` and ``stderr``.
        Refused and timed-out queries yield ``returncode`` 1, an empty
        ``stdout`` and an explanatory ``stderr``.

    Raises:
        FileNotFoundError: If the ``sqlite3`` executable is not installed.
    """
    import subprocess

    db_path = f"{spider_path}/database/{db_id}/{db_id}.sqlite"
    cmd = ["sqlite3", "-readonly", db_path, query]

    if query.lstrip().startswith("."):
        return subprocess.CompletedProcess(
            cmd, 1, stdout="", stderr="Error: sqlite3 dot-commands are not allowed"
        )

    try:
        return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return subprocess.CompletedProcess(
            cmd, 1, stdout="", stderr=f"Error: query timed out after {timeout} seconds"
        )


def run_spider(spider_path: str, split: str, limit=100):
    """Evaluate the global ``generator`` on a Spider split by execution match.

    For every example the gold query and the generated query are executed
    with ``run_sqlite``. A prediction counts as correct only if a query was
    extracted, it ran without error, and its stripped stdout equals that of
    the gold query. Mismatches are printed in detail, followed by the overall
    accuracy.

    Relies on the notebook-level names ``generator``, ``tokenizer``,
    ``get_prompt``, ``get_query`` and ``run_sqlite``.

    Args:
        spider_path: Root directory of the Spider dataset.
        split: One of ``"train_spider"``, ``"train_others"`` or ``"dev"``.
        limit: Evaluate only the first ``limit`` examples of the split
            (default 100, for quick tests); ``None`` evaluates all of them.

    Returns:
        None. Results are printed.
    """
    import json
    import os

    assert split in ("train_spider", "train_others", "dev")
    with open(os.path.join(spider_path, f"{split}.json")) as f:
        dataset = json.load(f)
    print(f"{len(dataset)=}")
    if limit is not None:
        dataset = dataset[:limit]

    # get_prompt needs a schema.sql, so keep only examples whose database
    # directory provides one (this also drops databases that are missing).
    database_dir = os.path.join(spider_path, "database")
    usable = {
        db_id
        for db_id in {data["db_id"] for data in dataset}
        if os.path.isfile(os.path.join(database_dir, db_id, "schema.sql"))
    }
    dataset = [data for data in dataset if data["db_id"] in usable]
    print(f"{len(dataset)=}")
    if not dataset:
        # Nothing to evaluate; avoid dividing by zero below.
        print("accuracy is undefined: no examples left to evaluate")
        return

    prompts = [
        get_prompt(spider_path, data["db_id"], data["question"]) for data in dataset
    ]
    generated_texts = generator(
        prompts,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        batch_size=32,
    )
    assert len(generated_texts) == len(dataset)

    num_errors = 0
    for data, prompt, generated in zip(dataset, prompts, generated_texts):
        db_id = data["db_id"]

        gold_query = data["query"]
        gold_result = run_sqlite(spider_path, db_id, gold_query)

        query = get_query(prompt, generated[0])
        result = run_sqlite(spider_path, db_id, query)

        # An empty or failing prediction produces empty stdout, which would
        # otherwise "match" any gold query whose result set is empty.
        failed = (
            not query or result.returncode != 0 or bool(result.stderr.strip())
        )
        if failed or result.stdout.strip() != gold_result.stdout.strip():
            num_errors += 1
            print("=" * 10)
            print(db_id, data["question"])
            print(gold_query)
            print(gold_result.stdout)
            print("-" * 5)
            print(query)
            print(result.stdout)
            print(result.stderr)
            print("=" * 10)

    accuracy = 1 - num_errors / len(dataset)
    print(f"{accuracy=}")


# Evaluate on the training split.
run_spider(spider_path="./data/spider", split="train_spider")

len(dataset)=7000
len(dataset)=100
department_management What are the names of the states where at least 3 heads were born?
SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3
California

-----
SELECT DISTINCT state FROM head WHERE age >= 3;

Error: no such column: state

department_management In which year were most departments established?
SELECT creation FROM department GROUP BY creation ORDER BY count(*) DESC LIMIT 1
1789

-----
SELECT year, COUNT(year) FROM (SELECT DISTINCT(year) FROM department) GROUP BY year ORDER BY COUNT(year) DESC LIMIT 1;

Error: no such column: year

department_management How many acting statuses are there?
SELECT count(DISTINCT temporary_acting) FROM management
2

-----
SELECT COUNT(*) FROM management;
5


department_management What are the distinct ages of the heads who are acting?
SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id  =  T2.head_id WHERE T2.temporary_acting  =  'Yes'
53.0
52.0
69.0

-----
SELECT DI