In [1]:
# cd ..

/root/anindya/Submission/text2sql/text2sql


## Generators

premsql generators are responsible for producing SQL from a user's natural language question. You can think of them as the inference API specific to text-to-SQL. Generators are modular: you can plug in any third-party API, model, or custom pipeline (more on this below).

This tutorial covers how to use the Hugging Face and PremAI providers to run local and hosted models for free. Finally, we also show how you can write your own generators. Let's start by importing the required packages.

In [2]:
from premsql.generators.huggingface import Text2SQLGeneratorHF
from premsql.generators.premai import Text2SQLGeneratorPremAI
from premsql.datasets import Text2SQLDataset



### How Generators work

premsql generators provide two generation strategies. The first is simple generation, where we generate the SQL directly from the prompt (which contains the table schemas, the user question, few-shot examples, etc.).

The second strategy, which sometimes gives a bump in performance, is execution-guided decoding. The model generates a SQL query and executes it against the DB. If execution raises an error, that error is fed into a self-correction prompt and the model generates again, until the query succeeds or the maximum number of trials is reached.
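The retry loop described above can be sketched in a few lines. This is only an illustration of the idea, not premsql's implementation: `generate_fn`, `try_execute`, `execution_guided_generate`, and the wording of the correction prompt are all hypothetical names and choices made for this sketch, and it assumes a SQLite database.

```python
import sqlite3
from contextlib import closing
from typing import Callable, Optional


def try_execute(db_path: str, sql: str) -> Optional[str]:
    """Run the query; return None on success, or the error message on failure."""
    try:
        with closing(sqlite3.connect(db_path)) as conn:
            conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


def execution_guided_generate(
    prompt: str,
    db_path: str,
    generate_fn: Callable[[str], str],  # hypothetical: maps a prompt to a SQL string
    max_retries: int = 3,
) -> str:
    """Generate SQL, execute it, and retry with the error message on failure."""
    sql = generate_fn(prompt)
    for _ in range(max_retries):
        error = try_execute(db_path, sql)
        if error is None:
            break  # the query ran, so keep it
        # Hypothetical self-correction prompt: original prompt + failed SQL + error.
        correction_prompt = (
            f"{prompt}\n\n# Previous SQL:\n{sql}\n"
            f"# Execution error:\n{error}\n# Corrected SQL:"
        )
        sql = generate_fn(correction_prompt)
    return sql
```

Note that a query that executes without error is not necessarily correct; the loop only filters out queries the database rejects.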

We will show examples of both below. Let's start with simple generation. We will use the BirdBench dev dataset for this example.

In [5]:
dataset = Text2SQLDataset(
    dataset_name="bird",
    split="test",
    database_folder_name="test_databases",
    json_file_name="test.json",
    dataset_folder="/root/anindya/Submission/text2sql/data",
).setup_dataset(
    num_rows=10,
    num_fewshot=3,
)

2024-09-05 18:17:04,316 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-05 18:17:04,317 - [BIRD-DATASET] - INFO - Setting up Bird Dataset


{'database_folder_name': 'test_databases', 'json_file_name': 'test.json'}


Applying prompt: 100%|██████████| 10/10 [00:00<00:00, 1724.35it/s]


In [6]:
dataset[0]

{'question_id': 0,
 'db_id': 'california_schools',
 'question': 'What is the highest eligible free rate for K-12 students in the schools in Alameda County?',
 'evidence': 'Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`',
 'SQL': "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
 'difficulty': 'simple',
 'db_path': '/root/anindya/Submission/text2sql/data/bird/test/test_databases/california_schools/california_schools.sqlite',
 'prompt': "\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should be a single line query in a  single line (string format)\n2. Make sure the column names are correct and exists in the table\n3. For column names which has a

The input to the generator is not just a prompt but a `data_blob`, which should contain the following information:

- `prompt`: The prompt which needs to be passed
- `db_path`: The db path 

If you have these two pieces of information, you can use the generators for inference on your own data. Make sure the prompt contains the schemas of all tables in the DB. Now let's define our generator. We will use [Prem-1B-SQL](https://huggingface.co/premai-io/prem-1B-SQL) for this experiment.

In [8]:
generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="test_generators",
    device="cuda:0",
    type="test"
)

2024-09-05 18:21:22,692 - [GENERATOR] - INFO - Created new experiment folder: experiments/test/test_generators
Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.08s/it]


`Text2SQLGeneratorHF` uses Hugging Face transformers internally. You instantiate the class with an `experiment_name`. A folder `./experiments/<type>/<experiment_name>` is created in your current directory, as the log above shows (you can change the base directory by passing a path to the `experiment_folder` argument).

These folders store the generation and evaluation results, so that you do not need to regenerate results every time; they are cached inside the experiment directory. Now let's generate a result for a single datapoint.

In [11]:
import sqlparse

sample = dataset[0]

response = generator.generate(
    data_blob={
        "prompt": sample["prompt"],
    },
    temperature=0.1,
    max_new_tokens=256
)

sqlparse.format(response)



"SELECT MAX(T1.`Free Meal Count (K-12)`) FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T2.County = 'Alameda' AND T1.`Free Meal Count (K-12)` IS NOT NULL AND T1.`Free Meal Count (K-12)` > 0 GROUP BY T1.`Free Meal Count (K-12)` ORDER BY T1.`Free Meal Count (K-12)` DESC LIMIT 1"

The `generate` method produces a single response and does not save anything. Now let's generate for multiple questions and save the results.