In [1]:
import os
import sys

# Absolute path to your project root (set explicitly if needed)
PROJECT_ROOT = "c:/vscode-projects/legal-rag-project"

# Set working directory
os.chdir(PROJECT_ROOT)
print("Working directory set to:", os.getcwd())

# Add to sys.path for module imports
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
    print("Added to sys.path:", PROJECT_ROOT)


Working directory set to: c:\vscode-projects\legal-rag-project
Added to sys.path: c:/vscode-projects/legal-rag-project


In [2]:
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from qdrant_client.models import SparseVector
from retrieval import get_qdrant_client, search_with_precomputed_vectors, match_article
from logger import get_logger

logger = get_logger("EVALUATE-LAWS")
# Load .pkl dataset
with open("evaluation\eval_query_all_embeddings.pkl", "rb") as f:
    eval_dset = pickle.load(f)

queries = eval_dset["queries"]
relevant_articles = eval_dset["relevant_articles"]

# Define configs: model -> collection mapping
model_configs = [
    {"model_name": "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "collection_name": "qwen-laws-2048-chunks"},
    {"model_name": "jinaai/jina-embeddings-v3", "collection_name": "jina-laws-2048-chunks"},
    {"model_name": "Alibaba-NLP/gte-multilingual-base", "collection_name": "gte-laws-2048-chunks"},
    {"model_name": "BAAI/bge-m3", "collection_name": "bge-laws-2048-chunks"},
]



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
top_k = 5
client = get_qdrant_client()
VERBOSE = False  # Set to True to enable debug logging

def compute_metrics(retrieved, relevant, k):
    top_k = retrieved[:k]
    relevant_set = set((g["law_code"], g["law_number"]) for g in relevant)

    correct = [pred for pred in top_k if (pred["law_code"], pred["law_number"]) in relevant_set]
    precision_at_k = len(correct) / k
    recall_at_k = len(correct) / len(relevant) if relevant else 0.0

    hits = 0
    ap = 0.0
    for i, pred in enumerate(top_k):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            hits += 1
            ap += hits / (i + 1)
    map_at_k = ap / len(relevant) if relevant else 0.0

    rr = 0.0
    for i, pred in enumerate(retrieved):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            rr = 1 / (i + 1)
            break

    dcg = 0.0
    idcg = sum([1 / np.log2(i + 2) for i in range(min(k, len(relevant)))])
    for i, pred in enumerate(top_k):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            dcg += 1 / np.log2(i + 2)
    ndcg = dcg / idcg if idcg > 0 else 0.0

    return precision_at_k, map_at_k, rr, ndcg, recall_at_k

results = []

for config in model_configs:
    model_name = config["model_name"]
    collection = config["collection_name"]
    embeddings = eval_dset[model_name]

    logger.info(f"Evaluating model: {model_name} on collection: {collection}")
    hit1 = 0
    hitk = 0
    failures = 0
    pk = []
    mapk = []
    mrr = []
    ndcg = []
    recall_values = []
    hits = []

    progress = tqdm(enumerate(queries), total=len(queries))
    total = len(queries)
    for i, query in progress:
        dense_vec = embeddings[i]
        relevant = relevant_articles[i]

        try:
            try:
                res = search_with_precomputed_vectors(
                    client=client,
                    collection_name=collection,
                    top_k=top_k,
                    retriever_type="dense",
                    dense_vector=dense_vec
                )
            except Exception as e:
                logger.error(f"[Search Error] Query {i}: {query}\n{e}")
                failures += 1
                continue

            try:
                retrieved = []
                for r in res.points:
                    payload = r.payload or {}
                    meta = payload.get("metadata", {})
                    law_code = meta.get("law_code")
                    law_number = meta.get("law_number")
                    if law_code and law_number:
                        retrieved.append({
                            "law_code": law_code,
                            "law_number": law_number
                        })
                seen = set()
                deduped = []
                for item in retrieved:
                    key = (item["law_code"], item["law_number"])
                    if key not in seen:
                        deduped.append(item)
                        seen.add(key)
                retrieved = deduped

            except Exception as e:
                logger.error(f"[Parsing Error] Query {i}: {query}\n{e}")
                failures += 1
                continue

            if not retrieved:
                logger.warning(f"[Empty Result] Query {i}: {query}")
                continue

            try:
                p, ap, r_, n, recall = compute_metrics(retrieved, relevant, top_k)
            except Exception as e:
                logger.error(f"[Metric Error] Query {i}: {query}\n{e}")
                failures += 1
                continue

            try:
                top_hit = next(x for x in retrieved if match_article(x, relevant[0]))
                hit1 += int(retrieved.index(top_hit) == 0)
            except StopIteration:
                pass
            except Exception as e:
                logger.warning(f"[Hit@1 Index Error] Query {i}: {query}\n{e}")
            hitk_it = int(p > 0)
            hitk += hitk_it
            hits.append(hitk_it)
            pk.append(p)
            mapk.append(ap)
            mrr.append(r_)
            ndcg.append(n)
            recall_values.append(recall)
            if hitk_it < recall:
                logger.warning(
                    f"[Hit@{top_k} Mismatch] Query {i}: Hit@{top_k} < Recall@{top_k}\n"
                    f"Hit@{top_k}: {hitk_it} | Recall@{top_k}: {recall}\n"
                    f"Precision@{top_k}: {p}\n"
                    f"Query: {query}\n"
                    f"Relevant: {relevant}\n"
                    f"Retrieved: {retrieved}"
                )
            if VERBOSE:
                logger.info(
                    f"[Query {i}] P@{top_k}: {p:.3f} | R@{top_k}: {recall:.3f} | AP: {ap:.3f} | RR: {r_:.3f} | NDCG: {n:.3f}"
                )

        except Exception as e:
            logger.exception(f"[Unknown Error] Query {i}: {query}\n{e}")
            failures += 1

        progress.set_description(
            f"{model_name} | Hit@{top_k} Rate: {hitk/total:.3f} | Recall@{top_k}: {np.mean(recall_values):.3f}"
        )
    # print(all(h >= r for h, r in zip(hits, recall_values)))
    results.append({
        "model": model_name,
        "Hit@1": round(hit1 / total, 3),
        f"Hit@{top_k}": round(hitk / total, 3),
        f"MAP@{top_k}": round(np.mean(mapk), 3),
        "MRR": round(np.mean(mrr), 3),
        f"NDCG@{top_k}": round(np.mean(ndcg), 3),
        f"Recall@{top_k}": round(np.mean(recall_values), 3),
        f"Precision@{top_k}": round(np.mean(pk), 3),
    })

models_results = pd.DataFrame(results)
display(models_results.sort_values("MRR", ascending=False).reset_index(drop=True))
models_results.to_csv(
    rf"C:\vscode-projects\legal-rag-project\eval_results\models_comparison_at{top_k}.csv", index=False
)



2025-04-21 22:58:47,020 - INFO - EVALUATE-LAWS - Evaluating model: Alibaba-NLP/gte-Qwen2-1.5B-instruct on collection: qwen-laws-2048-chunks
Alibaba-NLP/gte-Qwen2-1.5B-instruct | Hit@5 Rate: 0.646 | Recall@5: 0.591: 100%|██████████| 206/206 [00:06<00:00, 33.49it/s]
2025-04-21 22:58:53,181 - INFO - EVALUATE-LAWS - Evaluating model: jinaai/jina-embeddings-v3 on collection: jina-laws-2048-chunks
jinaai/jina-embeddings-v3 | Hit@5 Rate: 0.626 | Recall@5: 0.579: 100%|██████████| 206/206 [00:04<00:00, 48.74it/s]
2025-04-21 22:58:57,412 - INFO - EVALUATE-LAWS - Evaluating model: Alibaba-NLP/gte-multilingual-base on collection: gte-laws-2048-chunks
Alibaba-NLP/gte-multilingual-base | Hit@5 Rate: 0.597 | Recall@5: 0.553: 100%|██████████| 206/206 [00:04<00:00, 42.80it/s]
2025-04-21 22:59:02,229 - INFO - EVALUATE-LAWS - Evaluating model: BAAI/bge-m3 on collection: bge-laws-2048-chunks
BAAI/bge-m3 | Hit@5 Rate: 0.636 | Recall@5: 0.576: 100%|██████████| 206/206 [00:04<00:00, 42.63it/s]


Unnamed: 0,model,Hit@1,Hit@5,MAP@5,MRR,NDCG@5,Recall@5,Precision@5
0,BAAI/bge-m3,0.35,0.636,0.429,0.47,0.477,0.576,0.134
1,Alibaba-NLP/gte-Qwen2-1.5B-instruct,0.311,0.646,0.412,0.454,0.467,0.591,0.138
2,jinaai/jina-embeddings-v3,0.301,0.626,0.408,0.439,0.459,0.579,0.132
3,Alibaba-NLP/gte-multilingual-base,0.316,0.597,0.401,0.433,0.448,0.553,0.131


In [4]:
modes = ["dense", "hybrid", "sparse"]
results_bge = []

for mode in modes:
    print(f"Evaluating BGE (mode: {mode})")
    hit1 = 0
    hitk = 0
    pk = []
    mapk = []
    mrr = []
    ndcg = []
    recall_values = []
    hits = []

    progress = tqdm(enumerate(eval_dset["queries"]), total=len(eval_dset["queries"]))
    total = len(eval_dset["queries"])

    for i, query in progress:
        dense = eval_dset["BAAI/bge-m3"][i]
        indices = eval_dset["bm25_indices"][i]
        values = eval_dset["bm25_values"][i]
        relevant = eval_dset["relevant_articles"][i]
        sparse_vector = SparseVector(indices=indices, values=values)

        try:
            res = search_with_precomputed_vectors(
                client=client,
                collection_name="bge-laws-2048-chunks",
                top_k=top_k,
                retriever_type=mode,
                dense_vector=dense,
                sparse_vector=sparse_vector,
            )
        except Exception as e:
            logger.error(f"[Search Error] Query {i}: {query}\n{e}")
            continue

        try:
            retrieved = []
            for r in res.points:
                meta = (r.payload or {}).get("metadata", {})
                law_code = meta.get("law_code")
                law_number = meta.get("law_number")
                if law_code and law_number:
                    retrieved.append({
                        "law_code": law_code,
                        "law_number": law_number
                    })
            seen = set()
            deduped = []
            for item in retrieved:
                key = (item["law_code"], item["law_number"])
                if key not in seen:
                    deduped.append(item)
                    seen.add(key)
            retrieved = deduped
        except Exception as e:
            logger.error(f"[Parsing Error] Query {i}: {query}\n{e}")
            continue

        if not retrieved:
            logger.warning(f"[Empty Result] Query {i}: {query}")
            continue

        try:
            p, ap, r_, n, recall = compute_metrics(retrieved, relevant, top_k)
            
        except Exception as e:
            logger.error(f"[Metric Error] Query {i}: {query}\n{e}")
            continue

        try:
            top_hit = next(x for x in retrieved if match_article(x, relevant[0]))
            hit1 += int(retrieved.index(top_hit) == 0)
        except StopIteration:
            pass
        except Exception as e:
            logger.warning(f"[Hit@1 Index Error] Query {i}: {query}\n{e}")

        hitk_it = int(p > 0)
        hitk += hitk_it
        hits.append(hitk_it)
        pk.append(p)
        mapk.append(ap)
        mrr.append(r_)
        ndcg.append(n)
        recall_values.append(recall)

        if hitk_it < recall:
            logger.warning(
                f"[Hit@{top_k} Mismatch] Query {i}: Hit@{top_k} < Recall@{top_k}\n"
                f"Hit@{top_k}: {hitk_it} | Recall@{top_k}: {recall}\n"
                f"Precision@{top_k}: {p}\n"
                f"Query: {query}\n"
                f"Relevant: {relevant}\n"
                f"Retrieved: {retrieved}"
            )

        if VERBOSE:
            logger.info(
                f"[Query {i}] Mode: {mode} | P@{top_k}: {p:.3f} | R@{top_k}: {recall:.3f} | AP: {ap:.3f} | RR: {r_:.3f} | NDCG: {n:.3f}"
            )

        progress.set_description(
            f"{mode} | Hit@{top_k}: {hitk/total:.3f} | Recall@{top_k}: {np.mean(recall_values):.3f}"
        )

    results_bge.append({
        "mode": mode,
        "Hit@1": round(hit1 / total, 3),
        f"Hit@{top_k}": round(hitk / total, 3),
        f"MAP@{top_k}": round(np.mean(mapk), 3),
        "MRR": round(np.mean(mrr), 3),
        f"NDCG@{top_k}": round(np.mean(ndcg), 3),
        f"Recall@{top_k}": round(np.mean(recall_values), 3),
        f"Precision@{top_k}": round(np.mean(pk), 3),
    })

mode_results = pd.DataFrame(results_bge)
display(mode_results.sort_values("MRR", ascending=False).reset_index(drop=True))
mode_results.to_csv(
    rf"C:\vscode-projects\legal-rag-project\eval_results\retrieval_modes_comparison_at{top_k}.csv",
    index=False
)


Evaluating BGE (mode: dense)


dense | Hit@5: 0.636 | Recall@5: 0.576: 100%|██████████| 206/206 [00:04<00:00, 47.58it/s]


Evaluating BGE (mode: hybrid)


hybrid | Hit@5: 0.553 | Recall@5: 0.504: 100%|██████████| 206/206 [00:04<00:00, 50.66it/s]


Evaluating BGE (mode: sparse)


sparse | Hit@5: 0.214 | Recall@5: 0.193: 100%|██████████| 206/206 [00:03<00:00, 67.19it/s]


Unnamed: 0,mode,Hit@1,Hit@5,MAP@5,MRR,NDCG@5,Recall@5,Precision@5
0,dense,0.35,0.636,0.429,0.47,0.477,0.576,0.134
1,hybrid,0.214,0.553,0.329,0.361,0.382,0.504,0.116
2,sparse,0.073,0.214,0.12,0.135,0.142,0.193,0.046


In [5]:
from retrieval import ONNXReranker

# Load ONNX-based BGE-M3 reranker
top_rerank_k = 10  # number retrieved before rerank
reranker = ONNXReranker("bge-reranker-v2-m3-onnx-o3-cpu/model.onnx")
# eval_dset = {k: v[:25] for k, v in eval_dset.items()}

results_reranked = []
modes = ["dense", "hybrid"]

for mode in modes:
    print(f"Evaluating BGE reranker on mode: {mode}")
    hit1 = 0
    hitk = 0
    pk = []
    mapk = []
    mrr = []
    ndcg = []
    recall_values = []
    hits = []

    total = len(eval_dset["queries"])
    progress = tqdm(enumerate(eval_dset["queries"]), total=total)

    for i, query in progress:
        dense = eval_dset["BAAI/bge-m3"][i]
        indices = eval_dset["bm25_indices"][i]
        values = eval_dset["bm25_values"][i]
        relevant = eval_dset["relevant_articles"][i]
        sparse_vector = SparseVector(indices=indices, values=values)

        try:
            res = search_with_precomputed_vectors(
                client=client,
                collection_name="bge-laws-2048-chunks",
                top_k=top_rerank_k,
                retriever_type=mode,
                dense_vector=dense,
                sparse_vector=sparse_vector,
            )
        except Exception as e:
            logger.error(f"[Search Error] Query {i}: {query}\n{e}")
            continue

        # Collect and deduplicate retrieved documents
        try:
            docs = []
            seen = set()
            for r in res.points:
                payload = r.payload or {}
                meta = payload.get("metadata", {})
                law_code = meta.get("law_code")
                law_number = meta.get("law_number")
                text = payload.get("text", "")
                if law_code and law_number:
                    key = (law_code, law_number)
                    if key not in seen:
                        docs.append({
                            "text": text,
                            "law_code": law_code,
                            "law_number": law_number
                        })
                        seen.add(key)
        except Exception as e:
            logger.error(f"[Parsing Error] Query {i}: {query}\n{e}")
            continue

        if not docs:
            logger.warning(f"[Empty Result] Query {i}: {query}")
            continue

        # Rerank
        try:
            pairs = [(query, doc["text"]) for doc in docs]
            scores = reranker.predict(pairs)
            reranked = [doc for doc, _ in sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)][:top_k]
            retrieved = [{"law_code": doc["law_code"], "law_number": doc["law_number"]} for doc in reranked]
        except Exception as e:
            logger.error(f"[Reranking Error] Query {i}: {query}\n{e}")
            continue

        # Metrics
        try:
            p, ap, r_, n, recall = compute_metrics(retrieved, relevant, top_k)
        except Exception as e:
            logger.error(f"[Metric Error] Query {i}: {query}\n{e}")
            continue

        # Hit@1 logic
        try:
            top_hit = next(x for x in retrieved if match_article(x, relevant[0]))
            hit1 += int(retrieved.index(top_hit) == 0)
        except StopIteration:
            pass
        except Exception as e:
            logger.warning(f"[Hit@1 Index Error] Query {i}: {query}\n{e}")

        hitk_it = int(p > 0)
        hitk += hitk_it
        hits.append(hitk_it)
        pk.append(p)
        mapk.append(ap)
        mrr.append(r_)
        ndcg.append(n)
        recall_values.append(recall)

        if hitk_it < recall:
            logger.warning(
                f"[Hit@{top_k} Mismatch] Query {i}: Hit@{top_k} < Recall@{top_k}\n"
                f"Hit@{top_k}: {hitk_it} | Recall@{top_k}: {recall}\n"
                f"Precision@{top_k}: {p}\n"
                f"Query: {query}\n"
                f"Relevant: {relevant}\n"
                f"Retrieved: {retrieved}"
            )

        if VERBOSE:
            logger.info(
                f"[Query {i}] Mode: {mode} | P@{top_k}: {p:.3f} | R@{top_k}: {recall:.3f} | AP: {ap:.3f} | RR: {r_:.3f} | NDCG: {n:.3f}"
            )

        progress.set_description(
            f"{mode}+rerank | Hit@{top_k}: {hitk/total:.3f} | Recall@{top_k}: {np.mean(recall_values):.3f}"
        )

    results_reranked.append({
        "mode": mode + "+rerank",
        "Hit@1": round(hit1 / total, 3),
        f"Hit@{top_k}": round(hitk / total, 3),
        f"MAP@{top_k}": round(np.mean(mapk), 3),
        "MRR": round(np.mean(mrr), 3),
        f"NDCG@{top_k}": round(np.mean(ndcg), 3),
        f"Recall@{top_k}": round(np.mean(recall_values), 3),
        f"Precision@{top_k}": round(np.mean(pk), 3),
    })

# Display & save results
reranked_df = pd.DataFrame(results_reranked)
display(reranked_df.sort_values("MRR", ascending=False).reset_index(drop=True))

reranked_df.to_csv(
    rf"C:\vscode-projects\legal-rag-project\eval_results\reranked_modes_comparison_at{top_k}.csv",
    index=False
)


Evaluating BGE reranker on mode: dense


dense+rerank | Hit@5: 0.578 | Recall@5: 0.633:  85%|████████▌ | 176/206 [4:07:35<2:14:40, 269.36s/it]2025-04-22 03:07:14,113 - ERROR - retrieval.tools - Exception in: search_with_precomputed_vectors()
Traceback (most recent call last):
  File "c:\vscode-projects\legal-rag-project\.venv\lib\site-packages\httpx\_transports\default.py", line 101, in map_httpcore_exceptions
    yield
  File "c:\vscode-projects\legal-rag-project\.venv\lib\site-packages\httpx\_transports\default.py", line 250, in handle_request
    resp = self._pool.handle_request(req)
  File "c:\vscode-projects\legal-rag-project\.venv\lib\site-packages\httpcore\_sync\connection_pool.py", line 256, in handle_request
    raise exc from None
  File "c:\vscode-projects\legal-rag-project\.venv\lib\site-packages\httpcore\_sync\connection_pool.py", line 236, in handle_request
    response = connection.handle_request(
  File "c:\vscode-projects\legal-rag-project\.venv\lib\site-packages\httpcore\_sync\connection.py", line 103, in ha

Evaluating BGE reranker on mode: hybrid


hybrid+rerank | Hit@5: 0.636 | Recall@5: 0.580: 100%|██████████| 206/206 [4:25:59<00:00, 77.47s/it]   


Unnamed: 0,mode,Hit@1,Hit@5,MAP@5,MRR,NDCG@5,Recall@5,Precision@5
0,dense+rerank,0.427,0.68,0.491,0.536,0.534,0.617,0.144
1,hybrid+rerank,0.413,0.636,0.468,0.508,0.506,0.58,0.135
