In [1]:
import os 
import sys 

# Run this incase if you have not installed the repo as a package but still
# want to run this notebook

current_dir = os.getcwd()
dir_to_use = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(dir_to_use)

### Evaluating Defog AI SQL Coder on BIRDBench dataset

In this example, we are going to show how to evaluate Open Source LLMs using text2sql library. We are going to evaluate Defog AI's latest [text2sql model](https://huggingface.co/defog/llama-3-sqlcoder-8b). 

In [2]:
import re
import sqlparse

from text2sql.eval.dataset.bird import BirdBenchEvalDataset
from text2sql.eval.settings import SQLGeneratorConfig, ModelConfig
from text2sql.eval.generator import SQLGeneratorFromModel
from text2sql.eval.executor.bird.acc import BirdExecutorAcc
from text2sql.eval.executor.bird.ves import BirdExecutorVES

config = SQLGeneratorConfig()
model_config = ModelConfig(
    model_name="defog/llama-3-sqlcoder-8b",
    temperature=0.1,
    max_tokens=256,
    is_instruct=True
)

eval_dataset = BirdBenchEvalDataset(config=config)

def postprocess(input_string: str):
    sql_start_keywords = [
        r"\bSELECT\b",
        r"\bINSERT\b",
        r"\bUPDATE\b",
        r"\bDELETE\b",
        r"\bWITH\b",
    ]

    sql_start_pattern = re.compile("|".join(sql_start_keywords), re.IGNORECASE)
    match = sql_start_pattern.search(input_string)

    if match:
        start_pos = match.start()
        sql_statement = input_string[start_pos:]
        return sqlparse.format(sql_statement)
    else:
        return sqlparse.format(input_string)


def run(dataset, difficulty, num_rows):
    filter_by = ("difficulty", difficulty)
    processed = dataset.process_and_filter(
        num_rows=num_rows, 
        filter_by=filter_by
    ).apply_prompt(apply_knowledge=True)

    
    config = SQLGeneratorConfig(model_name=f"defog_{difficulty}_{num_rows}")
    client = SQLGeneratorFromModel(
        generator_config=config,
        engine_config=model_config
    )
    acc = BirdExecutorAcc(generator_config=config)
    ves = BirdExecutorVES(generator_config=config)

    data_with_gen = client.generate_and_save_results(
        data=processed, 
        force=True,
        postprocess=postprocess
    )

    acc.execute(model_responses=data_with_gen, filter_used=filter_by)
    ves.execute(model_responses=data_with_gen, filter_used=filter_by)
    print("\n")
    print("\n")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
run(eval_dataset, difficulty="simple", num_rows=100)

2024-08-05 07:06:42,980 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
2024-08-05 07:06:42,980 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.22s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  0%|          | 0/100 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 100/100 [03:14<00:00,  1.94s/it]
2024-08-05 07:10:03,050 - text2sql-eval - INFO - all responses written to ./experiments/eval/prem_defog_simple_100/predict_dev.json
2024-08-05 07:10:03,050 - text2sql-eval - INFO - all responses written to ./experiments/eval

=>  ./experiments/eval/prem_defog_simple_100 acc_simple.json
+-------------+-------------------+-------------------+
| Category    |   num_correct (%) |   total questions |
| simple      |                18 |               100 |
+-------------+-------------------+-------------------+
| overall     |                18 |               100 |
+-------------+-------------------+-------------------+
| moderate    |                 0 |                 0 |
+-------------+-------------------+-------------------+
| challenging |                 0 |                 0 |
+-------------+-------------------+-------------------+
+-------------+-----------+-------------------+
| Category    |   VES (%) |   total questions |
| simple      |   21.8705 |               100 |
+-------------+-----------+-------------------+
| overall     |   21.8705 |               100 |
+-------------+-----------+-------------------+
| moderate    |    0      |                 0 |
+-------------+-----------+----------------

In [4]:
run(eval_dataset, difficulty="moderate", num_rows=100)

2024-08-05 07:43:27,148 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
2024-08-05 07:43:27,148 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.08s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 100/100 [03:49<00:00,  2.29s/it]
2024-08-05 07:47:25,656 - text2sql-eval - INFO - all responses written to ./experiments/eval/prem_defog_moderate_100/predict_dev.json
2024-08-05 07:47:25,656 - text2sql-eval - INFO - all responses written to ./experiments/eval/prem_defog_moderate_100/predict_dev.json


=>  ./experiments/eval/prem_defog_moderate_100 acc_moderate.json
+-------------+-------------------+-------------------+
| Category    |   num_correct (%) |   total questions |
| moderate    |                17 |               100 |
+-------------+-------------------+-------------------+
| overall     |                17 |               100 |
+-------------+-------------------+-------------------+
| simple      |                 0 |                 0 |
+-------------+-------------------+-------------------+
| challenging |                 0 |                 0 |
+-------------+-------------------+-------------------+
+-------------+-----------+-------------------+
| Category    |   VES (%) |   total questions |
| moderate    |   22.2302 |               100 |
+-------------+-----------+-------------------+
| overall     |   22.2302 |               100 |
+-------------+-----------+-------------------+
| simple      |    0      |                 0 |
+-------------+-----------+------------

In [5]:
run(eval_dataset, difficulty="challenging", num_rows=100)

2024-08-05 07:47:34,755 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
2024-08-05 07:47:34,755 - text2sql-eval - INFO - ./data/eval/ is not empty. Use force=True to re-download and overwrite the contents.
Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.06s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 100/100 [04:10<00:00,  2.51s/it]
2024-08-05 07:51:54,827 - text2sql-eval - INFO - all responses written to ./experiments/eval/prem_defog_challenging_100/predict_dev.json
2024-08-05 07:51:54,827 - text2sql-eval - INFO - all responses written to ./experiments/eval/prem_defog_challenging_100/predict_dev.json


=>  ./experiments/eval/prem_defog_challenging_100 acc_challenging.json
+-------------+-------------------+-------------------+
| Category    |   num_correct (%) |   total questions |
| challenging |                27 |               100 |
+-------------+-------------------+-------------------+
| overall     |                27 |               100 |
+-------------+-------------------+-------------------+
| simple      |                 0 |                 0 |
+-------------+-------------------+-------------------+
| moderate    |                 0 |                 0 |
+-------------+-------------------+-------------------+
+-------------+-----------+-------------------+
| Category    |   VES (%) |   total questions |
| challenging |   27.9221 |               100 |
+-------------+-----------+-------------------+
| overall     |   27.9221 |               100 |
+-------------+-----------+-------------------+
| simple      |    0      |                 0 |
+-------------+-----------+------