In [1]:
import os
import re
from pprint import pprint

In [2]:
# helper to parse the ann files
def extract_relations(ann_path):
    """Parse the relation lines of an annotation (.ann) file.

    A relation line is a tab-separated line whose first field starts with
    "R" and whose second field has the form ``<type> Arg1:<id> Arg2:<id>``.
    All other lines are ignored.

    Args:
        ann_path: path of the annotation file to read (decoded as UTF-8).

    Returns:
        A list of dicts, one per relation line, in file order:
        ``{"id": <first field>, "type": <type>, "entities": (a, b)}`` where
        ``(a, b)`` holds the two argument ids ordered by string comparison,
        so the same pair always yields the same tuple whatever the argument
        order in the file.

    Relation lines that lack the second field or do not match the expected
    form are skipped with a warning instead of aborting the whole parse.
    """
    import warnings  # local import: only needed on the malformed-line path

    pattern = re.compile("(.*) Arg1:(.*) Arg2:(.*)")
    res = []
    # Explicit encoding: the file also carries free text on non-relation
    # lines, so decoding must not depend on the platform default.
    with open(ann_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            items = line.strip().split("\t")
            if not items[0].startswith("R"):
                continue
            match = pattern.match(items[1]) if len(items) > 1 else None
            if match is None:
                warnings.warn(
                    f"{ann_path}:{line_no}: skipping malformed relation line: {line.strip()!r}"
                )
                continue
            rel_type, arg1, arg2 = match.groups()
            res.append({
                "id": items[0],
                "type": rel_type,
                # canonical (string-sorted) order of the two argument ids
                "entities": (arg1, arg2) if arg1 < arg2 else (arg2, arg1)
            })
    return res

In [3]:
def freq_counter(data, key):
    """Tally the items of ``data`` by the value ``key(item)`` returns.

    Args:
        data: iterable of items to count.
        key: callable mapping an item to a hashable bucket value.

    Returns:
        A plain dict ``{bucket: count}`` whose keys appear in the order the
        buckets were first encountered.
    """
    counts = {}
    for element in data:
        bucket = key(element)
        counts[bucket] = counts.get(bucket, 0) + 1
    return counts

In [4]:
# evaluate the biobert prediction against the annotated results
def evaluate(pred_dir, target_dir, pub_num):
    """Print a comparison of predicted relations against target relations.

    Reads ``{pub_num}/{pub_num}.ann`` under both ``pred_dir`` and
    ``target_dir`` with ``extract_relations``, keys every relation by its
    entity pair, and prints four tallies, each followed by a frequency
    breakdown:

    * correct predictions   - pair present in the target with the same type
      (breakdown by predicted type);
    * incorrect predictions - pair present in the target with another type
      (breakdown by ``(predicted type, target type)``);
    * extra predictions     - pair absent from the target
      (breakdown by predicted type);
    * targets not found     - target pairs no prediction matched
      (breakdown by target type).

    Relations are keyed by entity pair, so when a file contains several
    relations for the same pair only the last one read is kept.

    Args:
        pred_dir: directory holding the predicted annotations.
        target_dir: directory holding the reference annotations.
        pub_num: publication identifier; used as both sub-directory and
            file stem.

    Returns:
        None. The report is written to stdout.
    """
    pred_path = os.path.join(pred_dir, f"{pub_num}/{pub_num}.ann")
    target_path = os.path.join(target_dir, f"{pub_num}/{pub_num}.ann")
    pred = {x["entities"]: x["type"] for x in extract_relations(pred_path)}
    target = {x["entities"]: x["type"] for x in extract_relations(target_path)}

    not_in_target_list = []
    found_correct_list = []
    found_incorrect_list = []

    for pk, pv in pred.items():
        # A pair may match in either order; look the target up under the key
        # that actually exists (indexing with pk unconditionally raised
        # KeyError when only the reversed pair was present).
        if pk in target:
            tk = pk
        elif (pk[1], pk[0]) in target:
            tk = (pk[1], pk[0])
        else:
            not_in_target_list.append((pk, pv))
            continue

        # Remove the matched target so that whatever remains in `target`
        # after the loop is exactly the set of relations never predicted.
        tv = target.pop(tk)
        if pv == tv:
            found_correct_list.append((pk, pv, tk, tv))
        else:
            found_incorrect_list.append((pk, pv, tk, tv))

    print(f"\n# correct predictions:\t{len(found_correct_list)}")
    pprint(freq_counter(found_correct_list, lambda x: x[1]))
    print(f"\n# incorrect predictions:\t{len(found_incorrect_list)}")
    pprint(freq_counter(found_incorrect_list, lambda x: (x[1], x[3])))
    print(f"\n# extra predictions (false positive):\t{len(not_in_target_list)}")
    pprint(freq_counter(not_in_target_list, lambda x: x[1]))
    print(f"\n# targets not found (false negative):\t{len(target)}")
    pprint(freq_counter(list(target.items()), lambda x: x[1]))

In [5]:
# Compare the predictions stored under pred_dir with the annotations under
# target_dir for two publications; evaluate() prints one report per call.
pred_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-3layer-e2e-2"
target_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-target"
evaluate(pred_dir, target_dir, "sb300091d")
print()  # blank line between the two reports
evaluate(pred_dir, target_dir, "sb4001382")


# correct predictions:	24
{'DownRegulator': 1, 'Substrate': 17, 'UpRegulator': 6}

# incorrect predictions:	1
{('Substrate', 'UpRegulator'): 1}

# extra predictions (false positive):	3
{'Substrate': 3}

# targets not found (false negative):	23
{'DownRegulator': 1, 'Substrate': 5, 'UpRegulator': 17}


# correct predictions:	95
{'DownRegulator': 24, 'Substrate': 71}

# incorrect predictions:	3
{('Substrate', 'DownRegulator'): 3}

# extra predictions (false positive):	18
{'DownRegulator': 4, 'Substrate': 14}

# targets not found (false negative):	70
{'DownRegulator': 13, 'Substrate': 56, 'UpRegulator': 1}


In [9]:
# Same comparison as the previous run, but against a different prediction
# directory; the target directory and the two publications are unchanged.
pred_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-3-1024-1"
target_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-target"
evaluate(pred_dir, target_dir, "sb300091d")
print()  # blank line between the two reports
evaluate(pred_dir, target_dir, "sb4001382")


# correct predictions:	18
{'DownRegulator': 1, 'Substrate': 12, 'UpRegulator': 5}

# incorrect predictions:	2
{('Substrate', 'UpRegulator'): 2}

# extra predictions (false positive):	6
{'Substrate': 6}

# targets not found (false negative):	28
{'DownRegulator': 1, 'Substrate': 10, 'UpRegulator': 17}


# correct predictions:	101
{'DownRegulator': 26, 'Substrate': 75}

# incorrect predictions:	1
{('DownRegulator', 'Substrate'): 1}

# extra predictions (false positive):	34
{'DownRegulator': 11, 'Substrate': 23}

# targets not found (false negative):	66
{'DownRegulator': 14, 'Substrate': 51, 'UpRegulator': 1}


In [10]:
# not used for evaluation
# for labeling false positive and conflicts
def get_fp_and_dual(pred_dir, target_dir, pub_num):
    """Collect the relations that need manual labeling for one publication.

    Reads ``{pub_num}/{pub_num}.ann`` under both directories with
    ``extract_relations`` and groups the relations by entity pair. For each
    predicted pair, in prediction order:

    * pair absent from the target: every predicted relation for the pair is
      emitted with ``"_biobert"`` appended to its type;
    * pair present in the target with two or more target relations: every
      predicted relation (type suffixed with ``"_biobert"``) is emitted,
      followed by every target relation for the pair with its type
      unchanged;
    * pair present with exactly one target relation: nothing is emitted.

    Args:
        pred_dir: directory holding the predicted annotations.
        target_dir: directory holding the reference annotations.
        pub_num: publication identifier; used as both sub-directory and
            file stem.

    Returns:
        A list of dicts with keys ``"entities"``, ``"id"`` and ``"type"``.
    """
    pred = extract_relations(os.path.join(pred_dir, f"{pub_num}/{pub_num}.ann"))
    target = extract_relations(os.path.join(target_dir, f"{pub_num}/{pub_num}.ann"))

    def transform(x):
        """Group relation ids and types by entity pair, in first-seen order."""
        y = {}
        for item in x:
            group = y.setdefault(item["entities"], {"id": [], "type": []})
            group["id"].append(item["id"])
            group["type"].append(item["type"])
        return y

    target_dict = transform(target)
    pred_dict = transform(pred)

    res = []
    for pk, pv in pred_dict.items():
        tv = target_dict.get(pk)
        if tv is not None and len(tv["type"]) < 2:
            # Pair annotated with a single target relation: nothing to label.
            continue

        # Either a false positive (no target for the pair) or a pair with
        # several target relations: list the predictions, marked by suffix.
        for rid, rtype in zip(pv["id"], pv["type"]):
            res.append({
                "entities": pk,
                "id": rid,
                "type": rtype + "_biobert",
            })
        if tv is not None:
            # Pair with several target relations: append those too, unmarked.
            for rid, rtype in zip(tv["id"], tv["type"]):
                res.append({
                    "entities": pk,
                    "id": rid,
                    "type": rtype,
                })
    return res

In [11]:
def to_ann_string(data):
    """Render relation dicts as annotation-style text, one line per item.

    Each item must provide ``"type"`` and a two-element ``"entities"``
    sequence. Lines are numbered from ``R0`` by position in ``data``; any
    ``"id"`` stored on the item is not used.

    Returns:
        The lines joined with newlines (an empty string for empty input).
    """
    lines = (
        f"R{idx}\t{rel['type']} Arg1:{rel['entities'][0]} Arg2:{rel['entities'][1]}"
        for idx, rel in enumerate(data)
    )
    return "\n".join(lines)

In [13]:
# For each publication, print the relations selected by get_fp_and_dual()
# in the line format produced by to_ann_string().
pred_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-3layer-e2e-2"
target_dir = "/sbksvol/data/acs-data/acs-re/acs-20210530-gold-target"
print(to_ann_string(get_fp_and_dual(pred_dir, target_dir, "sb4001382")))
print()  # blank line between the two listings
print(to_ann_string(get_fp_and_dual(pred_dir, target_dir, "sb300091d")))

R0	Substrate_biobert Arg1:T315 Arg2:T952
R1	Substrate_biobert Arg1:T1016 Arg2:T1119
R2	Substrate_biobert Arg1:T626 Arg2:T740
R3	Substrate_biobert Arg1:T664 Arg2:T740
R4	Substrate_biobert Arg1:T515 Arg2:T96
R5	Substrate_biobert Arg1:T487 Arg2:T839
R6	Substrate_biobert Arg1:T661 Arg2:T991
R7	DownRegulator_biobert Arg1:T100 Arg2:T580
R8	Substrate_biobert Arg1:T131 Arg2:T514
R9	Substrate_biobert Arg1:T485 Arg2:T920
R10	DownRegulator_biobert Arg1:T151 Arg2:T851
R11	Substrate_biobert Arg1:T1093 Arg2:T609
R12	DownRegulator_biobert Arg1:T1022 Arg2:T776
R13	Substrate_biobert Arg1:T394 Arg2:T688
R14	Substrate_biobert Arg1:T1056 Arg2:T933
R15	Substrate_biobert Arg1:T181 Arg2:T371
R16	Substrate_biobert Arg1:T255 Arg2:T310
R17	DownRegulator_biobert Arg1:T475 Arg2:T897

R0	Substrate_biobert Arg1:T599 Arg2:T682
R1	Substrate_biobert Arg1:T184 Arg2:T682
R2	Substrate_biobert Arg1:T67 Arg2:T685
