## Goals
 - Take the merged predictions and evaluate prediction accuracy using two different approaches:
 1. Look at the anaphora tags and then cross-reference co-reference labels
 2. Use the co-reference chains directly

In [2]:
import os
from collections import defaultdict

import dill

from BrattEssay import ANAPHORA
from CoRefHelper import EMPTY
from FindFiles import find_files
from Settings import Settings

DATASET = "CoralBleaching"  # CoralBleaching | SkinCancer
PARTITION = "Training"      # Training | Test

# Fail fast on a typo: later cells build paths from DATASET and choose the
# merged-predictions file from PARTITION.
assert DATASET in ("CoralBleaching", "SkinCancer"), DATASET
assert PARTITION in ("Training", "Test"), PARTITION

settings = Settings()
root_folder = settings.data_directory + DATASET + "/Thesis_Dataset/"
merged_predictions_folder = root_folder + "CoReference/"

Results Dir: /Users/simon.hughes/Google Drive/Phd/Results/
Data Dir:    /Users/simon.hughes/Google Drive/Phd/Data/
Root Dir:    /Users/simon.hughes/GitHub/NlpResearch/
Public Data: /Users/simon.hughes/GitHub/NlpResearch/Data/PublicDatasets/


In [3]:
# Pick the single merged-predictions file for the chosen partition. Match on the
# file name only, so words in the directory path cannot satisfy the filter.
# NOTE(review): assumes the test file name contains "test" - confirm on disk.
partition_keyword = {"Training": "train", "Test": "test"}[PARTITION]
essay_files = [path for path in find_files(merged_predictions_folder)
               if partition_keyword in os.path.basename(path).lower()]
assert len(essay_files) == 1, essay_files
with open(essay_files[0], "rb") as f:
    # dill, like pickle, can execute arbitrary code on load - only open trusted files.
    essays = dill.load(f)
len(essays)

902

In [4]:
# Show which merged-predictions file was loaded.
essay_files

['/Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/CoReference/training_processed.dill']

### Validate the Lengths

In [5]:
# Sanity check: every per-word annotation layer must line up one-to-one with the
# words of its sentence, and no sentence may be empty.
for e in essays:
    for sent_ix, sent in enumerate(e.sentences):
        # lengths in the same order as the original assertion message
        layer_lens = (
            len(e.ana_tagged_sentences[sent_ix]),
            len(e.pred_corefids[sent_ix]),
            len(e.pred_ner_tags_sentences[sent_ix]),
            len(e.pred_pos_tags_sentences[sent_ix]),
            len(e.pred_tagged_sentences[sent_ix]),
        )
        assert len(sent) > 0, (e.name, sent_ix)
        assert all(n == len(sent) for n in layer_lens), \
            (len(sent),) + layer_lens + (e.name, sent_ix)

## Look at the Anaphor Tags

In [117]:
from results_procesor import is_a_regular_code

# Count the gold tags across all essays:
#   reg_tally - tags accepted by is_a_regular_code
#   cc_tally  - ANAPHORA tags of the form "Anaphor:[code]"
#   cr_tally  - ANAPHORA tags containing "->" (presumably anaphoric relations - TODO confirm)
# ANAPHORA tags containing "other" are ignored.
reg_tally = defaultdict(int)
cc_tally = defaultdict(int)
cr_tally = defaultdict(int)
for essay in essays:
    for sentence in essay.sentences:
        for _word, word_tags in sentence:
            for tag in word_tags:
                if is_a_regular_code(tag):
                    reg_tally[tag] += 1
                counts_as_anaphora = ANAPHORA in tag and "other" not in tag
                if not counts_as_anaphora:
                    continue
                if "->" in tag:
                    cr_tally[tag] += 1
                elif "Anaphor:[" in tag:
                    cc_tally[tag] += 1
sorted(cc_tally.items())

[('Anaphor:[11]', 6),
 ('Anaphor:[12]', 11),
 ('Anaphor:[13]', 31),
 ('Anaphor:[14]', 28),
 ('Anaphor:[1]', 68),
 ('Anaphor:[2]', 9),
 ('Anaphor:[3]', 39),
 ('Anaphor:[4]', 13),
 ('Anaphor:[50]', 44),
 ('Anaphor:[5]', 15),
 ('Anaphor:[5b]', 7),
 ('Anaphor:[6]', 18),
 ('Anaphor:[7]', 55)]

In [11]:
def build_chain(e):
    """Group every word of an essay by the predicted co-reference ids it carries.

    Args:
        e: essay with `sentences` (one list of words per sentence) and
            `pred_corefids` (per sentence, per word: an iterable of coref ids).

    Returns:
        defaultdict(list) mapping each coref id to the (sent_ix, wd_ix)
        positions that carry it, in essay reading order.
    """
    chains = defaultdict(list)
    for sent_ix, sent in enumerate(e.sentences):
        sent_corefs = e.pred_corefids[sent_ix]
        for wd_ix in range(len(sent)):
            for cr_id in sent_corefs[wd_ix]:
                chains[cr_id].append((sent_ix, wd_ix))
    return chains

In [107]:
from processessays import Essay

def _preceding_chain_codes(e, chain, sent_ix, wd_ix, format_ana_tags):
    """Collect the predicted codes of the chain mentions that precede one word.

    Args:
        e: essay whose `pred_tagged_sentences[s][w]` holds one predicted tag
            (or EMPTY) per word.
        chain: (sent_ix, wd_ix) positions sharing one coref id, as built by
            build_chain.
        sent_ix, wd_ix: position of the anaphor being resolved.
        format_ana_tags: if True, wrap each code as "<ANAPHORA>:[<code>]";
            otherwise keep the raw predicted code.

    Returns:
        Set[str] of codes from the mentions before the anaphor. EMPTY
        predictions are ignored.
    """
    codes = set()
    for ch_sent_ix, ch_wd_ix in chain:
        # Anaphors refer backwards: skip the word itself and every later mention.
        if (ch_sent_ix, ch_wd_ix) >= (sent_ix, wd_ix):
            continue
        chain_ptag = e.pred_tagged_sentences[ch_sent_ix][ch_wd_ix]
        if chain_ptag == EMPTY:
            continue
        if format_ana_tags:
            chain_ptag = "{anaphora}:[{code}]".format(anaphora=ANAPHORA, code=chain_ptag)
        codes.add(chain_ptag)
    return codes


def get_ana_tagged_essays(essays, format_ana_tags=True):
    """Build copies of the essays whose predictions include resolved anaphora.

    Each word starts with its own predicted tag (unless EMPTY). If the word is
    tagged ANAPHORA, the predicted codes of every earlier word sharing one of
    its predicted coref ids are added as well.

    Args:
        essays: essays with `name`, `sentences`, `ana_tagged_sentences`,
            `pred_corefids` and `pred_tagged_sentences`, all aligned per word.
        format_ana_tags: if True, added codes are formatted as
            "<ANAPHORA>:[<code>]"; if False, they are added as plain codes.

    Returns:
        List of new Essay objects that share the original `sentences`. Their
        `pred_tagged_sentences` is a per-sentence list of per-word Set[str].
    """
    ana_tagged_essays = []
    for e in essays:
        ana_tagged_e = Essay(e.name, e.sentences)
        ana_tagged_e.pred_tagged_sentences = []
        ana_tagged_essays.append(ana_tagged_e)

        # coref id -> [(sent_ix, wd_ix), ...] over the whole essay
        corefid_2_chain = build_chain(e)

        for sent_ix, sent in enumerate(e.sentences):
            ana_tags = e.ana_tagged_sentences[sent_ix]
            coref_ids = e.pred_corefids[sent_ix]
            ptags = e.pred_tagged_sentences[sent_ix]

            ana_tagged_sent = []
            for wd_ix in range(len(sent)):
                wd_ptags = set()
                if ptags[wd_ix] != EMPTY:
                    wd_ptags.add(ptags[wd_ix])
                if ana_tags[wd_ix] == ANAPHORA:
                    for cr_id in coref_ids[wd_ix]:
                        wd_ptags |= _preceding_chain_codes(
                            e, corefid_2_chain[cr_id], sent_ix, wd_ix, format_ana_tags)
                ana_tagged_sent.append(wd_ptags)
            ana_tagged_e.pred_tagged_sentences.append(ana_tagged_sent)

    # validation check: essay and sentence lengths must align
    for tagged_e in ana_tagged_essays:
        assert len(tagged_e.sentences) == len(tagged_e.pred_tagged_sentences)
        for ix in range(len(tagged_e.sentences)):
            assert len(tagged_e.sentences[ix]) == len(tagged_e.pred_tagged_sentences[ix])
    return ana_tagged_essays

In [108]:
%%time
# Anaphors inherit the predicted codes of earlier coreferent words as extra
# ANAPHORA-formatted tags (e.g. "Anaphor:[50]").
ana_tagged_essays = get_ana_tagged_essays(essays)

CPU times: user 420 ms, sys: 10.2 ms, total: 430 ms
Wall time: 429 ms


In [109]:
# Same resolution, but the inherited codes are added as plain codes (e.g. "50")
# rather than "Anaphor:[50]", so they can be scored against the regular tags.
collapsed_ana_tagged_essays = get_ana_tagged_essays(essays, format_ana_tags=False)

In [52]:
from results_procesor import ResultsProcessor

# Adapted from ResultsProcessor so that the predicted tags for a word may be either
# a Set[str] of tags or a single scalar tag string.
def get_wd_level_preds(essays, expected_tags):
    """Build word-level prediction label lists, one per expected tag.

    Args:
        essays: essays whose `pred_tagged_sentences[sent_ix][wd_ix]` holds the
            predicted tag(s) for a word, either as a collection of tags or as
            a single tag string.
        expected_tags: the tags to score.

    Returns:
        defaultdict(list) mapping each expected tag to a list with one label
        per word, in essay/sentence/word order. Each label comes from
        ResultsProcessor's private `__get_label_` helper.
    """
    expected_tags = set(expected_tags)
    # `__get_label_` is private (name-mangled) on ResultsProcessor, hence the mangled name.
    get_label = ResultsProcessor._ResultsProcessor__get_label_
    ysbycode = defaultdict(list)
    for e in essays:
        for sentix in range(len(e.sentences)):
            for wd_ptags in e.pred_tagged_sentences[sentix]:
                # set() on a bare string splits it into characters
                # (set("50") == {"5", "0"}), so wrap scalar tags instead.
                if isinstance(wd_ptags, str):
                    ptag_set = {wd_ptags}
                else:
                    ptag_set = set(wd_ptags)
                for exp_tag in expected_tags:
                    ysbycode[exp_tag].append(get_label(exp_tag, ptag_set))
    return ysbycode

In [121]:
# Tags to evaluate: the regular codes and the "Anaphor:[code]" tags seen in the gold data.
# Sorting is lexicographic, so e.g. "Anaphor:[11]" sorts before "Anaphor:[1]".
reg_tags = sorted(reg_tally)
all_ana_tags = sorted(cc_tally)
all_ana_tags

['Anaphor:[11]',
 'Anaphor:[12]',
 'Anaphor:[13]',
 'Anaphor:[14]',
 'Anaphor:[1]',
 'Anaphor:[2]',
 'Anaphor:[3]',
 'Anaphor:[4]',
 'Anaphor:[50]',
 'Anaphor:[5]',
 'Anaphor:[5b]',
 'Anaphor:[6]',
 'Anaphor:[7]']

## Compute Accuracy on Anaphora Tags Only - Word Level

In [125]:
from results_procesor import metrics_to_df

def get_df(mean_metrics):
    """Turn metrics into a per-code DataFrame sorted by code.

    Keeps the code, recall, precision, f1_score and data_points columns, and
    drops every row whose code contains "MEAN".
    """
    columns = ["code", "recall", "precision", "f1_score", "data_points"]
    metrics_df = metrics_to_df(mean_metrics)[columns].sort_values("code")
    is_mean_row = metrics_df["code"].str.contains("MEAN")
    return metrics_df[~is_mean_row]

In [126]:
# Word-level scores for each "Anaphor:[code]" tag: actual labels vs. resolved predictions.
act_ys_bycode, pred_ys_bycode = (
    ResultsProcessor.get_wd_level_lbs(ana_tagged_essays, expected_tags=all_ana_tags),
    get_wd_level_preds(ana_tagged_essays, expected_tags=all_ana_tags),
)
mean_metrics = ResultsProcessor.compute_mean_metrics(act_ys_bycode, pred_ys_bycode)

get_df(mean_metrics)

Unnamed: 0,code,recall,precision,f1_score,data_points
2,Anaphor:[11],0.166667,1.0,0.285714,137166.0
6,Anaphor:[12],0.0,0.0,0.0,137166.0
10,Anaphor:[13],0.0,0.0,0.0,137166.0
0,Anaphor:[14],0.107143,0.5,0.176471,137166.0
12,Anaphor:[1],0.014706,0.1,0.025641,137166.0
1,Anaphor:[2],0.0,0.0,0.0,137166.0
5,Anaphor:[3],0.0,0.0,0.0,137166.0
8,Anaphor:[4],0.0,0.0,0.0,137166.0
9,Anaphor:[50],0.045455,0.2,0.074074,137166.0
3,Anaphor:[5],0.0,0.0,0.0,137166.0


## Accuracy with No Anaphora Tagging

In [131]:
# Baseline: score the original predictions (no anaphora resolution) on the regular codes.
reg_metrics = ResultsProcessor.compute_mean_metrics_from_tagged_essays(
    essays, expected_tags=reg_tags)
df = get_df(reg_metrics)
df.loc[df["code"] == "MICRO_F1"]

Unnamed: 0,code,recall,precision,f1_score,data_points
17,MICRO_F1,0.820049,0.846703,0.833163,1783158.0


## Accuracy with Anaphora Tagging (Adding in Ana Tags as Regular Tags, not Anaphor:[xyz])

In [132]:
# Score the regular codes again, now with antecedent codes copied onto anaphors as plain codes.
act_ys_bycode, pred_ys_bycode = (
    ResultsProcessor.get_wd_level_lbs(collapsed_ana_tagged_essays, expected_tags=reg_tags),
    get_wd_level_preds(collapsed_ana_tagged_essays, expected_tags=reg_tags),
)
mean_metrics = ResultsProcessor.compute_mean_metrics(act_ys_bycode, pred_ys_bycode)

df = get_df(mean_metrics)
df.loc[df["code"] == "MICRO_F1"]

Unnamed: 0,code,recall,precision,f1_score,data_points
17,MICRO_F1,0.820049,0.84602,0.832832,1783158.0


#### For the CB Training Data, Anaphora Tagging Very Slightly Hurts the Micro F1 Score