Official code for "LitE-SQL: A Lightweight and Efficient Text-to-SQL Framework with Vector-based Schema Linking and Execution-Guided Self-Correction"

To set up the environment, use the following commands:
git clone https://github.com/shengminp/LitE-SQL.git
cd LitE-SQL
conda env create -f environment.yml
After installation, your project directory structure should look like this:
.
├── datasets
├── schema_retriever
├── sql_generator
├── environment.yml
└── README.md
Specifically, after downloading the dataset, you should arrange the directory structure of the datasets as follows:
.
└── datasets
├── bird
| |── train
| | ├── train_databases
| | ├── train_tables.json
| | ├── train.json
| | └── train_gold.sql
| └── dev_20240627
| ├── dev_databases
| ├── dev.json
| ├── dev_tables.json
| ├── dev_tied_append.json
| └── dev.sql
└── spider
|── train
| ├── database
| ├── train_gold.sql
| ├── train_others.json
| ├── tables.json
| └── train_spider.sql
|── dev
| ├── database
| ├── dev.json
| ├── tables.json
| └── dev_gold.sql
└── test
├── test_database
├── test.json
├── test_tables.json
└── test_gold.sql
The directory structure of schema retriever is as follow:
.
└──schema_retriever
├── data
│ └── preprocessing.py
├── language_model
│ |── language_model.py
│ └── saved_model
├── utils
│ ├── configs.py
│ ├── dataset_util.py
│ ├── db_utils.py
| └── utils.py
└── scripts
├── fine-tune.py
└── retrieve.py
Please refer to ./schema_retriever/data/README.md
NOTE: ./schema_retriever/data/preprocessing.py must be executed before following steps.
After this fine-tuning step, the embedding model would be saved on $ft_path/$model_name. Run the following command:
python schema_retriever/scripts/fine-tune.py \
--ft_path $ft_path \
--model_name $model_name \
--epoch $epoch \
--batch_size $batch_size \
--n_limit $n_limit \
--LM_MODEL $LM_MODEL \
--DATA_PATH $DATA_PATH \
- $ft_path: Path of embedding model to be saved (
schema_retriever/language_model/saved_model) - $model_name: Name of embedding model (
fine-tuned-embedding-model) - $LM_MODEL: Embedding model from Huggingface (
intfloat/multilingual-e5-large) - $DATA_PATH: Path for training dataset (
schema_retriever/data/fine-tuning_samples_from_BIRD_augmented_version.json)
This Schema Retriever code is for BIRD. After running this retrieval step, retrieved schema information would be saved on sql_generator/dev_20240627/retrieved/BIRD-dev-more-schema.json. Run the following command:
python schema_retriever/scripts/retrieve.py \
--root_path $root_path \
--db_schema_info_path $db_schema_info_path \
--data_path $data_path \
--database_dir_path $database_dir_path \
--SL_K $K
- $root_path: Path of dataset root dir (
datasets/bird/dev_20240627) - $db_schema_info_path: Path of schema info (
dev_tables.json) - $data_path: Path of dataset (
dev.json) - $database_dir_path: Path of database (
dev_databases) - $SL_K: The number of retrieved columns (
25)
The directory structure of SQL generator is as follow:
.
└──sql_generator
├── datasets
│ └── Prepare.ipynb
├── models
├── results
└── scripts
├── utils
│ ├── spider_tool
| | ├── __init__.py
| | ├── evaluation.py
| | ├── exec_eval.py
| | ├── parse.py
| | └── process_sql.py
│ ├── __init__.py
│ ├── config.py
│ ├── data.py
│ ├── metric.py
| └── trainer.py
├── rft.py
├── sft.py
└── generate.py
- Run
./datasets/Prepare.ipynbto prepare dataset for suervised fine-tuning. - The dataset directory should look like this:
.
└── datasets
├── bird
| |── train
| | ├── sft
| | | └── sft_train.json
| | ├── rft
| | ├── BIRD-train-more-schema.json
| | ...
| └── dev_20240627
| ├── retrieved
| | ├── BIRD-dev-more-schema.json
| | ...
| ...
└── spider
|── train
| ├── sft
| | └── sft_train.json
| ├── rft
| ├── SPIDER-train-more-schema.json
| ...
|── dev
| ├── retrieved
| | ├── SPIDER-dev-more-schema.json
| | ...
| ...
└── test
├── retrieved
| ├── SPIDER-test-more-schema.json
| ...
...
Phase-1: Supervised Fine-tuning
In this phase, a Qwen2.5-Coder model is fine-tuned. Run the following command:
python sft.py \
--base_model $model_name \
--data_name $data_name \
--training_phase sft \
- $model_name: Model name from Huggingface
(Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-3B-Instruct, Qwen/Qwen2.5-Coder-7B-Instruct). - $data_name: Dataset name
(bird, spider).
Phase-2: Reinforcement Fine-tuning
In this phase, we fine tune the model with self-generated data. Run the following command:
python rft.py \
--base_model $base_model \
--tune_model $checkpoint \
--ref_model $checkpoint \
--data_name $data_name \
--training_phase rft \
--rft_iter $rft_iter \
- $base_model: Model name from Huggingface
(Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-3B-Instruct, Qwen/Qwen2.5-Coder-7B-Instruct). - $tune_model: Path of checkpoint from sft pahse.
- $ref_model: Same as $tune_model.
- $data_name: Dataset name
(bird, spider).
Use the following command to do inferences after sft phase:
python generate.py \
--base_model $base_model \
--data_name $data_name \
--training_phase sft \
--checkpoint_dir $checkpoint_path\
--schema_file $schema_file_path\
--generation_mode $generation_mode \
--generation_type $generation_type \
--generation_file $generation_file
- $base_name: Model name
(Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-3B-Instruct, Qwen/Qwen2.5-Coder-7B-Instruct). - $data_name: Dataset name
(bird, spider). - $checkpoint_dir: Path of checkpoint.
- $schema_file: Path of schema file
- $generation_mode: Generation mode
(generate, revise). - $generation_type: Generation type
(greedy, random). - $generation_file: Generation file
(test, augment).
Use the following command to generate data for rft phase:
python generate.py \
--base_model $base_model \
--data_name $data_name \
--training_phase sft \
--checkpoint_dir $checkpoint_path\
--schema_file $schema_file_path\
--generation_mode $generation_mode \
--generation_type $generation_type \
--generation_file $generation_file \
--augment_times $augment_times \
--generate_times $generate_times \
--rft_iter 1 \
- $augment_times: The times to augment the training data.
- $generate_times: The times to generate new data.
Use the following command to do inferences after rft phase:
python generate.py \
--base_model $base_model \
--data_name $data_name \
--training_phase rft \
--checkpoint_dir $checkpoint_path\
--schema_file $schema_file_path\
--generation_mode $generation_mode \
--generation_type $generation_type \
--generation_file $generation_file \
--revise_times $revise_times \
--rft_iter 1 \
- $revise_times: The times to revise the generated SQL.
This project is licensed under the MIT © Shengmin Piao & Jieun Lee