This is the official code implementation of the paper Training Task Experts through Retrieval Based Distillation by Jiaxin Ge*, Xueying Jia*, Vijay Viswanathan, Hongyin Luo, and Graham Neubig.
Retrieval Based Distillation (ReBase) is a method that first retrieves data from rich online sources and then transforms them into domain-specific data. This method greatly enhances data diversity. Moreover, ReBase generates Chain-of-Thought reasoning and distills the reasoning capacity of LLMs.
Please prepare an Anthropic REST API key and a Hugging Face token (with write access). Copy these into the `config.py` file located in both the `data_preparation` and `finetune` folders.
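As an illustration, the `config.py` file might look like the sketch below. The variable names here are assumptions, not the repo's actual names; check the `config.py` templates shipped in `data_preparation/` and `finetune/` for the exact fields.

```python
# config.py -- illustrative sketch only; the variable names are assumptions.
# Check the actual config.py in data_preparation/ and finetune/ for the real fields.
ANTHROPIC_API_KEY = "sk-ant-..."  # Anthropic REST API key
HF_TOKEN = "hf_..."               # Hugging Face token with write access
```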
conda create -n rebase python=3.10
conda activate rebase
pip install -r requirements.txt
- For more details on installing unsloth, see: https://github.com/unslothai/unsloth
You can modify and run `run_scripts.sh` to execute the full pipeline. However, we recommend reading through the following steps to understand what each stage does.
cd data_preparation
Step 1: download datasets and merge them into a corpus.
python merge_datasets.py
- This step takes one to two days to finish; to save time, you can process only a subset of `dataset_index.json`.
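One way to build such a subset is sketched below. It assumes `dataset_index.json` is a JSON mapping (or list) of dataset entries, which is an assumption about the file's structure; inspect the real file before relying on this.

```python
import json

def subset_index(index_path: str, out_path: str, n: int) -> int:
    """Write the first n entries of dataset_index.json to out_path so that
    merge_datasets.py processes a smaller corpus. The structure of
    dataset_index.json is assumed here -- check the real file first."""
    with open(index_path) as f:
        index = json.load(f)
    if isinstance(index, dict):
        subset = dict(list(index.items())[:n])
    else:  # a JSON list of entries
        subset = index[:n]
    with open(out_path, "w") as f:
        json.dump(subset, f, indent=2)
    return len(subset)
```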
Step 2: get corpus embedding
python embed_corpus.py
Step 3: retrieve from corpus for a specific task
python retrieve_data.py --task_name squad
- Available `task_name` values: 'mconala', 'mnli', 'squad', 'bbh'. The 'bbh' task includes the following subtasks: ['date_understanding', 'logical_deduction_five_objects', 'movie_recommendation', 'multistep_arithmetic_two', 'object_counting', 'word_sorting', 'hyperbaton', 'sports_understanding', 'boolean_expressions', 'tracking_shuffled_objects_seven_objects', 'ruin_names', 'tracking_shuffled_objects_three_objects', 'causal_judgement', 'reasoning_about_colored_objects', 'logical_deduction_seven_objects', 'temporal_sequences', 'salient_translation_error_detection', 'tracking_shuffled_objects_five_objects', 'geometric_shapes', 'disambiguation_qa', 'dyck_languages', 'navigate', 'formal_fallacies', 'web_of_lies', 'snarks', 'penguins_in_a_table', 'logical_deduction_three_objects']
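Steps 2 and 3 follow the standard dense-retrieval pattern: every corpus row is encoded into a vector, and rows are then ranked by similarity to an embedding of the task. The toy sketch below illustrates the ranking half with plain cosine similarity; the encoder and scoring used by `embed_corpus.py` and `retrieve_data.py` are assumptions here, not the scripts' actual implementation.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def top_k(corpus_embeddings, query_embedding, k):
    """Return indices of the k corpus rows most similar to the query,
    mirroring the embed-then-retrieve pattern of steps 2 and 3."""
    scores = [(cosine(row, query_embedding), i)
              for i, row in enumerate(corpus_embeddings)]
    scores.sort(reverse=True)
    return [i for _, i in scores[:k]]
```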
Step 4: transform retrieved data to training data
python dataset_transformer.py --task_name squad
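In this step, each retrieved row is sent to an LLM that rewrites it into the target task's format with chain-of-thought reasoning. The prompt wording below is purely illustrative, not the repo's actual prompt (which lives in `dataset_transformer.py`); the function name is hypothetical.

```python
def build_transform_prompt(row: dict, task_description: str) -> str:
    """Illustrative prompt for rewriting one retrieved corpus row into a
    training example with step-by-step reasoning. The actual prompt used
    by dataset_transformer.py will differ."""
    return (
        f"Task: {task_description}\n"
        f"Source data: {row}\n"
        "Rewrite this source data as a training example for the task. "
        "Include step-by-step reasoning before the final answer."
    )
```

The script then sends prompts like this to the Anthropic API using the key from `config.py` and collects the responses into the transformed CSV.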
cd finetune
Step 5: finetune model using transformed data
python finetune.py --model_name unsloth/llama-3-8b-bnb-4bit --data_path ../data_preparation/tasks/squad/transformed_data_score_use_full_row_dataset.csv --finetuned_model_name Username/modelname
- `model_name`: Name of the pre-trained model.
- `data_path`: Path to the training data CSV file.
- `finetuned_model_name`: Name of the fine-tuned model. Replace "Username" with your Hugging Face username and choose a model name.
We present the datasets generated by ReBase under the folder `rebase_datasets`. Each task folder contains three files:

- `retrieved_data_rank_score_dataset.json`: The row IDs retrieved from the corpus.
- `selected_rows_rank_score_dataset.json`: The original rows corresponding to those row IDs.
- `transformed_data_score_use_full_row_dataset.csv`: The final dataset transformed by ReBase.
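The three files can be inspected with the standard library alone. This sketch only assumes the file names listed above; the helper function itself is hypothetical.

```python
import csv
import json

def summarize_rebase_outputs(task_dir: str) -> dict:
    """Return row counts for the three files ReBase writes per task,
    using the file names documented above."""
    with open(f"{task_dir}/retrieved_data_rank_score_dataset.json") as f:
        retrieved = json.load(f)
    with open(f"{task_dir}/selected_rows_rank_score_dataset.json") as f:
        selected = json.load(f)
    with open(f"{task_dir}/transformed_data_score_use_full_row_dataset.csv") as f:
        transformed = list(csv.DictReader(f))
    return {
        "retrieved": len(retrieved),
        "selected": len(selected),
        "transformed": len(transformed),
    }
```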
These files provide an overview of what data is retrieved and how it is transformed.
If you find our work useful, please cite:
@article{ge2024training,
title={Training Task Experts through Retrieval Based Distillation},
author={Ge, Jiaxin and Jia, Xueying and Viswanathan, Vijay and Luo, Hongyin and Neubig, Graham},
journal={arXiv preprint arXiv:2407.05463},
year={2024}
}