In [None]:
import os
import sys

# Absolute path to your project root (set explicitly if needed)
PROJECT_ROOT = "c:/vscode-projects/legal-rag-project"

# Set working directory
os.chdir(PROJECT_ROOT)
print("Working directory set to:", os.getcwd())

# Add to sys.path for module imports
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
    print("Added to sys.path:", PROJECT_ROOT)


In [14]:
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from qdrant_client.models import SparseVector
from retrieval import get_qdrant_client, search_with_precomputed_vectors, match_article
# Load .pkl dataset
with open("evaluation\eval_query_all_embeddings.pkl", "rb") as f:
    eval_dset = pickle.load(f)

queries = eval_dset["queries"]
relevant_articles = eval_dset["relevant_articles"]

# Define configs: model -> collection mapping
model_configs = [
    {"model_name": "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "collection_name": "qwen-laws-2048-chunks"},
    {"model_name": "jinaai/jina-embeddings-v3", "collection_name": "jina-laws-2048-chunks"},
    {"model_name": "Alibaba-NLP/gte-multilingual-base", "collection_name": "gte-laws-2048-chunks"},
    {"model_name": "BAAI/bge-m3", "collection_name": "bge-laws-2048-chunks"},
]



In [12]:
top_k = 5
client = get_qdrant_client()
VERBOSE = False  # Set to False to suppress debugging prints

def compute_metrics(retrieved, relevant, k):
    top_k = retrieved[:k]
    relevant_set = set((g["law_code"], g["law_number"]) for g in relevant)

    correct = [pred for pred in top_k if (pred["law_code"], pred["law_number"]) in relevant_set]
    precision_at_k = len(correct) / k

    hits = 0
    ap = 0.0
    for i, pred in enumerate(top_k):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            hits += 1
            ap += hits / (i + 1)
    map_at_k = ap / len(relevant) if relevant else 0.0

    rr = 0.0
    for i, pred in enumerate(retrieved):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            rr = 1 / (i + 1)
            break

    dcg = 0.0
    idcg = sum([1 / np.log2(i + 2) for i in range(min(k, len(relevant)))])
    for i, pred in enumerate(top_k):
        if (pred["law_code"], pred["law_number"]) in relevant_set:
            dcg += 1 / np.log2(i + 2)
    ndcg = dcg / idcg if idcg > 0 else 0.0

    return precision_at_k, map_at_k, rr, ndcg


results = []

for config in model_configs:
    model_name = config["model_name"]
    collection = config["collection_name"]
    embeddings = eval_dset[model_name]

    print(f"\nEvaluating model: {model_name} on collection: {collection}")
    hit1 = 0
    hitk = 0
    mapk = []
    mrr = []
    ndcg = []

    for i, query in enumerate(tqdm(queries)):
        dense_vec = embeddings[i]
        relevant = relevant_articles[i]

        try:
            res = search_with_precomputed_vectors(
                client=client,
                collection_name=collection,
                top_k=top_k,
                retriever_type="dense",
                dense_vector=dense_vec
            )

            retrieved = []
            for r in res.points:
                payload = r.payload or {}
                meta = payload.get("metadata", {}) or {}

                law_code = meta.get("law_code")
                law_number = meta.get("law_number")

                if not law_code or not law_number:
                    if VERBOSE:
                        print(f"[WARN] Missing law_code/law_number in result {i}")
                    continue

                retrieved.append({
                    "law_code": law_code,
                    "law_number": law_number
                })

            if VERBOSE:
                print(f"\nQuery {i}: {query}")
                print(f"→ Retrieved: {len(retrieved)}")
                if retrieved:
                    print("→ First result:", retrieved[0])
                else:
                    print("→ No valid retrievals")

            if not retrieved:
                continue

            p, ap, r_, n = compute_metrics(retrieved, relevant, top_k)

            try:
                top_hit = next(x for x in retrieved if match_article(x, relevant[0]))
                hit1 += int(retrieved.index(top_hit) == 0)
            except StopIteration:
                pass

            hitk += int(p > 0)
            mapk.append(ap)
            mrr.append(r_)
            ndcg.append(n)

        except Exception as e:
            print(f"\n[ERROR] Failed on query: {query}\n{e}")

    total = len(queries)
    results.append({
        "model": model_name,
        "Hit@1": round(hit1 / total, 3),
        f"Hit@{top_k}": round(hitk / total, 3),
        f"MAP@{top_k}": round(np.mean(mapk), 3),
        "MRR": round(np.mean(mrr), 3),
        f"NDCG@{top_k}": round(np.mean(ndcg), 3),
    })

# Optional: Show results table
import pandas as pd
df = pd.DataFrame(results)
display(df.sort_values("MRR", ascending=False))



Evaluating model: Alibaba-NLP/gte-Qwen2-1.5B-instruct on collection: qwen-laws-2048-chunks


100%|██████████| 209/209 [00:05<00:00, 38.71it/s]



Evaluating model: jinaai/jina-embeddings-v3 on collection: jina-laws-2048-chunks


100%|██████████| 209/209 [00:05<00:00, 39.29it/s]



Evaluating model: Alibaba-NLP/gte-multilingual-base on collection: gte-laws-2048-chunks


100%|██████████| 209/209 [00:04<00:00, 42.76it/s]



Evaluating model: BAAI/bge-m3 on collection: bge-laws-2048-chunks


100%|██████████| 209/209 [00:05<00:00, 34.97it/s]


Unnamed: 0,model,Hit@1,Hit@5,MAP@5,MRR,NDCG@5
3,BAAI/bge-m3,0.344,0.627,0.441,0.463,0.483
0,Alibaba-NLP/gte-Qwen2-1.5B-instruct,0.306,0.636,0.437,0.447,0.482
1,jinaai/jina-embeddings-v3,0.297,0.617,0.411,0.433,0.46
2,Alibaba-NLP/gte-multilingual-base,0.311,0.589,0.397,0.427,0.442


In [16]:
# Compare BGE Dense vs Hybrid (using precomputed sparse vectors)

modes = ["dense", "hybrid"]
results_bge = []

for mode in modes:
    print(f"Evaluating BGE (mode: {mode})")
    hit1 = 0
    hitk = 0
    mapk = []
    mrr = []
    ndcg = []

    for i, query in enumerate(tqdm(eval_dset["queries"])):
        dense = eval_dset["BAAI/bge-m3"][i]
        indices = eval_dset["bm25_indices"][i]
        values = eval_dset["bm25_values"][i]
        relevant = eval_dset["relevant_articles"][i]
        sparse_vector = SparseVector(indices=indices, values=values)

        try:
            res = search_with_precomputed_vectors(
                client=client,
                collection_name="bge-laws-2048-chunks",
                top_k=top_k,
                retriever_type=mode,
                dense_vector=dense,
                sparse_vector=sparse_vector,
            )

            retrieved = []
            for r in res.points:
                meta = (r.payload or {}).get("metadata", {})
                retrieved.append({
                    "law_code": meta.get("law_code"),
                    "law_number": meta.get("law_number"),
                })

            if not retrieved:
                continue

            p, ap, r_, n = compute_metrics(retrieved, relevant, top_k)
            hit1 += int(p > 0 and retrieved.index([x for x in retrieved if match_article(x, relevant[0])][0]) == 0)
            hitk += int(p > 0)
            mapk.append(ap)
            mrr.append(r_)
            ndcg.append(n)

        except Exception:
            continue

    total = len(eval_dset["queries"])
    results_bge.append({
        "mode": mode,
        "Hit@1": round(hit1 / total, 3),
        f"Hit@{top_k}": round(hitk / total, 3),
        f"MAP@{top_k}": round(np.mean(mapk), 3),
        "MRR": round(np.mean(mrr), 3),
        f"NDCG@{top_k}": round(np.mean(ndcg), 3),
    })

# Add Sparse-only mode for BGE
print("Evaluating BGE (mode: sparse)")
hit1 = 0
hitk = 0
mapk = []
mrr = []
ndcg = []

for i, query in enumerate(tqdm(eval_dset["queries"])):
    indices = eval_dset["bm25_indices"][i]
    values = eval_dset["bm25_values"][i]
    relevant = eval_dset["relevant_articles"][i]
    sparse_vector = SparseVector(indices=indices, values=values)

    try:
        res = search_with_precomputed_vectors(
            client=client,
            collection_name="bge-laws-2048-chunks",
            top_k=top_k,
            retriever_type="sparse",
            sparse_vector=sparse_vector,
        )

        retrieved = []
        for r in res.points:
            meta = (r.payload or {}).get("metadata", {})
            retrieved.append({
                "law_code": meta.get("law_code"),
                "law_number": meta.get("law_number"),
            })

        if not retrieved:
            continue

        p, ap, r_, n = compute_metrics(retrieved, relevant, top_k)
        hit1 += int(p > 0 and retrieved.index([x for x in retrieved if match_article(x, relevant[0])][0]) == 0)
        hitk += int(p > 0)
        mapk.append(ap)
        mrr.append(r_)
        ndcg.append(n)

    except Exception:
        continue

total = len(eval_dset["queries"])
results_bge.append({
    "mode": "sparse",
    "Hit@1": round(hit1 / total, 3),
    f"Hit@{top_k}": round(hitk / total, 3),
    f"MAP@{top_k}": round(np.mean(mapk), 3),
    "MRR": round(np.mean(mrr), 3),
    f"NDCG@{top_k}": round(np.mean(ndcg), 3),
})

# Show final results
pd.DataFrame(results_bge)


Evaluating BGE (mode: dense)


100%|██████████| 209/209 [00:05<00:00, 40.84it/s]


Evaluating BGE (mode: hybrid)


100%|██████████| 209/209 [00:03<00:00, 56.34it/s]


Evaluating BGE (mode: sparse)


100%|██████████| 209/209 [00:03<00:00, 60.51it/s]


Unnamed: 0,mode,Hit@1,Hit@5,MAP@5,MRR,NDCG@5
0,dense,0.344,0.579,0.446,0.453,0.484
1,hybrid,0.23,0.512,0.354,0.359,0.396
2,sparse,0.072,0.196,0.125,0.125,0.141
