In [3]:
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

In [1]:
!pip install sentence_transformers datasets accelerate

Collecting accelerate
  Downloading accelerate-1.0.1-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.0.1-py3-none-any.whl (330 kB)
Installing collected packages: accelerate
Successfully installed accelerate-1.0.1


In [2]:
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset

  from tqdm.autonotebook import tqdm, trange


In [3]:
# Fetch the three splits: documents, questions, and the relevance labels
# linking a question id to a corpus id.
corpus = load_dataset("hiieu/legal_eval", split="corpus")
queries = load_dataset("hiieu/legal_eval", split="queries")
relevant_docs_data = load_dataset("hiieu/legal_eval_label", split="train")

# Turn each text split into a plain id -> text lookup table.
corpus = {cid: text for cid, text in zip(corpus["id"], corpus["text"])}  # cid => document
queries = {qid: text for qid, text in zip(queries["id"], queries["text"])}  # qid => question

# Group the labels per question: qid => set of relevant cids.
# Both ids are stringified before being stored.
relevant_docs = {}
for qid, corpus_ids in zip(relevant_docs_data["question_id"], relevant_docs_data["corpus_id"]):
    relevant_docs.setdefault(str(qid), set()).add(str(corpus_ids))

README.md:   0%|          | 0.00/426 [00:00<?, ?B/s]

queries-00000-of-00001.parquet:   0%|          | 0.00/58.8k [00:00<?, ?B/s]

corpus-00000-of-00001.parquet:   0%|          | 0.00/482k [00:00<?, ?B/s]

Generating queries split:   0%|          | 0/639 [00:00<?, ? examples/s]

Generating corpus split:   0%|          | 0/562 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/312 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/29.0k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/699 [00:00<?, ? examples/s]

In [5]:

import json
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets

model = SentenceTransformer("hiieu/halong_embedding")

# Embedding sizes to evaluate at. Important: keep ordered large to small.
matryoshka_dimensions = [768, 512, 256, 128, 64]

# One retrieval evaluator per dimension; each truncates embeddings to `dim`
# and reports its metrics under the "dim_<dim>" name prefix.
matryoshka_evaluators = [
    InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,
        score_functions={"cosine": cos_sim},
    )
    for dim in matryoshka_dimensions
]

# Chain the per-dimension evaluators into a single callable.
evaluator = SequentialEvaluator(matryoshka_evaluators)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/201 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/13.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.34k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [6]:
# Score the model with every per-dimension evaluator and list each metric.
results = evaluator(model)
for metric_name, score in results.items():
    print(metric_name, score)

dim_768_cosine_accuracy@1 0.8294209702660407
dim_768_cosine_accuracy@3 0.9233176838810642
dim_768_cosine_accuracy@5 0.9436619718309859
dim_768_cosine_accuracy@10 0.9687010954616588
dim_768_cosine_precision@1 0.8294209702660407
dim_768_cosine_precision@3 0.3145539906103286
dim_768_cosine_precision@5 0.1931142410015649
dim_768_cosine_precision@10 0.09906103286384975
dim_768_cosine_recall@1 0.8145539906103286
dim_768_cosine_recall@3 0.9178403755868545
dim_768_cosine_recall@5 0.9389671361502347
dim_768_cosine_recall@10 0.9640062597809077
dim_768_cosine_ndcg@10 0.8976041381292648
dim_768_cosine_mrr@10 0.879893558884169
dim_768_cosine_map@100 0.8763179130484675
dim_512_cosine_accuracy@1 0.8137715179968701
dim_512_cosine_accuracy@3 0.9233176838810642
dim_512_cosine_accuracy@5 0.9389671361502347
dim_512_cosine_accuracy@10 0.9702660406885759
dim_512_cosine_precision@1 0.8137715179968701
dim_512_cosine_precision@3 0.3145539906103286
dim_512_cosine_precision@5 0.1921752738654147
dim_512_cosine_pr

In [7]:
import pandas as pd
from datasets import Dataset

def prepare_training_dataset(queries: dict, corpus: dict, relevant_docs: dict) -> "Dataset":
    """Build (anchor, positive) text pairs from relevance labels.

    Every query is paired with each document listed as relevant to it, so a
    query with several relevant documents contributes several rows.

    Args:
        queries: Mapping of query id -> query text.
        corpus: Mapping of document id -> document text.
        relevant_docs: Mapping of query id -> collection of relevant document
            ids. Sets are iterated in sorted order; any other iterable keeps
            the order it was given in.

    Returns:
        The result of ``Dataset.from_dict`` on two equal-length columns:
        ``"anchor"`` (query texts) and ``"positive"`` (document texts).

    Raises:
        KeyError: If a referenced query id or document id is missing from
            ``queries`` or ``corpus``.
    """
    anchors = []
    positives = []
    for query_id, doc_ids in relevant_docs.items():
        # Set iteration order for string ids changes between interpreter runs
        # (hash randomisation), which would make the row order -- and any
        # seeded shuffling built on top of it -- non-reproducible.
        if isinstance(doc_ids, (set, frozenset)):
            doc_ids = sorted(doc_ids, key=str)
        for doc_id in doc_ids:
            anchors.append(queries[query_id])
            positives.append(corpus[doc_id])

    return Dataset.from_dict({"anchor": anchors, "positive": positives})

# NOTE(review): these pairs are built from the same `queries`, `corpus` and
# `relevant_docs` that were passed to the evaluators, so scores measured after
# training are computed on training data -- consider a held-out split.
pairs = prepare_training_dataset(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
)
pairs[0]  # show the first (anchor, positive) example

{'anchor': 'Khi có sự cố, tai nạn giao thông đường sắt đô thị, doanh nghiệp kinh doanh đường sắt báo ngay tin cho ai?',
 'positive': 'Điều 22. Báo tin và xử lý tin báo về sự cố, tai nạn giao thông đường sắt1. Khi có sự cố, tai nạn xảy ra lái tàu hoặc nhân viên hỗ trợ an toàn phải báo ngay cho nhân viên điều độ chạy tàu.\n2. Nhân viên điều độ chạy tàu phải báo ngay cho những tổ chức, cá nhân sau đây:\na) Các ga hai đầu khu gian;\nb) Doanh nghiệp kinh doanh đường sắt đô thị.\n3. Doanh nghiệp kinh doanh đường sắt đô thị phải báo ngay cho những tổ chức, cá nhân sau đây:\na) Cơ quan công an nơi gần nhất;\nb) Ủy ban nhân dân cấp tỉnh (trong trường hợp xảy ra tai nạn chết người và trong các trường hợp cần sự phối hợp của Ủy ban nhân dân các cấp) và các đơn vị có liên quan.\n4. Doanh nghiệp kinh doanh đường sắt đô thị quy định việc báo tin sự cố giao thông đường sắt gây ra không phải dừng tàu, không bế tắc chính tuyến, không ảnh hưởng đến biểu đồ chạy tàu.\n5. Trường hợp quá thời gian chạy tàu

In [8]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

# Embedding sizes the loss is applied at. Important: keep ordered large to small.
matryoshka_dimensions = [768, 512, 256, 128, 64]

# Base ranking loss, then wrapped by MatryoshkaLoss together with the list of dimensions.
inner_train_loss = MultipleNegativesRankingLoss(model=model)
train_loss = MatryoshkaLoss(
    model=model,
    loss=inner_train_loss,
    matryoshka_dims=matryoshka_dimensions,
)

In [10]:
import torch

# Read the installed PyTorch version and echo it as the cell's output
torch_version = torch.__version__
torch_version

'2.0.1'

In [15]:
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# Training configuration for the fine-tuning run.
# NOTE(review): the recorded run ended at global_step=21, well below
# save_steps=500, so no periodic checkpoint would be written for a run of that
# length -- verify that load_best_model_at_end has a checkpoint to restore.
args = SentenceTransformerTrainingArguments(
    # --- output / checkpointing ---
    output_dir="sample",                        # output directory
    save_steps=500,                             # checkpoint interval, in steps
    save_total_limit=3,                         # save only the last 3 models
    load_best_model_at_end=True,                # load the best model when training ends
    metric_for_best_model="eval_dim_768_cosine_ndcg@10",  # ndcg@10 of the 768-dimension evaluator
    # --- schedule / batching ---
    num_train_epochs=1,                         # number of epochs
    per_device_train_batch_size=4,              # train batch size per device
    gradient_accumulation_steps=2,              # 4 x 2 = 8 samples per device per optimizer step
    per_device_eval_batch_size=4,               # evaluation batch size per device
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # --- optimisation ---
    learning_rate=2e-5,                         # learning rate
    warmup_ratio=0.1,                           # warmup ratio
    lr_scheduler_type="cosine",                 # cosine learning rate schedule
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    fp16=True,                                  # fp16 mixed precision
    # --- evaluation / logging ---
    eval_strategy="steps",                      # evaluate on a step interval
    logging_steps=10,                           # log every 10 steps
)


Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information.


In [16]:
from sentence_transformers import SentenceTransformerTrainer
# Wire the model, training pairs, loss and evaluator into the trainer.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,  # training arguments
    train_dataset=pairs,
    loss=train_loss,
    evaluator=evaluator,
)

# Keep the TrainOutput so it can still be shown as the cell result below.
train_output = trainer.train()

# Persist the trained model to args.output_dir. The recorded run stopped at
# step 21, far short of save_steps=500, so without an explicit save nothing
# loadable is left in the output directory for later cells to read back.
trainer.save_model()

train_output

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
10,0.9373,No log,0.843505,0.959311,0.974961,0.985915,0.843505,0.327074,0.2,0.101095,0.829421,0.954617,0.971831,0.982786,0.919873,0.901638,0.898474,0.84507,0.957746,0.974961,0.98748,0.84507,0.326552,0.2,0.101252,0.830986,0.953052,0.971831,0.984351,0.920941,0.902395,0.899354,0.834116,0.949922,0.968701,0.981221,0.834116,0.323422,0.198435,0.100626,0.819249,0.944444,0.964789,0.978091,0.913018,0.894211,0.891096,0.837246,0.940532,0.959311,0.978091,0.837246,0.320814,0.19687,0.100469,0.823161,0.935837,0.956182,0.975743,0.910541,0.891603,0.88895,0.779343,0.923318,0.948357,0.973396,0.779343,0.313511,0.194679,0.1,0.765258,0.916275,0.945227,0.971049,0.881243,0.854637,0.8513,0.8513
20,0.7118,No log,0.849765,0.960876,0.979656,0.99061,0.849765,0.327595,0.200939,0.101565,0.835681,0.956182,0.976526,0.98748,0.925194,0.907144,0.903909,0.849765,0.964006,0.978091,0.995305,0.849765,0.328638,0.200626,0.102191,0.834898,0.959311,0.974961,0.992958,0.926479,0.907309,0.903512,0.840376,0.957746,0.971831,0.985915,0.840376,0.326552,0.199374,0.101252,0.825509,0.953052,0.968701,0.983568,0.918827,0.900027,0.896893,0.840376,0.946792,0.964006,0.984351,0.840376,0.323422,0.198122,0.101095,0.826291,0.942879,0.961659,0.982003,0.915662,0.896208,0.893699,0.793427,0.924883,0.957746,0.976526,0.793427,0.314554,0.19687,0.100313,0.77856,0.918623,0.955399,0.974178,0.889248,0.864471,0.860818,0.860818


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

TrainOutput(global_step=21, training_loss=0.8058630809897468, metrics={'train_runtime': 337.9295, 'train_samples_per_second': 1.95, 'train_steps_per_second': 0.062, 'total_flos': 0.0, 'train_loss': 0.8058630809897468, 'epoch': 1.0})

In [None]:
### Re-evaluate
from sentence_transformers import SentenceTransformer
import torch
# Load the model stored in the training output directory, on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
fine_tuned_model = SentenceTransformer(args.output_dir, device=device)

# Run the same evaluator on the reloaded model and list every metric.
results = evaluator(fine_tuned_model)
for metric_name, score in results.items():
    print(metric_name, score)