## Import packages

In [1]:
import json
import os
import shutil
import subprocess

from datasets import load_dataset
from FlagEmbedding import FlagModel
from FlagEmbedding.baai_general_embedding.finetune.hn_mine import find_knn_neg

In [2]:
def save_jsonl_file(file, path):
    """Write an iterable of JSON-serializable items to ``path`` as JSON Lines.

    Args:
        file: Iterable of objects that ``json`` can serialize; each one
            becomes exactly one line of the output file.
        path: Destination file path. It is opened in ``'w'`` mode, so any
            existing content is overwritten.

    Returns:
        None. A confirmation message is printed once the file is closed.
    """
    with open(path, 'w') as handle:
        for record in file:
            # One JSON document per line, newline-terminated.
            handle.write(json.dumps(record))
            handle.write('\n')

    print(f"Save at {path}")

## Configs

In [3]:
# Dataset passed to `load_dataset` and base model passed to `FlagModel` and
# to the training command (`--model_name_or_path`).
dataset_name = "airesearch/WangchanX-Legal-ThaiCCL-RAG"
model_name = "BAAI/bge-m3"

# Relative working directories: `output_dir` is handed to the training command
# as `--output_dir`; `temporary_dir` holds the intermediate JSONL files and is
# deleted in the last cell.
output_dir = "outputs"
temporary_dir = "temp"

In [4]:
# Create the working directories up front. `exist_ok=True` keeps the cell safe
# to re-run and avoids the check-then-create race of a separate
# `os.path.exists` test.
for directory in (output_dir, temporary_dir):
    os.makedirs(directory, exist_ok=True)

## Load data from Hugging Face

In [5]:
# Load the dataset identified by `dataset_name`; the cells below index the
# result by split name (e.g. 'train') and iterate over its examples.
legal_dataset = load_dataset(dataset_name)

## Prepare data

In [6]:
# legal documents
# Gather every non-empty positive-context text across all splits. The set
# drops duplicates and sorting gives the candidate pool a stable order.
unique_texts = {
    context['text']
    for split_name in legal_dataset.keys()
    for example in legal_dataset[split_name]
    for context in example['positive_contexts']
    if len(context['text']) != 0
}
legal_documents = sorted(unique_texts)
print(f'No.legal documents = {len(legal_documents)}')

No.legal documents = 4513


In [7]:
# positive data
# Pair each training question with its non-empty positive passages; questions
# left without any positive passage are skipped.
positive_data = []
for example in legal_dataset['train']:
    passages = [ctx['text'] for ctx in example['positive_contexts'] if len(ctx['text']) != 0]
    if not passages:
        continue
    positive_data.append({'query': example['question'], 'pos': passages})

# Persist the pairs as JSONL; the hard-negative mining step reads this file.
temp_input_path = os.path.join(temporary_dir, 'temp_positive_data.jsonl')
save_jsonl_file(positive_data, temp_input_path)

Save at temp/temp_positive_data.jsonl


### Hard Negatives
Hard negative mining is a widely used method to improve the quality of sentence embeddings. You can mine hard negatives using the following arguments:

- `input_file`: json data for finetuning. 
- `output_file`: path to save JSON data with mined hard negatives for finetuning
- `negative_number`: the number of sampled negatives 
- `sample_range`: the rank range to sample negatives from. For example, `[2, 100]` means sampling `negative_number` negatives from the top2-top100 documents. **You can set larger values to reduce the difficulty of the negatives (e.g., set it to `[60, 300]` to sample negatives from the top60-top300 passages)**
- `candidate_pool`: the pool to retrieve negatives from. The default value is None, in which case the script retrieves from the combination of all `neg` in `input_file`. 
The format of this file is the same as [pretrain data](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/pretrain#2-data-format). If a candidate_pool is given, the script retrieves negatives from it instead.
- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.

In [8]:
# hard negative data for fine tuning
# Output path for the mined data, inside the temporary directory.
temp_hn_path = os.path.join(temporary_dir, 'temp_hn_data.jsonl')
# Embedding model used for mining. The query instruction is an empty string —
# presumably so queries are encoded without any instruction prefix; confirm
# against the FlagModel documentation.
model = FlagModel(model_name, query_instruction_for_retrieval="")
# Mine 3 negatives per query, using the de-duplicated legal documents as the
# candidate pool. `sample_range` and the other arguments are described in the
# "Hard Negatives" notes above.
find_knn_neg(
    model=model, 
    input_file=temp_input_path, 
    candidate_pool=legal_documents, 
    output_file=temp_hn_path, 
    sample_range=[2, 100], 
    negative_number=3, 
    use_gpu=False
)
# Delete the positives-only intermediate file. The whole temporary directory is
# later passed as --train_data, so presumably only the mined file should remain
# there — verify against the training script's data loading.
os.remove(temp_input_path)

----------using 4*GPUs----------
inferencing embedding for corpus (number=4513)--------------


Inference Embeddings: 100%|███████████████████████| 5/5 [00:06<00:00,  1.25s/it]


inferencing embedding for queries (number=7238)--------------


Inference Embeddings: 100%|███████████████████████| 8/8 [00:01<00:00,  4.86it/s]


create index and search------------------


Batches: 100%|████████████████████████████████| 114/114 [00:04<00:00, 27.34it/s]


## Training

Here is a simple example of how to perform unified fine-tuning (dense embedding, sparse embedding and ColBERT) based on `BAAI/bge-m3`.

**some important arguments**:
- `per_device_train_batch_size`: batch size in training. In most cases, a larger batch size will bring stronger performance.
- `train_group_size`: the number of positives and negatives for a query in training.
There is always one positive, so this argument controls the number of negatives (#negatives=train_group_size-1).
Note that the number of negatives should not be larger than the number of negatives in the data `"neg":List[str]`.
Besides the negatives in this group, the in-batch negatives will also be used in fine-tuning.
- `negatives_cross_device`: share the negatives across all GPUs. This argument will extend the number of negatives.
- `learning_rate`: select an appropriate value for your model. Recommended: 1e-5/2e-5/3e-5 for large/base/small-scale models. 
- `temperature`: It will influence the distribution of similarity scores. **Recommended value: 0.01-0.1.**
- `query_max_len`: max length for query. Please set it according to the average length of queries in your data.
- `passage_max_len`: max length for passage. Please set it according to the average length of passages in your data.

For more training arguments please refer to [transformers.TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments)


In [None]:
# Shell commands for the training step. They are only defined here; the
# following cell executes them.
# NOTE: the f-string is not raw, so each trailing backslash + newline is
# consumed by Python as a line continuation and the torchrun invocation reaches
# the shell as one long line.
# --train_data points at the temporary directory holding the mined
# hard-negative JSONL. --train_group_size 4 corresponds to 1 positive +
# 3 negatives per the notes above, matching negative_number=3 used for mining.
# NOTE(review): the spelling `--normlized` is kept as-is; presumably it matches
# the argument name expected by FlagEmbedding.BGE_M3.run — verify before
# "fixing" it.
commands = [
    "cd FlagEmbedding",
    f"""
    CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node 1 \
        -m FlagEmbedding.BGE_M3.run \
        --output_dir {output_dir} \
        --model_name_or_path {model_name} \
        --train_data {temporary_dir} \
        --learning_rate 1e-5 \
        --num_train_epochs 5 \
        --per_device_train_batch_size 1 \
        --dataloader_drop_last True \
        --normlized True \
        --temperature 0.02 \
        --query_max_len 64 \
        --passage_max_len 256 \
        --train_group_size 4 \
        --negatives_cross_device \
        --logging_steps 1000 \
        --same_task_within_batch True \
        --unified_finetuning True \
        --use_self_distill True \
        --save_strategy epoch \
    """,
]

In [None]:
# run commands
# NOTE: every entry is executed in its own shell, so a `cd` in one entry does
# not change the working directory of the entries that follow.
for command in commands:
    result = subprocess.run(command, shell=True, capture_output=True, text=True)

    # Output is captured rather than streamed, so echo it here; otherwise the
    # training logs and any error messages never appear in the notebook.
    if result.stdout:
        print(result.stdout)
    if result.stderr:
        print(result.stderr)

    # Report failures instead of discarding them. This deliberately does not
    # raise, so the remaining commands still run as they did before.
    if result.returncode != 0:
        print(f"Command exited with status {result.returncode}: {command.strip()}")

In [9]:
# Clean up: delete the temporary directory and everything in it, including the
# mined hard-negative file written earlier.
shutil.rmtree(temporary_dir)