In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
import torch
import warnings

warnings.simplefilter("ignore", FutureWarning)

torch.__version__

'2.5.1+cu124'

## **BGE-M3**

**Note:** **[BGE-M3 Fine-tune Guide](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune/embedder)**


### **Fine-tune**

In [None]:
from BGE_M3.arguments import DataArguments, ModelArguments, RetrieverTrainingArguments
from BGE_M3.data import CustomTrainDataset, CustomEmbedCollator
from BGE_M3.modeling import BGEM3Model

import logging
import os

# Root logging configuration: INFO level, "timestamp - level - logger name - message".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

# Named logger for this notebook.
logger = logging.getLogger('BGE_M3_FINE_TUNE')

In [None]:
# --- Model arguments: checkpoint to start from and the tokenizer to pair with it ---
model_args = ModelArguments(
    model_name_or_path="bge-m3-output",
    tokenizer_name="BAAI/bge-m3",
    cache_dir=None,
)

# --- Data arguments ---
data_args = DataArguments(
    # Source file and cache location
    train_data=["Data/train_finetune_hn_v3.jsonl"],
    cache_path="/data_cache",
    # Batch composition
    knowledge_distillation=True,
    same_task_within_batch=True,
    train_group_size=8,
    max_example_num_per_dataset=100000,
    shuffle_ratio=0.0,
    small_threshold=0,
    drop_threshold=0,
    # Sequence lengths / padding
    query_max_len=160,
    passage_max_len=256,
    pad_to_multiple_of=8,
)

# --- Training arguments ---
training_args = RetrieverTrainingArguments(
    output_dir="/kaggle/working/bge-m3-output",
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=8,
    sub_batch_size=-1,
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Loss / pooling ("normlized" is the spelling the arguments class is called with)
    temperature=0.02,
    sentence_pooling_method='cls',
    normlized=True,
    negatives_cross_device=False,
    # Which parts of the model are trained / enabled
    fix_position_embedding=True,
    fix_encoder=False,
    unified_finetuning=False,
    use_self_distill=False,
    self_distill_start_step=0,
    enable_sub_batch=False,
    colbert_dim=-1,
)

import os

# Make sure the checkpoint directory exists before anything is saved into it.
os.makedirs(training_args.output_dir, exist_ok=True)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

In [None]:
# Slow (pure-Python) tokenizer for the checkpoint named in model_args.
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

# Options that are forwarded unchanged from training_args to the model
# constructor (keyword name == attribute name).
_FORWARDED_TRAINING_OPTIONS = (
    "normlized",
    "sentence_pooling_method",
    "negatives_cross_device",
    "temperature",
    "enable_sub_batch",
    "unified_finetuning",
    "use_self_distill",
    "colbert_dim",
    "self_distill_start_step",
)

model = BGEM3Model(
    model_name=model_args.model_name_or_path,
    tokenizer=tokenizer,
    **{option: getattr(training_args, option) for option in _FORWARDED_TRAINING_OPTIONS},
)


In [None]:
# False: additionally freeze every encoder layer except the last 12.
# True:  leave all encoder layers trainable.
FINETUNE_ALL = False

# Freezing that applies in both modes.
for param_name, param in model.named_parameters():
    # Every parameter whose name contains "embeddings" is frozen. This already
    # includes "position_embeddings", so training_args.fix_position_embedding
    # makes no additional difference here.
    if "embeddings" in param_name:
        param.requires_grad = False
    # With fix_encoder set, the colbert_linear / sparse_linear parameters are frozen.
    # NOTE(review): the flag name suggests freezing the encoder instead — confirm intent.
    if training_args.fix_encoder and ("colbert_linear" in param_name or 'sparse_linear' in param_name):
        param.requires_grad = False

if not FINETUNE_ALL:
    # Freeze all encoder layers except the last 12.
    for encoder_block in model.model.encoder.layer[:-12]:
        for _, block_param in encoder_block.named_parameters():
            block_param.requires_grad = False

# Count parameters in a single pass.
total_params = 0
trainable_params = 0
for p in model.parameters():
    n_values = p.numel()
    total_params += n_values
    if p.requires_grad:
        trainable_params += n_values

print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")


In [None]:
# Training dataset, built from data_args with the per-device batch size and seed.
train_dataset = CustomTrainDataset(
    args=data_args,
    default_batch_size=training_args.per_device_train_batch_size,
    seed=training_args.seed,
)

# Collator: pads queries and passages to their fixed maximum lengths and
# returns PyTorch tensors.
# NOTE(review): pad_to_multiple_of is 2 here while data_args.pad_to_multiple_of
# is 8 — confirm which value is intended.
collator_options = dict(
    query_max_len=data_args.query_max_len,
    passage_max_len=data_args.passage_max_len,
    sub_batch_size=training_args.sub_batch_size,
    pad_to_multiple_of=2,
    padding='max_length',
    return_tensors="pt",
)
data_collator = CustomEmbedCollator(tokenizer, **collator_options)

# batch_size is left at the DataLoader default; presumably each dataset item
# is already a complete batch (see default_batch_size above) — TODO confirm.
train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, shuffle=True)

In [None]:
from torch import optim
from torch.cuda.amp import autocast, GradScaler
from transformers import BatchEncoding
from tqdm import tqdm
from transformers import get_scheduler


# Move the model to the selected device and switch it to training mode
# before the optimizer is created in the next cell.
model = model.to(device).train()

In [None]:
# Number of passes over the training data.
# This override has to happen BEFORE the schedule is sized: it used to be
# applied after get_scheduler(), so warmup and cosine decay were planned for
# the epoch count given when training_args was created while the loop below
# ran for 5 epochs, i.e. the last epochs ran past the end of the schedule.
training_args.num_train_epochs = 5

# Optimizer hyper-parameters come from training_args so that the configuration
# cell is the single source of truth (the values are the same as before:
# lr=1e-5, weight_decay=0.01).
optimizer = optim.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

num_steps = len(train_dataloader)
total_traning_steps = num_steps * training_args.num_train_epochs

# Loss scaler for mixed-precision training.
scaler = GradScaler()

train_dataset.refresh_epoch()

# Linear warmup for warmup_ratio of all steps, then cosine decay down to min_lr.
scheduler = get_scheduler(
    name='cosine_with_min_lr',
    optimizer=optimizer,
    num_warmup_steps=int(total_traning_steps * training_args.warmup_ratio),
    num_training_steps=total_traning_steps,
    scheduler_specific_kwargs={'min_lr': 1e-6}
)

# Training loop
for epoch in range(training_args.num_train_epochs):
    print(('\n' + '%15s' * 3) % ('epoch', 'memory', 'loss'))
    p_bar = tqdm(train_dataloader, total=num_steps)
    loss_total = 0
    step = 0

    for batch in p_bar:
        # Only BatchEncoding values are moved to the device; every other
        # value in the batch is passed through unchanged.
        batch = {k: v.to(device) if isinstance(v, BatchEncoding) else v for k, v in batch.items()}

        optimizer.zero_grad()

        # Forward pass under automatic mixed precision.
        with autocast():
            outputs = model(**batch)
            loss = outputs.loss

        # Scaled backward pass, optimizer step, then scaler bookkeeping.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The learning-rate schedule advances once per batch.
        scheduler.step()

        loss_total += loss.item()
        step += 1

        # Progress bar text: epoch counter, reserved CUDA memory, running mean loss.
        memory = f'{torch.cuda.memory_reserved() / 1E9:.4g}G'  # (GB)
        s = ('%15s' * 2 + '%15.5g') % (f'{epoch + 1}/{training_args.num_train_epochs}', memory, loss_total / step)
        p_bar.set_description(s)

    train_dataset.refresh_epoch()

    # Checkpoint after every epoch and append the epoch's mean loss to loss.txt.
    # max(step, 1) avoids a ZeroDivisionError if the dataloader yielded nothing.
    model.save(training_args.output_dir)
    with open(os.path.join(training_args.output_dir, 'loss.txt'), 'a') as f:
        f.write(f"Epoch {epoch + 1}, Loss: {loss_total / max(step, 1)}\n")

### **Eval**

In [None]:
from BGE_M3.utils import BGEM3Eval

# Baseline: evaluate the BAAI/bge-m3 checkpoint on the test queries.
m3_eval = BGEM3Eval(
    model_name="BAAI/bge-m3",
    tokenizer_name='BAAI/bge-m3',
    data_path='Data/test_finetune_hn.jsonl',
    candidate_pool='Data/corpus.jsonl',
    batch_size=128,
    query_max_length=160,
    positive_max_length=256,
    max_neg=100,
)

queries_dense_vecs, positives_dense_vecs = m3_eval.inference()

# Report accuracy and recall at each cutoff.
for k in (1, 3, 5, 10, 100):
    acc, recall, _ = BGEM3Eval.top_k_accuracy_recall(
        queries_dense_vecs, positives_dense_vecs, m3_eval.labels, k
    )
    print(f"\nTop-{k} accuracy: {acc:.4f}", "---", f"Top-{k} reacll: {recall:.4f}\n")


In [None]:
from BGE_M3.utils import BGEM3Eval

# Fine-tuned model: same evaluation, but loading the "bge-m3-output" checkpoint.
m3_eval = BGEM3Eval(
    model_name="bge-m3-output",
    tokenizer_name='BAAI/bge-m3',
    data_path='Data/test_finetune_hn.jsonl',
    candidate_pool='Data/corpus.jsonl',
    batch_size=128,
    query_max_length=160,
    positive_max_length=256,
    max_neg=100,
)

queries_dense_vecs, positives_dense_vecs = m3_eval.inference()

# Report accuracy and recall at each cutoff.
for k in (1, 3, 5, 10, 100):
    acc, recall, _ = BGEM3Eval.top_k_accuracy_recall(
        queries_dense_vecs, positives_dense_vecs, m3_eval.labels, k
    )
    print(f"\nTop-{k} accuracy: {acc:.4f}", "---", f"Top-{k} reacll: {recall:.4f}\n")


## **BGE-M3-Reranker**

### **Fine-tune**

In [None]:
# Reranker tokenizer and the sequence-classification checkpoint to fine-tune.
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
model = AutoModelForSequenceClassification.from_pretrained("bge-m3-reranker-output")


# True:  leave all encoder layers trainable.
# False: additionally freeze every encoder layer except the last 15.
FINETUNE_ALL = True
USE_LORA = False  # not read anywhere in this cell

# Frozen in both modes: every parameter whose name contains "embeddings"
# (this already covers "position_embeddings").
for param_name, param in model.named_parameters():
    if "embeddings" in param_name:
        param.requires_grad = False

if not FINETUNE_ALL:
    for encoder_block in model.roberta.encoder.layer[:-15]:
        for _, block_param in encoder_block.named_parameters():
            block_param.requires_grad = False

# Count parameters in a single pass.
total_params = 0
trainable_params = 0
for p in model.parameters():
    n_values = p.numel()
    total_params += n_values
    if p.requires_grad:
        trainable_params += n_values

print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")


In [None]:
from BGE_M3.data import CrossDataset, RerankerCollator


# Reranker training set.
# The positional values 160, 256, 3 are presumably the query max length,
# passage max length and group size — TODO confirm against CrossDataset.
dataset = CrossDataset(
    tokenizer,
    'Data/train_finetune_rerank_hn.jsonl',
    160,
    256,
    3,
)

print("samples:", len(dataset))

In [9]:
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from transformers import get_scheduler
import os

output_dir = '/kaggle/working/bge-m3-reranker-output'
os.makedirs(output_dir, exist_ok=True)

epochs = 3
batch_size = 4
# drop_last=True: the loss below reshapes the logits to (batch_size, -1) and
# builds a target of length batch_size. A trailing partial batch would either
# make that reshape fail or silently misalign the groups, so it is dropped.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                        collate_fn=RerankerCollator(tokenizer))
num_steps = len(dataloader)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).train()

criterion = torch.nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Linear warmup over the first 10% of all steps, then cosine decay down to min_lr.
scheduler = get_scheduler(
    name='cosine_with_min_lr',
    optimizer=optimizer,
    num_warmup_steps=int(num_steps * epochs * 0.1),
    num_training_steps=num_steps * epochs,
    scheduler_specific_kwargs={'min_lr': 1e-6}
)

# Loss scaler for mixed-precision training.
scaler = GradScaler()

# Training loop
for epoch in range(epochs):
    print(('\n' + '%15s' * 3) % ('epoch', 'memory', 'loss'))
    p_bar = tqdm(dataloader, total=num_steps)
    loss_total = 0
    step = 0

    for batch in p_bar:
        inputs = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()

        with autocast():
            logits = model(**inputs).logits
            # One row of scores per training example.
            grouped_logits = logits.view(batch_size, -1)
            # Class 0 is the target for every row; presumably the positive
            # passage is placed first in each group — TODO confirm against CrossDataset.
            target = torch.zeros(batch_size, device=device, dtype=torch.long)

            loss = criterion(grouped_logits, target)

        # Scaled backward pass, optimizer step, then scaler bookkeeping.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The learning-rate schedule advances once per batch.
        scheduler.step()

        loss_total += loss.item()
        step += 1

        # Progress bar text: epoch counter, reserved CUDA memory, running mean loss.
        memory = f'{torch.cuda.memory_reserved() / 1E9:.4g}G'  # (GB)
        s = ('%15s' * 2 + '%15.5g') % (f'{epoch + 1}/{epochs}', memory, loss_total / step)
        p_bar.set_description(s)

    # Checkpoint after every epoch and append the epoch's mean loss to loss.txt.
    # max(step, 1) avoids a ZeroDivisionError if the dataloader yielded nothing.
    model.save_pretrained(output_dir, state_dict=model.state_dict())
    with open(os.path.join(output_dir, 'loss.txt'), 'a') as f:
        f.write(f"Epoch {epoch + 1}, Loss: {loss_total / max(step, 1)}\n")


          epoch         memory           loss


            1/3          13.4G        0.29176: 100%|██████████| 1706/1706 [32:12<00:00,  1.13s/it]



          epoch         memory           loss


            2/3          13.4G        0.22979: 100%|██████████| 1706/1706 [31:44<00:00,  1.12s/it]



          epoch         memory           loss


            3/3          13.4G         0.1641: 100%|██████████| 1706/1706 [31:29<00:00,  1.11s/it]


### **Eval**

In [12]:
import torch
import gc

# Run Python garbage collection first so unreferenced objects are destroyed,
# then ask PyTorch to release its cached, currently unused CUDA memory.
gc.collect()
torch.cuda.empty_cache() 

In [None]:
from BGE_M3.utils import BGEM3Eval, Reranker, reranker_top_k_accuracy_recall

In [None]:
# Evaluation helper for the fine-tuned retriever checkpoint ("bge-m3-output").
m3_eval = BGEM3Eval(
    model_name="bge-m3-output",
    tokenizer_name='BAAI/bge-m3',
    data_path='Data/test_finetune_hn.jsonl',
    candidate_pool='Data/corpus.jsonl',
    batch_size=128,
    query_max_length=160,
    positive_max_length=256,
    max_neg=100,
)

Queries shape: torch.Size([2048, 160])
Passage shape: torch.Size([31527, 256])


In [18]:
# Run the retriever over the evaluation data; the call returns two values that
# are used below as the dense vectors of the queries and of the passages.
queries_dense_vecs, positives_dense_vecs = m3_eval.inference()

Queries inferencing: 100%|██████████| 16/16 [00:14<00:00,  1.08it/s]
Passage inferencing: 100%|██████████| 247/247 [05:54<00:00,  1.43s/it]


In [None]:
# Load the fine-tuned reranker checkpoint. The second argument is presumably
# the tokenizer name — TODO confirm against the Reranker signature.
reranker = Reranker('bge-m3-reranker-output',
                    'BAAI/bge-reranker-v2-m3')

In [23]:
# Rerank the dense-retrieval results and report accuracy / recall at each cutoff.
# n_candidate=16 is presumably the number of retrieved passages handed to the
# reranker per query — TODO confirm against reranker_top_k_accuracy_recall.
reranker_top_k_accuracy_recall(
    reranker,
    queries_dense_vecs,
    positives_dense_vecs,
    m3_eval.corpus,
    m3_eval.queries,
    m3_eval.labels,
    k_list=[1, 3, 5, 10],
    n_candidate=16,
)

Batches: 100%|██████████| 4/4 [00:00<00:00,  6.44it/s]


reranking...


2048it [08:08,  4.19it/s]

computing metric...

Top-1 accuracy: 0.5713 --- Top-1 reacll: 0.3758


Top-3 accuracy: 0.8047 --- Top-3 reacll: 0.6409


Top-5 accuracy: 0.8750 --- Top-5 reacll: 0.7259


Top-10 accuracy: 0.9199 --- Top-10 reacll: 0.8077






## **Ensemble BM25**

In [None]:
from BGE_M3.utils import Reranker
from BGE_M3.modeling import BGEM3ForInference
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset
from rank_bm25 import BM25Okapi
from pyvi import ViTokenizer
from tqdm import tqdm
import numpy as np
import json
import pickle


In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3', use_fast=False)

# Inference wrapper around the fine-tuned retriever checkpoint.
model_inference = BGEM3ForInference(
    model_name='bge-m3-output',
    tokenizer=tokenizer,
    enable_sub_batch=False,
    unified_finetuning=False,
)

# Move to the device, cast the weights to half precision, switch to eval mode.
m3_eval_model = model_inference.to(device).half().eval()


In [None]:
def _read_jsonl_texts(path):
    """Return the 'text' field of every line of a UTF-8 JSON Lines file, in file order."""
    texts = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            record = json.loads(raw_line.strip())
            texts.append(record['text'])
    return texts


queries = _read_jsonl_texts('bm25-eval/queries.jsonl')
corpus = _read_jsonl_texts('bm25-eval/corpus.jsonl')

In [9]:
batch_size = 256


def _tokenize_fixed_length(texts, max_length):
    """Tokenize `texts`, truncating and padding every sequence to exactly `max_length` tokens."""
    return tokenizer.batch_encode_plus(
        texts,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )


# Queries are capped at 160 tokens, passages at 256.
queries_tk = _tokenize_fixed_length(queries, 160)
passage_tk = _tokenize_fixed_length(corpus, 256)

# Pair token ids with their attention masks so that both are batched together.
q_dataset = TensorDataset(queries_tk['input_ids'], queries_tk['attention_mask'])
p_dataset = TensorDataset(passage_tk['input_ids'], passage_tk['attention_mask'])

# No shuffling: row order has to stay aligned with `queries` / `corpus`.
q_dataloader = DataLoader(q_dataset, batch_size=batch_size)
p_dataloader = DataLoader(p_dataset, batch_size=batch_size)

print("Queries shape:", queries_tk['input_ids'].shape)
print("Passage shape:", passage_tk['input_ids'].shape)


Queries shape: torch.Size([2048, 160])
Passage shape: torch.Size([31527, 256])


In [11]:
def _encode_dense(loader, desc):
    """Run the inference model over `loader` and return all 'dense_vecs' outputs as one CPU tensor.

    `loader` yields (input_ids, attention_mask) pairs; `desc` is the progress-bar label.
    """
    chunks = []
    with torch.no_grad():
        for input_ids, attention_mask in tqdm(loader, desc=desc):
            model_inputs = {
                'input_ids': input_ids.to(device),
                'attention_mask': attention_mask.to(device),
            }
            # Move each batch of vectors to the CPU right away.
            chunks.append(m3_eval_model(model_inputs)['dense_vecs'].cpu())
    return torch.cat(chunks, dim=0)


queries_dense_vecs = _encode_dense(q_dataloader, "Queries inferencing")
passages_dense_vecs = _encode_dense(p_dataloader, "Passage inferencing")

Queries inferencing: 100%|██████████| 8/8 [00:11<00:00,  1.50s/it]
Passage inferencing: 100%|██████████| 124/124 [06:02<00:00,  2.93s/it]


In [None]:
def _segment_words(text):
    """Lower-case `text`, run pyvi's ViTokenizer on it and split the result on whitespace."""
    return ViTokenizer.tokenize(text.lower()).split()


# Word-segment every document in the corpus.
tokenized_corpus = [_segment_words(document) for document in corpus]

# Build the BM25 index over the segmented corpus.
bm25 = BM25Okapi(tokenized_corpus)

# Segment the queries the same way, then score every query against the whole corpus.
tokenized_queries = [_segment_words(query_text) for query_text in queries]
queries_bm25_scores = [bm25.get_scores(query_tokens) for query_tokens in tokenized_queries]

# Save the scores so they can be reloaded instead of recomputed.
with open('bm25-eval/queries_bm25_scores.pkl', 'wb') as f:
    pickle.dump(queries_bm25_scores, f)


# with open('bm25-eval/queries_bm25_scores.pkl', 'rb') as f:
#     queries_bm25_scores = pickle.load(f)

In [None]:
# Load the relevance labels. The metric cells iterate this object as one
# collection of relevant corpus indices per query.
# The name `lables` (sic) is what those cells read, so it is kept as is.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('bm25-eval/labels.pkl', 'rb') as f:
    lables = pickle.load(f)

In [None]:
# Inner product of every query vector with every passage vector, converted to
# a NumPy array with one row per query and one column per passage.
sims = (queries_dense_vecs @ passages_dense_vecs.T).numpy()

# with open('bm25-eval/bge_m3_scores.pkl', 'rb') as f:
#     sims = pickle.load(f)


In [7]:
# Stack the per-query BM25 score vectors into one 2-D float array (one row per query).
_bm25_score_matrix = np.asarray(queries_bm25_scores, dtype=np.float64)

# L2 norm of every row, kept as a column so it broadcasts over the row.
l2_norms = np.linalg.norm(_bm25_score_matrix, axis=1, keepdims=True)

# A query with no lexical match has an all-zero row; dividing by its zero norm
# would produce NaN and corrupt the fused ranking for that query. Dividing by 1
# instead leaves such a row at zero.
l2_norms[l2_norms == 0] = 1.0

queries_bm25_scores_normalized = _bm25_score_matrix / l2_norms

In [22]:
# Score fusion: dense similarity plus the normalized BM25 scores weighted by 3.
bm25_component = 3 * np.array(queries_bm25_scores_normalized)
fusion_scores = sims + bm25_component

# Sorting the negated scores ranks passages from highest to lowest fused score.
fusion_rank = np.argsort(-fusion_scores, axis=1)

In [23]:
print('computing metric...')
n_queries = len(fusion_rank)

for k in (1, 3, 5, 10):
    # For every query: how many of its relevant passages appear in the top k.
    hits_per_query = []
    for i, rel_indexs in enumerate(lables):
        topk_candidates = fusion_rank[i][:k]
        hits_per_query.append(sum(1 for rel_indx in rel_indexs if rel_indx in topk_candidates))

    # Accuracy: share of queries with at least one relevant passage in the top k.
    acc = sum(1 for hits in hits_per_query if hits > 0) / n_queries
    # Recall: mean fraction of each query's relevant passages found in the top k.
    recall = sum(hits / len(rel_indexs) for hits, rel_indexs in zip(hits_per_query, lables)) / n_queries

    print(f"\nTop-{k} accuracy: {acc:.4f}", "---", f"Top-{k} reacll: {recall:.4f}\n")

computing metric...

Top-1 accuracy: 0.5469 --- Top-1 reacll: 0.3563


Top-3 accuracy: 0.8008 --- Top-3 reacll: 0.6353


Top-5 accuracy: 0.8682 --- Top-5 reacll: 0.7181


Top-10 accuracy: 0.9219 --- Top-10 reacll: 0.8066



In [None]:
# Load the fine-tuned reranker checkpoint. The second argument is presumably
# the tokenizer name — TODO confirm against the Reranker signature.
reranker = Reranker('bge-m3-reranker-output',
                    'BAAI/bge-reranker-v2-m3')

In [25]:
print('reranking...')
reranking = []
for i, query in tqdm(enumerate(queries)):
    query_rank = fusion_rank[i]
    # Only the 16 best fused candidates of each query are handed to the reranker.
    candidate_passages = [corpus[idx] for idx in query_rank[:16]]
    # rerank() returns positions into candidate_passages; map them back to corpus indices.
    reranked_positions = reranker.rerank(query, candidate_passages)
    reranking.append([query_rank[idx] for idx in reranked_positions])

reranking...


2048it [09:30,  3.59it/s]


In [26]:
print('computing metric...')
# The denominator is taken from fusion_rank, as in the fusion metric cell.
n_queries = len(fusion_rank)

for k in (1, 3, 5, 10):
    # For every query: how many of its relevant passages appear in the reranked top k.
    hits_per_query = []
    for i, rel_indexs in enumerate(lables):
        topk_candidates = reranking[i][:k]
        hits_per_query.append(sum(1 for rel_indx in rel_indexs if rel_indx in topk_candidates))

    # Accuracy: share of queries with at least one relevant passage in the top k.
    acc = sum(1 for hits in hits_per_query if hits > 0) / n_queries
    # Recall: mean fraction of each query's relevant passages found in the top k.
    recall = sum(hits / len(rel_indexs) for hits, rel_indexs in zip(hits_per_query, lables)) / n_queries

    print(f"\nTop-{k} accuracy: {acc:.4f}", "---", f"Top-{k} reacll: {recall:.4f}\n")

computing metric...

Top-1 accuracy: 0.5605 --- Top-1 reacll: 0.3713


Top-3 accuracy: 0.8101 --- Top-3 reacll: 0.6506


Top-5 accuracy: 0.8735 --- Top-5 reacll: 0.7299


Top-10 accuracy: 0.9209 --- Top-10 reacll: 0.8140

