In [1]:
from csv import DictReader
from baseline.budi_et_al.model import Model
from baseline.budi_et_al.rule import Rule
from utils.scorers import MUCScorer, B3Scorer, AverageScorer
from utils.clusterers import BestFirstClusterer, get_anaphora_scores_by_antecedent, ClosestFirstClusterer

In [2]:
def convert_field(value: str):
    """Convert a raw CSV field to a number when it is written as one.

    Args:
        value: The field text as read from the CSV file.

    Returns:
        ``int`` if ``value`` consists only of decimal digits, ``float`` if it
        consists of decimal digits with exactly one ``.`` in it, otherwise
        ``value`` itself, unchanged. Signed numbers (``'-1'``) and exponent
        notation (``'1e5'``) are therefore returned as strings.
    """
    # isdecimal() is used instead of isnumeric(): isnumeric() also accepts
    # characters such as superscripts and vulgar fractions, which int() and
    # float() reject with ValueError.
    if value.isdecimal():
        return int(value)

    # Drop at most one decimal point; what remains must be pure digits.
    # A lone '.' leaves the empty string, for which isdecimal() is False.
    if value.replace('.', '', 1).isdecimal():
        return float(value)

    return value

In [3]:
# Mention-pair test set; the path is relative to the working directory.
data_path = 'data/testing/mention_pairs_for_budi_et_al_implementation.csv'

data = []     # one converted row (dict keyed by CSV header) per mention pair
m1_ids = []   # value of the 'm1_id' column, per row
m2_ids = []   # value of the 'm2_id' column, per row
labels = []   # [1 - is_coreference, is_coreference], per row

# newline='' is what the csv module asks for when it is handed a file object,
# so that line endings inside quoted fields are interpreted by the reader
# itself rather than being translated by the text layer first.
with open(data_path, 'r', newline='') as f:
    csv_file = DictReader(f)

    for row in csv_file:
        # Replace every field by its numeric value where it has one. Only
        # existing keys are reassigned, so iterating over the row is safe.
        for field in row:
            row[field] = convert_field(row[field])

        data.append(row)
        m1_ids.append(row['m1_id'])
        m2_ids.append(row['m2_id'])

        # Same two-slot layout as the model predictions built later on:
        # index 0 is the negative score, index 1 the positive one.
        is_coreference = row['is_coreference']
        labels.append([1 - is_coreference, is_coreference])

# Chains derived from the labels; these are what predictions are scored
# against. NOTE(review): the exact structure returned by the clusterer is
# defined in utils.clusterers and is not visible here.
label_chains = ClosestFirstClusterer().get_chains(get_anaphora_scores_by_antecedent(m1_ids, m2_ids, labels))

In [4]:
# Columns of each data row that are forwarded to Rule as keyword arguments.
rule_fields = [
    'is_string_match',
    'is_string_without_punctuation_match',
    'is_abbreviation',
    'is_first_pronoun',
    'is_second_pronoun',
    'is_on_one_sentence',
    'is_substring',
    'first_name_class',
    'second_name_class',
]

# One Rule per mention-pair row, in the same order as `data`.
rules = [
    Rule(**{name: row[name] for name in rule_fields})
    for row in data
]

In [5]:
# Location of the stored Budi et al. model; relative to the working directory.
model_path = 'models/budi_et_al/model.csv'

# Load it through the project's Model.load.
# NOTE(review): the file layout is defined by Model and is not visible here.
model = Model.load(model_path)

In [6]:
# One (1 - p, p) pair per rule, where p is what model.predict_proba returns
# for that rule. Index 1 is the slot the threshold sweep reads later.
# NOTE(review): p is presumably the coreference probability - confirm in Model.
preds = [(1 - prob, prob) for prob in map(model.predict_proba, rules)]

In [7]:
# Hand-written logarithmic grid: 0, then 1..9 times each base value.
# It is NOT what the sweep uses; it is kept under its own name so it can be
# swapped in for `thresholds` below when a fixed grid is wanted.
base_thresholds = [0.1, 0.01, 0.001, 0.0001, 0.00001]
grid_thresholds = [0] + [base * multiplier for base in base_thresholds for multiplier in range(1, 10)]

# Cut-offs actually swept by get_sorted_scores: every distinct value found in
# slot 1 of the prediction pairs.
thresholds = {pred[1] for pred in preds}

muc_scorer = MUCScorer()
b3_scorer = B3Scorer()
# Not called by get_sorted_scores, which averages the MUC and B3 F1 values
# by hand instead.
average_scorer = AverageScorer([muc_scorer, b3_scorer])

def get_sorted_scores(clusterer, pred):
    """Score the clusterer's chains at every candidate threshold.

    Reads the module-level ``thresholds``, ``muc_scorer``, ``b3_scorer`` and
    ``label_chains``.

    Args:
        clusterer: Object whose ``get_chains(pred, threshold)`` builds the
            predicted chains.
        pred: Pairwise scores, passed through to ``clusterer.get_chains``.

    Returns:
        A list with one ``(average_f1, muc, b3, threshold)`` tuple per
        threshold, sorted in descending tuple order (so highest average F1
        first). ``muc`` and ``b3`` are whatever the respective scorer's
        ``get_scores`` returns; index 2 of each is used as its F1.
    """
    def score_at(cutoff):
        chains = clusterer.get_chains(pred, cutoff)
        muc_result = muc_scorer.get_scores(chains, label_chains)
        b3_result = b3_scorer.get_scores(chains, label_chains)
        # Plain mean of the two F1 values, computed here directly.
        mean_f1 = (muc_result[2] + b3_result[2]) / 2
        return (mean_f1, muc_result, b3_result, cutoff)

    results = [score_at(cutoff) for cutoff in thresholds]
    results.sort(reverse=True)
    return results

def reorder_score(score):
    """Rearrange a score entry for display.

    Args:
        score: An iterable of exactly four items in the order
            ``(avg_f1, muc, b3, threshold)``.

    Returns:
        The tuple ``(muc, b3, avg_f1, threshold)``.
    """
    mean_f1, muc_part, b3_part, cutoff = score
    reordered = (muc_part, b3_part, mean_f1, cutoff)
    return reordered

def evaluate():
    """Sweep thresholds with the best-first clusterer and print the top result.

    Reads the module-level ``m1_ids``, ``m2_ids`` and ``preds``. Prints two
    progress lines and then the first entry of the sorted scores, reordered
    as ``(muc, b3, average_f1, threshold)``. Returns nothing.
    """
    print('getting anaphora scores by antecedent dict')
    scores_by_antecedent = get_anaphora_scores_by_antecedent(m1_ids, m2_ids, preds)

    print('get sorted_scores_without_sc_best')
    ranking = get_sorted_scores(BestFirstClusterer(), scores_by_antecedent)

    # The ranking is sorted in descending order, so entry 0 is the best one.
    best = ranking[0]
    print('Without singleton classifier, best-first:', reorder_score(best))

In [8]:
# Run the threshold sweep and print the best-scoring entry:
# (MUC scores, B3 scores, average F1, threshold).
evaluate()

getting anaphora scores by antecedent dict
get sorted_scores_without_sc_best
Without singleton classifier, best-first: ((0.4146341463414634, 0.6538461538461539, 0.5074626865671641), (0.3373749885310579, 0.5555102040816325, 0.41979696893063406), 0.46362982774889905, 0.2268041237113402)
