# Fine-Tuning for Text-to-SQL with Gradient and LlamaIndex


In [None]:
!pip install llama-index gradientai -q

In [None]:
import os
from llama_index.llms import GradientBaseModelLLM
from llama_index.finetuning.gradient.base import GradientFinetuneEngine

In [None]:
os.environ["GRADIENT_ACCESS_TOKEN"] = os.getenv("GRADIENT_API_KEY")
os.environ[
    "GRADIENT_WORKSPACE_ID"
] = "3c1cc9e5-1ea3-4807-aee8-c416b59a1250_workspace"

## Prepare Data

In [None]:
dialect = "sqlite"

In [None]:
from datasets import load_dataset
from pathlib import Path
import json


def load_jsonl(data_dir):
    """Load a JSON Lines file with Hugging Face ``datasets``.

    Args:
        data_dir: Path to the JSON Lines file. Despite the name, this is a
            file path, not a directory.

    Returns:
        The result of ``load_dataset("json", ...)``; callers read the rows
        from its ``"train"`` split.
    """
    return load_dataset("json", data_files=Path(data_dir).as_posix())


def save_jsonl(data_dicts, out_path):
    """Write records to ``out_path`` in JSON Lines format, overwriting it.

    Args:
        data_dicts: Iterable of JSON-serialisable mappings; each becomes one
            line produced by ``json.dumps``.
        out_path: Destination file path.
    """
    serialized = (json.dumps(record) + "\n" for record in data_dicts)
    with open(out_path, "w") as out_file:
        out_file.writelines(serialized)


def load_data_sql(
    data_dir: str = "data_sql",
    dataset_name: str = "b-mc2/sql-create-context",
):
    """Download a text-to-SQL dataset and write it as a JSON Lines file.

    Each output line has the keys ``input`` (the question), ``context``
    (the dataset's ``context`` field, e.g. CREATE TABLE statements) and
    ``output`` (the reference SQL answer).

    Args:
        data_dir: Path of the output *file* (not a directory); its parent
            directory is created if needed.
        dataset_name: Hugging Face dataset to download. It must provide a
            ``train`` split with ``question``, ``context`` and ``answer``
            columns.
    """
    dataset = load_dataset(dataset_name)

    dataset_splits = {"train": dataset["train"]}
    out_path = Path(data_dir)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Open the file once so that, if more splits are added above, they are
    # appended rather than each split truncating the previous one.
    with open(out_path, "w", encoding="utf-8") as f:
        for ds in dataset_splits.values():
            for item in ds:
                record = {
                    "input": item["question"],
                    "context": item["context"],
                    "output": item["answer"],
                }
                f.write(json.dumps(record) + "\n")

In [None]:
load_data_sql(data_dir="data_sql")

In [None]:
from math import ceil


def get_train_val_splits(
    data_dir: str = "data_sql",
    val_ratio: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
):
    """Split a JSON Lines dataset into training and validation sets.

    Args:
        data_dir: Path of the JSON Lines file to load.
        val_ratio: Fraction of rows held out for validation. The validation
            size is rounded up with ``ceil``.
        seed: Seed for the shuffle done by ``train_test_split``.
        shuffle: Whether to shuffle rows before splitting.

    Returns:
        ``(train_split, val_split)`` taken from the loaded ``"train"`` split.
    """
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)

    # train_test_split already does a seeded shuffle when ``shuffle`` is True.
    # Re-shuffling the results without a seed would make the order
    # non-reproducible, and would shuffle even when ``shuffle=False``.
    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=shuffle, seed=seed
    )
    return train_val["train"], train_val["test"]

In [None]:
raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
# Persist the unformatted splits (input / context / output records).
for _raw_split, _raw_path in (
    (raw_train_data, "train_data_raw.jsonl"),
    (raw_val_data, "val_data_raw.jsonl"),
):
    save_jsonl(_raw_split, _raw_path)

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Show one raw example: question ("input"), schema ("context"), SQL ("output").
raw_train_data[2]

{'input': 'When was the ship launched when the commissioned or completed(*) is 6 june 1864?',
 'context': 'CREATE TABLE table_12592074_1 (launched VARCHAR, commissioned_or_completed_ VARCHAR, _ VARCHAR)',
 'output': 'SELECT launched FROM table_12592074_1 WHERE commissioned_or_completed_ * _ = "6 June 1864"'}

In [None]:
# Prompt templates used to turn each example into model text.
# The training template wraps the instruction (system + user message) and the
# reference SQL in <s>...</s>. The inference template stops right after the
# "### Response:" header so the model writes the SQL itself.
text_to_sql_tmpl_str = (
    "<s>### Instruction:\n"
    "{system_message}{user_message}\n\n"
    "### Response:\n"
    "{response}</s>"
)

text_to_sql_inference_tmpl_str = (
    "<s>### Instruction:\n"
    "{system_message}{user_message}\n\n"
    "### Response:\n"
)

# Alternative prompt formats (not used here):
#
# Instruction format without a system message:
#   "<s>### Instruction:\n{user_message}\n\n### Response:\n{response}</s>"
#   "<s>### Instruction:\n{user_message}\n\n### Response:\n"
#
# Llama 2 chat format:
#   "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] {response} </s>"
#   "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] "


def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
    """Build a text-to-SQL prompt from a question and its table context.

    Args:
        input: Natural-language question.
        context: Table context, e.g. ``CREATE TABLE`` statements.
        dialect: Value placed under the "### Dialect:" header.
        output: Reference SQL. When truthy, the training template is used and
            the answer is appended. Otherwise (``""`` or ``None``) the
            inference template is used and the prompt ends at the response
            header.

    Returns:
        The formatted prompt string.
    """
    # Whitespace is spelled out explicitly (the space after "tables." and the
    # indented trailing lines) so the prompt text stays byte-for-byte the same
    # for fine-tuning data and inference prompts.
    system_message = (
        "You are a powerful text-to-SQL model. Your job is to answer "
        "questions about a database. You are given a question and context "
        "regarding one or more tables. \n"
        "\n"
        "You must output the SQL query that answers the question.\n"
        "    \n"
        "    "
    )
    # NOTE(review): this message already ends with "### Response:\n", and the
    # module-level templates add another "### Response:" header, so the
    # header appears twice. Changing that would alter the fine-tuning data
    # format, so it is kept as-is.
    user_message = (
        f"### Dialect:\n{dialect}\n\n"
        f"### Input:\n{input}\n\n"
        f"### Context:\n{context}\n\n"
        "### Response:\n"
    )
    messages = {"system_message": system_message, "user_message": user_message}
    if not output:
        return text_to_sql_inference_tmpl_str.format(**messages)
    return text_to_sql_tmpl_str.format(response=output, **messages)


def generate_prompt(data_point):
    """Turn one raw example into a record holding its full SQLite prompt.

    Args:
        data_point: Mapping with ``input`` (question), ``context`` (table
            schema) and ``output`` (reference SQL) keys.

    Returns:
        ``{"inputs": prompt}``. The reference SQL is included in ``prompt``
        when ``output`` is non-empty.
    """
    question = data_point["input"]
    schema = data_point["context"]
    answer = data_point["output"]
    prompt = _generate_prompt_sql(question, schema, dialect="sqlite", output=answer)
    return {"inputs": prompt}

In [None]:
def _format_and_save(raw_split, out_path):
    # Add the rendered prompt ("inputs") to each record, then write the split.
    formatted = raw_split.map(generate_prompt)
    save_jsonl(formatted, out_path)
    return formatted


train_data = _format_and_save(raw_train_data, "train_data.jsonl")
val_data = _format_and_save(raw_val_data, "val_data.jsonl")

Map:   0%|          | 0/70719 [00:00<?, ? examples/s]

Map:   0%|          | 0/7858 [00:00<?, ? examples/s]

## Run Fine-Tuning

In [None]:
# Base (not fine-tuned) Gradient model, used later as the comparison baseline.
# Another slug that can be used here: "nous-hermes2".
base_model_slug = "llama2-7b-chat"
base_llm = GradientBaseModelLLM(base_model_slug=base_model_slug, max_tokens=100)

In [None]:
# Configure a Gradient fine-tuning job for the base model on the formatted
# training prompts. max_steps=200 is a small cap meant for a quick test run;
# raise it for a full fine-tune.
# NOTE(review): presumably the engine trains on the "inputs" field written by
# generate_prompt; confirm. val_data.jsonl is written earlier but not passed here.
finetune_engine = GradientFinetuneEngine(
    base_model_slug=base_model_slug,
    name="text_to_sql",
    data_path="train_data.jsonl",
    verbose=True,
    max_steps=200,
    batch_size=4,
)

In [None]:
# Identifier of the model adapter backing this fine-tuning job.
finetune_engine.model_adapter_id

'805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter'

In [None]:
epochs = 1
# Call finetune() once per epoch, announcing each epoch first.
for epoch_idx in range(epochs):
    print(f"** EPOCH {epoch_idx} **")
    finetune_engine.finetune()

** EPOCH 0 **
fine-tuning step 4: loss=1980.5546, trainable tokens=631
fine-tuning step 8: loss=1312.1577, trainable tokens=648
fine-tuning step 12: loss=1150.765, trainable tokens=665
fine-tuning step 16: loss=867.1619, trainable tokens=620
fine-tuning step 20: loss=747.4055, trainable tokens=650
fine-tuning step 24: loss=577.833, trainable tokens=640
fine-tuning step 28: loss=375.29626, trainable tokens=583
fine-tuning step 32: loss=506.39142, trainable tokens=604
fine-tuning step 36: loss=489.96997, trainable tokens=653
fine-tuning step 40: loss=380.26834, trainable tokens=613
fine-tuning step 44: loss=432.8493, trainable tokens=585
fine-tuning step 48: loss=597.9427, trainable tokens=638
fine-tuning step 52: loss=546.3855, trainable tokens=636
fine-tuning step 56: loss=466.26605, trainable tokens=614
fine-tuning step 60: loss=317.3381, trainable tokens=627
fine-tuning step 64: loss=447.2127, trainable tokens=655
fine-tuning step 68: loss=323.37115, trainable tokens=585
fine-tuning 

In [None]:
# LLM backed by the fine-tuned adapter.
# NOTE(review): max_tokens=300 here vs 100 for base_llm, so the base and
# fine-tuned models are not compared under identical generation limits.
ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)

## Try It Out and Evaluate

In [None]:
# create sample
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
from llama_index import SQLDatabase

In [None]:
# Throwaway in-memory SQLite database plus the metadata registry that the
# demo table is declared on.
metadata_obj = MetaData()
engine = create_engine("sqlite:///:memory:")

In [None]:
# Demo table: city name (primary key), population, and a required country.
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column(name="city_name", type_=String(length=16), primary_key=True),
    Column(name="population", type_=Integer),
    Column(name="country", type_=String(length=16), nullable=False),
)
# Emit CREATE TABLE for the tables registered on metadata_obj.
metadata_obj.create_all(engine)

In [None]:
# Expose only the demo table to LlamaIndex.
sql_database = SQLDatabase(engine, include_tables=[table_name])

In [None]:
# insert sample rows
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert every row in one transaction (an executemany over the list of dicts).
# engine.begin() commits on success and rolls back if any insert fails,
# instead of opening a connection and committing once per row.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

In [None]:
from llama_index import ServiceContext

# Held-out validation example used below to compare base vs. fine-tuned completions.
test_datapoint = raw_val_data[3]


def get_text2sql_completion(llm, sql_database, raw_datapoint):
    """Ask ``llm`` to write SQL for one raw example, without executing it.

    Args:
        llm: An LLM exposing ``complete(prompt)``.
        sql_database: Unused; kept so existing call sites keep working.
        raw_datapoint: Mapping with ``input`` (question) and ``context``
            (table schema) keys.

    Returns:
        The model completion as a string.
    """
    # output=None selects the inference template, so the prompt ends at the
    # response header and the model writes the SQL.
    prompt = _generate_prompt_sql(
        raw_datapoint["input"],
        raw_datapoint["context"],
        dialect="sqlite",
        output=None,
    )
    # The prompt goes straight to llm.complete, so no ServiceContext is needed.
    return str(llm.complete(prompt))

In [None]:
# Inspect the held-out example, including its reference SQL under "output".
print(test_datapoint)

{'input': 'What is the lowest number conceded for the team that had less than 8 wins, scored 21, and had less than 23 points?', 'context': 'CREATE TABLE table_name_77 (conceded INTEGER, points VARCHAR, wins VARCHAR, scored VARCHAR)', 'output': 'SELECT MIN(conceded) FROM table_name_77 WHERE wins < 8 AND scored = 21 AND points < 23'}


In [None]:
# Baseline: ask the un-tuned base model for the same example's SQL.
# tmp_model_slug = "nous-hermes2"
tmp_model_slug = "llama2-7b-chat"
tmp_llm = GradientBaseModelLLM(base_model_slug=tmp_model_slug, max_tokens=100)
# Alternative baseline (OpenAI is imported in a later cell):
# tmp_llm = OpenAI(model="gpt-4")
get_text2sql_completion(tmp_llm, sql_database, test_datapoint)

'SELECT * FROM table_name_77 WHERE conceded = (SELECT MIN(conceded) FROM table_name_77 WHERE points < 23 AND wins < 8);\n\nThis SQL query will return all the rows from the `table_name_77` table where the `conceded` column is equal to the minimum value of `conceded` found in rows where `points` is less than 23 and `wins`'

In [None]:
# Fine-tuned model's SQL for the same example; compare it with the reference "output".
get_text2sql_completion(ft_llm, sql_database, test_datapoint)

'SELECT MIN(conceded) FROM table_name_77 WHERE wins < 8 AND scored = 21 AND points < 23'

In [None]:
from llama_index.query_engine import NLSQLTableQueryEngine
from llama_index import ServiceContext, PromptTemplate


def get_text2sql_query_engine(llm, sql_database):
    """Build an NL-to-SQL query engine that prompts ``llm`` in the fine-tuning format.

    The text-to-SQL prompt is the inference prompt from
    ``_generate_prompt_sql`` with literal ``{query_str}``, ``{schema}`` and
    ``{dialect}`` placeholders, presumably filled in by the query engine.
    ``synthesize_response=False`` makes the engine return the query result
    rather than an LLM-written answer.

    Args:
        llm: LLM used to generate SQL.
        sql_database: LlamaIndex ``SQLDatabase`` the SQL is run against.

    Returns:
        A configured ``NLSQLTableQueryEngine``.
    """
    prompt_template = PromptTemplate(
        _generate_prompt_sql(
            "{query_str}", "{schema}", dialect="{dialect}", output=""
        )
    )
    return NLSQLTableQueryEngine(
        sql_database,
        text_to_sql_prompt=prompt_template,
        service_context=ServiceContext.from_defaults(llm=llm),
        synthesize_response=False,
    )

In [None]:
# Natural-language question run through each query engine below.
# query = "Which city has the highest population?"
query = "What is the average population and total population of the cities?"

In [None]:
from llama_index.llms import OpenAI

# Reference run: GPT-4 behind the same prompt format and query engine.
tmp_llm = OpenAI(model="gpt-4")
tmp_query_engine = get_text2sql_query_engine(llm=tmp_llm, sql_database=sql_database)

tmp_response = tmp_query_engine.query(query)
print(tmp_response)

[(7336250.0, 29345000)]


In [None]:
# Query engine backed by the base (not fine-tuned) Gradient model.
base_query_engine = get_text2sql_query_engine(base_llm, sql_database)

In [None]:
# Run the question through the base model's query engine.
# NOTE: in the recorded run this raised "You can only execute one statement at
# a time" (presumably the base model emitted more than one SQL statement), so
# base_response was never assigned and the next two cells hit NameError.
base_response = base_query_engine.query(query)

Warning: You can only execute one statement at a time.

In [None]:
# Query result from the base model (NameError if the query cell above failed).
print(str(base_response))

NameError: name 'base_response' is not defined

In [None]:
# SQL generated by the base model (NameError if the query cell above failed).
base_response.metadata["sql_query"]

NameError: name 'base_response' is not defined

In [None]:
# Query engine backed by the fine-tuned adapter.
ft_query_engine = get_text2sql_query_engine(ft_llm, sql_database)

In [None]:
# Run the same question through the fine-tuned model.
ft_response = ft_query_engine.query(query)

In [None]:
# Rows returned by executing the fine-tuned model's SQL.
print(str(ft_response))

[(2930000.0, 1), (13960000.0, 1), (9776000.0, 1), (2679000.0, 1)]


In [None]:
# SQL generated by the fine-tuned model.
# NOTE(review): the recorded query groups by country and counts rows, which
# does not answer "average and total population" (compare the GPT-4 result).
ft_response.metadata["sql_query"]

'SELECT AVG(population), COUNT(*) FROM city_stats GROUP BY country'