In [1]:
import sys
sys.path.append("../")

import pandas as pd
import pickle
import annotations
from utils import open_pickle
from collections import defaultdict

from sklearn.metrics import classification_report

In [2]:
# Tab-separated predictions file (read with sep='\t' in the loading cell below).
PREDS_PATH = "../biobert_re/output/test_predictions.txt"
# Pickled gold-label records for the test set, loaded through utils.open_pickle.
ACTUALS_PATH = "../biobert_re/dataset/test_labels_rel.pkl"

In [3]:
# Predictions table; later cells join it on an `index` column and read its
# `prediction` column.
preds_df = pd.read_csv(PREDS_PATH, sep='\t')
# Gold records; each one is later read for its 'label' and 'relation' entries.
# NOTE(review): open_pickle presumably wraps pickle.load - unpickling can execute
# arbitrary code, so only point ACTUALS_PATH at trusted files.
actual_dicts = open_pickle(ACTUALS_PATH)

In [8]:
def convert_actual_labels(actual_dicts):
    """Flatten the gold relation records into a DataFrame.

    Parameters
    ----------
    actual_dicts : iterable of dict
        Each record must provide a ``'label'`` entry and a ``'relation'``
        entry exposing a ``.name`` attribute (presumably an Enum member -
        TODO confirm against the code that writes the pickle).

    Returns
    -------
    pandas.DataFrame
        One row per record with the columns ``index`` (0-based position of
        the record in ``actual_dicts``), ``label`` and ``relation`` (the
        relation's ``.name``). All three columns are present even when
        ``actual_dicts`` is empty, so a later merge on ``index`` does not
        fail with a KeyError.
    """
    # Materialise once so the input may be any iterable, not only a list.
    records = list(actual_dicts)

    return pd.DataFrame({
        # Explicit dtype keeps the join key integer-typed for an empty input too.
        "index": pd.Series(range(len(records)), dtype="int64"),
        "label": [record['label'] for record in records],
        "relation": [record['relation'].name for record in records],
    })

def gen_classification_reports(df):
    """Print classification reports and return them as nested dictionaries.

    A report is produced for the whole frame and then for every distinct
    value of ``df.relation``, in the order ``unique()`` yields them.

    Parameters
    ----------
    df : pandas.DataFrame
        Must expose ``label``, ``prediction`` and ``relation`` columns;
        ``label`` and ``prediction`` are cast to ``int`` before scoring.

    Returns
    -------
    dict
        ``{"overall": <report>, <relation>: <report>, ...}`` where each
        report is the ``output_dict=True`` form of
        ``sklearn.metrics.classification_report``.
    """
    separator = "-"*55

    def _print_and_collect(title, frame):
        # Show the human-readable table, then hand back the dict form of the same report.
        print(separator)
        print(title)
        y_true = frame.label.astype(int)
        y_pred = frame.prediction.astype(int)
        print(classification_report(y_true, y_pred))
        return classification_report(y_true, y_pred, output_dict=True)

    report = {"overall": _print_and_collect("Overall Classification Report", df)}

    for relation in df.relation.unique():
        report[relation] = _print_and_collect(
            "Classification Report for {} Relation".format(relation),
            df[df.relation == relation],
        )

    print(separator)
    return report

In [9]:
# Tabulate the gold records: one row per record with index / label / relation columns.
actual_df = convert_actual_labels(actual_dicts)

In [10]:
# Left-join the predictions onto the gold labels by `index`. `validate` rejects
# duplicated `index` values on either side, which would otherwise silently add rows.
final_df = actual_df.merge(preds_df, how="left", on="index", suffixes=("_actual", "_predicted"),
                           validate="one_to_one")

# A left join leaves NaN wherever a gold row has no prediction; fail here with a
# clear message instead of in the later `astype(int)` cast inside the report code.
n_missing = int(final_df["prediction"].isna().sum())
if n_missing:
    raise ValueError("{} gold rows have no matching prediction".format(n_missing))

In [12]:
# Prints the overall and per-relation reports; `report` keeps the dict form of each,
# keyed by "overall" and by relation name.
report = gen_classification_reports(final_df)

-------------------------------------------------------
Overall Classification Report
              precision    recall  f1-score   support

           0       0.98      0.97      0.98      5448
           1       0.98      0.99      0.98      8203

    accuracy                           0.98     13651
   macro avg       0.98      0.98      0.98     13651
weighted avg       0.98      0.98      0.98     13651

-------------------------------------------------------
Classification Report for Dosage-Drug Relation
              precision    recall  f1-score   support

           0       0.99      0.98      0.99       548
           1       0.99      0.99      0.99       864

    accuracy                           0.99      1412
   macro avg       0.99      0.99      0.99      1412
weighted avg       0.99      0.99      0.99      1412

-------------------------------------------------------
Classification Report for Form-Drug Relation
              precision    recall  f1-score   support

 