In [1]:
# cd ..



## Finetuner

The premsql fine-tuner is the module that fine-tunes models for text-to-SQL tasks. It supports the following fine-tuning methods:

1. Full fine-tuning 
2. PEFT using LoRA
3. PEFT using QLoRA
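To make the trade-offs between these methods concrete, here is a small back-of-the-envelope sketch. It is not part of premsql, and the 2048 x 2048 matrix, the rank `r = 8`, and the one-billion-parameter model are illustrative assumptions. It counts the trainable parameters of a single weight matrix under full fine-tuning and under LoRA, and the memory needed just to store frozen weights at 16-bit versus the 4-bit precision that QLoRA uses.

```python
def full_finetune_params(d: int, k: int) -> int:
    """Full fine-tuning updates every entry of a d x k weight matrix."""
    return d * k


def lora_params(d: int, k: int, r: int) -> int:
    """LoRA freezes W and trains B (d x r) and A (r x k); the update is W + (alpha / r) * B @ A."""
    return d * r + r * k


def weight_memory_gb(num_params: int, bits: int) -> float:
    """Memory (decimal GB) for the weights alone, ignoring optimizer state and activations."""
    return num_params * bits / 8 / 1e9


# Illustrative shapes and sizes, not taken from prem-1B-SQL.
d, k, r = 2048, 2048, 8
print(full_finetune_params(d, k))  # 4194304
print(lora_params(d, k, r))        # 32768

# QLoRA stores the frozen base weights in 4-bit instead of 16-bit precision.
print(weight_memory_gb(1_000_000_000, 16))  # 2.0
print(weight_memory_gb(1_000_000_000, 4))   # 0.5
```

In this example LoRA trains under 1% of the matrix's parameters (32,768 vs. 4,194,304), and QLoRA additionally shrinks the memory taken by the frozen base weights by a factor of four.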

You can also build your own custom fine-tuning pipeline from the components and tools that premsql provides. This tutorial assumes you are familiar with the following topics:

1. [premsql datasets](/examples/datasets.ipynb)
2. [premsql generators](/examples/generators.ipynb)
3. [premsql evaluators](/examples/evaluation.ipynb)
4. [premsql error handling datasets](/examples/error_dataset.ipynb)

It also helps to have a basic understanding of how Hugging Face's [TRL](https://huggingface.co/docs/trl/en/index) library works. We start by importing some packages.

In [2]:
from premsql.datasets import (
    BirdDataset,
    SpiderUnifiedDataset,
    DomainsDataset,
    GretelAIDataset
)

from premsql.evaluator.from_sqlite import SQLiteExecutor
from premsql.datasets import Text2SQLDataset
from premsql.tuner.peft import Text2SQLPeftTuner
from premsql.datasets.error_dataset import ErrorDatasetGenerator


In [3]:
path = "./data"  # same dataset folder used by the other datasets below
model_name_or_path = "premai-io/prem-1B-SQL"

### Defining different datasets

First, we need some training datasets. For demonstration purposes, this tutorial uses small subsets of the various datasets that premsql provides; for actual fine-tuning, you should use the full datasets. We start by loading the BirdBench training dataset.

In [4]:
bird_train = BirdDataset(split="train", dataset_folder=path).setup_dataset(
    num_rows=100,
)

2024-09-07 10:41:16,255 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-07 10:41:16,257 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|██████████| 100/100 [00:00<00:00, 3519.80it/s]
2024-09-07 10:41:16,891 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 10:41:16,892 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 188.71it/s]
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 199.78it/s]


Next, we load the Spider dataset.

In [5]:
spider_train = SpiderUnifiedDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100
)

2024-09-07 10:41:44,000 - [SPIDER-DATASET] - INFO - Loaded Spider Dataset
2024-09-07 10:41:44,005 - [SPIDER-DATASET] - INFO - Setting up Spider Dataset
Applying prompt: 100%|██████████| 100/100 [00:00<00:00, 4144.69it/s]
2024-09-07 10:41:44,636 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 10:41:44,637 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 399.31it/s]
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 436.83it/s]


Then we load the Domains dataset.

In [6]:
domains_dataset = DomainsDataset(split="train", dataset_folder="./data").setup_dataset(
    num_rows=100,
)

2024-09-07 10:42:00,249 - [DOMAINS-DATASET] - INFO - Loaded Domains Dataset
2024-09-07 10:42:00,252 - [DOMAINS-DATASET] - INFO - Setting up Domains Dataset
Applying prompt: 100%|██████████| 100/100 [00:00<00:00, 2671.91it/s]
2024-09-07 10:42:00,681 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 10:42:00,682 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 226.39it/s]
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 241.73it/s]


We also load the Gretel AI synthetic text-to-SQL dataset.

In [7]:
gertelai_dataset = GretelAIDataset(split="train", dataset_folder="./data",).setup_dataset(
    num_rows=100,
)

Applying prompt: 100%|██████████| 100/100 [00:00<00:00, 162130.03it/s]
2024-09-07 10:42:14,958 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 10:42:14,958 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 517.27it/s]
Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 579.19it/s]


Last but not least, we load an error-handling dataset. You can learn more about error-handling datasets [here](/examples/error_dataset.ipynb).

In [8]:
existing_error_dataset = ErrorDatasetGenerator.from_existing(
    experiment_name="testing_error_gen",
)

2024-09-07 10:42:28,011 - [DATASET] - INFO - Casted dataset with model chat template
2024-09-07 10:42:28,012 - [DATASET] - INFO - Starting Tokenization ...
Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 160.95it/s]
Tokenizing: 100%|██████████| 10/10 [00:00<00:00, 180.55it/s]


### NOTE:

Since this tutorial fine-tunes with PEFT (using LoRA), the workflow uses TRL internally. The datasets we instantiate therefore do not need to be tokenized, because TRL handles tokenization under the hood.


Now let's merge all the datasets. Each dataset can be unpacked with `*`, so we can combine them into a single list like this:

In [9]:
merged_dataset = [
    *spider_train,
    *bird_train,
    *domains_dataset,
    *gertelai_dataset,
    *existing_error_dataset,
]
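The unpacking above works because each premsql dataset can be iterated over as a sequence of examples. The plain-Python sketch below shows the same pattern with hypothetical dict records standing in for the real datasets; the record fields are illustrative only and are not premsql's schema.

```python
from collections import Counter

# Hypothetical stand-ins for the premsql datasets: each is just an iterable of dict records.
spider_like = [{"source": "spider", "question": "q1"}, {"source": "spider", "question": "q2"}]
bird_like = [{"source": "bird", "question": "q3"}]
# A generator works too, since star-unpacking accepts any iterable.
error_like = ({"source": "error", "question": f"q{i}"} for i in range(4, 6))

# Star-unpacking concatenates the iterables, in order, into one flat list.
merged = [*spider_like, *bird_like, *error_like]

print(len(merged))  # 5
print(Counter(row["source"] for row in merged))
```

The merged list keeps each source's examples contiguous. Assuming the underlying TRL trainer keeps the default Hugging Face `Trainer` behavior of randomly sampling training examples, this ordering should not bias training.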

We also initialize the BirdBench validation dataset to use during evaluation.

Validating a text-to-SQL model differs from the validation process of typical LLM fine-tuning tasks. Here we execute the generated SQL on the database and check whether its result matches the result of the ground-truth query. premsql therefore provides a custom, robust [Hugging Face callback](/premsql/tuner/callback.py) that runs this evaluation at every evaluation step during training, using the same evaluation method as the premsql evaluators.
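The core idea behind execution-based validation can be sketched with Python's built-in `sqlite3` module. This is a simplified illustration, not premsql's `SQLiteExecutor` or its callback: the `execution_match` helper and the toy `schools` table are hypothetical, and comparing result rows as sets (ignoring row order and duplicates) is one common choice among several.

```python
import sqlite3


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """Return True if the predicted query runs and returns the same set of rows as the gold query.

    Predicted SQL that fails to execute counts as a mismatch.
    """
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False
    gold_rows = conn.execute(gold_sql).fetchall()
    return set(predicted_rows) == set(gold_rows)


# Toy in-memory database with a hypothetical schema and data.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE schools (name TEXT, city TEXT, students INTEGER)")
conn.executemany(
    "INSERT INTO schools VALUES (?, ?, ?)",
    [("A", "Fresno", 300), ("B", "Oakland", 500), ("C", "Fresno", 700)],
)

gold = "SELECT name FROM schools WHERE city = 'Fresno'"
# Same rows in a different order: counted as a match.
print(execution_match(conn, "SELECT name FROM schools WHERE city = 'Fresno' ORDER BY students DESC", gold))  # True
# Different rows: mismatch.
print(execution_match(conn, "SELECT name FROM schools WHERE students > 400", gold))  # False
# Invalid SQL: mismatch.
print(execution_match(conn, "SELEC name FROM schools", gold))  # False
```

Because the comparison is on query results rather than on SQL text, two differently written queries can both be judged correct.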

All you need to do is define your validation dataset; the callback takes care of the rest. If you are unfamiliar with the syntax below, check out the [datasets](/examples/datasets.ipynb) and [evaluator](/examples/evaluation.ipynb) tutorials.

In [10]:
bird_dev = Text2SQLDataset(dataset_name="bird", split="validation", dataset_folder=path).setup_dataset(
    num_rows=10,
    filter_by=("difficulty", "challenging")
)

2024-09-07 10:43:00,302 - [BIRD-DATASET] - INFO - Loaded Bird Dataset
2024-09-07 10:43:00,303 - [BIRD-DATASET] - INFO - Setting up Bird Dataset
Applying prompt: 100%|██████████| 10/10 [00:00<00:00, 1762.53it/s]
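Conceptually, `filter_by` and `num_rows` select a subset of the validation records. The sketch below is a hypothetical plain-Python stand-in, not premsql's implementation; in particular, it assumes that filtering is applied before the row limit, which you should confirm against the premsql source.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple

Record = Dict[str, Any]


def select_rows(
    records: Iterable[Record],
    filter_by: Optional[Tuple[str, Any]] = None,
    num_rows: Optional[int] = None,
) -> List[Record]:
    """Hypothetical helper: keep records matching (key, value), then take the first num_rows."""
    rows = list(records)
    if filter_by is not None:
        key, value = filter_by
        rows = [row for row in rows if row.get(key) == value]
    if num_rows is not None:
        rows = rows[:num_rows]
    return rows


# Toy records using BIRD-style difficulty labels.
dev = [
    {"question_id": 0, "difficulty": "simple"},
    {"question_id": 1, "difficulty": "challenging"},
    {"question_id": 2, "difficulty": "moderate"},
    {"question_id": 3, "difficulty": "challenging"},
    {"question_id": 4, "difficulty": "challenging"},
]

subset = select_rows(dev, filter_by=("difficulty", "challenging"), num_rows=2)
print([row["question_id"] for row in subset])  # [1, 3]
```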


Now that everything is set up, we initialize the tuner class. It needs a `model_name_or_path`, which specifies the model to load and fine-tune, and an `experiment_name`, under which all logs are saved.

In [11]:
tuner = Text2SQLPeftTuner(
    model_name_or_path=model_name_or_path,
    experiment_name="lora_tuning"
)

Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.82s/it]


Finally, we call the `train` method with the following arguments:

1. `train_datasets`: The merged dataset used for training.
2. `output_dir`: The directory in which the model weights are stored.
3. `num_train_epochs`: The number of training epochs.
4. `per_device_train_batch_size`: The training batch size per device.
5. `gradient_accumulation_steps`: The number of gradient accumulation steps.
6. `evaluation_dataset`: The evaluation dataset. It can also be `None`, in which case no evaluation is run during fine-tuning.
7. `eval_steps`: The number of training steps between two evaluations.
8. `max_seq_length`: The maximum permissible sequence length.
9. `executor`: The [executor](/examples/evaluation.ipynb) that runs the generated SQL. Provide it only when you have defined an `evaluation_dataset`.
10. `filter_eval_results_by`: A `(key, value)` pair used to filter the evaluation results; make sure the key and value are present in the evaluation dataset. Here we filter by difficulty so that only challenging data points are evaluated.
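Several of these arguments interact. The rough sketch below estimates the effective batch size, the optimizer steps per epoch, and how many evaluations `eval_steps` triggers. It assumes a single device, no sequence packing, and that the merged dataset has 410 examples (100 rows from each of the four datasets plus 10 error-dataset rows, as the progress bars above suggest); the helper functions are illustrative and only approximate how Hugging Face `Trainer` counts steps.

```python
import math


def effective_batch_size(per_device_train_batch_size: int, gradient_accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples that contribute to each optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def optimizer_steps_per_epoch(
    num_examples: int, per_device_train_batch_size: int, gradient_accumulation_steps: int, num_devices: int = 1
) -> int:
    """Approximate update steps per epoch (no packing, last partial batch kept)."""
    batches = math.ceil(num_examples / (per_device_train_batch_size * num_devices))
    return max(batches // gradient_accumulation_steps, 1)


def num_evaluations(total_steps: int, eval_steps: int) -> int:
    """Evaluation runs once every `eval_steps` update steps."""
    return total_steps // eval_steps


# Assumed size of merged_dataset: 4 datasets x 100 rows + 10 error-dataset rows.
num_examples = 4 * 100 + 10

# Settings from this tutorial: batch size 1, no accumulation, 1 epoch.
steps = optimizer_steps_per_epoch(num_examples, 1, 1)
print(effective_batch_size(1, 1), steps, num_evaluations(steps, 100))  # 1 410 4

# A larger effective batch gives far fewer steps, so eval_steps=100 is never reached.
steps_big = optimizer_steps_per_epoch(num_examples, 4, 8)
print(effective_batch_size(4, 8), steps_big, num_evaluations(steps_big, 100))  # 32 12 0
```

If you raise the batch size or accumulation steps, consider lowering `eval_steps` accordingly so that evaluation still happens during training.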

You can also pass additional parameters compatible with [transformers `TrainingArguments`](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/trainer#transformers.TrainingArguments) as `**kwargs`; these override the default settings. Now let's use this to train the model.

In [None]:
tuner.train(
    train_datasets=merged_dataset,
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    evaluation_dataset=bird_dev,
    eval_steps=100,
    max_seq_length=1024,
    executor=SQLiteExecutor(),
    filter_eval_results_by=("difficulty", "challenging")
)

This starts training the model. The model outputs are stored in `./output`, and the fine-tuning logs are stored in the `./experiments/train/<experiment-name>` directory. See our [LoRA fine-tuning script](/examples/lora_tuning.py) for an end-to-end example.
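The `**kwargs` override behavior described before the training cell can be illustrated with a plain dictionary merge. This is a minimal sketch, not premsql's code: `DEFAULT_TRAINING_ARGS` and `build_training_args` are hypothetical, and the default values are made up. Only the keys themselves (`learning_rate`, `lr_scheduler_type`, `logging_steps`, `warmup_ratio`) are real `TrainingArguments` fields.

```python
from typing import Any, Dict

# Hypothetical defaults for illustration only; premsql's actual defaults may differ.
DEFAULT_TRAINING_ARGS: Dict[str, Any] = {
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "logging_steps": 10,
}


def build_training_args(**kwargs: Any) -> Dict[str, Any]:
    """Merge user kwargs over the defaults: in a dict literal, later keys win."""
    return {**DEFAULT_TRAINING_ARGS, **kwargs}


args = build_training_args(learning_rate=1e-4, warmup_ratio=0.03)
print(args)
```

Here `learning_rate` is overridden, `warmup_ratio` is added, and the remaining defaults are kept unchanged. Following the description above, you would pass such keys directly to `tuner.train(...)`, for example `learning_rate=1e-4`.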