In [1]:
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
import torch
import warnings

warnings.simplefilter("ignore", FutureWarning)

torch.__version__

'2.6.0+cu124'

## **BGE-M3**

**Note:**

**[BGE-M3 Fine-tune Guide](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune/embedder)**

**[BGE-M3-Reranker-v2 Fine-tune Guide](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune/reranker)**

### **Fine-tune**

In [2]:
import logging
import os

logger = logging.getLogger('BGE_M3_FINE_TUNE')
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)

In [3]:
from BGE_M3.arguments import DataArguments, ModelArguments, RetrieverTrainingArguments
from BGE_M3.data import CustomTrainDataset, CustomEmbedCollator
from BGE_M3.modeling import BGEM3Model

04/09/2025 19:02:59 - INFO - numexpr.utils -   NumExpr defaulting to 8 threads.
04/09/2025 19:03:02 - INFO - datasets -   PyTorch version 2.6.0+cu124 available.
04/09/2025 19:03:02 - INFO - datasets -   TensorFlow version 2.10.1 available.


In [4]:
# Khởi tạo đối tượng ModelArguments với tham số
model_args = ModelArguments(
    model_name_or_path="BAAI/bge-m3",
    tokenizer_name="BAAI/bge-m3",
    cache_dir=None
)

# Khởi tạo đối tượng DataArguments với tham số
data_args = DataArguments(
    knowledge_distillation=True,
    train_data=["./Data/train.jsonl"],
    cache_path="./data_cache",
    train_group_size=4,
    query_max_len=256,
    passage_max_len=400,
    max_example_num_per_dataset=100000,
    same_task_within_batch=True,
    shuffle_ratio=0.0,
    small_threshold=0,
    drop_threshold=0,
    pad_to_multiple_of=8
)

# Khởi tạo đối tượng RetrieverTrainingArguments với tham số
training_args = RetrieverTrainingArguments(
    output_dir="./bge-m3-output",
    negatives_cross_device=False,
    temperature=0.02,
    fix_position_embedding=True,
    sentence_pooling_method='cls',
    normlized=True,
    enable_sub_batch=False,
    unified_finetuning=False,
    use_self_distill=False,
    fix_encoder=False,
    colbert_dim=-1,
    self_distill_start_step=0,
    per_device_train_batch_size=2,
    sub_batch_size=-1,
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
)

os.makedirs(training_args.output_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda


In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)

model = BGEM3Model(model_name=model_args.model_name_or_path,
                   tokenizer=tokenizer,
                   normlized=training_args.normlized,
                   sentence_pooling_method=training_args.sentence_pooling_method,
                   negatives_cross_device=training_args.negatives_cross_device,
                   temperature=training_args.temperature,
                   enable_sub_batch=training_args.enable_sub_batch,
                   unified_finetuning=training_args.unified_finetuning,
                   use_self_distill=training_args.use_self_distill,
                   colbert_dim=training_args.colbert_dim,
                   self_distill_start_step=training_args.self_distill_start_step)

if training_args.fix_position_embedding:
    for k, v in model.named_parameters():
        if "position_embeddings" in k:
            logging.info(f"Freeze the parameters for {k}")
            v.requires_grad = False
if training_args.fix_encoder:
    for k, v in model.named_parameters():
        if "colbert_linear" in k or 'sparse_linear' in k:
            logging.info(f"train the parameters for {k}")
        else:
            v.requires_grad = False


Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

04/09/2025 19:03:52 - INFO - BGE_M3.modeling -   loading model-model_name: C:\Users\ACER\.cache\huggingface\hub\models--BAAI--bge-m3\snapshots\5617a9f61b028005a4858fdac845db406aefb181
04/09/2025 19:03:57 - INFO - BGE_M3.modeling -   loading existing colbert_linear and sparse_linear---------
04/09/2025 19:03:57 - INFO - root -   Freeze the parameters for model.embeddings.position_embeddings.weight


In [6]:
train_dataset = CustomTrainDataset(args=data_args, 
                                        default_batch_size=training_args.per_device_train_batch_size, 
                                        seed=training_args.seed)
data_collator = CustomEmbedCollator(
    tokenizer,
    query_max_len=data_args.query_max_len,
    passage_max_len=data_args.passage_max_len,
    sub_batch_size=training_args.sub_batch_size,
    pad_to_multiple_of=data_args.pad_to_multiple_of,
    padding='max_length',
    return_tensors="pt"
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=training_args.per_device_train_batch_size,
    collate_fn=data_collator,
    shuffle=True
)

In [7]:
total_params = sum(p.numel() for p in model.model.parameters())

print(f"Tổng số tham số trong mô hình: {total_params}")

Tổng số tham số trong mô hình: 567754752


In [8]:
for k, v in model.named_parameters():
    if "embeddings" in k:
        v.requires_grad = False

for layer in model.model.encoder.layer[:-1]:
    for param in layer.parameters():
        param.requires_grad = False

for k, v in model.named_parameters():
    print(k, v.requires_grad)  # In ra trạng thái của từng tham số trong mô hình

model.embeddings.word_embeddings.weight False
model.embeddings.position_embeddings.weight False
model.embeddings.token_type_embeddings.weight False
model.embeddings.LayerNorm.weight False
model.embeddings.LayerNorm.bias False
model.encoder.layer.0.attention.self.query.weight False
model.encoder.layer.0.attention.self.query.bias False
model.encoder.layer.0.attention.self.key.weight False
model.encoder.layer.0.attention.self.key.bias False
model.encoder.layer.0.attention.self.value.weight False
model.encoder.layer.0.attention.self.value.bias False
model.encoder.layer.0.attention.output.dense.weight False
model.encoder.layer.0.attention.output.dense.bias False
model.encoder.layer.0.attention.output.LayerNorm.weight False
model.encoder.layer.0.attention.output.LayerNorm.bias False
model.encoder.layer.0.intermediate.dense.weight False
model.encoder.layer.0.intermediate.dense.bias False
model.encoder.layer.0.output.dense.weight False
model.encoder.layer.0.output.dense.bias False
model.encode

In [9]:
model = model.to(device)
model.train()

BGEM3Model(
  (model): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(8194, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True

In [10]:
from torch import optim
from torch.cuda.amp import autocast, GradScaler
from transformers import BatchEncoding
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

In [11]:
model.save(training_args.output_dir)

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

num_steps = len(train_dataloader)
total_traning_steps = num_steps * training_args.num_train_epochs

scaler = GradScaler()

scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(total_traning_steps * training_args.warmup_ratio),
    num_training_steps=total_traning_steps,
)

# Training loop
for epoch in range(training_args.num_train_epochs):
    print(('\n' + '%15s' * 3) % ('epoch', 'memory', 'loss'))
    p_bar = tqdm(train_dataloader, total=num_steps)
    loss_total = 0
    step = 0

    for batch in p_bar:
        batch = {k: v.to(device) if isinstance(v, BatchEncoding) else v for k, v in batch.items()}

        optimizer.zero_grad()

        with autocast():
            outputs = model(**batch)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # outputs = model(**batch)
        # loss = outputs.loss

        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()

        loss_total += loss.item()
        step += 1

        memory = f'{torch.cuda.memory_reserved() / 1E9:.4g}G'  # (GB)
        s = ('%15s' * 2 + '%15.5g') % (f'{epoch + 1}/{training_args.num_train_epochs}', memory, loss_total / step)
        p_bar.set_description(s)

    print(f"Epoch {epoch + 1}, Loss: {loss_total / step}")


    model.save(training_args.output_dir)
    with open(os.path.join(training_args.output_dir, 'loss.txt'), 'a') as f:
        f.write(f"Epoch {epoch + 1}, Loss: {loss_total / step}\n")

### **Eval**

In [2]:
from BGE_M3.modeling import BGEM3ForInference

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3', use_fast=False)

model_inference = BGEM3ForInference(model_name='./bge-m3-output', tokenizer=tokenizer,
                                    enable_sub_batch=False, unified_finetuning=False)

model_inference = model_inference.half().to(device).eval()


In [5]:
import json

with open('Data/test.json', 'r') as f:
    test_data = json.load(f)

In [6]:
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

batch_size = 8
queries = []
positives = []

for item in test_data:
    queries.append(item['query'])
    positives.append(item['pos'])

# encode queries and positives
queries = tokenizer.batch_encode_plus(queries, padding='max_length', truncation=True, 
                                      max_length=256, return_tensors="pt")
positives = tokenizer.batch_encode_plus(positives, padding='max_length', truncation=True, 
                                        max_length=400, return_tensors="pt")

dataset = TensorDataset(queries['input_ids'], queries['attention_mask'],
                        positives['input_ids'], positives['attention_mask'])

dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

print("Queries shape:", queries['input_ids'].shape)
print("Positives shape:", positives['input_ids'].shape)

Queries shape: torch.Size([2000, 256])
Positives shape: torch.Size([2000, 400])


In [8]:
# Biến để lưu trữ tất cả các dự đoán
queries_dense_vecs = []
positives_dense_vecs = []

# Duyệt qua DataLoader và đánh giá mô hình
with torch.no_grad():  # Không tính toán gradient trong chế độ đánh giá
    for batch in tqdm(dataloader):
        query_input, query_mask, pos_input, pos_mask = batch
        queries = {'input_ids': query_input.to(device), 'attention_mask': query_mask.to(device)}
        positives = {'input_ids': pos_input.to(device), 'attention_mask': pos_mask.to(device)}

        query_outputs = model_inference(queries)['dense_vecs']
        pos_outputs = model_inference(positives)['dense_vecs']
        
        queries_dense_vecs.append(query_outputs.cpu())
        positives_dense_vecs.append(pos_outputs.cpu())
        

100%|██████████| 250/250 [10:58<00:00,  2.63s/it]


In [10]:
queries_dense_vecs_all = torch.cat(queries_dense_vecs, dim=0)
positives_dense_vecs_all = torch.cat(positives_dense_vecs, dim=0)

In [13]:
similarity = queries_dense_vecs_all @ positives_dense_vecs_all.T
similarity.shape

torch.Size([2000, 2000])

In [14]:
def top_k_accuracy(similarity, k):
    # similarity: ma trận tương đồng giữa các truy vấn và các đoạn văn
    # k: số lượng đoạn văn hàng đầu để xem xét
    top_k_indices = similarity.topk(k, dim=1).indices
    correct_count = 0

    for i in range(similarity.size(0)):
        if i in top_k_indices[i]:
            correct_count += 1

    return correct_count / similarity.size(0)

In [25]:
for k in [1, 3, 5, 10]:
    acc = top_k_accuracy(similarity, k)
    print(f"Top-{k} accuracy: {acc:.4f}")

Top-1 accuracy: 0.7930
Top-3 accuracy: 0.9200
Top-5 accuracy: 0.9420
Top-10 accuracy: 0.9670
