In [14]:
import pickle

def calculate_iou(set1, set2):
    """
    Calculate the Intersection over Union (IoU) of two sets.

    Args:
    set1 (set): The first set.
    set2 (set): The second set.

    Returns:
    float: The IoU of the two sets.
    """
    # Calculate the intersection
    intersection = set1.intersection(set2)

    # Calculate the union
    union = set1.union(set2)

    # Calculate the IoU
    iou = len(intersection) / len(union)

    return iou

file_path = 'data/val_result.pkl'

# Open the file in read-binary mode and load the list of dictionaries
with open(file_path, 'rb') as file:
    val_result_list = pickle.load(file)

In [15]:
result, label = val_result_list[0]

In [23]:
exact_match = 0

result_predicate_set = set()
result_subject_set = set()
result_object_set = set()
result_subject_type_set = set()
result_object_type_set = set()

label_predicate_set = set()
label_subject_set = set()
label_object_set = set()
label_subject_type_set = set()
label_object_type_set = set()

def try_add_to_set(s, d, key, element_name):
    try:
        s.add(d[key])
    except KeyError:
        print(f'KeyError: {key} not found in {element_name}')
    except Exception as e:
        print(f'Error with {element_name}: {(key, e)}')
        
for i, (result, label) in enumerate(val_result_list):
    if result == label:
        exact_match += 1

    if 'predicate' not in result:
        continue
    if 'subject' not in result:
        continue
    if 'object' not in result:
        continue
    
    try_add_to_set(result_predicate_set, result, 'predicate', 'result')
    try_add_to_set(result_subject_set, result, 'subject', 'result')
    try_add_to_set(result_object_set, result, 'object', 'result')
    try_add_to_set(result_subject_type_set, result, 'subject_type', 'result')
    try_add_to_set(result_object_type_set, result, 'object_type', 'result')

    try_add_to_set(label_predicate_set, label, 'predicate', 'label')
    try_add_to_set(label_subject_set, label, 'subject', 'label')
    try_add_to_set(label_object_set, label, 'object', 'label')
    try_add_to_set(label_subject_type_set, label, 'subject_type', 'label')
    try_add_to_set(label_object_type_set, label, 'object_type', 'label')

predicate_iou = calculate_iou(result_predicate_set, label_predicate_set)
subject_iou = calculate_iou(result_subject_set, label_subject_set)
object_iou = calculate_iou(result_object_set, label_object_set)
subject_type_iou = calculate_iou(result_subject_type_set, label_subject_type_set)
object_type_iou = calculate_iou(result_object_type_set, label_object_type_set)

perfect_match_rate = exact_match / len(val_result_list)

Error with result: ('object', TypeError("unhashable type: 'list'"))
Error with result: ('object_type', TypeError("unhashable type: 'list'"))


In [25]:
print("Predicate IOU: ", predicate_iou)
print("Subject IOU: ", subject_iou)
print("Object IOU: ", object_iou)
print("Subject Type IOU: ", subject_type_iou)
print("Object Type IOU: ", object_type_iou)

Predicate IOU:  0.5679012345679012
Subject IOU:  0.519280205655527
Object IOU:  0.24156588160407383
Subject Type IOU:  0.5555555555555556
Object Type IOU:  0.7857142857142857


In [26]:
result_list, label_list = zip(*val_result_list)

In [30]:
import pandas as pd

result_df, label_df = pd.DataFrame(result_list), pd.DataFrame(label_list)

In [32]:
result_df.head()

Unnamed: 0,subject,subject_type,object,object_type,predicate
0,胆囊炎,疾病,30%,流行病学,死亡率
1,骨性关节炎,疾病,踝关节和腕关节,部位,发病部位
2,急性胰腺炎,疾病,早期 ERCP,检查,辅助检查
3,乙型肝炎,疾病,HBsAg阳性的人性交时如果对方未接种疫苗或无自然免疫应采取防护，不应与他人共用牙刷或剃须刀...,社会学,高危因素
4,感染性心内膜炎,疾病,葡萄球菌,疾病,相关（导致）
