In [2]:
import json
from datasets import load_dataset
from oauth2client.service_account import ServiceAccountCredentials
import os
import gspread

dataset = "cuad"

# The service-account key is addressed by a relative path, so it is resolved
# against the current working directory.
service_account_json_file = os.path.join('../../', 'thesis_service_account.json')
if not os.path.exists(service_account_json_file):
    raise FileNotFoundError(f"[!] service account JSON file not found: {service_account_json_file}")

# Hub repository name template; the prefix is chosen from the dataset name.
workshop = "ylkhayat/{dataset_prefix}-generation-workshop"
dataset_prefixes = {"cuad": "CUAD", "obli_qa": "OBLI_QA"}
if dataset not in dataset_prefixes:
    raise ValueError("Unknown dataset")
workshop = workshop.format(dataset_prefix=dataset_prefixes[dataset])

# Alternative retrieval setup, kept for quick switching:
# data_dir = "bm25_relevant_passages_oracle_documents"
data_dir = "dense_relevant_passages_oracle_documents/sentence-transformers_all-MiniLM-L6-v2"
current_dataset = load_dataset(workshop, data_dir=data_dir, split="test")

# Pretty-print a single example (index 8) for inspection.
print(json.dumps(current_dataset[8], indent=4))
def _resolve_dotted(mapping, dotted_key):
    """Follow the dot-separated ``dotted_key`` through nested dicts.

    Returns the value found, or None as soon as any path component is missing.
    """
    value = mapping
    for part in dotted_key.split('.'):
        value = value.get(part, None)
        if value is None:
            break
    return value


def add_record(data, args):
    """Append one row of hit-rate results to the 'Dataset — Hit Rate' worksheet.

    The row is: the upper-cased dataset name, an upper-cased abbreviation of
    the setup (first character of each ``_``-separated segment), followed by
    one cell per entry of ``data["hit_rate"]`` in the dict's own order.

    Args:
        data: dict with a ``"hit_rate"`` dict mapping metric name -> value.
        args: dict with string entries ``"dataset"`` and ``"setup"``.

    Raises:
        KeyError: if ``data["hit_rate"]`` or ``args["setup"]`` is missing.
        AttributeError: if ``args`` has no ``"dataset"`` entry (None has no
            ``upper``).
    """
    # The second key used to be "  setup" (leading spaces), which could never
    # resolve; the cell is overwritten below, so the written row is unchanged.
    sheet_columns_args = ["dataset", "setup"]
    sheet_columns_data = [f"hit_rate.{k}" for k in data["hit_rate"].keys()]

    new_row = [_resolve_dotted(args, key) for key in sheet_columns_args]
    new_row += [_resolve_dotted(data, key) for key in sheet_columns_data]

    # Skip empty segments (e.g. from "a__b" or a trailing "_") instead of
    # failing with IndexError on part[0].
    short_setup = ''.join(part[0] for part in args["setup"].split("_") if part)
    new_row[0] = new_row[0].upper()
    new_row[1] = short_setup.upper()

    scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/spreadsheets",
             "https://www.googleapis.com/auth/drive.file", "https://www.googleapis.com/auth/drive"]
    # service_account_json_file is the module-level path to the key file.
    creds = ServiceAccountCredentials.from_json_keyfile_name(service_account_json_file, scope)
    client = gspread.authorize(creds)
    spreadsheet_url = "https://docs.google.com/spreadsheets/d/1bE5AbY1hrqlR-_v-ohLCgHA6hFvRCTpH4lrRPqXm9UU/edit?usp=sharing"
    sheet = client.open_by_url(spreadsheet_url)
    worksheet = sheet.worksheet('Dataset — Hit Rate')
    worksheet.append_row(new_row, value_input_option="USER_ENTERED")
    print("[!] added record!")

# Define the top_k values
top_k_values = [3, 6, 10, 20, 50]

# Initialize accumulators
hits = {f"top_{k}": 0 for k in top_k_values}  # gold lines found, summed over examples
sum_ratios = {f"top_{k}": 0.0 for k in top_k_values}  # per-example found/gold ratios, summed
total = 0  # number of examples that were actually scored


def parse_gold_lines(raw_gold_text, dataset_name):
    """Split an already-stripped gold text into the list of gold strings.

    For ``dataset_name == 'cuad'``:
      * text starting with "No Highlights" yields an empty list;
      * text starting with "Highlights:" is read as one entry per following
        line: lines of the form ``- '...'`` lose the ``- '`` prefix and the
        closing quote, any other line loses its first character;
      * anything else falls back to the generic handling below.
    For every other dataset each non-empty stripped line is one entry.

    Empty entries are never returned: the empty string is a substring of
    every passage, so it would always be counted as a hit.
    """
    if dataset_name == 'cuad':
        if raw_gold_text.startswith("No Highlights"):
            return []
        if raw_gold_text.startswith("Highlights:"):
            gold = []
            for line in raw_gold_text.split('\n')[1:]:
                line = line.strip()
                if not line:
                    continue
                if line.startswith("- '") and line.endswith("'"):
                    extracted = line[3:-1].strip()
                else:
                    # NOTE(review): drops the first character unconditionally,
                    # presumably a leading "-" bullet; confirm against the data.
                    extracted = line[1:].strip()
                if extracted:
                    gold.append(extracted)
            return gold
    return [line.strip() for line in raw_gold_text.split('\n') if line.strip()]


def count_found(gold_set, passages):
    """Count the gold strings that occur as a substring of at least one passage."""
    return sum(1 for gold_line in gold_set if any(gold_line in passage for passage in passages))


for example in current_dataset:
    gold_lines = parse_gold_lines(example['gold_text'].strip(), dataset)

    # Prefer the 50-passage ranking; fall back to the 10-passage one when the
    # example has no 'top_50_passages' key.
    top_passages = example['top_50_passages'] if 'top_50_passages' in example else example['top_10_passages']
    if not top_passages or not gold_lines:
        # Nothing to score: no retrieved passages or no gold highlights.
        continue

    total += 1

    for k in top_k_values:
        found = count_found(gold_lines, top_passages[:k])
        hits[f"top_{k}"] += found
        # gold_lines is non-empty here, so the division is safe.
        sum_ratios[f"top_{k}"] += found / len(gold_lines)

# Mean per-example ratio for each cutoff; 0.0 when no example was scored.
avg_ratios = {key: sum_ratios[key] / total if total > 0 else 0.0 for key in sum_ratios}
scores = {"hit_rate": avg_ratios}

args = {
    "dataset": dataset,
    "setup": data_dir,
}
add_record(scores, args)

{
    "docid": "LohaCompanyltd_20191209_F-1_EX-10.16_11917878_EX-10.16_Supply Agreement__Most Favored Nation",
    "previous_text": "Highlight the parts (if any) of this contract related to \"Most Favored Nation\" that should be reviewed by a lawyer. Details: Is there a clause that if a third party gets better terms on the licensing or sale of technology/goods/services described in the contract, the buyer of such technology/goods/services under the contract shall be entitled to those better terms?",
    "gold_text": "No Highlights",
    "citations": [
        [
            "REFERENCE",
            "Exhibit 10.16 SUPPLY CONTRACT Contract No: Date: The buyer/End-User: Shenzhen LOHAS Supply Chain Management Co., Ltd. ADD: Tel No. : Fax No. : The seller: ADD: The Contract is concluded and signed by the Buyer and Seller on , in Hong Kong. 1. General provisions 1.1 This is a framework agreement, the terms and conditions are applied to all purchase orders which signed by this agreement (herei