In [1]:
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
import pandas as pd
from scipy import sparse
import numpy as np
import gc
import json

connections.connect(uri="http://localhost:19530")  # Replace with your Milvus server IP

In [2]:
from pymilvus import AnnSearchRequest
from pymilvus import WeightedRanker, RRFRanker

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
collection = Collection(name="bat_dong_san")
collection.load()

In [4]:
def query(query_embeds, rerank, limit=3):
    content_sparse_search_params = {
        "data": query_embeds["sparse"],
        "anns_field": "content_sparse",
        "param": {
            "metric_type": "IP",
        },
        "limit": limit,
    }
    content_sparse_request = AnnSearchRequest(**content_sparse_search_params)

    content_dense_search_params = {
        "data": query_embeds["dense"],
        "anns_field": "content_dense",
        "param": {
            "metric_type": "IP",
            "ef": 20,
        },
        "limit": limit,
    }
    content_dense_request = AnnSearchRequest(**content_dense_search_params)

    reqs = [
        content_sparse_request,
        content_dense_request,
    ]

    res = collection.hybrid_search(
        reqs,  # List of AnnSearchRequests created in step 1
        rerank,  # Reranking strategy specified in step 2
        limit=limit,  # Number of final search results to return
        output_fields=["id", "url", "title_text", "content_text"],
    )
    return res

In [5]:
def get_retrieval(df, embed_col, rerank, limit=10):
    full_context = []

    for i, row in df.iterrows():
        res = query(row[embed_col], rerank=rerank, limit=limit)

        context = []
        for item in res[0]:
            item_dict = {
                "id": item.id,
                "score": item.score,
                "title": item.entity.get("title_text"),
                "content": item.entity.get("content_text"),
            }
            context.append(item_dict)
        full_context.append(context)
    return full_context

In [6]:
def hit_rate(expected_ids, actual_ids, k):
    """
    Hit Rate (HR) measures whether the expected item is in the actual top-k list.
    """
    return int(any([item in actual_ids for item in expected_ids]))


def mrr(expected_ids, actual_ids, k):
    """
    Mean Reciprocal Rank (MRR) evaluates how soon the relevant item appears in the result list.
    """
    for rank, item in enumerate(actual_ids[:k], start=1):
        if item in expected_ids:
            return 1 / rank
    return 0


def precision(expected_ids, actual_ids, k):
    """
    Precision is the fraction of retrieved documents that are relevant.
    """
    relevant_items = [item for item in actual_ids[:k] if item in expected_ids]
    return len(relevant_items) / k


def recall(expected_ids, actual_ids, k, relevant_chunks):
    """
    Recall is the fraction of relevant documents that are retrieved.
    """
    relevant_items = [item for item in actual_ids[:k] if item in expected_ids]
    return min(1, len(relevant_items) / relevant_chunks)

In [7]:
def calculate_metrics(df, context_col, k):
    actual_ids_col = "actual_ids"
    df[actual_ids_col] = df[context_col].apply(
        lambda x: list([item["id"] // 10000 for item in x])
    )
    df["hit_rate"] = df.apply(
        lambda x: hit_rate([x["id"]], x[actual_ids_col], k), axis=1
    )
    df["mrr"] = df.apply(lambda x: mrr([x["id"]], x[actual_ids_col], k), axis=1)
    df["precision"] = df.apply(
        lambda x: precision([x["id"]], x[actual_ids_col], k), axis=1
    )
    df["recall"] = df.apply(
        lambda x: recall([x["id"]], x[actual_ids_col], k, x["relevant_chunks"]),
        axis=1,
    )

In [8]:
def describe_metric(df, metric_col, k, question_type):
    mean = df[metric_col].mean()
    std = df[metric_col].std()
    median = df[metric_col].median()
    q1 = df[metric_col].quantile(0.25)
    q3 = df[metric_col].quantile(0.75)
    confidence_99 = df[metric_col].quantile(1 - 0.99)
    confidence_95 = df[metric_col].quantile(1 - 0.95)
    confidence_90 = df[metric_col].quantile(1 - 0.90)
    error_point = len(df[df[metric_col] == 0]) / len(df)

    index = pd.MultiIndex.from_tuples(
        ((question_type, metric_col, k),), names=["question_type", "metric", "k"]
    )
    describe_metric = pd.DataFrame(
        {
            "mean": mean,
            "std": std,
            "median": median,
            "confidence_25": q3,
            "confidence_75": q1,
            "confidence_90": confidence_90,
            "confidence_95": confidence_95,
            "confidence_99": confidence_99,
            "zero_point": error_point,
        },
        index=index,
    )
    return describe_metric

In [9]:
df = pd.read_pickle("data/benchmark.pkl")
print(df.shape)
df.head()

(200, 12)


Unnamed: 0,id,document_name,relevant_chunks,easy_question,easy_answer,medium_question,medium_answer,hard_question,hard_answer,easy_embeddings,medium_embeddings,hard_embeddings
0,4461,Quyết định 08/2022/QĐ-UBND về tỷ lệ phần trăm ...,1,Tỷ lệ phần trăm điều tiết nguồn thu tiền sử dụ...,\n Căn cứ tại Điều 1 Quyết định 08/...,Những cơ quan nào có trách nhiệm tổ chức thực ...,\n Căn cứ tại khoản 2 Điều 2 Quyết ...,Cơ quan nào có thẩm quyền ban hành Nghị quyết ...,\n Căn cứ tại Điều 1 Quyết định 08/...,"{'dense': [[-0.03312935, 0.0031744768, -0.0036...","{'dense': [[-0.058976743, -0.014345074, -0.024...","{'dense': [[-0.056709077, 0.011795056, 0.00993..."
1,3101,Quyết định 330/QĐ-UBND năm 2023 về Kế hoạch tổ...,5,Dự thảo Luật Đất đai (sửa đổi) sẽ được lấy ý k...,\n Căn cứ tại văn bản Quyết định 33...,Mục đích của việc tổ chức lấy ý kiến Nhân dân ...,\n Căn cứ tại văn bản Quyết định 33...,Cơ quan nào sẽ chịu trách nhiệm thi hành Quyết...,\n Căn cứ tại văn bản Quyết định 33...,"{'dense': [[-0.0021552492, 0.037226226, 0.0037...","{'dense': [[-0.024163222, 0.00011042302, -0.01...","{'dense': [[-0.005862655, 0.03228869, 0.001264..."
2,445,Nghị quyết 20/NQ-HĐND thông qua Danh mục điều ...,2,"Tỉnh Bà Rịa - Vũng Tàu có những huyện, thị xã,...",\n Căn cứ tại Nghị quyết 20/NQ-HĐND...,Dự án Đường Long Sơn - Cái Mép có diện tích đấ...,\n Căn cứ tại Điều 2 Nghị quyết 20/...,Tỉnh Bà Rịa - Vũng Tàu có những dự án nào cần ...,\n Căn cứ tại Điều 1 Nghị quyết 20/...,"{'dense': [[-0.064513795, 0.066465534, -0.0119...","{'dense': [[-0.043495033, 0.008877934, -0.0416...","{'dense': [[-0.0394588, 0.030524949, 0.0020539..."
3,189,Nghị quyết 29/NQ-HĐND điều chỉnh danh mục các ...,15,Nghị quyết 29/NQ-HĐND được ban hành vào ngày b...,\n Căn cứ tại phần đầu của Nghị quy...,Nghị quyết 29/NQ-HĐND quy định về những dự án ...,\n Căn cứ tại Điều 1 Nghị quyết 29/...,Nghị quyết 29/NQ-HĐND quy định về việc thu hồi...,\n Căn cứ tại Điều 1 Nghị quyết 29/...,"{'dense': [[-0.013742263, -0.009551991, -0.024...","{'dense': [[-0.039806806, -0.010978846, -0.007...","{'dense': [[-0.038123883, 0.032869037, -0.0161..."
4,3077,Quyết định 11/2023/QĐ-UBND quy định đơn giá câ...,18,Khi nào thì được áp dụng quyết định về đơn giá...,\n Căn cứ vào Điều 5 Quyết định 11/...,Nếu cây trồng không được nêu trong Phụ lục đơn...,\n Căn cứ vào khoản 3 Điều 3 Quyết ...,"Trong trường hợp dự án, hạng mục dự án đã được...",\n Căn cứ vào khoản 2 Điều 5 Quyết ...,"{'dense': [[-0.062342368, 0.030936865, -0.0087...","{'dense': [[-0.055531275, 0.005023834, -0.0154...","{'dense': [[-0.027396763, 0.024663005, 0.01161..."


In [10]:
rerank = RRFRanker()
# rerank = WeightedRanker(0.9, 0.1)

In [11]:
df["easy_context"] = get_retrieval(df, "easy_embeddings", rerank, 10)
df["medium_context"] = get_retrieval(df, "medium_embeddings", rerank, 10)
df["hard_context"] = get_retrieval(df, "hard_embeddings", rerank, 10)

In [12]:
list_df = []
for context in ["easy_context", "medium_context", "hard_context"]:
    for k in [1, 2, 3, 5, 10]:
        calculate_metrics(df=df, context_col=context, k=k)
        for metric in ["hit_rate", "mrr", "precision", "recall"]:
            metrics = describe_metric(
                df=df,
                metric_col=metric,
                k=k,
                question_type=context.replace("_context", "_question"),
            )
            list_df.append(metrics)
full_metrics = pd.concat(list_df)
full_metrics.sort_index(level=[0, 1, 2], ascending=True, inplace=True)

In [13]:
full_metrics[full_metrics.index.get_level_values("question_type") == "easy_question"]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,std,median,confidence_25,confidence_75,confidence_90,confidence_95,confidence_99,zero_point
question_type,metric,k,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
easy_question,hit_rate,1,0.695,0.461563,1.0,1.0,0.0,0.0,0.0,0.0,0.305
easy_question,hit_rate,2,0.695,0.461563,1.0,1.0,0.0,0.0,0.0,0.0,0.305
easy_question,hit_rate,3,0.695,0.461563,1.0,1.0,0.0,0.0,0.0,0.0,0.305
easy_question,hit_rate,5,0.695,0.461563,1.0,1.0,0.0,0.0,0.0,0.0,0.305
easy_question,hit_rate,10,0.695,0.461563,1.0,1.0,0.0,0.0,0.0,0.0,0.305
easy_question,mrr,1,0.395,0.490077,0.0,1.0,0.0,0.0,0.0,0.0,0.605
easy_question,mrr,2,0.46,0.465816,0.5,1.0,0.0,0.0,0.0,0.0,0.475
easy_question,mrr,3,0.47,0.459335,0.5,1.0,0.0,0.0,0.0,0.0,0.445
easy_question,mrr,5,0.48625,0.446321,0.5,1.0,0.0,0.0,0.0,0.0,0.375
easy_question,mrr,10,0.495962,0.437033,0.5,1.0,0.0,0.0,0.0,0.0,0.305


In [14]:
full_metrics[full_metrics.index.get_level_values("question_type") == "medium_question"]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,std,median,confidence_25,confidence_75,confidence_90,confidence_95,confidence_99,zero_point
question_type,metric,k,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
medium_question,hit_rate,1,0.675,0.46955,1.0,1.0,0.0,0.0,0.0,0.0,0.325
medium_question,hit_rate,2,0.675,0.46955,1.0,1.0,0.0,0.0,0.0,0.0,0.325
medium_question,hit_rate,3,0.675,0.46955,1.0,1.0,0.0,0.0,0.0,0.0,0.325
medium_question,hit_rate,5,0.675,0.46955,1.0,1.0,0.0,0.0,0.0,0.0,0.325
medium_question,hit_rate,10,0.675,0.46955,1.0,1.0,0.0,0.0,0.0,0.0,0.325
medium_question,mrr,1,0.395,0.490077,0.0,1.0,0.0,0.0,0.0,0.0,0.605
medium_question,mrr,2,0.455,0.46805,0.5,1.0,0.0,0.0,0.0,0.0,0.485
medium_question,mrr,3,0.466667,0.460621,0.5,1.0,0.0,0.0,0.0,0.0,0.45
medium_question,mrr,5,0.482417,0.448056,0.5,1.0,0.0,0.0,0.0,0.0,0.38
medium_question,mrr,10,0.489234,0.441552,0.5,1.0,0.0,0.0,0.0,0.0,0.325


In [15]:
full_metrics[full_metrics.index.get_level_values("question_type") == "hard_question"]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,std,median,confidence_25,confidence_75,confidence_90,confidence_95,confidence_99,zero_point
question_type,metric,k,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
hard_question,hit_rate,1,0.685,0.465682,1.0,1.0,0.0,0.0,0.0,0.0,0.315
hard_question,hit_rate,2,0.685,0.465682,1.0,1.0,0.0,0.0,0.0,0.0,0.315
hard_question,hit_rate,3,0.685,0.465682,1.0,1.0,0.0,0.0,0.0,0.0,0.315
hard_question,hit_rate,5,0.685,0.465682,1.0,1.0,0.0,0.0,0.0,0.0,0.315
hard_question,hit_rate,10,0.685,0.465682,1.0,1.0,0.0,0.0,0.0,0.0,0.315
hard_question,mrr,1,0.455,0.49922,0.0,1.0,0.0,0.0,0.0,0.0,0.545
hard_question,mrr,2,0.5,0.478167,0.5,1.0,0.0,0.0,0.0,0.0,0.455
hard_question,mrr,3,0.513333,0.468541,0.5,1.0,0.0,0.0,0.0,0.0,0.415
hard_question,mrr,5,0.525083,0.458334,0.5,1.0,0.0,0.0,0.0,0.0,0.365
hard_question,mrr,10,0.531938,0.451381,0.5,1.0,0.0,0.0,0.0,0.0,0.315
