In [1]:
# cd ..

/root/anindya/Submission/text2sql/text2sql


## Error Handling Datasets and Prompts

In this section we discuss how to create error handling prompts. You can pass these prompts to a model during inference so that it can correct its own errors, or use them to fine-tune your models further so that they learn how to handle errors. 

In [2]:
from premsql.datasets.error_dataset import ErrorDatasetGenerator
from premsql.generators.huggingface import Text2SQLGeneratorHF
from premsql.evaluator.from_langchain import ExecutorUsingLangChain

  from .autonotebook import tqdm as notebook_tqdm


To build an error handling dataset or an error handling prompt, make sure each data point has a `db_id`, a `db_path` and the existing `prompt` that was used earlier to generate results from the model. Let's look at an example to understand this better. We will use our standard BirdBench dataset. We also define our generator, which in this case is the [Prem-1B-SQL](https://huggingface.co/premai-io/prem-1B-SQL) model, and a DB executor from LangChain. 
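
Before running the full pipeline, it can help to check that each data point carries the keys listed above. The snippet below is a minimal sketch in plain Python: `missing_keys` is a hypothetical helper (not part of premsql), and the path and prompt in the sample row are placeholders.

```python
# Keys named in the text above as required for error handling prompts.
REQUIRED_KEYS = ("db_id", "db_path", "prompt")


def missing_keys(row: dict) -> list:
    """Return the required keys that are absent from a data point."""
    return [key for key in REQUIRED_KEYS if key not in row]


# Hypothetical data point; the path and prompt are placeholders.
row = {
    "db_id": "movie_platform",
    "db_path": "path/to/movie_platform.sqlite",
    "question": "Name movie titles released in year 1945.",
    "prompt": "# Database and Table Schema: ...",
}

print(missing_keys(row))                          # []
print(missing_keys({"db_id": "movie_platform"}))  # ['db_path', 'prompt']
```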

If you aren't familiar with generators, executors and datasets, you can check out the following:

1. [Datasets tutorial](/examples/datasets.ipynb)
2. [Generators tutorial](/examples/generators.ipynb)
3. [Executors and evaluators tutorial](/examples/evaluation.ipynb)

Since we are making an error dataset, we will start from existing datasets: our goal is to transform the existing training datasets into error handling datasets. 

The flow is simple:

### For training

1. Start with an existing dataset that is compatible with premsql datasets. 
2. Run a generator on that dataset. The executor will gather the errors for incorrect generations. 
3. Use the existing response, the initial prompt and the error to create new data points that use an error handling prompt. 
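
The three steps above can be sketched in plain Python. This is an illustration under stated assumptions, not premsql's implementation: the `predictions` list stands in for the generator output, an in-memory `sqlite3` database stands in for the executor, `execute` and `build_error_rows` are hypothetical helpers, and only execution errors (not wrong results) are collected.

```python
import sqlite3

# Shortened version of the error handling template, for illustration only.
ERROR_TEMPLATE = (
    "{existing_prompt}\n\n# Generated SQL: {sql}\n\n"
    "## Error Message\n\n{error_msg}\n\n# SQL: "
)


def execute(conn, sql):
    """Run a query; return the error message, or None if it ran cleanly."""
    try:
        conn.execute(sql).fetchall()
        return None
    except sqlite3.Error as exc:
        return str(exc)


def build_error_rows(conn, rows, predictions):
    """Keep only the rows whose prediction failed and rewrite their prompt."""
    error_rows = []
    for row, sql in zip(rows, predictions):
        error_msg = execute(conn, sql)
        if error_msg is None:
            continue  # the query ran, so there is nothing to correct
        new_prompt = ERROR_TEMPLATE.format(
            existing_prompt=row["prompt"], sql=sql, error_msg=error_msg
        )
        error_rows.append({**row, "prompt": new_prompt})
    return error_rows


# Step 1: an existing dataset (two toy rows) and a toy database.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE movies (movie_title TEXT, movie_release_year INTEGER)")
rows = [
    {"db_id": "demo", "prompt": "List titles from 1945.",
     "SQL": "SELECT movie_title FROM movies WHERE movie_release_year = 1945"},
    {"db_id": "demo", "prompt": "Count the movies.",
     "SQL": "SELECT COUNT(*) FROM movies"},
]

# Step 2: stand-in for generator output; the first query uses a wrong column.
predictions = [
    "SELECT title FROM movies WHERE movie_release_year = 1945",
    "SELECT COUNT(*) FROM movies",
]

# Step 3: build the error handling data points.
error_rows = build_error_rows(conn, rows, predictions)
print(len(error_rows))  # 1
```

The gold `SQL` of each row is left untouched, so the new data point pairs the error handling prompt with the correct query as its target.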

### For Inference

premsql already handles errors automatically in the [simple-pipeline](/premsql/pipelines/simple.py) and in the [execution guided decoding](/examples/generators.ipynb) section, so you do not need to worry about that. 
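
To give an intuition for what self-correction at inference time means, here is a minimal sketch of a retry loop. It is not the premsql pipeline: `self_correct` is a hypothetical helper, `fake_generate` is a stand-in for a model, and the database is an in-memory `sqlite3` instance.

```python
import sqlite3


def self_correct(conn, prompt, generate, max_retries=2):
    """Generate SQL; on an execution error, re-prompt with the error appended."""
    sql = generate(prompt)
    for _ in range(max_retries):
        try:
            conn.execute(sql).fetchall()
            return sql  # the query ran without errors
        except sqlite3.Error as exc:
            retry_prompt = (
                f"{prompt}\n\n# Generated SQL: {sql}\n\n"
                f"## Error Message\n\n{exc}\n\n# SQL: "
            )
            sql = generate(retry_prompt)
    return sql  # last attempt, returned without further checking


def fake_generate(prompt):
    # Stand-in for a model: it fixes the table name once it sees the error.
    if "no such table" in prompt:
        return "SELECT movie_title FROM movies"
    return "SELECT movie_title FROM movie"


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE movies (movie_title TEXT)")

final_sql = self_correct(conn, "List all movie titles.", fake_generate)
print(final_sql)  # SELECT movie_title FROM movies
```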


Now let's start by defining our generator and executor. 

In [3]:
generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="testing_error_gen",
    type="train",  # keep "train" (not "test"): these generations are used for training
    device="cuda:0"
)

executor = ExecutorUsingLangChain()

2024-09-07 09:15:05,666 - [GENERATOR] - INFO - Experiment folder found in: experiments/train/testing_error_gen
Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.20s/it]


Next we define our existing training dataset. We are using the BirdBench dataset, but you can also use your own text2sql compatible datasets or any of our existing datasets. For demo purposes we have set `num_rows` to 10, but in a real scenario you should use the full training dataset. The error dataset will generally be smaller than the training dataset if you are using a decently trained model that can already generate SQL, so the full training set is needed to collect enough errors.

In [4]:
from premsql.datasets import BirdDataset

bird_train = BirdDataset(
    split="train",
    dataset_folder="/root/anindya/text2sql/data"
).setup_dataset(
    num_rows=10,
)

2024-09-07 09:15:12,907 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-07 09:15:12,908 - [BIRD-DATASET] - INFO - Setting up Bird Dataset


{}


Applying prompt: 100%|██████████| 10/10 [00:00<00:00, 2803.30it/s]


Now we define our error handling dataset generator. All you need to do is pass in the generator of your choice and the executor. 

In [5]:
error_dataset_gen = ErrorDatasetGenerator(
    generator=generator,
    executor=executor
)

Now we generate and save the results. Pass `force=True` if you want to run the generation again. Once the error prompts have been created, the dataset is saved to `./experiments/train/<generator-experiment-name>/error_dataset.json`. 
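
The save-once, regenerate-on-`force` behaviour described above follows a common caching pattern. The sketch below illustrates that pattern in plain Python and is an assumption about the general idea, not premsql's code: `load_or_generate` and `fake_generate` are hypothetical, and a temporary directory stands in for the experiments folder.

```python
import json
import tempfile
from pathlib import Path


def load_or_generate(path: Path, generate, force: bool = False):
    """Reuse the saved JSON file unless it is missing or `force` is set."""
    if path.exists() and not force:
        return json.loads(path.read_text())
    data = generate()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2))
    return data


calls = []  # records how many times generation actually ran


def fake_generate():
    # Stand-in for the expensive generate-and-execute step.
    calls.append(1)
    return [{"db_id": "demo", "prompt": "..."}]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "experiments" / "train" / "demo_experiment" / "error_dataset.json"
    load_or_generate(path, fake_generate)              # generates and saves
    load_or_generate(path, fake_generate)              # reads the saved file
    load_or_generate(path, fake_generate, force=True)  # regenerates

print(len(calls))  # 2
```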

In [6]:
error_dataset = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    force=True
)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Generating result ...: 100%|██████████| 10/10 [00:23<00:00,  2.33s/it]
2024-09-07 09:15:43,302 - [GENERATOR] - INFO - All responses are written to: experiments/train/testing_error_gen
2024-09-07 09:15:43,303 - [ERROR-HANDLING-DATASET] - INFO - Starting Evaluation
100%|██████████| 10/10 [00:29<00:00,  2.99s/it]
2024-09-07 09:16:13,195 - [UTILS] - INFO - Saved JSON in: experiments/train/testing_error_gen/accuracy.json
2024-09-07 09:16:13,197 - [UTILS] - INFO - Saved JSON in: experiments/train/testing_error_gen/predict.json
Applying error prompt: 100%|██████████| 10/10 [00:00<00:00, 43873.47it/s]


Once generation is finished, the `prompt` key of each data point contains the error handling prompt. This is what the error prompt template looks like:


```python
ERROR_HANDLING_PROMPT = """
{existing_prompt}

# Generated SQL: {sql}

## Error Message

{error_msg}

Carefully review the original question and error message, then rewrite the SQL query to address the identified issues. 
Ensure your corrected query uses correct column names, 
follows proper SQL syntax, and accurately answers the original question 
without introducing new errors.

# SQL: 
"""
```

You can also change the prompt by the following method:

```python
error_dataset = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    force=True,
    prompt_template=your_prompt_template
)
```

Make sure your prompt template contains at least the placeholders used in the default error handling prompt shown above (`{existing_prompt}`, `{sql}` and `{error_msg}`). 
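
A quick way to check a custom template before a long generation run is to list its placeholders with the standard library. This is a minimal sketch: `missing_fields` is a hypothetical helper, and the required set is taken from the placeholders that appear in the default template above (`{existing_prompt}`, `{sql}`, `{error_msg}`).

```python
from string import Formatter

# Placeholders that appear in the default error handling prompt above.
REQUIRED_FIELDS = {"existing_prompt", "sql", "error_msg"}


def template_fields(template: str) -> set:
    """Collect the named placeholders used in a str.format-style template."""
    return {name for _, name, _, _ in Formatter().parse(template) if name}


def missing_fields(template: str) -> set:
    """Return the required placeholders that the template does not use."""
    return REQUIRED_FIELDS - template_fields(template)


your_prompt_template = """
{existing_prompt}

# Failed SQL: {sql}
# Database error: {error_msg}

Rewrite the query so that it runs without errors.

# SQL: 
"""

print(sorted(missing_fields(your_prompt_template)))         # []
print(sorted(missing_fields("{existing_prompt}\n# SQL: ")))  # ['error_msg', 'sql']
```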

In [7]:
error_dataset[0]

{'db_id': 'movie_platform',
 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
 'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
 'prompt': '\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should be a single line query in a  single line (string format)\n2. Make sure the column names are correct and exists in the table\n3. For column names which has a space with it, make sure you have put `` in that column name\n4. Think step by step and always check schema and question and the column names before writing the\nquery. \n\n# Database and Table Schema:\nCREATE TABLE "lists"\n(\n    user_id                     INTEGER\n        references lists_users (user_id),\n    list_id        

You do not need to rerun the error handling pipeline once the dataset has been generated. The next time you need this dataset (most probably for fine-tuning), you can simply use the `from_existing` class method. 

It takes `experiment_name` as a required argument. Make sure that the experiment exists: it is the same experiment name that was given to the generator used to create the error handling dataset. Here is an example. 

In [9]:
existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen"
)

print(existing_error_dataset[0])

{'db_id': 'movie_platform', 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.', 'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1', 'prompt': '\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should be a single line query in a  single line (string format)\n2. Make sure the column names are correct and exists in the table\n3. For column names which has a space with it, make sure you have put `` in that column name\n4. Think step by step and always check schema and question and the column names before writing the\nquery. \n\n# Database and Table Schema:\nCREATE TABLE "lists"\n(\n    user_id                     INTEGER\n        references lists_users (user_id),\n    list_id           

You can also tokenize the data points if you want. Right now we only support Hugging Face Transformers tokenizers for tokenizing the error dataset at load time. This is how to do it while loading an existing dataset. 

In [11]:
# Even tokenize this

existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen",
    tokenize_model_name_or_path="premai-io/prem-1B-SQL",
)

existing_error_dataset[0]

2024-09-07 09:21:15,995 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 09:21:15,996 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing:   0%|          | 0/10 [00:00<?, ?it/s]

Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 160.05it/s]
Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 181.33it/s]


{'input_ids': tensor([32013, 32013,  2042,  ...,   207,    16, 32021]),
 'labels': tensor([ -100,  -100,  -100,  ...,   207,    16, 32021]),
 'raw': {'db_id': 'movie_platform',
  'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
  'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
  'prompt': '<｜begin▁of▁sentence｜>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\n\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should
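
In the tokenized output above, `labels` starts with `-100` and ends with the same ids as `input_ids`. `-100` is the default ignore index of PyTorch's cross-entropy loss, so positions carrying it do not contribute to the loss. The sketch below shows the usual way such labels are built, assuming that the prompt tokens are the masked ones; `build_labels` is a hypothetical helper and the token ids are illustrative.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_labels(prompt_ids, completion_ids):
    """Concatenate prompt and completion; mask the prompt part in the labels."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


# Illustrative token ids, not real tokenizer output.
input_ids, labels = build_labels([32013, 2042, 417], [207, 16, 32021])
print(input_ids)  # [32013, 2042, 417, 207, 16, 32021]
print(labels)     # [-100, -100, -100, 207, 16, 32021]
```

With labels built this way, the model is trained only on the SQL it should produce, not on reproducing the prompt.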

### Another example using the SQLite executor

Here is another example that uses the SQLite executor to do the same thing as above. It shows how easy it is to plug components in and out and customize the pipeline. 

In [12]:
from premsql.evaluator.from_sqlite import SQLiteExecutor

generator = Text2SQLGeneratorHF(
    model_or_name_or_path="premai-io/prem-1B-SQL",
    experiment_name="testing_error_sqlite",
    type="train",
    device="cuda:0"
)
sqlite_executor = SQLiteExecutor()

error_dataset_gen = ErrorDatasetGenerator(
    generator=generator,
    executor=sqlite_executor
)

2024-09-07 09:21:27,223 - [GENERATOR] - INFO - Experiment folder found in: experiments/train/testing_error_sqlite
Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.90s/it]


You can also generate a tokenized dataset on the fly. Here is how you do that. 

In [15]:
error_dataset_from_sqlite = error_dataset_gen.generate_and_save(
    datasets=bird_train,
    tokenize=True,
    force=True
)

Generating result ...: 100%|██████████| 10/10 [00:22<00:00,  2.22s/it]
2024-09-07 09:22:09,232 - [GENERATOR] - INFO - All responses are written to: experiments/train/testing_error_sqlite
2024-09-07 09:22:09,233 - [ERROR-HANDLING-DATASET] - INFO - Starting Evaluation
100%|██████████| 10/10 [00:29<00:00,  2.91s/it]
2024-09-07 09:22:38,359 - [UTILS] - INFO - Saved JSON in: experiments/train/testing_error_sqlite/accuracy.json
2024-09-07 09:22:38,361 - [UTILS] - INFO - Saved JSON in: experiments/train/testing_error_sqlite/predict.json
Applying error prompt: 100%|██████████| 10/10 [00:00<00:00, 44104.14it/s]
2024-09-07 09:22:38,583 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 09:22:38,584 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 158.85it/s]
Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 182.43it/s]


In [16]:
error_dataset_from_sqlite[0]

{'input_ids': tensor([32013, 32013,  2042,  ...,   207,    16, 32021]),
 'labels': tensor([ -100,  -100,  -100,  ...,   207,    16, 32021]),
 'raw': {'db_id': 'movie_platform',
  'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
  'SQL': 'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1',
  'prompt': '<｜begin▁of▁sentence｜>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\n\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write correct\nerror free SQL query based on the question asked. Please make sure:\n\n1. Do not add ``` at start / end of the query. It should

That's it: this is how you generate an error handling dataset. The dataset is compatible with other premsql datasets, so you can mix them and use the result as a single dataset for fine-tuning. 
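
As an illustration of what mixing means in practice, the sketch below wraps several indexable datasets behind a single index. `MixedDataset` is a hypothetical class written for this example, not a premsql API, and plain lists stand in for the real datasets.

```python
class MixedDataset:
    """Expose several indexable datasets as one (hypothetical helper)."""

    def __init__(self, *datasets):
        # One (dataset, local index) entry per data point, in order.
        self.index = [(d, i) for d in datasets for i in range(len(d))]

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        dataset, local_idx = self.index[idx]
        return dataset[local_idx]


# Plain lists stand in for a base training dataset and an error dataset.
base_rows = [{"prompt": f"base {i}"} for i in range(3)]
error_rows = [{"prompt": f"error {i}"} for i in range(2)]

mixed = MixedDataset(base_rows, error_rows)
print(len(mixed))          # 5
print(mixed[3]["prompt"])  # error 0
```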