## Goals
 - Take the merged predictions and evaluate their accuracy using two different approaches:
 1. Look at the anaphora tags and then cross-reference co-reference labels
 2. Use the co-reference chains directly

In [2]:
import dill
from FindFiles import find_files
from Settings import Settings
from CoRefHelper import EMPTY
from collections import defaultdict
from BrattEssay import ANAPHORA

# Experiment configuration: which dataset and which partition to evaluate.
DATASET = "CoralBleaching" # CoralBleaching | SkinCancer
PARTITION = "Training" # Training | Test

# Fail fast on a mistyped configuration value, rather than later with an
# unrelated-looking error (e.g. no matching predictions file).
if DATASET not in ("CoralBleaching", "SkinCancer"):
    raise ValueError("Unknown DATASET: {0!r}".format(DATASET))
if PARTITION not in ("Training", "Test"):
    raise ValueError("Unknown PARTITION: {0!r}".format(PARTITION))

settings = Settings()
root_folder = settings.data_directory + DATASET + "/Thesis_Dataset/"
# Folder holding the merged (coreference + tagger) prediction files.
merged_predictions_folder = root_folder + "CoReference/"

Results Dir: /Users/simon.hughes/Google Drive/Phd/Results/
Data Dir:    /Users/simon.hughes/Google Drive/Phd/Data/
Root Dir:    /Users/simon.hughes/GitHub/NlpResearch/
Public Data: /Users/simon.hughes/GitHub/NlpResearch/Data/PublicDatasets/


In [3]:
import os

essay_files = find_files(merged_predictions_folder)
# Select the processed file for the requested partition (e.g. training_processed.dill).
# Match on the file name only, so a directory name that happens to contain
# "train"/"test" cannot cause a false match. Filter for BOTH partitions: with no
# filter for "Test", the training file would also be selected when present.
_PARTITION_FILE_KEYS = {"Training": "train", "Test": "test"}
if PARTITION not in _PARTITION_FILE_KEYS:
    raise ValueError("Unknown PARTITION: {0!r}".format(PARTITION))
partition_key = _PARTITION_FILE_KEYS[PARTITION]
essay_files = [path for path in essay_files
               if partition_key in os.path.basename(path).lower()]
assert len(essay_files) == 1, essay_files
with open(essay_files[0], "rb") as f:
    essays = dill.load(f)
len(essays)

902

In [4]:
# Show which processed predictions file was selected in the previous cell.
essay_files

['/Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/CoReference/training_processed.dill']

### Validate the Lengths

In [5]:
# Check that every per-sentence annotation layer is aligned with e.sentences:
# the same number of sentences and, within each sentence, one entry per word.
for e in essays:
    layers = (
        e.ana_tagged_sentences,
        e.pred_corefids,
        e.pred_ner_tags_sentences,
        e.pred_pos_tags_sentences,
        e.pred_tagged_sentences,
    )
    # A layer with a different sentence count would otherwise either crash with an
    # opaque IndexError (shorter) or go unnoticed (longer).
    assert all(len(layer) == len(e.sentences) for layer in layers), \
        (e.name, len(e.sentences), [len(layer) for layer in layers])

    for sent_ix, (sent, ana_tags, coref_ids, ner_tags, pos_tags, ptags) in enumerate(
            zip(e.sentences, *layers)):
        assert len(sent) == len(ana_tags) == len(coref_ids) == len(ner_tags) == len(pos_tags) == len(ptags),\
            (len(sent), len(ana_tags), len(coref_ids), len(ner_tags), len(pos_tags), len(ptags), e.name, sent_ix)
        assert len(sent) > 0, (e.name, sent_ix)

## Look at the Anaphor Tags

In [10]:
# Count the word-level anaphora tags found in e.sentences, ignoring any tag that
# mentions "other". Tags containing "->" go to cr_tally; tags of the form
# "Anaphor:[code]" go to cc_tally, which is displayed sorted.
cc_tally = defaultdict(int)
cr_tally = defaultdict(int)
all_word_tags = (tag
                 for essay in essays
                 for sentence in essay.sentences
                 for _word, word_tags in sentence
                 for tag in word_tags)
for tag in all_word_tags:
    if ANAPHORA not in tag or "other" in tag:
        continue
    if "->" in tag:
        cr_tally[tag] += 1
    elif "Anaphor:[" in tag:
        cc_tally[tag] += 1
sorted(cc_tally.items())

[('Anaphor:[11]', 6),
 ('Anaphor:[12]', 11),
 ('Anaphor:[13]', 31),
 ('Anaphor:[14]', 28),
 ('Anaphor:[1]', 68),
 ('Anaphor:[2]', 9),
 ('Anaphor:[3]', 39),
 ('Anaphor:[4]', 13),
 ('Anaphor:[50]', 44),
 ('Anaphor:[5]', 15),
 ('Anaphor:[5b]', 7),
 ('Anaphor:[6]', 18),
 ('Anaphor:[7]', 55)]

In [11]:
def build_chain(e):
    """Map each predicted coreference id in essay *e* to the words that carry it.

    Walks ``e.sentences`` alongside ``e.pred_corefids``, where
    ``e.pred_corefids[sent_ix][wd_ix]`` is an iterable of coreference ids for that
    word (presumably a set of str).

    Returns:
        defaultdict(list): coref id -> list of ``(sent_ix, wd_ix)`` tuples, in the
        order the words occur in the essay.
    """
    chains = defaultdict(list)
    for sent_ix, sentence in enumerate(e.sentences):
        sentence_coref_ids = e.pred_corefids[sent_ix]
        for wd_ix, _token in enumerate(sentence):
            position = (sent_ix, wd_ix)
            for cr_id in sentence_coref_ids[wd_ix]:
                chains[cr_id].append(position)
    return chains

In [26]:
from processessays import Essay

# For every word the anaphora tagger marked as ANAPHORA, look back along each of its
# predicted coreference chains and copy the predicted code of every EARLIER mention
# onto the word as a "<ANAPHORA>:[code]" tag. Results are stored as new Essay objects
# whose pred_tagged_sentences hold one set of tags per word.

# "<ANAPHORA>:[code]" -> number of antecedent mentions that contributed it. This counts
# contributing mentions, so it can exceed the number of words that received the tag.
tally = defaultdict(int)

ana_tagged_essays = []
for e in essays:
    ana_tagged_e = Essay(e.name, e.sentences)
    ana_tagged_e.pred_tagged_sentences = []
    ana_tagged_essays.append(ana_tagged_e)

    # map coref ids to (sent_ix, wd_ix) mentions
    corefid_2_chain = build_chain(e)

    for sent_ix, sent in enumerate(e.sentences):
        ana_tagged_sent = []
        ana_tagged_e.pred_tagged_sentences.append(ana_tagged_sent)

        ana_tags = e.ana_tagged_sentences[sent_ix]
        coref_ids = e.pred_corefids[sent_ix]
        ptags = e.pred_tagged_sentences[sent_ix]

        for wd_ix in range(len(sent)):
            # Start from the word's own predicted code (if any).
            wd_ptags = set()
            pred_cc_tag = ptags[wd_ix]
            if pred_cc_tag != EMPTY:
                wd_ptags.add(pred_cc_tag)
            ana_tagged_sent.append(wd_ptags)

            if ana_tags[wd_ix] != ANAPHORA:
                continue

            for cr_id in coref_ids[wd_ix]:
                for ch_sent_ix, ch_wd_ix in corefid_2_chain[cr_id]:
                    # Anaphors refer backwards: skip the current word itself and any
                    # mention at or after it in the essay (lexicographic tuple compare).
                    if (ch_sent_ix, ch_wd_ix) >= (sent_ix, wd_ix):
                        continue
                    chain_ptag = e.pred_tagged_sentences[ch_sent_ix][ch_wd_ix]
                    if chain_ptag != EMPTY:
                        formatted_code = "{anaphora}:[{code}]".format(
                            anaphora=ANAPHORA, code=chain_ptag)
                        # A word may collect several anaphor codes when its antecedents
                        # carry different predicted codes; the set de-duplicates repeats.
                        wd_ptags.add(formatted_code)
                        tally[formatted_code] += 1
        

In [28]:
# Sanity check: the new per-word tag sets must line up 1:1 with the essay's words.
for tagged_essay in ana_tagged_essays:
    assert len(tagged_essay.sentences) == len(tagged_essay.pred_tagged_sentences), \
        (tagged_essay.name, len(tagged_essay.sentences), len(tagged_essay.pred_tagged_sentences))
    for ix, (sentence, sentence_ptags) in enumerate(
            zip(tagged_essay.sentences, tagged_essay.pred_tagged_sentences)):
        assert len(sentence) == len(sentence_ptags), \
            (tagged_essay.name, ix, len(sentence), len(sentence_ptags))

In [27]:
from pprint import pprint
# Antecedent-derived anaphor code counts from the cross-reference cell, sorted by tag string.
pprint(sorted(tally.items()))

[('Anaphor:[11]', 2),
 ('Anaphor:[14]', 33),
 ('Anaphor:[1]', 23),
 ('Anaphor:[3]', 3),
 ('Anaphor:[50]', 22),
 ('Anaphor:[6]', 7)]


In [29]:
# Spot-check the per-word predicted tag sets for a single essay (index 35).
ana_tagged_essays[35].pred_tagged_sentences

[[set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  {'50'},
  {'50'},
  set(),
  {'50'},
  set(),
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  set(),
  set(),
  set(),
  set(),
  set(),
  set()],
 [set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  {'50'},
  set(),
  set(),
  set(),
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  {'50'},
  set()],
 [set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  {'50'},
  {'50'},
  set(),
  set(),
  set(),
  set(),
  set()],
 [set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set()],
 [set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set(),
  set