## Datasets

premsql datasets make it easy to use several existing, pre-processed Text-to-SQL datasets. This matters because Text-to-SQL is a complex task: each example depends on a database and its tables, not just on a question and an answer.

premsql datasets provide simple APIs for these datasets and also help you create your own dataset from your private databases.

In [1]:
import os 
import sys 

current_dir = os.getcwd()
dir_to_use = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(dir_to_use)

Currently the following datasets are readily available:

1. [BirdBench Dataset](https://huggingface.co/datasets/premai-io/birdbench)
2. [Spider Unified Datasets](https://huggingface.co/datasets/premai-io/spider)
3. [Domains](https://huggingface.co/datasets/premai-io/domains)
4. [Gretel AI Dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) (A synthetic text to SQL dataset by Gretel AI)

Now let's see how to use these datasets.

In [2]:
from premsql.datasets import Text2SQLDataset
from premsql.datasets.utils import print_data
# load the bird dataset

bird_dataset = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="../data"
)

2024-09-04 19:35:47,374 - [BIRD-DATASET] - INFO - Loaded Bird Dataset


At this point the object only holds the raw data. It exposes two attributes:

1. `raw_dataset`: the raw data loaded from the JSON file, as a list of dictionaries (one per example).
2. `filter_availables`: the list of keys the dataset can be filtered by.

Here is how to inspect the raw data of the training split.

In [3]:
raw_bird_training_dataset = bird_dataset.raw_dataset
raw_bird_training_dataset[0]


[1m{[0m
[2;32m│   [0m[32m'db_id'[0m: [32m'movie_platform'[0m,
[2;32m│   [0m[32m'question'[0m: [32m'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.'[0m,
[2;32m│   [0m[32m'evidence'[0m: [32m'released in the year 1945 refers to movie_release_year = 1945;'[0m,
[2;32m│   [0m[32m'SQL'[0m: [32m'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'[0m
[1m}[0m

You can also check which filters are available for the dataset with `.filter_availables`.

In [4]:
bird_dataset.filter_availables

[1m[[0m[32m'db_id'[0m[1m][0m

To load the processed dataset, call the `setup_dataset` method. It processes the raw data and returns the dataset object.

`setup_dataset` accepts the following optional arguments for further customization:

- `filter_by: tuple | None`: filters the dataset by a `(key, value)` pair, for example `("difficulty", "simple")`.

- `num_rows: int | None`: limits the dataset to this number of rows.

- `num_fewshot: int | None`: the number of few-shot examples to include in each prompt.

- `model_name_or_path: str | None`: applies the prompt template of the chosen model. For example, if you want to fine-tune a Llama model, the prompt is wrapped in the Llama prompt template. If this is not provided, the dataset is not tokenized.

- `prompt_template: str | None`: a custom prompt template, if you want to use a different one. You can check out the default prompt template [here](/premsql/datasets/prompts.py).

**Note**:
If `model_name_or_path` is provided, the dataset is automatically wrapped in that model's prompt template and tokenized; otherwise neither step is performed.
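In the tokenized outputs in this notebook, `labels` begins with `-100` while its tail matches `input_ids`. This is consistent with the common fine-tuning convention of masking the prompt tokens so that the loss is computed only on the target SQL (`-100` is the default `ignore_index` of PyTorch's `CrossEntropyLoss`). The sketch below illustrates that convention in plain Python; `build_labels` is a hypothetical helper and the token ids are illustrative, not premsql's actual tokenization code.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(
    prompt_ids: list[int], answer_ids: list[int]
) -> tuple[list[int], list[int]]:
    """Concatenate prompt and answer ids, masking the prompt part of the labels.

    Hypothetical helper for illustration only; not part of premsql.
    """
    input_ids = prompt_ids + answer_ids
    # Prompt positions are ignored by the loss; only the SQL answer is learned.
    labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
    return input_ids, labels


input_ids, labels = build_labels([32013, 2042, 1000], [207, 16, 32021])
print(input_ids)  # [32013, 2042, 1000, 207, 16, 32021]
print(labels)     # [-100, -100, -100, 207, 16, 32021]
```

Both lists have the same length, so each label lines up with the token at the same position.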

In [5]:
# Now let's setup the bird dataset 

bird_dataset = bird_dataset.setup_dataset(
    model_name_or_path="premai-io/prem-1B-SQL", 
    num_fewshot=3, 
    num_rows=3
)

print_data(bird_dataset[0])

2024-09-04 19:35:47,403 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 2846.81it/s]
2024-09-04 19:35:47,795 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-04 19:35:47,796 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 308.34it/s]
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 371.85it/s]



[1m{[0m
[2;32m│   [0m[32m'input_ids'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1;36m32013[0m, [1;36m32013[0m,  [1;36m2042[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'labels'[0m: [1;35mtensor[0m[1m([0m[1m[[0m [1;36m-100[0m,  [1;36m-100[0m,  [1;36m-100[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'raw'[0m: [1m{[0m
[2;32m│   │   [0m[32m'db_id'[0m: [32m'movie_platform'[0m,
[2;32m│   │   [0m[32m'question'[0m: [32m'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.'[0m,
[2;32m│   │   [0m[32m'evidence'[0m: [32m'released in the year 1945 refers to movie_release_year = 1945;'[0m,
[2;32m│   │   [0m[32m'SQL'[0m: [32m'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'[0m,
[2;32m│   │   [0m[32m'db_path'[0m: [3

Tokenization can be slow and computationally heavy, so you can also preview the dataset without tokenizing it first. Here is how:

In [6]:
bird_dataset_without_tokenization = Text2SQLDataset(
    dataset_name='bird', split="train", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None, num_fewshot=3, num_rows=3
)

print_data(bird_dataset_without_tokenization[0])

2024-09-04 19:35:47,840 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-04 19:35:47,840 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 2420.72it/s]



[1m{[0m
[2;32m│   [0m[32m'db_id'[0m: [32m'movie_platform'[0m,
[2;32m│   [0m[32m'question'[0m: [32m'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.'[0m,
[2;32m│   [0m[32m'evidence'[0m: [32m'released in the year 1945 refers to movie_release_year = 1945;'[0m,
[2;32m│   [0m[32m'SQL'[0m: [32m'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'[0m,
[2;32m│   [0m[32m'db_path'[0m: [32m'../data/bird/train/train_databases/movie_platform/movie_platform.sqlite'[0m,
[2;32m│   [0m[32m'prompt'[0m: [32m'\n# Follow these instruction:\nYou will be given schemas of tables of a database. Your job is to write....itles released in year 1945. Sort the listing by the descending order of movie popularity.\n\n# SQL: \n'[0m
[1m}[0m

The Bird dataset has two splits: `train` and `validation`. For the train split you can only filter by `db_id`, which returns only the examples belonging to that database.

For the validation split (`BirdDevDataset`) you can filter by `db_id` and `difficulty`. Here is how to load the validation split and filter it by `difficulty`.

In [7]:
# Load the BirdBench dev dataset and filter the dataset by 
# difficulty

bird_validation = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path=None, 
    num_fewshot=3, 
    num_rows=100,
    filter_by=("difficulty", "simple")
)

# count the number of examples in the dataset which has 
# difficulty level as simple

len([
    example for example in bird_validation 
    if example["difficulty"] == "simple"
])

2024-09-04 19:35:47,900 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-04 19:35:47,901 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|███████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 3927.99it/s]


[1;36m100[0m

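You can filter by `db_id` in the same way, for example with `filter_by=("db_id", "california_schools")`. The next cell loads the full validation split without a filter, this time tokenized with the `premai-io/prem-1B-SQL` prompt template: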

In [8]:
bird_validation = Text2SQLDataset(
    dataset_name='bird', split="validation", force_download=False,
    dataset_folder="../data"
).setup_dataset(
    model_name_or_path="premai-io/prem-1B-SQL",
)
print_data(bird_validation[0])

2024-09-04 19:35:47,945 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-04 19:35:47,945 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|█████████████████████████████████████████████████████████████████| 1534/1534 [00:00<00:00, 3733.67it/s]
2024-09-04 19:35:48,675 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-04 19:35:48,676 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|███████████████████████████████████████████████████████████████████████| 1534/1534 [00:04<00:00, 364.93it/s]
Tokenizing: 100%|███████████████████████████████████████████████████████████████████████| 1534/1534 [00:04<00:00, 369.42it/s]



[1m{[0m
[2;32m│   [0m[32m'input_ids'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1;36m32013[0m, [1;36m32013[0m,  [1;36m2042[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'labels'[0m: [1;35mtensor[0m[1m([0m[1m[[0m [1;36m-100[0m,  [1;36m-100[0m,  [1;36m-100[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'raw'[0m: [1m{[0m
[2;32m│   │   [0m[32m'question_id'[0m: [1;36m0[0m,
[2;32m│   │   [0m[32m'db_id'[0m: [32m'california_schools'[0m,
[2;32m│   │   [0m[32m'question'[0m: [32m'What is the highest eligible free rate for K-12 students in the schools in Alameda County?'[0m,
[2;32m│   │   [0m[32m'evidence'[0m: [32m'Eligible free rate for K-12 = `Free Meal Count [0m[32m([0m[32mK-12[0m[32m)[0m[32m` / `Enrollment [0m[32m([0m[32mK-12[0m[32m)[0m[32m`'[0m,
[2;32m│   │   [0m[32m'SQL'[0m: [32m"SELECT `Free Me

That's it: that's how easy it is to use the datasets. You can load the other available datasets in the same way.

In [9]:
# Loading Spider Dataset

spider_dataset = Text2SQLDataset(
    dataset_name="spider",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

2024-09-04 19:35:57,783 - [SPIDER-DATASET] - INFO - Loaded Spider Dataset
2024-09-04 19:35:57,786 - [SPIDER-DATASET] - INFO - Setting up Spider Dataset
Applying prompt: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 2293.22it/s]
2024-09-04 19:35:58,266 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-04 19:35:58,266 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 472.86it/s]
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 578.37it/s]


In [10]:
## Loading Domains dataset

domains = Text2SQLDataset(
    dataset_name="domains",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

2024-09-04 19:35:58,299 - [DOMAINS-DATASET] - INFO - Loaded Domains Dataset
2024-09-04 19:35:58,301 - [DOMAINS-DATASET] - INFO - Setting up Domains Dataset
Applying prompt: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 2168.72it/s]
2024-09-04 19:35:58,655 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-04 19:35:58,656 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 268.16it/s]
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 234.70it/s]


In [11]:
# Loading Gretel AI Dataset (This is a synthetic dataset)

gretel_dataset = Text2SQLDataset(
    dataset_name="gretel",
    split="train",
    dataset_folder="../data",
).setup_dataset(
    num_fewshot=3,
    num_rows=3,
    model_name_or_path="premai-io/prem-1B-SQL",
)

Applying prompt: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 2851.97it/s]
2024-09-04 19:35:59,384 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-04 19:35:59,385 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 645.87it/s]
Tokenizing: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 845.97it/s]


In [12]:
print_data(gretel_dataset[0]["raw"])


[1m{[0m
[2;32m│   [0m[32m'id'[0m: [1;36m5097[0m,
[2;32m│   [0m[32m'question'[0m: [32m'What is the total volume of timber sold by each salesperson, sorted by salesperson?'[0m,
[2;32m│   [0m[32m'schema'[0m: [32m"CREATE TABLE salesperson [0m[32m([0m[32msalesperson_id INT, name TEXT, region TEXT[0m[32m)[0m[32m; INSERT INTO salesperson [0m[32m([0m[32msalesperson_id, name, region[0m[32m)[0m[32m VALUES [0m[32m([0m[32m1, 'John Doe', 'North'[0m[32m)[0m[32m, [0m[32m([0m[32m2, 'Jane Smith', 'South'[0m[32m)[0m[32m; CREATE TABLE timber_sales [0m[32m([0m[32msales_id INT, salesperson_id INT, volume REAL, sale_date DATE[0m[32m)[0m[32m; INSERT INTO timber_sales [0m[32m([0m[32msales_id, salesperson_id, volume, sale_date[0m[32m)[0m[32m VALUES [0m[32m([0m[32m1, 1, 120, '2021-01-01'[0m[32m)[0m[32m, [0m[32m([0m[32m2, 1, 150, '2021-02-01'[0m[32m)[0m[32m, [0m[32m([0m[32m3, 2, 180, '2021-01-01'[0m[32m)[0m[32m;"[0m,
[2

Another useful feature of premsql datasets is that you can combine (merge) multiple datasets and use them as a single dataset. This is handy when you want to train on several datasets at once.

In [13]:
# Merge all the datasets

print(f"Length of bird dataset: {len(bird_dataset)}")
print(f"Length of spider dataset: {len(spider_dataset)}")
print(f"Length of domains dataset: {len(domains)}")
print(f"Length of gretel dataset: {len(gretel_dataset)}")

merged_dataset = [*bird_dataset, *spider_dataset, *domains, *gretel_dataset]
print(f"Length of merged dataset: {len(merged_dataset)}")
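The cell above merges the datasets with list unpacking, which keeps each source in order (all Bird examples first, then Spider, and so on). If you would rather interleave the sources, a minimal standard-library sketch is to concatenate and shuffle with a fixed seed. `merge_datasets` is a hypothetical helper, and whether you need it depends on your training setup, since many trainers already shuffle the training data.

```python
import random
from itertools import chain


def merge_datasets(*datasets, seed: int = 42) -> list:
    """Concatenate several iterable datasets into one shuffled list.

    Hypothetical helper: a fixed seed makes the interleaving reproducible,
    so a run does not see every example of one source before the next.
    """
    merged = list(chain.from_iterable(datasets))
    random.Random(seed).shuffle(merged)
    return merged


bird = [{"source": "bird", "i": i} for i in range(3)]
spider = [{"source": "spider", "i": i} for i in range(2)]
merged = merge_datasets(bird, spider)
print(len(merged))  # 5
```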

In [14]:
print_data(merged_dataset[0])


[1m{[0m
[2;32m│   [0m[32m'input_ids'[0m: [1;35mtensor[0m[1m([0m[1m[[0m[1;36m32013[0m, [1;36m32013[0m,  [1;36m2042[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'labels'[0m: [1;35mtensor[0m[1m([0m[1m[[0m [1;36m-100[0m,  [1;36m-100[0m,  [1;36m-100[0m,  [33m...[0m,   [1;36m207[0m,    [1;36m16[0m, [1;36m32021[0m[1m][0m[1m)[0m,
[2;32m│   [0m[32m'raw'[0m: [1m{[0m
[2;32m│   │   [0m[32m'db_id'[0m: [32m'movie_platform'[0m,
[2;32m│   │   [0m[32m'question'[0m: [32m'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.'[0m,
[2;32m│   │   [0m[32m'evidence'[0m: [32m'released in the year 1945 refers to movie_release_year = 1945;'[0m,
[2;32m│   │   [0m[32m'SQL'[0m: [32m'SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1'[0m,
[2;32m│   │   [0m[32m'db_path'[0m: [3

### What a prompt looks like in premsql

You might wonder what a prompt looks like in premsql. This is how a single prompt looks when wrapped in a model's prompt template.

In [15]:
print(gretel_dataset[0]["raw"]["prompt"])

### Creating your own dataset

In this section, we will see how to create your own dataset, similar to the ones above. Creating your own dataset can involve many customizations and variables. One of the easiest ways is to organize and annotate your data in the following file structure:

```
├── databases
│   ├── california_schools
│   │   └── california_schools.sqlite
│   ├── card_games
│   ├── codebase_community
│   ├── debit_card_specializing
│   ├── european_football_2
│   ├── financial
│   ├── formula_1
│   ├── student_club
│   ├── superhero
│   ├── thrombosis_prediction
│   └── toxicology
├── train.json
└── validation.json # Optional
```

We use this hierarchy because, in a real-world scenario, you can have multiple databases, and each database can contain multiple tables.

Suppose you save everything inside a `./data` folder. That folder should contain a `databases` folder (you can name it something else) and a `train.json` file, plus an optional `validation.json` file.

Inside the `databases` folder there should be one sub-folder per database, and each sub-folder should contain a `.sqlite` file with the same name. For example, if the database is named `california_schools`, there should be a `california_schools.sqlite` file inside the `california_schools` folder.

The `train` or `validation` JSON file should contain a list of dictionaries with the following required keys:

1. `db_id`: the name of the database folder and of its `.sqlite` file.
2. `question`: the question asked by the user.
3. `SQL`: the ground-truth SQL query.

**Please note:** all keys are case sensitive. Here is an example of a single data point:

```json
{
  "question_id": 0,
  "db_id": "california_schools",
  "question": "What is the highest eligible free rate for K-12 students in the schools in Alameda County?",
  "evidence": "Eligible free rate for K-12 = `Free Meal Count (K-12)` / `Enrollment (K-12)`",
  "SQL": "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
  "difficulty": "simple"
}
```
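Before loading a folder like this, it can help to check it against the rules above: every example has the required keys, and every `db_id` has a matching `databases/<db_id>/<db_id>.sqlite` file. The checker below is a hypothetical sketch, not part of premsql; its `database_folder_name` and `json_file_name` arguments simply mirror the names used by `StandardDataset` in the next cell.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = ("db_id", "question", "SQL")  # keys are case sensitive


def validate_dataset_folder(
    dataset_path: str,
    database_folder_name: str = "databases",
    json_file_name: str = "train.json",
) -> list[str]:
    """Return the problems found in a dataset folder (empty list if none).

    Hypothetical checker for illustration; not part of premsql.
    """
    root = Path(dataset_path)
    examples = json.loads((root / json_file_name).read_text())
    if not isinstance(examples, list):
        return [f"{json_file_name} must contain a list of objects"]
    problems = []
    for idx, example in enumerate(examples):
        missing = [key for key in REQUIRED_KEYS if key not in example]
        if missing:
            problems.append(f"example {idx}: missing keys {missing}")
            continue
        db_id = example["db_id"]
        # Each database lives in <databases>/<db_id>/<db_id>.sqlite
        sqlite_file = root / database_folder_name / db_id / f"{db_id}.sqlite"
        if not sqlite_file.is_file():
            problems.append(f"example {idx}: {sqlite_file} not found")
    return problems


# Usage: build a tiny folder with one valid and two broken examples.
with tempfile.TemporaryDirectory() as tmp:
    db_dir = Path(tmp) / "databases" / "california_schools"
    db_dir.mkdir(parents=True)
    (db_dir / "california_schools.sqlite").touch()
    data = [
        {"db_id": "california_schools", "question": "q", "SQL": "SELECT 1"},
        {"db_id": "card_games", "question": "q", "SQL": "SELECT 1"},
        {"db_id": "card_games", "question": "q", "sql": "SELECT 1"},
    ]
    (Path(tmp) / "train.json").write_text(json.dumps(data))
    problems = validate_dataset_folder(tmp)

print(problems)
```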

You can also keep other keys; they are automatically used as filter keys. You can then load your dataset from the folder with the following code:

In [16]:
from premsql.datasets import StandardDataset

path = "../data/bird/validation"
dataset = StandardDataset(
    split="validation",
    dataset_path=path,
    database_folder_name="dev_databases",
    json_file_name="validation.json",
)

In [17]:
dataset.filter_availables

[1m[[0m[32m'db_id'[0m, [32m'difficulty'[0m[1m][0m

We have loaded the Bird dev dataset again, but this time with the `StandardDataset` class. `StandardDataset` acts as a template for any Text-to-SQL dataset that follows the structure above.

### Towards more customization

Last but not least, there is one more level of customization available when creating Text-to-SQL datasets. So far, all the use cases above were tightly coupled to `.sqlite` databases. However, if you:

1. have different databases (like PostgreSQL or a cloud database instance),
2. want to add a lot of custom logic before building prompts, or
3. want to add more utilities on top of premsql,

this section will help you achieve that.

**Note:** For point 1, you can also migrate a subset of your data to SQLite. Once you have migrated a subset of your database content to SQLite and annotated it, you can follow the `StandardDataset` route described above to create a Text2SQL-compatible dataset for fine-tuning and inference.
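Migrating a subset to SQLite can be done with Python's built-in `sqlite3` module. The sketch below writes one table into the `databases/<db_id>/<db_id>.sqlite` layout expected by `StandardDataset`. `export_to_sqlite` is a hypothetical helper: fetching the rows from your source database (PostgreSQL or otherwise) is left to you, and the sample rows are taken from the Gretel example shown earlier.

```python
import sqlite3
import tempfile
from pathlib import Path


def export_to_sqlite(
    root: str, db_id: str, create_sql: str, table: str, rows: list[tuple]
) -> Path:
    """Write one table into <root>/databases/<db_id>/<db_id>.sqlite.

    Hypothetical helper; `rows` would come from your source database.
    """
    db_path = Path(root) / "databases" / db_id / f"{db_id}.sqlite"
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute(create_sql)
            if rows:
                placeholders = ", ".join("?" for _ in rows[0])
                # `table` is interpolated directly: only pass trusted names.
                conn.executemany(f"INSERT INTO {table} VALUES ({placeholders})", rows)
    finally:
        conn.close()
    return db_path


with tempfile.TemporaryDirectory() as tmp:
    path = export_to_sqlite(
        tmp,
        db_id="sales",
        create_sql="CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT)",
        table="salesperson",
        rows=[(1, "John Doe", "North"), (2, "Jane Smith", "South")],
    )
    check = sqlite3.connect(path)
    count = check.execute("SELECT COUNT(*) FROM salesperson").fetchone()[0]
    check.close()

print(count)  # 2
```

Repeat the export per table (or per database) and then annotate the questions and SQL in `train.json` as described earlier.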

If you still want full customization, you can achieve it in two steps. A detailed tutorial on this will come in a future version; in short, you need to define two things to make a fully custom premsql dataset.

**DatasetInstance:** a dataset instance handles operations on individual data points. Extend the `premsql.datasets.base.Text2SQLBaseInstance` class to define your own. Here is a blueprint:

```python
from premsql.datasets.base import Text2SQLBaseInstance


class CustomDataInstance(Text2SQLBaseInstance):
    def __init__(self, dataset: list[dict]) -> None:
        super().__init__(dataset=dataset)

    def schema_prompt(self, db_path: str) -> str:
        # Write your schema prompt here: fetch the schema from your
        # database and format it as text. For a SQLite database this
        # could use a query like:
        #   SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'
        # Check out Text2SQLBaseInstance in premsql/datasets/base for more details.
        raise NotImplementedError
```

This class also provides the `additional_prompt` and `apply_prompt` methods. They have database-agnostic default implementations, which you can override if needed.

Once your instance is defined, you can define your custom dataset class by inheriting from the
`premsql.datasets.base.Text2SQLBaseDataset` class, like this:


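```python
import logging
from pathlib import Path
from typing import Optional, Union

from premsql.datasets.base import Text2SQLBaseDataset

# A standard-library logger; swap in your preferred logger if needed.
logger = logging.getLogger(__name__)


class CustomText2SQLDataset(Text2SQLBaseDataset):
    def __init__(
        self,
        split: str,
        dataset_folder: Optional[Union[str, Path]] = "./data",
        hf_token: Optional[str] = None,
        force_download: Optional[bool] = False,
    ):
        # Define your logic here: locate or download the data and
        # load the raw examples for the given split.
        pass

    def setup_dataset(
        self,
        filter_by: tuple | None = None,
        num_rows: int | None = None,
        num_fewshot: int | None = None,
        model_name_or_path: str | None = None,
        prompt_template: str | None = None,
    ):
        logger.info("Setting up Custom Dataset")
        # Reuse the base class logic for filtering, prompting and tokenization.
        return super().setup_dataset(
            filter_by, num_rows, num_fewshot, model_name_or_path, prompt_template
        )
```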

Based on your requirements, define everything you need in the `__init__` and `setup_dataset` methods. Check out the `Text2SQLBaseDataset` class to see how things are defined there. A detailed tutorial on building a dataset for a different database is coming soon.