In [27]:
import pandas as pd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import sys
import os

In [28]:
# Make the parent of the current working directory importable so that the
# `services` package can be resolved from this notebook.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Guard the insert so re-running this cell does not keep stacking duplicate
# entries at the front of sys.path.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
from services.scenario_service import ScenarioService

In [72]:
# Load the labelled scenarios, record the character length of each text and
# order the rows from shortest to longest.
test_df = (
    pd.read_csv('test_data/test_scenarios.csv')
    .assign(text_length=lambda frame: frame['text'].apply(len))
    .sort_values(by='text_length', ascending=True)
)
# Preview the ten shortest scenarios whose text has at least 100 characters.
text_df = test_df.loc[test_df['text_length'] >= 100].head(10)
text_df

Unnamed: 0,id,text,sections,text_length
7147,139473951,"The application for anticipatory bail is, thus...","379, 34, 324, 354, 498, 323, 326",100
6742,126690523,The application for anticipatory bail is dispo...,"325, 147, 379, 307, 438, 506, 323",100
5912,101344383,There can hardly be any cavil that there has t...,"304, 304, 5, 279, 188, 338, 337, 300",101
5459,88592299,"The application for anticipatory bail is, thus...","325, 34, 447, 438, 506, 323, 354",102
9898,193398036,"The application for anticipatory bail is, thus...","498, 506, 376, 323, 325, 34",104
4228,5340308,"Considered I.A. No.11252/18, an application fo...","323, 365, 394",105
3843,41693864,This application has been filed with the praye...,"34, 323, 379",105
4136,51485411,"Accordingly, the prayer for anticipatory bail ...","353, 332, 188, 506, 34, 186",105
4728,67502505,And In the matter of: Imran Mondal @ Imran Isl...,"438, 306, 161, 509",105
1028,115122832,"The application for anticipatory bail is, thus...","379, 341, 325, 307, 342, 323, 34",106


In [73]:
def _rounded_mean(scores, ndigits=4):
    """Return the mean of ``scores`` rounded to ``ndigits``, or 0.0 when empty."""
    if not scores:
        return 0.0
    return round(sum(scores) / len(scores), ndigits)


def evaluate_scenarios(service, top_k=5, data_path='test_data/test_scenarios.csv',
                       min_text_length=100, max_samples=10):
    """Score a scenario-retrieval service against labelled test scenarios.

    The test CSV is expected to have a ``text`` column (the query) and a
    ``sections`` column (comma-separated ground-truth section numbers).
    Rows are sorted by text length, those shorter than ``min_text_length``
    are dropped, and the first ``max_samples`` remaining rows are evaluated.

    Args:
        service: Object exposing ``get_top_scenarios(query, history, top_k)``
            that returns an ordered sequence of mappings, each with a
            ``"Section Number"`` key.
        top_k: Number of predictions requested per query; also the
            denominator of precision@k. Must be positive.
        data_path: Path of the CSV holding the labelled scenarios.
        min_text_length: Minimum number of characters a scenario text must
            have to be evaluated.
        max_samples: Maximum number of scenarios to evaluate (shortest
            first). ``None`` evaluates every qualifying row.

    Returns:
        A tuple ``(results, precision_scores, recall_scores,
        reciprocal_ranks)`` where ``results`` maps ``"Mean Precision@k"``,
        ``"Mean Recall@k"`` and ``"MRR"`` to means rounded to 4 decimals, and
        the three lists hold the per-scenario values in evaluation order.

    Raises:
        ValueError: If ``top_k`` is not positive.
    """
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")

    # Rows without a query or without labels cannot be scored; dropping them
    # also keeps the length computation and the label split from failing on NaN.
    scenarios = pd.read_csv(data_path).dropna(subset=['text', 'sections'])
    scenarios = scenarios.assign(text_length=scenarios['text'].astype(str).str.len())
    scenarios = scenarios.sort_values(by='text_length', ascending=True)

    eval_df = scenarios[scenarios['text_length'] >= min_text_length]
    if max_samples is not None:
        eval_df = eval_df.head(max_samples)

    precision_scores = []
    recall_scores = []
    reciprocal_ranks = []

    # Iterate over the filtered subset only, not over the full CSV.
    for _, row in tqdm(eval_df.iterrows(), total=len(eval_df), desc="Evaluating"):
        query = row["text"]
        # Skip empty tokens (e.g. from a trailing comma) so they do not
        # inflate the recall denominator.
        ground_truth = {
            token.strip()
            for token in str(row["sections"]).split(",")
            if token.strip()
        }

        top_sections = service.get_top_scenarios(query, history=[], top_k=top_k)
        # Ground truth is parsed as strings; normalise predictions the same
        # way so that e.g. an integer 323 still matches "323".
        predicted = [str(sec["Section Number"]).strip() for sec in top_sections]

        hits = sum(1 for sec in predicted if sec in ground_truth)
        precision_scores.append(hits / top_k)
        recall_scores.append(hits / len(ground_truth) if ground_truth else 0.0)

        # Reciprocal rank of the first relevant prediction, 0 if none is relevant.
        rr = next(
            (1 / rank
             for rank, sec in enumerate(predicted, start=1)
             if sec in ground_truth),
            0,
        )
        reciprocal_ranks.append(rr)

    # Average over the number of scenarios actually evaluated.
    results = {
        "Mean Precision@k": _rounded_mean(precision_scores),
        "Mean Recall@k": _rounded_mean(recall_scores),
        "MRR": _rounded_mean(reciprocal_ranks),
    }
    return results, precision_scores, recall_scores, reciprocal_ranks


In [None]:
# Build the retrieval service on top of the BNS dataset, then score its
# top-5 predictions against the labelled test scenarios.
scenario_service = ScenarioService(
    dataset_path="../data/Updated_BNS_Dataset.csv",
)
results, precisions, recalls, mrrs = evaluate_scenarios(
    service=scenario_service,
    top_k=5,
)
results