In [20]:
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

In [19]:
!pip install sentence_transformers datasets accelerate



In [21]:
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset

In [22]:
corpus = load_dataset("hiieu/legal_eval", split="corpus")
queries = load_dataset("hiieu/legal_eval", split="queries")
relevant_docs_data = load_dataset("hiieu/legal_eval_label", split="train")
# Convert the datasets to dictionaries
corpus = dict(zip(corpus["id"], corpus["text"]))  # Our corpus (cid => document)
queries = dict(zip(queries["id"], queries["text"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for qid, corpus_ids in zip(relevant_docs_data["question_id"], relevant_docs_data["corpus_id"]):
    qid = str(qid)
    corpus_ids = str(corpus_ids)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_ids)

In [23]:
corpus

{'01/2013/tt-bvhttdl_5': 'Điều 5. Tổ chức đại hội1. Tổ chức các kỳ đại hội:\na) Trước khi tổ chức đại hội nhiệm kỳ hoặc đại hội bất thường 30 (ba mươi) ngày làm việc, hội thể thao quốc gia gửi văn bản báo cáo theo quy định đến Bộ Nội vụ và Bộ Văn hóa, Thể thao và Du lịch (qua Tổng cục Thể dục thể thao);\nb) Theo yêu cầu của Bộ Nội vụ, Tổng cục Thể dục thể thao có trách nhiệm thẩm định hồ sơ và trình Bộ Văn hóa, Thể thao và Du lịch xem xét tham gia ý kiến bằng văn bản với Bộ Nội vụ về việc cho phép hội thể thao quốc gia tổ chức đại hội nhiệm kỳ hoặc đại hội bất thường;\nc) Tổng cục Thể dục thể thao hướng dẫn và tạo điều kiện để hội thể thao quốc gia tiến hành đại hội theo đúng quy định của pháp luật.\n2. Nhân sự lãnh đạo của hội thể thao quốc gia:\na) Theo đề nghị của hội thể thao quốc gia, Tổng cục Thể dục thể thao tham gia ý kiến với Ban chấp hành hội về việc giới thiệu cán bộ, công chức, viên chức thuộc thẩm quyền quản lý tham gia ứng cử vào Ban chấp hành nhiệm kỳ mới;\nb) Đối với nh

In [24]:

import json
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model = SentenceTransformer("hiieu/halong_embedding")
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small
matryoshka_evaluators = []
# Iterate over the different dimensions
for dim in matryoshka_dimensions:
    ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(ir_evaluator)

# Create a sequential evaluator
evaluator = SequentialEvaluator(matryoshka_evaluators)

In [26]:
for query_id, docs in relevant_docs.items():
    print(query_id)

beaf0fc9dc93543de34d8eebd3326b1c
5a4dabcf640ca9c7789cd34545d690b6
aa2a714d00de99332c5ad7fcbd9d47e8
0ea4b37ac7c50de1c9e730ea76bae04c
3263bd5f639ce577a8e28dd3974b110e
e96a6a244829e4347e7769fda7b85feb
6d7726b8c09c7974e238b236a5b25d16
27f27c7127815cedc0582a39260286df
84c4eb084883098dedd4ec4c10f4643d
2cc0913a6a611556689058a6280223cc
90631b1682e226bc5feb1865b4b96d28
cf0430eb62d782c85b69d5ca5f13af5e
f781aba8f899d9e55def28224dd1bc9b
9790aa611a815aa9afcdc8620f67fd04
a19a6a08240b0756e496f8607d4c6dbe
a50dd37d818965b79bade1fdf7b2b5ff
f2f544cca405ff27ad52b3e7fed85899
58bae9de200d788ab19c032458296d5b
a91676f6aa0124bc4fd30fe87eddb2c9
87264edf452ba35e4b570d6f07a505e7
ffc174f568c3501fc886fcef08e2233f
88ad3c7b071edd17c72bd8190a709f3e
eac5a7ec979e0293006180a47245ec94
eb19da44545d6444bdc3813ce4e7800d
a377cd49ca931a6854dadb6cbdbcf1d1
25082bfa83461a54300cf3427c495e7b
5cfabfd55f5dc66b95e51d1cd92d0c9a
e66ae5eecc1672bac2c5799673666bda
e3f19447a5df1614eeb6f7e3e3faa754
b79a0d2ddf6f7d400d36ae63200fa1a3
56a0c5f94a

In [29]:
import pandas as pd
from datasets import Dataset

def prepare_training_dataset(queries, corpus, relevant_docs):
    anchors = []
    positives = []
    for query_id, docs in relevant_docs.items():
        for doc_id in docs:
            print(query_id, doc_id)
            anchors.append(queries[query_id])
            positives.append(corpus[doc_id] )
    df = {
        "anchor": anchors,
        "positive": positives
    }

    return Dataset.from_dict(df)

pairs = prepare_training_dataset(queries, corpus, relevant_docs)
pairs

beaf0fc9dc93543de34d8eebd3326b1c 23/2018/tt-bgtvt_22
5a4dabcf640ca9c7789cd34545d690b6 100/2019/nđ-cp_5
aa2a714d00de99332c5ad7fcbd9d47e8 23/2015/nđ-cp_2
0ea4b37ac7c50de1c9e730ea76bae04c 02/2019/tt-btp_2
3263bd5f639ce577a8e28dd3974b110e 21/2021/nđ-cp_20
e96a6a244829e4347e7769fda7b85feb 55/2019/nđ-cp_9
6d7726b8c09c7974e238b236a5b25d16 15/2021/nđ-cp_27
27f27c7127815cedc0582a39260286df 23/2018/tt-bgtvt_22
84c4eb084883098dedd4ec4c10f4643d 101/2015/qh13_80
2cc0913a6a611556689058a6280223cc 30/2020/nđ-cp_7
90631b1682e226bc5feb1865b4b96d28 91/2015/qh13_466
cf0430eb62d782c85b69d5ca5f13af5e 116/2007/nđ-cp_6
f781aba8f899d9e55def28224dd1bc9b 28/2017/nđ-cp_2
f781aba8f899d9e55def28224dd1bc9b 158/2013/nđ-cp_15
9790aa611a815aa9afcdc8620f67fd04 44/2020/nđ-cp_14
a19a6a08240b0756e496f8607d4c6dbe 01/2016/qh14_12
a50dd37d818965b79bade1fdf7b2b5ff 126/2020/nđ-cp_16
f2f544cca405ff27ad52b3e7fed85899 50/2019/tt-btc_8
58bae9de200d788ab19c032458296d5b 82/2020/nđ-cp_15
a91676f6aa0124bc4fd30fe87eddb2c9 09/2019/tt-nhn

Dataset({
    features: ['anchor', 'positive'],
    num_rows: 659
})

In [8]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)
train_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
)

In [15]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# torch.distributed.init_process_group(backend='nccl')
# define training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="sample", # output directory and hugging face model ID
    num_train_epochs=1,                         # number of epochs
    per_device_train_batch_size=4,             # train batch size
    gradient_accumulation_steps=2,             # for a global batch size of 512
    per_device_eval_batch_size=4,              # evaluation batch size
    #gradient_checkpointing=True,
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    #tf32=True,                                  # use tf32 precision
    fp16=True,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="steps",                      # evaluate after each epoch
    #save_strategy="epoch",                      # save after each epoch
    save_steps = 500,
    logging_steps=10,                           # log every 10 steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_768_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
)


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.


In [16]:
from sentence_transformers import SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,  # training arguments
    train_dataset=pairs,
    loss=train_loss,
    evaluator=evaluator,
)

trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
10,0.9373,No log,0.843505,0.959311,0.974961,0.985915,0.843505,0.327074,0.2,0.101095,0.829421,0.954617,0.971831,0.982786,0.919873,0.901638,0.898474,0.84507,0.957746,0.974961,0.98748,0.84507,0.326552,0.2,0.101252,0.830986,0.953052,0.971831,0.984351,0.920941,0.902395,0.899354,0.834116,0.949922,0.968701,0.981221,0.834116,0.323422,0.198435,0.100626,0.819249,0.944444,0.964789,0.978091,0.913018,0.894211,0.891096,0.837246,0.940532,0.959311,0.978091,0.837246,0.320814,0.19687,0.100469,0.823161,0.935837,0.956182,0.975743,0.910541,0.891603,0.88895,0.779343,0.923318,0.948357,0.973396,0.779343,0.313511,0.194679,0.1,0.765258,0.916275,0.945227,0.971049,0.881243,0.854637,0.8513,0.8513
20,0.7118,No log,0.849765,0.960876,0.979656,0.99061,0.849765,0.327595,0.200939,0.101565,0.835681,0.956182,0.976526,0.98748,0.925194,0.907144,0.903909,0.849765,0.964006,0.978091,0.995305,0.849765,0.328638,0.200626,0.102191,0.834898,0.959311,0.974961,0.992958,0.926479,0.907309,0.903512,0.840376,0.957746,0.971831,0.985915,0.840376,0.326552,0.199374,0.101252,0.825509,0.953052,0.968701,0.983568,0.918827,0.900027,0.896893,0.840376,0.946792,0.964006,0.984351,0.840376,0.323422,0.198122,0.101095,0.826291,0.942879,0.961659,0.982003,0.915662,0.896208,0.893699,0.793427,0.924883,0.957746,0.976526,0.793427,0.314554,0.19687,0.100313,0.77856,0.918623,0.955399,0.974178,0.889248,0.864471,0.860818,0.860818


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=21, training_loss=0.8058630809897468, metrics={'train_runtime': 337.9295, 'train_samples_per_second': 1.95, 'train_steps_per_second': 0.062, 'total_flos': 0.0, 'train_loss': 0.8058630809897468, 'epoch': 1.0})

In [None]:
### Re-evaluate
from sentence_transformers import SentenceTransformer
import torch
fine_tuned_model = SentenceTransformer(
    args.output_dir, device="cuda" if torch.cuda.is_available() else "cpu"
)
# Evaluate the model
results = evaluator(fine_tuned_model)

for k,v in results.items():
    print(k, v)