## Goals
 - Take the merged predictions and evaluate the prediction accuracy using the 2 different approaches
 1. Look at the anaphora tags and then cross-reference co-reference labels
 2. Use the co-reference chains directly

In [86]:
from FindFiles import find_files
from Settings import Settings
from CoRefHelper import EMPTY
from collections import defaultdict
import dill

ANAPHOR = "Anaphor"

DATASET = "CoralBleaching" # CoralBleaching | SkinCancer
PARTITION = "Training" # Training | Test

settings = Settings()
root_folder = settings.data_directory + DATASET + "/Thesis_Dataset/"
merged_predictions_folder = root_folder + "CoReference/"

Results Dir: /Users/simon.hughes/Google Drive/Phd/Results/
Data Dir:    /Users/simon.hughes/Google Drive/Phd/Data/
Root Dir:    /Users/simon.hughes/GitHub/NlpResearch/
Public Data: /Users/simon.hughes/GitHub/NlpResearch/Data/PublicDatasets/


In [27]:
essay_files = find_files(merged_predictions_folder)
if PARTITION == "Training":
    essay_files = [e for e in essay_files if "train" in e]
assert len(essay_files) == 1
with open(essay_files[0], "rb") as f:
    essays = dill.load(f)
len(essays)

902

In [32]:
essay_files

['/Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/CoReference/training_processed.dill']

### Validate the Lengths

In [70]:
for e in essays:    
    # map coref ids to sent_ix, wd_ix tuples
    # now look for ana tags that are also corefs, and cross reference
    for sent_ix in range(len(e.sentences)):
        sent     = e.sentences[sent_ix]
        ana_tags = e.ana_tagged_sentences[sent_ix]
        coref_ids= e.pred_corefids[sent_ix]
        ner_tags = e.pred_ner_tags_sentences[sent_ix]
        pos_tags = e.pred_pos_tags_sentences[sent_ix]
        ptags    = e.pred_tagged_sentences[sent_ix]
        
        assert len(sent) == len(coref_ids)
                
        assert len(sent) == len(ana_tags) == len(coref_ids) == len(ner_tags) == len(pos_tags) == len(ptags),\
            (len(sent), len(ana_tags), len(coref_ids), len(ner_tags), len(pos_tags), len(ptags), e.name, sent_ix)
        assert len(sent) > 0

## Look at the Anaphor Tags

In [88]:

cc_tally = defaultdict(int)
cr_tally = defaultdict(int)
for e in essays:
    for sent in e.sentences:
        for wd, tags in sent:
            for t in tags:                
                if ANAPHOR in t and "other" not in t:
                    if "->" in t:
                        cr_tally[t] += 1
                    elif "Anaphor:[" in t:
                        cc_tally[t] += 1
sorted(cc_tally.items())

[('Anaphor:[11]', 6),
 ('Anaphor:[12]', 11),
 ('Anaphor:[13]', 31),
 ('Anaphor:[14]', 28),
 ('Anaphor:[1]', 68),
 ('Anaphor:[2]', 9),
 ('Anaphor:[3]', 39),
 ('Anaphor:[4]', 13),
 ('Anaphor:[50]', 44),
 ('Anaphor:[5]', 15),
 ('Anaphor:[5b]', 7),
 ('Anaphor:[6]', 18),
 ('Anaphor:[7]', 55)]

In [129]:
corefid_2_chain['1']

[(3, 9), (3, 10), (7, 0), (7, 1), (7, 7), (8, 0), (9, 14), (9, 15)]

In [154]:
for e in essays:    
    # map coref ids to sent_ix, wd_ix tuples
    corefid_2_chain = defaultdict(list)
    for sent_ix in range(len(e.sentences)):
        sent     = e.sentences[sent_ix]
        coref_ids= e.pred_corefids[sent_ix]
        for wd_ix in range(len(sent)):
            wd_coref_ids = coref_ids[wd_ix] # Set[str]
            for cr_id in wd_coref_ids:
                pair = (sent_ix, wd_ix)
                corefid_2_chain[cr_id].append(pair)
    
    # now look for ana tags that are also corefs, and cross reference
    for sent_ix in range(len(e.sentences)):
        sent     = e.sentences[sent_ix]
        ana_tags = e.ana_tagged_sentences[sent_ix]
        coref_ids= e.pred_corefids[sent_ix]
        ner_tags = e.pred_ner_tags_sentences[sent_ix]
        pos_tags = e.pred_pos_tags_sentences[sent_ix]
        ptags    = e.pred_tagged_sentences[sent_ix]
        
        for wd_ix in range(len(sent)):
            word = sent[wd_ix]
            is_ana_tag = ana_tags[wd_ix] == ANAPHOR
            wd_coref_ids = coref_ids[wd_ix]
            word, wd_tags = sent[wd_ix]
            if is_ana_tag and len(wd_coref_ids) >= 1:
                for cr_id in wd_coref_ids:
                    chain = corefid_2_chain[cr_id]
                    if len(chain) > 0:
                        
                        pred_words = []
                        pred_tags = []
                        last_sent_ix = 0
                        for ch_sent_ix, ch_wd_ix in chain:
                            if ch_sent_ix == sent_ix and ch_wd_ix == wd_ix:
                                continue
                            # for anaphors only - only look at chain ixs before the current word
                            if ch_sent_ix > sent_ix or (ch_sent_ix == sent_ix and ch_wd_ix >= wd_ix):
                                continue
                            
                            chain_wd, chain_tags = e.sentences[ch_sent_ix][ch_wd_ix]
                            chain_ptag = e.pred_tagged_sentences[ch_sent_ix][ch_wd_ix]
                            pred_words.append(chain_wd)
                            pred_tags.append(chain_ptag)
                            
                            if last_sent_ix != ch_sent_ix:
                                PAD = 8
                                context_start = max(wd_ix-3,0)
                                context_end = min(wd_ix+3, len(e.sentences[sent_ix])-1)
                                context,_ = zip(*e.sentences[sent_ix][context_start: context_end])

                                print(word.upper(), wd_ix)
                                print(" ".join(context), [t for t in wd_tags if "->" not in t]) 
                                print(" ".join(map(lambda s: s.ljust(PAD), pred_words)))
                                print(" ".join(map(lambda s: s.ljust(PAD), pred_tags)))
                                print()
                                pred_words = []
                                pred_tags = []                                

                            last_sent_ix = ch_sent_ix
                        
                        print(word.upper(), wd_ix)
                        print(" ".join(context), [t for t in wd_tags if "->" not in t]) 
                        print(" ".join(map(lambda s: s.ljust(PAD), pred_words)))
                        print(" ".join(map(lambda s: s.ljust(PAD), pred_tags)))
                        print()
                
        assert len(sent) == len(ana_tags) == len(coref_ids) == len(ner_tags) == len(pos_tags) == len(ptags)\
            ,((len(sent), len(ana_tags), len(coref_ids), len(ner_tags), len(pos_tags), len(ptags)))
        assert len(sent) > 0

IT 1
they also give []



THIS 1
they also give ['Anaphor', 'Result', 'Anaphor:[5]', 'Result:Anaphor']



THIS 0
they also give []



IT 0
it turns white []
a       
50      

IT 0
it turns white []
coral    it       it      
50       50       Empty   

THIS 0
it turns white []



THIS 0
it turns white []



THESE 0
these winds drag ['Causer', 'Causer:1']
weaker  
1       

THESE 0
these winds drag ['Causer', 'Causer:1']
trade    winds   
1        1       

THIS 0
this shifting winds []
the     
Empty   

THIS 0
this shifting winds []
shifting trade    winds   
1        1        1       

IT 0
it causes the ['Anaphor', 'Anaphor:[3]', 'Causer:Anaphor', 'Causer']
the     
Empty   

IT 0
it causes the ['Anaphor', 'Anaphor:[3]', 'Causer:Anaphor', 'Causer']
ocean    INFREQUENT
Empty    Empty   

IT 0
it will cause []
coral   
50      

IT 0
it will cause []
bleaching that     it      
50       Empty    Empty   

IT 0
it will cause []
coral    bleaching it      
50       50       Empty   

T

In [80]:
chain

[(11, 1)]

In [47]:
from collections import Counter
Counter(all_ana_tags)

Counter({'Anaphor': 301, 'Empty': 136865})

In [44]:
e.pred_corefids[0]

[set(), set(), set(), set(), set(), set(), set(), set(), set()]