In [1]:
!pip install -U -q sentence-transformers -q qdrant-client[fastembed]
!pip install vncorenlp

# Download VnCoreNLP-1.1.1.jar & its word segmentation component (i.e. RDRSegmenter) 
!mkdir -p vncorenlp/models/wordsegmenter
!wget https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/VnCoreNLP-1.1.1.jar
!wget https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/models/wordsegmenter/vi-vocab
!wget https://raw.githubusercontent.com/vncorenlp/VnCoreNLP/master/models/wordsegmenter/wordsegmenter.rdr
!mv VnCoreNLP-1.1.1.jar vncorenlp/ 
!mv vi-vocab vncorenlp/models/wordsegmenter/
!mv wordsegmenter.rdr vncorenlp/models/wordsegmenter/

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 23.8.0 requires cubinlinker, which is not installed.
cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 23.8.0 requires ptxcompiler, which is not installed.
cuml 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
apache-beam 2.46.0 requires dill<0.3.2,>=0.3.1.1, but you have dill 0.3.8 which is incompatible.
apache-beam 2.46.0 requires numpy<1.25.0,>=1.14.3, but you have numpy 1.26.4 which is incompatible.
apache-beam 2.46.0 requires protobuf<4,>3.12.2, but you have protobuf 4.25.3 which is incompatible.
apache-beam 2.46.0 requires pyarrow<10.0.0,>=3.0.0, but you have pyarrow 15.0.2 which is incompatible.
cudf 23.8

In [2]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from torch.utils.data import Dataset, DataLoader
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Filter, FieldCondition, VectorParams, SearchRequest
import time
from vncorenlp import VnCoreNLP
from transformers import AutoTokenizer, AutoModel

In [3]:
from functools import reduce
import json

class RetrievalMetrics:
    def __init__(self, retrieval_results:dict, true_results:str='true_results.json')->None:
        """
        - retrieval_results: str, path to the json file containing the retrieval results
        - true_results: str, path to the json file containing the true relevant articles
        - Both files follow the format:
        {
            "question_id": ["article_id_1", "article_id_2", ...]
        }
        Note that the article_ids are sorted by relevance, i.e., the first article is the most relevant one.
        - question_id is taken as is from the dataset
        - article_id is the law title concatenated with the article number using a % sign, for example: "28/2020/nđ-cp%21"
        """

        self.retrieval_results = retrieval_results
        with open(true_results, 'r') as f:
            self.true_results = json.load(f)
        self.num_questions = len(self.retrieval_results)
        assert len(set(self.retrieval_results.keys()) & set(self.true_results.keys())) == len(self.retrieval_results)

    
    def get_recall_at_k(self, question_id:str=None, k:int=1)->float:
        
        if question_id:
            retrieved_articles = self.retrieval_results[question_id]
            true_relevant_articles = self.true_results[question_id]
            assert k <= len(retrieved_articles)

            recall = len(set(retrieved_articles[:k]) & set(true_relevant_articles)) / len(true_relevant_articles)
            return recall
        
        else:
            recalls = []
            for question_id in self.retrieval_results:
                print('recall at k',self.get_recall_at_k(question_id, k))
                recalls.append(self.get_recall_at_k(question_id, k))
            return sum(recalls) / self.num_questions
    
    def get_mrr(self)->float:
        mrr = 0
        for question_id in self.retrieval_results:
            retrieved_articles = self.retrieval_results[question_id]
            true_relevant_articles = self.true_results[question_id]
            for i, article_id in enumerate(retrieved_articles):
                if article_id in true_relevant_articles:
                    mrr += 1/(i+1)
                    break
        return mrr / self.num_questions


In [4]:
class CorpusDataset(Dataset):
    def __init__(self, dataframe, type_data):
        self.dataframe = dataframe
        self.type_data = type_data

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        if self.type_data == 'valid':
            return row['question_id'], row['relevant_articles'], row['segmented_question']
            
class BiEncoder:
    def __init__(self, 
                 url = 'https://4d31c99a-9390-4174-8547-167526f0138b.europe-west3-0.gcp.cloud.qdrant.io:6333',
                 api_key = 'WcOwiXDWcFcKJBXkI2zH9LHQ2he0npQNkYTEmS84UGx4kcLwRbMVKg',
                 old_checkpoint = 'bkai-foundation-models/vietnamese-bi-encoder',
                 tunned_checkpoint = '/kaggle/input/checkpoint-1/best_checkpoint.pt',
                 tunned = False,
                 segment = 'rdr', #pyvi
                 valid_qna = '/kaggle/input/new-data-for-acc-grad/valid_qna.csv',
                 true_valid_results = '/kaggle/input/new-data-for-acc-grad/true_valid_results.json',
                 
                ):
        #-----------Setup connection------#
        self.url = url
        self.apikey = api_key
        self.client = QdrantClient(
            url=self.url, 
            api_key=self.apikey,
        )
        self.collection_name = 'embedding_legal_1'
        
        #--------Segment-----------#
        if segment == 'rdr':
            self.rdrsegmenter = VnCoreNLP("/kaggle/working/vncorenlp/VnCoreNLP-1.1.1.jar", annotators="wseg", max_heap_size='-Xmx500m') 
            
        #-------Model--------#
        self.tokenizer = AutoTokenizer.from_pretrained(old_checkpoint)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = AutoModel.from_pretrained(old_checkpoint).to(self.device)
        if tunned == True:
            checkpoint = torch.load(tunned_checkpoint, map_location=self.device)
            self.model.load_state_dict(checkpoint['model'])
            self.model.to(self.device)
            
        self.valid_encode = 4
        self.valid_qna = valid_qna
        self.true_valid_results = true_valid_results

    def encode(self, segmented_questions):
        encoded_input = self.tokenizer(segmented_questions, padding=True, truncation=True, return_tensors='pt', max_length=256)
        # Compute token embeddings
        encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling. In this case, mean pooling.
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return sentence_embeddings
        
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
        
    def query(self, question = 'chạy xe phân_khối không có bằng lái bị phạt bao_nhiêu ?', topk = 10, segmented = True):
        if not segmented:
            segmented_question = " ".join(self.rdrsegmenter.tokenize(question)[0])
        else: 
            segmented_question = question
        print('segmented_question',segmented_question)
        results = self.client.search(
                    collection_name = self.collection_name,
                    query_vector=self.encode([question])[0],
                    limit=topk,
                )
        return {point.id: {
            'content': point.payload['content'],
            'law_article_id': point.payload['law_article_id']
            }
            for point in results
        }
    
    def validate(self, topk = 10, limit = None):
        test_df = pd.read_csv(self.valid_qna) if limit is None else pd.read_csv(self.valid_qna)[:limit]
        # Create DataLoader
        dataset = CorpusDataset(test_df, type_data = 'valid')
        dataloader = DataLoader(dataset, batch_size=self.valid_encode, shuffle=False)

        recall = 0
        retrieval_results = {}
        for batch in dataloader:
            ques_ids, law_article_ids, questions = batch
            vectors = self.encode(questions)
            search_queries = [SearchRequest(vector=vector, limit=topk, with_payload = True) for vector in vectors]
            wait = 0.1
            while True:
                try:
                    results = self.client.search_batch(collection_name=self.collection_name, requests=search_queries)
                    for qid, points in zip(ques_ids, results):
                        retrieval_results[qid] = list([p.payload['law_article_id']['law_id'] + '%' + p.payload['law_article_id']['article_id'] for p in points])
                    break
                except:
                    time.sleep(wait)
                    wait*=2
        metrics = RetrievalMetrics(retrieval_results = retrieval_results, true_results = self.true_valid_results)
        recall = metrics.get_recall_at_k(k=topk)
        mrr = metrics.get_mrr()
        
        return recall, mrr
    
biencoder = BiEncoder(tunned = False)
# print(biencoder.query(question = 'khái_niệm quỹ đại_chúng', segmented = True, topk = 10))

recall, mrr = biencoder.validate(
    limit = None,
    topk = 10,
)

print('recall', recall)
print('mrr',mrr)

tokenizer_config.json:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/895k [00:00<?, ?B/s]

bpe.codes:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/22.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/167 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/540M [00:00<?, ?B/s]

recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 0.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 0.0
recall at k 1.0
recall at k 1.0
recall at k 0.5
recall at k 0.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 0.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall at k 1.0
recall a