In [1]:
import os
import json
import pandas as pd
import numpy as np
from sklearn import metrics

In [2]:
# Minimum word-level Jaccard (intersection-over-union) score at which
# compute_precision_recall treats a prediction as matching a ground-truth answer.
IOU_THRESH = 0.5

In [3]:
def get_questions_from_csv():
    """
    Build a {category name: description} mapping from the category CSV.

    Reads "./data/category_descriptions.csv". For every row, the text that
    follows "Category: " in the first column becomes the (title-cased) key and
    the text that follows "Description: " in the second column becomes the value.

    Raises IndexError if a cell does not contain its expected marker.
    """
    table = pd.read_csv("./data/category_descriptions.csv")
    questions = {}
    for row in table.itertuples(index=False, name=None):
        raw_category, raw_description = row[0], row[1]
        name = raw_category.split("Category: ")[1]
        text = raw_description.split("Description: ")[1]
        questions[name.title()] = text
    return questions

In [32]:
# Load the category -> description mapping and display it as the cell output.
qtype_dict = get_questions_from_csv()
qtype_dict

{'Document Name': 'The name of the contract',
 'Parties': 'The two or more parties who signed the contract',
 'Agreement Date': 'The date of the contract',
 'Effective Date': 'The date when the contract is effective\xa0',
 'Expiration Date': "On what date will the contract's initial term expire?",
 'Renewal Term': 'What is the renewal term after the initial term expires? This includes automatic extensions and unilateral extensions with prior notice.',
 'Notice Period To Terminate Renewal': 'What is the notice period required to terminate renewal?',
 'Governing Law': "Which state/country's law governs the interpretation of the contract?",
 'Most Favored Nation': 'Is there a clause that if a third party gets better terms on the licensing or sale of technology/goods/services described in the contract, the buyer of such technology/goods/services under the contract shall be entitled to those better terms?',
 'Non-Compete': 'Is there a restriction on the ability of a party to compete with th

In [6]:
def load_json(path):
    """
    Read the JSON file at `path` and return the decoded Python object.

    The file is decoded as UTF-8 explicitly: without `encoding`, `open` uses
    the platform's locale encoding, which can corrupt or reject non-ASCII
    text on systems whose default is not UTF-8.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [7]:
def get_preds(nbest_preds_dict, conf=None):
    """
    Select, per question, the predicted texts whose probability exceeds `conf`.

    Args:
        nbest_preds_dict: maps a question id to a list of dicts, each with a
            "text" and a "probability" entry.
        conf: probability threshold; a prediction is kept only when its
            probability is strictly greater than `conf`. With the default
            `None`, no threshold is applied and every non-empty prediction is
            kept (comparing against `None` would raise TypeError).

    Returns:
        dict mapping each question id to the list of kept prediction texts.
        Empty-string predictions are never included. If the same text occurs
        more than once for a question, the probability of its last occurrence
        is the one compared against `conf`.
    """
    results = {}
    for question_id, candidates in nbest_preds_dict.items():
        preds = {}
        for candidate in candidates:
            text = candidate["text"]
            prob = candidate["probability"]
            if text != "":  # don't count empty string as a prediction
                preds[text] = prob
        results[question_id] = [
            text for text, prob in preds.items() if conf is None or prob > conf
        ]
    return results

In [8]:
def get_answers(test_json_dict):
    """
    Collect the ground-truth answer texts for every question.

    Walks test_json_dict["data"] -> contract["paragraphs"] -> paragraph["qas"]
    and returns a dict mapping each qa["id"] to the list of "text" values found
    in qa["answers"] (an empty list when the question has no answers). If an id
    occurs more than once, the last occurrence wins.
    """
    answers_by_id = {}
    for contract in test_json_dict["data"]:
        for paragraph in contract["paragraphs"]:
            for qa in paragraph["qas"]:
                question_id = qa["id"]
                answers_by_id[question_id] = [answer["text"] for answer in qa["answers"]]
    return answers_by_id

In [9]:
def get_jaccard(gt, pred):
    """
    Word-level Jaccard similarity between a ground-truth string and a prediction.

    Both strings are normalised the same way: the characters ".", ",", ";" and
    ":" are removed, the text is lower-cased, "/" is turned into a space, and
    the result is split on single spaces into a set of words. The score is
    |intersection| / |union| of the two word sets.

    Note: splitting on a single space means consecutive spaces (or an empty
    string) produce an empty-string token, which takes part in the sets like
    any other word.
    """
    def tokenize(text):
        for symbol in (".", ",", ";", ":"):
            text = text.replace(symbol, "")
        return set(text.lower().replace("/", " ").split(" "))

    gt_tokens = tokenize(gt)
    pred_tokens = tokenize(pred)
    return len(gt_tokens & pred_tokens) / len(gt_tokens | pred_tokens)

In [10]:
def compute_precision_recall(gt_dict, preds_dict, category=None):
    """
    Count true/false positives and false negatives and derive precision/recall.

    Args:
        gt_dict: maps question id -> list of ground-truth answer strings.
        preds_dict: maps question id -> collection of predicted strings; must
            contain every key of gt_dict that is evaluated.
        category: when truthy, only question ids containing this substring
            are evaluated.

    Matching rule: a prediction matches an answer when their get_jaccard score
    is at least IOU_THRESH. For question ids containing "Parties", a prediction
    that contains the answer as a substring also matches.

    Counting rule:
        * question without answers: every prediction is a false positive;
        * each answer matched by at least one prediction is one true positive,
          otherwise one false negative (several predictions hitting the same
          answer are not double counted);
        * each prediction that matches no answer is one false positive.

    Returns:
        (precision, recall); either is np.nan when its denominator is zero.
    """
    def _matches(answer, prediction, allow_substring):
        if get_jaccard(answer, prediction) >= IOU_THRESH:
            return True
        return allow_substring and answer in prediction

    true_pos = false_pos = false_neg = 0

    for question_id, answers in gt_dict.items():
        if category and category not in question_id:
            continue

        allow_substring = "Parties" in question_id
        preds = preds_dict[question_id]

        if len(answers) == 0:
            # nothing to find, so every prediction is spurious
            false_pos += len(preds)
            continue

        # hit_rows[a][p] tells whether answer a is matched by prediction p
        hit_rows = []
        for ans in answers:
            assert len(ans) > 0
            hit_rows.append([_matches(ans, pred, allow_substring) for pred in preds])

        # answers are the unit for tp/fn so one answer is never counted twice
        found = sum(1 for row in hit_rows if any(row))
        true_pos += found
        false_neg += len(hit_rows) - found

        # predictions that hit no answer at all are false positives
        false_pos += sum(
            1 for col in range(len(preds)) if not any(row[col] for row in hit_rows)
        )

    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos > 0 else np.nan
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg > 0 else np.nan

    return precision, recall

In [11]:
def process_precisions(precision):
    """
    Replace each precision by the best precision found at that position or later.

    Assumes the list `precision` is ordered by increasing recall, so the result
    is a non-increasing curve: precision and recall never both get worse when
    moving along it. Returns a new list; the input list is not modified.
    """
    envelope = []
    for value in reversed(precision):
        if envelope:
            # carry the best value seen so far (walking from the high-recall end)
            value = max(envelope[-1], value)
        envelope.append(value)
    envelope.reverse()
    return envelope

In [12]:
def get_prec_at_recall(precisions, recalls, confs, recall_thresh=0.9):
    """
    Return the precision (and its confidence threshold) at the first point
    whose recall reaches `recall_thresh`.

    Assumes recalls are sorted in increasing order. Precisions are first passed
    through process_precisions.

    `confs` may be shorter than `recalls`: get_precisions_recalls prepends a
    (recall=0, precision=1) point that has no threshold of its own. The missing
    thresholds are taken to belong to the leading points, so every remaining
    point is paired with its own threshold and no trailing point is dropped
    (a plain three-way zip would stop at the shortest list and shift the
    pairing by one).

    Returns:
        (precision, conf) for the first qualifying point; conf is None when
        that point has no threshold. When no point reaches `recall_thresh`,
        returns (0, last entry of confs), or (0, None) if confs is empty.
    """
    processed_precisions = process_precisions(precisions)
    # number of leading curve points that carry no confidence threshold
    offset = max(len(recalls) - len(confs), 0)
    for index, (prec, recall) in enumerate(zip(processed_precisions, recalls)):
        if recall >= recall_thresh:
            conf = confs[index - offset] if index >= offset else None
            return prec, conf
    # no operating point reaches the requested recall
    return 0, (confs[-1] if len(confs) > 0 else None)

In [13]:
def get_precisions_recalls(pred_dict, gt_dict, category=None):
    """
    Sweep the confidence threshold and collect the precision/recall curve.

    Thresholds run from 0.99 down towards 0 in steps of 0.01, followed by 0.001
    and 0. For each one, predictions are filtered with get_preds and scored with
    compute_precision_recall (restricted to `category` when given).

    Returns:
        (precisions, recalls, confs). precisions and recalls start with the
        extra point (precision=1, recall=0), so they are one element longer
        than confs, which holds only the swept thresholds.
    """
    thresholds = list(np.arange(0.99, 0, -0.01)) + [0.001, 0]
    points = [
        compute_precision_recall(gt_dict, get_preds(pred_dict, threshold), category=category)
        for threshold in thresholds
    ]
    precisions = [1] + [prec for prec, _ in points]
    recalls = [0] + [recall for _, recall in points]
    confs = list(thresholds)
    return precisions, recalls, confs

In [14]:
def get_aupr(precisions, recalls):
    """
    Area under the precision-recall curve.

    Precisions are passed through process_precisions before the area is
    computed with metrics.auc over (recalls, precisions). A NaN area is
    reported as 0.
    """
    area = metrics.auc(recalls, process_precisions(precisions))
    return 0 if np.isnan(area) else area

In [15]:
def get_results(model_path, gt_dict, verbose=False):
    """
    Evaluate one model's n-best predictions against the ground truth.

    Args:
        model_path: directory containing "nbest_predictions_.json".
        gt_dict: maps question id -> list of ground-truth answer strings.
        verbose: when True, print the three metrics.

    Returns:
        dict with keys "name" (last component of model_path), "aupr",
        "prec_at_80_recall" and "prec_at_90_recall".

    Raises:
        AssertionError: if the prediction ids and ground-truth ids differ.
    """
    predictions_path = os.path.join(model_path, "nbest_predictions_.json")
    # normpath drops a trailing separator, so "models/foo/" still yields "foo"
    # instead of an empty name; basename also handles the platform's separators.
    name = os.path.basename(os.path.normpath(model_path))

    pred_dict = load_json(predictions_path)

    assert sorted(list(pred_dict.keys())) == sorted(list(gt_dict.keys())), \
        "prediction ids do not match ground-truth ids"

    precisions, recalls, confs = get_precisions_recalls(pred_dict, gt_dict)
    prec_at_90_recall, _ = get_prec_at_recall(precisions, recalls, confs, recall_thresh=0.9)
    prec_at_80_recall, _ = get_prec_at_recall(precisions, recalls, confs, recall_thresh=0.8)
    aupr = get_aupr(precisions, recalls)

    if verbose:
        print("AUPR: {:.3f}, Precision at 80% Recall: {:.3f}, Precision at 90% Recall: {:.3f}".format(aupr, prec_at_80_recall, prec_at_90_recall))

    # collect the metrics in a plain dict and return it
    results = {"name": name, "aupr": aupr, "prec_at_80_recall": prec_at_80_recall, "prec_at_90_recall": prec_at_90_recall}
    return results

In [16]:
# Input/output locations for this evaluation run.
test_json_path = "./data/test.json"
model_path = "./models/deberta-v2-xlarge"
save_dir = "./results"
# exist_ok=True replaces the check-then-create pattern (no race between the
# check and the mkdir) and keeps the cell safe to re-run; makedirs also creates
# missing parent directories.
os.makedirs(save_dir, exist_ok=True)

In [17]:
# Ground truth: question id -> list of answer texts, extracted from the test set JSON.
gt_dict = get_answers(load_json(test_json_path))

In [None]:
# Preview only a handful of entries; printing every question id and its
# answers would flood the notebook output.
preview_count = 5
for question_id, answers in list(gt_dict.items())[:preview_count]:
    print(question_id, answers)

In [25]:
# Score the model's n-best predictions against the ground truth and print the metrics.
results = get_results(model_path, gt_dict, verbose=True)

AUPR: 0.478, Precision at 80% Recall: 0.440, Precision at 90% Recall: 0.178


In [27]:
# Display the metrics dict returned by get_results.
results

{'name': 'deberta-v2-xlarge',
 'aupr': 0.4779347388727093,
 'prec_at_80_recall': 0.4403155490969483,
 'prec_at_90_recall': 0.17828358208955225}

In [28]:
# Name the results file after the model directory. normpath removes a trailing
# separator first, so a model_path ending in "/" does not produce ".json".
save_path = os.path.join(save_dir, "{}.json".format(os.path.basename(os.path.normpath(model_path))))
save_path

'./results/deberta-v2-xlarge.json'

In [29]:
# Disabled: persist `results` to save_path.
# NOTE(review): "{}".format(results) writes the Python dict repr (single-quoted
# keys), which is not valid JSON even though save_path ends in ".json"; prefer
# json.dump(results, f) if this cell is re-enabled.
# with open(save_path, "w") as f:
    # f.write("{}\n".format(results))