In [None]:
# Install the package found in the current directory in editable mode (-e).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -e .

In [1]:
# NOTE: `gym` is only an alias for the `spider_env` package in this notebook;
# it is not the Gym/Gymnasium library itself.
import spider_env as gym

# Environment instance that the following cells reset and step.
env = gym.SpiderEnv()

Loading cached Spider dataset from /home/wangdazhang/.cache/spider
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/icfp_1
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/small_bank_1
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/flight_4
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/company_1
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/chinook_1
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/epinions_1
Schema file not found for /home/wangdazhang/.cache/spider/spider/database/twitter_1


In [2]:
# Start an episode; reset returns an (observation, info) pair.
# NOTE(review): whether 7 is a seed or an example index is not visible from
# this notebook — confirm against SpiderEnv.reset.
observation, info = env.reset(7)

# As the printed output shows, observation has the keys 'observation',
# 'instruction' and 'feedback'; info has 'schema', 'gold_query' and
# 'gold_result'.
print(observation)
print(info)

{'observation': 'department_management', 'instruction': 'What are the names of the states where at least 3 heads were born?', 'feedback': None}
{'schema': 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\nPRIMARY KEY ("Department_ID")\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\nPRIMARY KEY ("head_ID")\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\nPRIMARY KEY ("Department_ID","head_ID"),\nFOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),\nFOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")\n);\n', 'gold_query': 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3', 'gold_result': [('California',)]}


In [3]:
import os

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Credentials must never be committed in a notebook. The Hugging Face token
# is read from the HF_TOKEN environment variable instead; when it is unset,
# None is passed and the libraries fall back to their own default credential
# lookup (e.g. a cached `huggingface-cli login`).
# SECURITY: the literal token that used to be assigned here has been exposed
# in the notebook history and should be revoked.
access_token = os.environ.get("HF_TOKEN")

checkpoint = "bigcode/starcoder"

# Left padding so that generated tokens directly follow the prompt.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint, padding_side="left", token=access_token
)
# 8-bit weights with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    token=access_token,
    load_in_8bit=True,
    device_map="auto",
)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Text-generation pipeline shared by every generation cell below.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)


# TODO: More robust implementation.
def get_query(prompt: str, generated: dict) -> str:
    """Extract the first SQL statement from one text-generation result.

    The first ``len(prompt)`` characters of ``generated["generated_text"]``
    are dropped (this assumes the generated text starts with the prompt),
    the remainder is stripped, and everything up to and including the first
    ``;`` is kept with every run of whitespace collapsed to a single space.

    Args:
        prompt: The prompt that was passed to the generator.
        generated: A single result mapping containing ``"generated_text"``.

    Returns:
        The normalized statement ending in ``;``, or ``""`` when no statement
        can be extracted (missing key, wrong type, or no ``;`` in the text).
        Failures are reported on stdout rather than raised.
    """
    # Bound before the try block so the handler below can always report it,
    # even when the lookup of "generated_text" is the step that failed.
    generated_text = ""
    try:
        generated_text = generated["generated_text"][len(prompt) :].strip()
        # index() raises ValueError when there is no ";" terminator.
        query = " ".join(generated_text[: generated_text.index(";") + 1].split())
    except Exception as e:
        print("get_query error:", e)
        print(f"{generated_text=}")
        return ""
    return query

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 7/7 [00:14<00:00,  2.06s/it]


In [4]:
def get_prompt(observation: dict, info: dict) -> str:
    """Compose the zero-shot text-to-SQL prompt.

    Args:
        observation: Mapping that provides the question under ``"instruction"``.
        info: Mapping that provides the DDL text under ``"schema"``; the
            schema must not be None.

    Returns:
        The schema, a blank line, the translation instruction, a blank line,
        the question and a trailing ``"SQL: "`` cue for the model to complete.
    """
    instruction = observation["instruction"]
    ddl = info["schema"]
    assert ddl is not None

    # Assembled piece by piece so the exact line breaks are explicit.
    return (
        f"{ddl}\n"
        "\n"
        "Translate the following question into SQL.\n"
        "\n"
        f"Question: {instruction}\n"
        "SQL: "
    )


# Build the zero-shot prompt for the current episode and print it for inspection.
prompt = get_prompt(observation, info)
print(prompt)

# Request a single sequence of at most 64 new tokens, padding with the EOS token.
# The sampling options are left commented out, so decoding presumably uses the
# model's default (non-sampling) settings — TODO confirm.
generated = generator(
    prompt,
    # do_sample=True,
    # top_p=0.9,
    num_return_sequences=1,
    max_new_tokens=64,
    pad_token_id=tokenizer.eos_token_id,
)

# Keep only the first SQL statement from the first returned sequence.
# `query` is notebook-level state that later cells read.
query = get_query(prompt, generated[0])
print(query)

CREATE TABLE IF NOT EXISTS "department" (
"Department_ID" int,
"Name" text,
"Creation" text,
"Ranking" int,
"Budget_in_Billions" real,
"Num_Employees" real,
PRIMARY KEY ("Department_ID")
);
CREATE TABLE IF NOT EXISTS "head" (
"head_ID" int,
"name" text,
"born_state" text,
"age" real,
PRIMARY KEY ("head_ID")
);
CREATE TABLE IF NOT EXISTS "management" (
"department_ID" int,
"head_ID" int,
"temporary_acting" text,
PRIMARY KEY ("Department_ID","head_ID"),
FOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),
FOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")
);


Translate the following question into SQL.

Question: What are the names of the states where at least 3 heads were born?
SQL: 
SELECT DISTINCT state FROM head WHERE age >= 3;


In [5]:
# Submit the generated SQL string as the action.
action = query

# step returns a 5-tuple; terminated and truncated are unpacked but not
# inspected in this cell.
observation, reward, terminated, truncated, info = env.step(action)

# The observation's 'feedback' entry carries 'result' and 'error' for the
# submitted query (see the printed output).
print(observation)
print(info)
print(f"{reward=}")

{'observation': 'department_management', 'instruction': 'What are the names of the states where at least 3 heads were born?', 'feedback': {'result': None, 'error': 'no such column: state'}}
{'schema': 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\nPRIMARY KEY ("Department_ID")\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\nPRIMARY KEY ("head_ID")\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\nPRIMARY KEY ("Department_ID","head_ID"),\nFOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),\nFOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")\n);\n', 'gold_query': 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3', 'gold_result': [('California',)]}
reward=0.0


In [6]:
# TODO: Make this prompt effective.
def get_prompt_with_error(
    observation: dict, info: dict, previous_query: "str | None" = None
) -> str:
    """Compose a repair prompt that shows the failed SQL and its error.

    Args:
        observation: Mapping with the question under ``"instruction"`` and a
            ``"feedback"`` mapping whose ``"error"`` entry is the execution
            error of the previous attempt.
        info: Mapping that provides the DDL text under ``"schema"``.
        previous_query: The SQL that produced the error. When omitted, the
            notebook-level ``query`` variable is used, which keeps existing
            two-argument calls working; passing it explicitly avoids relying
            on hidden kernel state.

    Returns:
        The schema, the question, the failed SQL, the error feedback and a
        trailing ``"SQL: "`` cue for the model to complete.

    Raises:
        AssertionError: If there is no feedback, no error, or no schema.
    """
    question = observation["instruction"]
    feedback = observation["feedback"]
    # Feedback is None before any query has been submitted; fail with a clear
    # message instead of an opaque "NoneType is not subscriptable".
    assert feedback is not None, "no feedback available: submit a query first"
    error = feedback["error"]
    assert error is not None
    schema = info["schema"]
    assert schema is not None

    if previous_query is None:
        # Backward-compatible fallback to the notebook-global `query`.
        previous_query = query

    prompt = f"""{schema}

Translate the following question into SQL.

Question: {question}

SQL: {previous_query}

Feedback: The SQL query fails to execute with this error: {error}

Fix the SQL.

SQL: """
    return prompt


# Build the repair prompt from the latest observation (which now carries the
# execution error). The failed SQL is taken from the notebook-level `query`
# variable set by the previous generation cell.
prompt = get_prompt_with_error(observation, info)
print(prompt)

# Same generation settings as the first attempt: one sequence, at most 64 new
# tokens, EOS used for padding, sampling options left commented out.
generated = generator(
    prompt,
    # do_sample=True,
    # top_p=0.9,
    num_return_sequences=1,
    max_new_tokens=64,
    pad_token_id=tokenizer.eos_token_id,
)

# Overwrites `query` with the new attempt. In the recorded output the model
# repeats the failing query unchanged.
query = get_query(prompt, generated[0])
print(query)

CREATE TABLE IF NOT EXISTS "department" (
"Department_ID" int,
"Name" text,
"Creation" text,
"Ranking" int,
"Budget_in_Billions" real,
"Num_Employees" real,
PRIMARY KEY ("Department_ID")
);
CREATE TABLE IF NOT EXISTS "head" (
"head_ID" int,
"name" text,
"born_state" text,
"age" real,
PRIMARY KEY ("head_ID")
);
CREATE TABLE IF NOT EXISTS "management" (
"department_ID" int,
"head_ID" int,
"temporary_acting" text,
PRIMARY KEY ("Department_ID","head_ID"),
FOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),
FOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")
);


Translate the following question into SQL.

Question: What are the names of the states where at least 3 heads were born?

SQL: SELECT DISTINCT state FROM head WHERE age >= 3;

Feedback: The SQL query fails to execute with this error: no such column: state

Fix the SQL.

SQL: 
SELECT DISTINCT state FROM head WHERE age >= 3;


In [7]:
# Submit the retried SQL string as the action.
action = query

# step returns a 5-tuple; terminated and truncated are unpacked but not
# inspected in this cell.
observation, reward, terminated, truncated, info = env.step(action)

# The recorded output shows the same 'no such column' error and reward 0.0.
print(observation)
print(info)
print(f"{reward=}")

{'observation': 'department_management', 'instruction': 'What are the names of the states where at least 3 heads were born?', 'feedback': {'result': None, 'error': 'no such column: state'}}
{'schema': 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\nPRIMARY KEY ("Department_ID")\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\nPRIMARY KEY ("head_ID")\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\nPRIMARY KEY ("Department_ID","head_ID"),\nFOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),\nFOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")\n);\n', 'gold_query': 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3', 'gold_result': [('California',)]}
reward=0.0


In [8]:
# Manual experiment: the same question as above, but the hand-written prompt
# contains only the "head" table definition instead of the full schema text
# from info["schema"].

prompt = """CREATE TABLE IF NOT EXISTS "head" (
"head_ID" int,
"name" text,
"born_state" text,
"age" real,
PRIMARY KEY ("head_ID")
);

Translate the following question into SQL.

Question: What are the names of the states where at least 3 heads were born?
SQL: """

# Same generation settings as the earlier cells: one sequence, at most 64 new
# tokens, EOS used for padding, sampling options left commented out.
generated = generator(
    prompt,
    # do_sample=True,
    # top_p=0.9,
    num_return_sequences=1,
    max_new_tokens=64,
    pad_token_id=tokenizer.eos_token_id,
)

# Keep only the first SQL statement from the first returned sequence.
query = get_query(prompt, generated[0])
print(query)

SELECT DISTINCT born_state FROM head GROUP BY born_state HAVING COUNT(*) >= 3;


In [9]:
# Submit the SQL produced from the manual prompt as the action.
action = query

# step returns a 5-tuple; terminated and truncated are unpacked but not
# inspected in this cell.
observation, reward, terminated, truncated, info = env.step(action)

# The recorded output shows the feedback result matching gold_result and
# reward 1.0.
print(observation)
print(info)
print(f"{reward=}")

{'observation': 'department_management', 'instruction': 'What are the names of the states where at least 3 heads were born?', 'feedback': {'result': [('California',)], 'error': None}}
{'schema': 'CREATE TABLE IF NOT EXISTS "department" (\n"Department_ID" int,\n"Name" text,\n"Creation" text,\n"Ranking" int,\n"Budget_in_Billions" real,\n"Num_Employees" real,\nPRIMARY KEY ("Department_ID")\n);\nCREATE TABLE IF NOT EXISTS "head" (\n"head_ID" int,\n"name" text,\n"born_state" text,\n"age" real,\nPRIMARY KEY ("head_ID")\n);\nCREATE TABLE IF NOT EXISTS "management" (\n"department_ID" int,\n"head_ID" int,\n"temporary_acting" text,\nPRIMARY KEY ("Department_ID","head_ID"),\nFOREIGN KEY ("Department_ID") REFERENCES "department"("Department_ID"),\nFOREIGN KEY ("head_ID") REFERENCES "head"("head_ID")\n);\n', 'gold_query': 'SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3', 'gold_result': [('California',)]}
reward=1.0
