# Inter-annotator agreement for coreference annotation

Algorithms: "Passonneau, 2004, Computing Reliability for Coreference Annotation" and "Krippendorff's alpha"

Krippendorff’s alpha is a chance-corrected agreement coefficient in the same family as Cohen’s kappa. Its values typically fall between -1 and 1: 1 means perfect agreement between the raters, 0 means agreement no better than chance, and negative values mean the raters disagree systematically, that is, more often than chance would predict.

(This can happen when raters value different things: for example, rater A takes a crowded store as a sign of success, while rater B takes it as evidence of understaffing and poor management.)
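To make the scale concrete, here is a minimal sketch of the nominal case, where alpha = 1 - D_o / D_e (observed over expected disagreement). It uses only the standard library; the function name `nominal_alpha` and the use of `None` for a missing rating are assumptions of this sketch, not part of the `krippendorff` package used later, and degenerate inputs (for example a single label throughout) are not handled.

```python
from collections import Counter
from itertools import permutations


def nominal_alpha(rows):
    """Krippendorff's alpha for nominal labels; None marks a missing rating."""
    coincidences = Counter()  # (label_c, label_k) -> weighted pair count
    for unit in zip(*rows):  # one unit = one column of the reliability data
        labels = [v for v in unit if v is not None]
        m = len(labels)
        if m < 2:
            continue  # a unit rated by fewer than two raters cannot be paired
        for c, k in permutations(labels, 2):  # ordered pairs of ratings
            coincidences[(c, k)] += 1 / (m - 1)

    totals = Counter()  # marginal total n_c of each label
    for (c, _k), count in coincidences.items():
        totals[c] += count
    n = sum(totals.values())

    observed = sum(cnt for (c, k), cnt in coincidences.items() if c != k) / n
    expected = sum(
        totals[c] * totals[k] for c in totals for k in totals if c != k
    ) / (n * (n - 1))
    return 1 - observed / expected


# Two raters who always agree.
print(nominal_alpha([[1, 2, 1, 2], [1, 2, 1, 2]]))  # 1.0
# Two raters who always pick the opposite label.
print(round(nominal_alpha([[1, 2, 1, 2], [2, 1, 2, 1]]), 6))  # -0.75
```

The second call shows how systematic disagreement pushes the coefficient below zero.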

## Prepare

In [79]:
import sys
sys.path.append("../../src")

import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from collections import defaultdict

from IPython.display import display, HTML
# display(HTML(df.to_html()))

from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Event
from common_utils.data_loader_utils import load_mimic_cxr_bySection
from common_utils.coref_utils import resolve_mention_and_group_num
from common_utils.file_checker import FileChecker
from common_utils.common_utils import check_and_create_dirs, check_and_remove_dirs

FILE_CHECKER = FileChecker()
START_EVENT = Event()

mpl.style.use("default")

SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

In [80]:
from hydra import compose, initialize
from omegaconf import OmegaConf

config = None
with initialize(version_base=None, config_path="../config", job_name="nlp_ensemble"):
    config = compose(config_name="coreference_resolution", overrides=["+nlp_ensemble@_global_=mimic_cxr"])

In [81]:
from krippendorff import alpha
import numpy as np
from typing import Any

## Example 1

Example from "Passonneau, 2004, Computing Reliability for Coreference Annotation"

The results reported in Passonneau:
- Unweighted Krippendorff's alpha: 0.45 (ours: 0.45)
- Weighted Krippendorff's alpha: 0.74 (ours: 0.76)

In [82]:
reliability_data =[
    [1,2,4,1,2,2,6,4,6,6,2],
    [1,2,5,1,2,2,7,9,10,7,2],
    [1,2,3,1,2,2,8,3,3,3,2]
]

w = """0	1	1	1	1	1	1	1	1	1
1	0	1	1	1	1	1	1	1	1
1	1	0	0.33	0.33	0.67	0.67	1	0.33	0.33
1	1	0.33	0	0.33	1	1	1	0.33	1
1	1	0.33	0.33	0	1	1	1	1	1
1	1	0.67	1	1	0	0.33	0.33	1	0.33
1	1	0.67	1	1	0.33	0	0.33	1	1
1	1	1	1	1	0.33	0.33	0	1	1
1	1	0.33	0.33	1	1	1	1	0	1
1	1	0.33	1	1	0.33	1	1	1	0"""
weights = np.array([[np.float64(cell) if cell else np.float64(0) for cell in row.split("\t")] for row in w.split("\n")])

def weight_matrix(v1: np.ndarray, v2: np.ndarray, dtype: Any = np.float64, **kwargs) -> np.ndarray:  # noqa
    """ A matrix of weights for nominal label pairs """
    return weights.astype(dtype)

print("Unweighted Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="nominal"), 6))
print("Weighted Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement=weight_matrix), 6))

Unweighted Krippendorff's alpha: 0.449541
Weighted Krippendorff's alpha: 0.763529
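As a cross-check on the two numbers above, the following standard-library sketch recomputes alpha = 1 - D_o / D_e with a caller-supplied distance between labels. The names `weighted_alpha`, `DATA`, `WEIGHTS` and `table_delta` are assumptions of this sketch; the data and the distance table are copied from the cell above, and the sketch assumes the distance of a label to itself is 0.

```python
from collections import Counter
from itertools import permutations


def weighted_alpha(rows, delta):
    """alpha = 1 - D_o / D_e, where delta(c, k) is the distance between labels."""
    coincidences = Counter()
    for unit in zip(*rows):  # one unit = one column of the reliability data
        labels = [v for v in unit if v is not None]
        if len(labels) < 2:
            continue
        for c, k in permutations(labels, 2):
            coincidences[(c, k)] += 1 / (len(labels) - 1)

    totals = Counter()
    for (c, _k), count in coincidences.items():
        totals[c] += count
    n = sum(totals.values())

    observed = sum(
        cnt * delta(c, k) for (c, k), cnt in coincidences.items() if c != k
    ) / n
    expected = sum(
        totals[c] * totals[k] * delta(c, k)
        for c in totals for k in totals if c != k
    ) / (n * (n - 1))
    return 1 - observed / expected


DATA = [
    [1, 2, 4, 1, 2, 2, 6, 4, 6, 6, 2],
    [1, 2, 5, 1, 2, 2, 7, 9, 10, 7, 2],
    [1, 2, 3, 1, 2, 2, 8, 3, 3, 3, 2],
]
# Row/column i holds the distances for label i + 1 (labels run from 1 to 10).
WEIGHTS = [
    [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0.33, 0.33, 0.67, 0.67, 1, 0.33, 0.33],
    [1, 1, 0.33, 0, 0.33, 1, 1, 1, 0.33, 1],
    [1, 1, 0.33, 0.33, 0, 1, 1, 1, 1, 1],
    [1, 1, 0.67, 1, 1, 0, 0.33, 0.33, 1, 0.33],
    [1, 1, 0.67, 1, 1, 0.33, 0, 0.33, 1, 1],
    [1, 1, 1, 1, 1, 0.33, 0.33, 0, 1, 1],
    [1, 1, 0.33, 0.33, 1, 1, 1, 1, 0, 1],
    [1, 1, 0.33, 1, 1, 0.33, 1, 1, 1, 0],
]


def table_delta(c, k):
    """Look up the distance between labels c and k in WEIGHTS."""
    return WEIGHTS[c - 1][k - 1]


# A 0/1 distance reduces to the unweighted (nominal) coefficient.
print(round(weighted_alpha(DATA, lambda c, k: float(c != k)), 4))  # 0.4495
print(round(weighted_alpha(DATA, table_delta), 4))  # 0.7635
```

Partial credit for overlapping chains lowers the observed disagreement more than the expected disagreement here, which is why the weighted value is higher than the unweighted one.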


## Example 2

Example from http://en.wikipedia.org/wiki/Krippendorff's_Alpha

Use np.nan for categories such as “cannot code,” “no answer,” or “lacking an observation.”

In [83]:
reliability_data = (
    "*    *    *    *    *    3    4    1    2    1    1    3    3    *    3", # coder A
    "1    *    2    1    3    3    4    3    *    *    *    *    *    *    *", # coder B
    "*    *    2    1    3    4    4    *    2    1    1    3    3    *    4", # coder C
)
reliability_data = [d.split() for d in reliability_data]  # convert to 2D list of string items
reliability_data = [[int(cell) if cell!="*" else np.nan for cell in row] for row in reliability_data]
print(reliability_data, "\n")

print("nominal Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="nominal"), 6))
print("ordinal Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="ordinal"), 6))
print("interval Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="interval"), 6))
print("ratio Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="ratio"), 6))

[[nan, nan, nan, nan, nan, 3, 4, 1, 2, 1, 1, 3, 3, nan, 3], [1, nan, 2, 1, 3, 3, 4, 3, nan, nan, nan, nan, nan, nan, nan], [nan, nan, 2, 1, 3, 4, 4, nan, 2, 1, 1, 3, 3, nan, 4]]

nominal Krippendorff's alpha: 0.691358
ordinal Krippendorff's alpha: 0.806721
interval Krippendorff's alpha: 0.810845
ratio Krippendorff's alpha: 0.808944


## Resolve BRAT output

Manual:
- ../../output/brat_annotation/round2_merged_new
- ../../output/brat_annotation/MIMIC_manual_Hantao
- ../../output/brat_annotation/MIMIC_manual_Irena

Semiauto:
- ../../output/brat_annotation/round3_merged_new
- ../../output/brat_annotation/MIMIC_semiauto_Hantao
- ../../output/brat_annotation/MIMIC_semiauto_Irena



In [84]:
### Modify ###
# Should not end with "/"
# brat_source_dirs = ["../../output/brat_annotation/round1_IS/",
#                     "../../output/brat_annotation/round1_merged_new"]
brat_source_dirs = ["../../output/brat_annotation/MIMIC_manual_Hantao",
                    "../../output/brat_annotation/round2_merged_new"]


In [85]:
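from typing import Optional

def find_sub_list(sublist, source_list) -> Optional[tuple[int, int]]:
    """ Returns: start index, end index (inclusive) of the first occurrence of sublist,
    or None if sublist does not occur in source_list.
    """
    sll = len(sublist)
    for ind in (i for i, e in enumerate(source_list) if e == sublist[0]):
        if source_list[ind:ind + sll] == sublist:
            return ind, ind + sll - 1
    return None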

In [86]:
class BratMention:
    def __init__(self,uid,brat_id,start,end, mention_str) -> None:
        self.uid = uid # {section_name}_{doc_id} 
        self.brat_id = brat_id
        self.tok_start = start
        self.tok_end = end # Not inclusive
        self.mention_str = mention_str
    
    def __eq__(self, __o: object) -> bool:
        if isinstance(__o, BratMention):
            return self.tok_start == __o.tok_start and self.tok_end == __o.tok_end and self.uid == __o.uid
        elif isinstance(__o, str) and __o.startswith("T"):
            # Allows lookups by brat id, e.g. mention_list.index("T3")
            return self.brat_id == __o
        return False

    def __hash__(self) -> int:
        return self.tok_start.__hash__() + self.tok_end.__hash__()
    
    def __repr__(self) -> str:
        return self.__str__()

    def __str__(self) -> str:
        return f"{self.mention_str}({self.brat_id})({self.tok_start}-{self.tok_end})({self.uid})"
        
        
class BratCorefGroup:
    def __init__(self) -> None:
        self.coreference_list:list[BratCoreference] = []
        self.mention_list:list[BratMention] = []

    def add(self, ment_a:BratMention, ment_b:BratMention):
        to_be_merged_idxs = set()
        for group_id, _brat_coref_obj in enumerate(self.coreference_list):
            if _brat_coref_obj.hasMention(ment_a) or _brat_coref_obj.hasMention(ment_b):
                to_be_merged_idxs.add(group_id)
                _brat_coref_obj.update([ment_a,ment_b])
                
        if len(to_be_merged_idxs) == 0:
            # Not exist in curr group, thus create a new BratCoreference
            self.coreference_list.append(BratCoreference({ment_a,ment_b}))
        elif len(to_be_merged_idxs) > 1:
            # Exist in multiple groups, thus need to merge
            new_coref_set = set()
            to_be_removed = list(to_be_merged_idxs.copy())
            while to_be_merged_idxs:
                old_coref_set = self.coreference_list[to_be_merged_idxs.pop()].mention_set
                new_coref_set = new_coref_set.union(old_coref_set)
            for index in sorted(to_be_removed, reverse=True):
                del self.coreference_list[index]
            self.coreference_list.append(BratCoreference(new_coref_set))

    def __str__(self) -> str:
        out = []
        for coref_obj in self.coreference_list:
            out.append("\n".join(map(str, coref_obj.mention_set)))
        return "\n\n".join(out)
        
    
class BratCoreference:
    def __init__(self, mentionSet:set) -> None:
        self.aggregrate_id = None
        self.mention_set:set[BratMention] = mentionSet
        
    def hasMention(self, mention_x):
        return mention_x in self.mention_set
    
    def update(self, elements):
        self.mention_set.update(elements)
        
    def __eq__(self, __o: object) -> bool:
        if isinstance(__o, BratCoreference):
            return self.mention_set == __o.mention_set
        return False

    def __hash__(self) -> int:
        hash_value = 0
        for mention in self.mention_set:
            hash_value += hash(mention)
        return hash_value
    
    def __str__(self) -> str:
        return "|".join(map(str, self.mention_set))

    def __repr__(self) -> str:
        return self.__str__()

In [87]:
def resolve_brat_file(ann_file_list,section_name="",doc_id="") -> tuple[list[BratMention],BratCorefGroup]:
    mention_list:list[BratMention] = []
    bratCorefGroup_obj = BratCorefGroup()
    for line in [line.strip() for line in ann_file_list]:
        if not line:
            # Skip blank lines; line[0] would raise IndexError on an empty string
            continue
        line_info_list = line.split("\t")
        if line[0] == "T":
            # Mention
            mention_id = line_info_list[0]
            ment_start = line_info_list[1].split(" ")[1]
            ment_end = line_info_list[1].split(" ")[-1]
            mention_str = line_info_list[2]
            uid = section_name+"_"+doc_id
            mention_list.append(BratMention(uid,mention_id,ment_start,ment_end, mention_str))
        elif line[0] == "R":
            # relation
            relation_id = line_info_list[0]
            mention_a_id = line_info_list[1].split(" ")[1].split(":")[-1]
            mention_b_id = line_info_list[1].split(" ")[2].split(":")[-1]
            mention_a = mention_list[mention_list.index(mention_a_id)]
            mention_b = mention_list[mention_list.index(mention_b_id)]
            bratCorefGroup_obj.add(mention_a, mention_b)
    return mention_list, bratCorefGroup_obj
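The parsing step can be illustrated on a tiny hand-written annotation. The sketch below reads tab-separated brat standoff lines (`T` lines for text-bound mentions, `R` lines for binary relations) and merges the pairwise links into chains. The sample lines, the `Mention` and `Coreference` labels, and the helper names `parse_ann` and `group_chains` are all hypothetical, and the merging uses a small union-find rather than the `BratCorefGroup.add` logic above.

```python
SAMPLE_ANN = [
    "T1\tMention 0 9\tthe heart",
    "T2\tMention 20 22\tit",
    "T3\tMention 30 39\tthe lungs",
    "T4\tMention 45 49\tthey",
    "T5\tMention 60 69\tthe organ",
    "R1\tCoreference Arg1:T1 Arg2:T2",
    "R2\tCoreference Arg1:T3 Arg2:T4",
    "R3\tCoreference Arg1:T5 Arg2:T2",
]


def parse_ann(lines):
    """Return ({brat_id: (start, end, text)}, [(brat_id_a, brat_id_b), ...])."""
    mentions, links = {}, []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        fields = line.split("\t")
        if fields[0].startswith("T"):
            parts = fields[1].split(" ")  # e.g. ["Mention", "0", "9"]
            mentions[fields[0]] = (int(parts[1]), int(parts[-1]), fields[2])
        elif fields[0].startswith("R"):
            _label, arg1, arg2 = fields[1].split(" ")
            links.append((arg1.split(":")[-1], arg2.split(":")[-1]))
    return mentions, links


def group_chains(links):
    """Merge pairwise links into coreference chains with a union-find."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in links:
        parent[find(a)] = find(b)

    chains = {}
    for x in parent:
        chains.setdefault(find(x), set()).add(x)
    return sorted(sorted(chain) for chain in chains.values())


mentions, links = parse_ann(SAMPLE_ANN)
print(group_chains(links))  # [['T1', 'T2', 'T5'], ['T3', 'T4']]
```

The third relation links T5 to T2, so T5 joins the chain that already contains T1 and T2 instead of starting a new one.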

In [88]:
def get_reliability_data(aggregrated_mentions, annotator_output_dict, aggregrated_coreferences, use_single_coding=False) -> list[list[int]]:
    """ If use_single_coding is True, each np.nan is replaced by a new nominal id.
    e.g. Use np.nan: [0,0,1,1,nan,2,2]; use single_coding: [0,0,1,1,3,2,2]
    """
    reliability_data:list[list[int]] = []
    for annotator, bratCorefGroup_obj_list in annotator_output_dict.items():
        reliability_data_row:list[int] = []
        annotator_mention_list = [_mention_obj for bratCorefGroup_obj in bratCorefGroup_obj_list for _mention_obj in bratCorefGroup_obj.mention_list]
        annotator_coref_list = [_coref_obj for bratCorefGroup_obj in bratCorefGroup_obj_list for _coref_obj in bratCorefGroup_obj.coreference_list]
        for mention_obj in sorted(aggregrated_mentions, key=lambda x: x.uid):
            cell_value = np.nan
            if mention_obj in annotator_mention_list:
                # Find out that the mention belongs to which coref_obj
                for coref_obj in annotator_coref_list:
                    if coref_obj.hasMention(mention_obj):
                        cell_value = coref_obj.aggregrate_id
            # Replace nan with a new coding (consider mention as a coref singleton)
            if np.isnan(cell_value) and use_single_coding:
                brat_coref_obj = BratCoreference({mention_obj})
                brat_coref_obj.aggregrate_id = len(aggregrated_coreferences)
                cell_value = len(aggregrated_coreferences)
                aggregrated_coreferences.add(brat_coref_obj)
                
            reliability_data_row.append(cell_value)
        reliability_data.append(reliability_data_row)
    return reliability_data

In [89]:
def get_weights(aggregrated_coreferences) -> list[list[float]]:
    weights_2dList:list[list[float]] = []
    for coref_obj_row in sorted(aggregrated_coreferences,key=lambda x: x.aggregrate_id):
        weights_row:list[float] = []
        for coref_obj_col in sorted(aggregrated_coreferences,key=lambda x: x.aggregrate_id):
            weight = None
            if coref_obj_col == coref_obj_row:
                weight = 0
            elif coref_obj_col.mention_set.issubset(coref_obj_row.mention_set) or coref_obj_col.mention_set.issuperset(coref_obj_row.mention_set):
                weight = 0.33
            elif not coref_obj_col.mention_set.isdisjoint(coref_obj_row.mention_set):
                weight = 0.67
            else:
                weight = 1
            weights_row.append(weight)
        weights_2dList.append(weights_row)
    return weights_2dList
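The four distance values used in `get_weights` follow the set-based scheme from Passonneau (2004): identical chains, one chain contained in the other, overlapping chains, and disjoint chains. A minimal standalone sketch on plain sets, with the hypothetical helper name `set_distance` and made-up mention ids:

```python
def set_distance(a: frozenset, b: frozenset) -> float:
    """Distance between two coreference chains, each given as a set of mentions."""
    if a == b:
        return 0.0   # identical chains
    if a <= b or b <= a:
        return 0.33  # one chain subsumes the other
    if a & b:
        return 0.67  # chains overlap without subsumption
    return 1.0       # chains are disjoint


chains = [
    frozenset({"T1", "T2"}),
    frozenset({"T1", "T2", "T3"}),
    frozenset({"T2", "T4"}),
    frozenset({"T5"}),
]
for row in chains:
    print([set_distance(row, col) for col in chains])
# [0.0, 0.33, 0.67, 1.0]
# [0.33, 0.0, 0.67, 1.0]
# [0.67, 0.67, 0.0, 1.0]
# [1.0, 1.0, 1.0, 0.0]
```

The resulting matrix is symmetric with a zero diagonal, which is what a distance table for alpha needs.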

Compute the alpha for each document, and then take the mean of the alphas.

There are two options when computing the alpha:
1. Use np.nan for mentions that are not observed -> get_reliability_data(..., use_single_coding=False)
2. Treat the unobserved mention as a singleton coreference group -> get_reliability_data(..., use_single_coding=True)
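The difference between the two options can be seen on a single row of reliability data. The sketch below is a simplified stand-in for the second option: it assumes every missing cell receives a fresh label that no other cell uses, with `None` playing the role of np.nan. The helper name `fill_singletons` is hypothetical.

```python
def fill_singletons(rows):
    """Replace every missing cell (None) with a fresh, otherwise unused label."""
    next_id = max(v for row in rows for v in row if v is not None) + 1
    filled = []
    for row in rows:
        new_row = []
        for value in row:
            if value is None:
                new_row.append(next_id)  # the mention becomes its own singleton
                next_id += 1
            else:
                new_row.append(value)
        filled.append(new_row)
    return filled


print(fill_singletons([[0, 0, 1, 1, None, 2, 2]]))
# [[0, 0, 1, 1, 3, 2, 2]]
```

With the first option, a mention that only one annotator marked has a single rating and drops out of the alpha computation. With the second option it stays in as a unit on which the annotators disagree, which lowers the unweighted alpha.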



In [90]:
annotator_output_dict:dict[str,list[BratCorefGroup]] = defaultdict(list)
aggregrated_mentions:set[BratMention] = set()
aggregrated_coreferences:set[BratCoreference] = set()

for section_name in ["findings","impression"]:
    dir_for_docid = os.path.join(brat_source_dirs[0], section_name)
    # Derive the doc ids from the .txt file names (drop the extension, keep the stem)
    for doc_id in [os.path.splitext(f)[0] for f in FILE_CHECKER.filter(os.listdir(dir_for_docid)) if f.endswith(".txt")]:
        # if doc_id != "s50873220" or section_name != "findings":
        #     continue
        
        # Aggregrate the outputs of multiple annotators
        for brat_source_dir in brat_source_dirs:
            annotator = os.path.basename(brat_source_dir)
            brat_dir = os.path.join(brat_source_dir, section_name)
            
            # brat outputs
            with open(os.path.join(brat_dir, doc_id+".txt"), "r", encoding="UTF-8") as f:
                txt_file_str = "".join(f.readlines())
            with open(os.path.join(brat_dir, doc_id+".ann"), "r", encoding="UTF-8") as f:
                ann_file_list = f.readlines()
            
            # Resolve brat files
            mention_list, bratCorefGroup_obj = resolve_brat_file(ann_file_list, section_name, doc_id)
            
            bratCorefGroup_obj.mention_list = mention_list
            annotator_output_dict[annotator].append(bratCorefGroup_obj)
            aggregrated_mentions.update(mention_list)
            aggregrated_coreferences.update(bratCorefGroup_obj.coreference_list)

        
# Assign id to BratCoreference
for coref_id, _bratCoref_obj in enumerate(aggregrated_coreferences):
    _bratCoref_obj.aggregrate_id = coref_id
    # Equal BratCoreference objects are distinct instances in annotator_output_dict. Replace them with the shared instance that carries the id.
    for annotator, bratCorefGroup_obj_list in annotator_output_dict.items():
        for bratCorefGroup_obj in bratCorefGroup_obj_list:
            try:
                idx = bratCorefGroup_obj.coreference_list.index(_bratCoref_obj)
                bratCorefGroup_obj.coreference_list[idx] = _bratCoref_obj
            except ValueError:
                pass

reliability_data:list[list[int]] = get_reliability_data(aggregrated_mentions, annotator_output_dict, aggregrated_coreferences, use_single_coding=True)

weights_2dList:list[list[float]] = get_weights(aggregrated_coreferences)
weights_np = np.array(weights_2dList)
def weight_matrix(v1: np.ndarray, v2: np.ndarray, dtype: Any = np.float64, **kwargs) -> np.ndarray:  # noqa
    """ A matrix of weights for nominal label pairs """
    return weights_np.astype(dtype)


print("Unweighted Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement="nominal"), 6))
print("Weighted Krippendorff's alpha:",round(alpha(reliability_data=reliability_data, level_of_measurement=weight_matrix), 6))

Unweighted Krippendorff's alpha: 0.229023
Weighted Krippendorff's alpha: 0.71748


### Check the intermediate data for calculation

In [91]:
print("The canonical form of reliability data:")
for annotator, row in zip([ann for ann, _ in annotator_output_dict.items()],reliability_data):
    print(annotator, row)

print("\nCompared one by one",[f"{i}|{j}" for i,j in zip(reliability_data[0],reliability_data[1])])

print("\n",weights_np)

The canonical form of reliability data:
MIMIC_manual_Hantao [239, 240, 241, 242, 243, 244, 245, 186, 186, 246, 247, 199, 199, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 66, 232, 66, 66, 232, 258, 259, 33, 33, 142, 89, 260, 261, 142, 89, 262, 263, 142, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 85, 276, 277, 85, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 174, 289, 174, 290, 291, 95, 95, 292, 293, 294, 295, 296, 297, 130, 130, 130, 298, 299, 300, 301, 302, 161, 161, 226, 303, 304, 226, 305, 306, 307, 308, 309, 310, 311, 312, 15, 64, 64, 15, 169, 169, 210, 210, 193, 313, 314, 112, 315, 316, 193, 112, 317, 318, 53, 319, 320, 53, 321, 322, 323, 103, 324, 233, 103, 233, 227, 325, 326, 327, 5, 227, 328, 5, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 61, 61, 339, 340, 98, 341, 342, 343, 344, 345, 346, 98, 108, 347, 348, 108, 349, 350, 36, 2, 2, 13, 351, 13, 36, 352, 206, 206, 54, 353, 354, 54, 355, 356, 357, 358, 27, 27, 55, 55, 359, 113, 360, 361, 362