## Goals
 - Take the merged predictions and evaluate their accuracy (F1, precision and recall) using two different approaches:
 1. Look at the anaphora tags and then cross-reference co-reference labels
 2. Use the co-reference chains directly

In [84]:
import dill
import pandas as pd
from FindFiles import find_files
from Settings import Settings
from CoRefHelper import EMPTY
from collections import defaultdict
from BrattEssay import ANAPHORA
from results_procesor import ResultsProcessor
from results_procesor import metrics_to_df

# progress bar widget
from ipywidgets import IntProgress
from IPython.display import display

DATASET = "CoralBleaching"  # CoralBleaching | SkinCancer
PARTITION = "Training"  # Training | Test

# Resolve the input folders for the chosen dataset from the shared project settings.
settings = Settings()
root_folder = "{data_dir}{dataset}/Thesis_Dataset/".format(
    data_dir=settings.data_directory, dataset=DATASET)
stanford_coref_predictions_folder = "{root}CoReference/".format(root=root_folder)
print("CoRef Data: ", stanford_coref_predictions_folder)

Results Dir: /Users/simon.hughes/Google Drive/Phd/Results/
Data Dir:    /Users/simon.hughes/Google Drive/Phd/Data/
Root Dir:    /Users/simon.hughes/GitHub/NlpResearch/
Public Data: /Users/simon.hughes/GitHub/NlpResearch/Data/PublicDatasets/
CoRef Data:  /Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/CoReference/


In [9]:
import os

def get_essays(folder, partition):
    """Load the dill-serialised list of processed essays for one data partition.

    folder:    directory searched (via find_files) for the processed essay files
    partition: "Training" selects the file whose name contains "train";
               any other value selects the file whose name contains "test"

    Returns the object loaded from the single matching file.
    Raises ValueError unless exactly one matching file is found.
    """
    name_key = "train" if partition == "Training" else "test"
    # Match on the file name only: a "train"/"test" substring elsewhere in the
    # directory path would otherwise make every file match.
    essay_files = [f for f in find_files(folder) if name_key in os.path.basename(f)]
    if len(essay_files) != 1:
        raise ValueError("Expected exactly one '{}' essay file in {}, found {}: {}".format(
            name_key, folder, len(essay_files), essay_files))
    print("Found file", essay_files[0])
    # NOTE: dill, like pickle, can execute arbitrary code while loading - only load trusted files.
    with open(essay_files[0], "rb") as f:
        loaded_essays = dill.load(f)
    return loaded_essays

essays = get_essays(stanford_coref_predictions_folder, PARTITION)

Found file /Users/simon.hughes/Google Drive/Phd/Data/CoralBleaching/Thesis_Dataset/CoReference/training_processed.dill


### Validate the Lengths

In [10]:
def validate_essays(essays):
    """Check that every per-word annotation layer of each essay lines up with its words.

    For each essay, the layers ana_tagged_sentences, pred_corefids,
    pred_ner_tags_sentences, pred_pos_tags_sentences and pred_tagged_sentences must
    have one entry per sentence and, within each sentence, one entry per word.
    Every sentence must also be non-empty.

    Raises AssertionError on the first mismatch. The message includes the lengths,
    the essay name and (for word-level mismatches) the sentence index.
    """
    layer_names = ("ana_tagged_sentences", "pred_corefids", "pred_ner_tags_sentences",
                   "pred_pos_tags_sentences", "pred_tagged_sentences")
    for e in essays:
        layers = [getattr(e, name) for name in layer_names]
        # Sentence counts must agree too; otherwise extra trailing sentences in a
        # layer would never be inspected by the per-sentence loop below.
        n_sents = len(e.sentences)
        assert all(len(layer) == n_sents for layer in layers), \
            (n_sents, [len(layer) for layer in layers], e.name)
        for sent_ix, sent in enumerate(e.sentences):
            word_counts = [len(layer[sent_ix]) for layer in layers]
            assert all(n == len(sent) for n in word_counts), \
                (len(sent), *word_counts, e.name, sent_ix)
            assert len(sent) > 0
            
# Fail fast if any per-word annotation layer is misaligned with the essay's words.
validate_essays(essays)

In [11]:
def tally_essay_attributes(essays, attribute_name="pred_pos_tags_sentences"):
    """Count how often each tag occurs in a per-sentence, per-word essay attribute.

    essays:         iterable of essay objects
    attribute_name: name of an attribute holding a list (one per sentence) of lists
                    (one per word). Each word entry is either a single tag (str) or a
                    collection of tags (set/frozenset/list/tuple); every tag in a
                    collection is counted.

    Returns a defaultdict(int) mapping tag -> count.
    Raises TypeError for any other word-entry type.
    """
    tally = defaultdict(int)
    for e in essays:
        for sentence in getattr(e, attribute_name):
            for item in sentence:
                if isinstance(item, str):
                    tally[item] += 1
                elif isinstance(item, (set, frozenset, list, tuple)):
                    for tag in item:
                        tally[tag] += 1
                else:
                    raise TypeError("Unexpected item type: {}".format(type(item).__name__))
    return tally

In [12]:
# Tag frequency tables over the whole partition (predicted NER tags, predicted POS tags).
ner_tally, pos_tally = (
    tally_essay_attributes(essays, attribute_name=attr)
    for attr in ("pred_ner_tags_sentences", "pred_pos_tags_sentences")
)

## Look at the Anaphor Tags

In [13]:
from results_procesor import is_a_regular_code

# Tally every gold tag into three buckets:
#   reg_tally - tags accepted by is_a_regular_code
#   cr_tally  - ANAPHORA tags containing "->"
#   cc_tally  - other ANAPHORA tags containing "Anaphor:["
# ANAPHORA tags that contain "other" are excluded from both anaphora buckets.
cc_tally = defaultdict(int)
cr_tally = defaultdict(int)
reg_tally = defaultdict(int)
every_tag = (tag
             for essay in essays
             for sentence in essay.sentences
             for _word, word_tags in sentence
             for tag in word_tags)
for tag in every_tag:
    if is_a_regular_code(tag):
        reg_tally[tag] += 1
    if ANAPHORA not in tag or "other" in tag:
        continue
    if "->" in tag:
        cr_tally[tag] += 1
    elif "Anaphor:[" in tag:
        cc_tally[tag] += 1

reg_tags = sorted(reg_tally)
all_ana_tags = sorted(cc_tally)
all_ana_tags

['Anaphor:[11]',
 'Anaphor:[12]',
 'Anaphor:[13]',
 'Anaphor:[14]',
 'Anaphor:[1]',
 'Anaphor:[2]',
 'Anaphor:[3]',
 'Anaphor:[4]',
 'Anaphor:[50]',
 'Anaphor:[5]',
 'Anaphor:[5b]',
 'Anaphor:[6]',
 'Anaphor:[7]']

In [16]:
def build_chain(e):
    """Map each coreference id (essay scope) to every word position that carries it.

    Returns a Dict[str, List[Tuple[int, int]]] (a defaultdict(list)) from coref id to
    (sent_ix, wd_ix) pairs, listed in essay order: by sentence, then by word.
    """
    chains = defaultdict(list)
    for sent_ix, sentence in enumerate(e.sentences):
        sent_coref_ids = e.pred_corefids[sent_ix]
        for wd_ix in range(len(sentence)):
            for cr_id in sent_coref_ids[wd_ix]:  # Set[str] of ids on this word
                chains[cr_id].append((sent_ix, wd_ix))
    return chains

In [18]:
def build_segmented_chain(e):
    """Split each coreference chain into contiguous mentions (phrases).

    Uses build_chain(e). Returns a Dict[str, List[List[Tuple[int, int]]]] mapping each
    coref id to a list of phrases, where each phrase is a list of (sent_ix, wd_ix) pairs.
    A new phrase starts whenever the next position is in a different sentence or is more
    than one word after the previous position.
    """
    segmented_by_id = dict()
    for cref, positions in build_chain(e).items():
        phrases = []
        segmented_by_id[cref] = phrases
        prev_position = None
        for sent_ix, wd_ix in positions:
            starts_new_phrase = (
                prev_position is None
                or sent_ix != prev_position[0]
                or wd_ix - prev_position[1] > 1
            )
            if starts_new_phrase:
                phrases.append([])
            phrases[-1].append((sent_ix, wd_ix))
            prev_position = (sent_ix, wd_ix)
    return segmented_by_id

In [19]:
# Print the segmented coreference chains of one sample essay (index 2), one phrase per line.
corefid_2_segmented_chain = build_segmented_chain(essays[2])
for cref in sorted(corefid_2_segmented_chain):
    print(cref)
    for phrase in corefid_2_segmented_chain[cref]:
        print("\t", str(phrase))

1
	 [(0, 23), (0, 24), (0, 25)]
2
	 [(7, 0), (7, 1)]
	 [(7, 8)]
3
	 [(2, 0), (2, 1)]
	 [(3, 2)]
	 [(3, 9)]


In [41]:
from processessays import Essay
import warnings

def get_processed_essays(essays, format_ana_tags=True, filter_to_predicted_tags=True, look_back_only=True,
                         max_cref_phrase_len=None, ner_ch_filter=None, pos_filter=None, pos_ch_filter=None
                         ):
    """
    Create copies of essays whose pred_tagged_sentences are augmented with anaphora tags
    resolved through the predicted coreference chains.

    For every word that carries at least one predicted coref id (and passes the filters
    below), the predicted concept codes of other words in the same coref chain(s) are
    added to that word's predicted tags.

    essays:                   List[Essay] objects - merged tagged essays
    format_ana_tags:          bool - add chain codes as "{ANAPHORA}:[code]" (True) or as the
                                bare concept code (False)
    filter_to_predicted_tags: bool - only resolve words whose ana_tagged_sentences entry is ANAPHORA
    look_back_only:           bool - only use chain words occurring earlier in the essay
    max_cref_phrase_len:      Union(int,None) - if not None, skip chain phrases (contiguous
                                mentions) longer than this many words (0 is a valid limit)
    ner_ch_filter:            Union(Set[str],None) - if specified, only use chain words
                                with one of those NER tags
    pos_filter:               Union(Set[str],None) - if specified, only resolve words (the word
                                itself, not chain words) with one of those POS tags
    pos_ch_filter:            Union(Set[str],None) - if specified, only use chain words
                                with one of those POS tags

    Returns List[Essay]: new Essay objects built from each input essay's name and sentences,
    whose pred_tagged_sentences hold one set per word. Each set contains the word's own
    predicted concept code (when not EMPTY) plus any codes gathered from its coref chains.
    """
    if ner_ch_filter and EMPTY in ner_ch_filter:
        warnings.warn("EMPTY tag in NER filter ", UserWarning)
    if pos_filter and EMPTY in pos_filter:
        warnings.warn("EMPTY tag in POS filter ", UserWarning)
    if pos_ch_filter and EMPTY in pos_ch_filter:
        warnings.warn("EMPTY tag in POS chain filter ", UserWarning)

    ana_tagged_essays = []
    for e in essays:
        ana_tagged_e = Essay(e.name, e.sentences)
        ana_tagged_e.pred_tagged_sentences = []
        ana_tagged_essays.append(ana_tagged_e)

        # coref id -> list of phrases (contiguous mentions), each a list of (sent_ix, wd_ix)
        corefid_2_chain = build_segmented_chain(e)

        for sent_ix, sent in enumerate(e.sentences):
            ana_tagged_sent = []
            ana_tagged_e.pred_tagged_sentences.append(ana_tagged_sent)

            # SENTENCE LEVEL TAGS / PREDICTIONS (one entry per word)
            ana_tags = e.ana_tagged_sentences[sent_ix]
            coref_ids = e.pred_corefids[sent_ix]
            pos_tags = e.pred_pos_tags_sentences[sent_ix]
            ptags = e.pred_tagged_sentences[sent_ix]

            for wd_ix in range(len(sent)):
                # Word tags become a set (not a single string) so chain codes can be added.
                # The word's own predicted concept code is kept; the evaluation code later
                # filters down to the tags it scores.
                wd_ptags = set()
                pred_cc_tag = ptags[wd_ix]
                if pred_cc_tag != EMPTY:
                    wd_ptags.add(pred_cc_tag)
                # Append BEFORE any `continue` below so every word gets an entry and the
                # sentence stays aligned with e.sentences.
                ana_tagged_sent.append(wd_ptags)

                wd_coref_ids = coref_ids[wd_ix]  # Set[str]
                if len(wd_coref_ids) == 0:
                    continue
                # POS FILTER - applies to the word being resolved, NOT to chain words
                if pos_filter and pos_tags[wd_ix] not in pos_filter:
                    continue
                if filter_to_predicted_tags and ana_tags[wd_ix] != ANAPHORA:
                    continue

                for cr_id in wd_coref_ids:
                    # NOTE(review): every phrase of the chain is visited, including the phrase
                    # that contains the current word, so earlier words of the word's own
                    # mention can contribute codes - confirm this is intended.
                    for cref_phrase in corefid_2_chain[cr_id]:
                        # LENGTH FILTER - `is not None` so that a limit of 0 is honoured
                        if max_cref_phrase_len is not None and len(cref_phrase) > max_cref_phrase_len:
                            continue

                        for ch_sent_ix, ch_wd_ix in cref_phrase:
                            if look_back_only:
                                # skip the current word and anything later in the essay
                                if (ch_sent_ix, ch_wd_ix) >= (sent_ix, wd_ix):
                                    continue
                            elif (ch_sent_ix, ch_wd_ix) == (sent_ix, wd_ix):
                                continue  # never resolve a word to itself

                            # CHAIN WORD TYPE FILTERS (NER / POS tag of the chain word)
                            if ner_ch_filter and e.pred_ner_tags_sentences[ch_sent_ix][ch_wd_ix] not in ner_ch_filter:
                                continue
                            if pos_ch_filter and e.pred_pos_tags_sentences[ch_sent_ix][ch_wd_ix] not in pos_ch_filter:
                                continue

                            chain_ptag = e.pred_tagged_sentences[ch_sent_ix][ch_wd_ix]
                            if chain_ptag != EMPTY:
                                code = chain_ptag
                                if format_ana_tags:
                                    code = "{anaphora}:[{code}]".format(
                                        anaphora=ANAPHORA, code=chain_ptag)
                                wd_ptags.add(code)

    # validation check: essay and sentence lengths must align with the inputs
    for ana_e in ana_tagged_essays:
        assert len(ana_e.sentences) == len(ana_e.pred_tagged_sentences)
        for sent, sent_ptags in zip(ana_e.sentences, ana_e.pred_tagged_sentences):
            assert len(sent) == len(sent_ptags)

    return ana_tagged_essays

In [26]:
def get_wd_level_preds(essays, expected_tags):
    """Build per-tag, word-level prediction labels for evaluation.

    Variant of the ResultsProcessor helper that accepts, for each word in
    e.pred_tagged_sentences, either a single tag (str) or a collection of tags
    (set/frozenset/list/tuple).

    essays:        iterable of essays with sentences and pred_tagged_sentences
    expected_tags: iterable of tags to score
    Returns a defaultdict(list): tag -> one label per predicted word, in essay/sentence/word
    order, computed by ResultsProcessor's private __get_label_ helper.
    Raises TypeError for any other per-word tag type.
    """
    expected_tags = set(expected_tags)
    ysbycode = defaultdict(list)
    for e in essays:
        for sentix in range(len(e.sentences)):
            for tags in e.pred_tagged_sentences[sentix]:
                if isinstance(tags, str):
                    ptag_set = {tags}
                elif isinstance(tags, (set, frozenset, list, tuple)):
                    ptag_set = set(tags)
                else:
                    raise TypeError("Unrecognized tag type: {}".format(type(tags).__name__))
                for exp_tag in expected_tags:
                    # Name-mangled private method of ResultsProcessor (brittle if that class changes).
                    ysbycode[exp_tag].append(
                        ResultsProcessor._ResultsProcessor__get_label_(exp_tag, ptag_set))
    return ysbycode

In [27]:
def get_df(mean_metrics):
    """Tabulate per-code metrics, sorted by code, with any row whose code contains "MEAN" removed."""
    wanted_columns = ["code", "recall", "precision", "f1_score", "data_points"]
    table = metrics_to_df(mean_metrics)[wanted_columns].sort_values("code")
    is_mean_row = table["code"].str.contains("MEAN")
    return table[~is_mean_row]

In [51]:
def get_metrics_raw(essays, expected_tags, micro_only=False):
    """Compute word-level evaluation metrics and return the raw mean-metrics mapping.

    essays:        essays whose pred_tagged_sentences may hold a str or Set[str] per word;
                   gold labels come from ResultsProcessor.get_wd_level_lbs
    expected_tags: tags to score
    micro_only:    if True, return only the "MICRO_F1" entry (as a one-key dict)

    Returns the mapping produced by ResultsProcessor.compute_mean_metrics (indexed by
    code, e.g. "MICRO_F1"), or just its "MICRO_F1" entry when micro_only is True.
    """
    act_ys_bycode = ResultsProcessor.get_wd_level_lbs(essays, expected_tags=expected_tags)
    pred_ys_bycode = get_wd_level_preds(essays, expected_tags=expected_tags)
    mean_metrics = ResultsProcessor.compute_mean_metrics(act_ys_bycode, pred_ys_bycode)
    if micro_only:
        return {"MICRO_F1": mean_metrics["MICRO_F1"]}
    return mean_metrics

In [28]:
def get_metrics(essays, expected_tags, micro_only=False):
    """Compute word-level metrics as a DataFrame (one row per code, "MEAN" rows dropped).

    Delegates the metric computation to get_metrics_raw so both entry points share a
    single implementation. If micro_only is True, only the "MICRO_F1" row is kept.
    """
    mean_metrics = get_metrics_raw(essays, expected_tags)
    df = get_df(mean_metrics)
    if micro_only:
        df = df[df.code == "MICRO_F1"]
    return df

In [50]:
# pos_tally

In [32]:
# Penn Treebank POS tag groups - for tag meanings see
# https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
pos_nouns = {pos for pos in pos_tally if pos.strip().startswith("NN")}
pos_verbs = {pos for pos in pos_tally if pos.strip().startswith("VB")}
pos_pronouns = {"PRP", "PRP$", "WP", "WP$"}
# the, a, which, that, etc. WDT is a wh-determiner; PDT is a pre-determiner such as 'all' or 'half'
pos_determiners = {"DT", "WDT", "PDT"}
pos_pron_dt = pos_pronouns | pos_determiners
# IN (preposition / subordinating conjunction), nouns and pronouns
pos_filter = pos_nouns | pos_pronouns | {"IN"}
pos_nouns, pos_verbs, pos_filter

({'NN', 'NNP', 'NNPS', 'NNS'},
 {'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'},
 {'IN', 'NN', 'NNP', 'NNPS', 'NNS', 'PRP', 'PRP$', 'WP', 'WP$'})

In [60]:
# POS filters applied to the word being resolved ("None" = no filter).
dict_pos_filter = dict([
    ("None", None),
    ("PRN", pos_pronouns),
    ("DT", pos_determiners),
    ("PRN+DT", pos_pron_dt),
])

# POS filters applied to the coreference-chain words ("None" = no filter).
dict_pos_ch_filter = dict([
    ("None", None),
    ("NN", pos_nouns),
    ("VB", pos_verbs),
])

In [65]:
def blank_if_none(val):
    """Return "-" for "no value" inputs, otherwise return val unchanged.

    "No value" means: None, an empty string/bytes/container, or anything whose str()
    is "none" (case-insensitive, e.g. the filter key "None"). Genuine falsy values such
    as 0 and False are returned as-is rather than blanked.
    """
    if val is None:
        return "-"
    if isinstance(val, (str, bytes, list, tuple, set, frozenset, dict)) and not val:
        return "-"
    if str(val).lower() == "none":
        return "-"
    return val

In [58]:
# Max coreference phrase lengths to try; None means "no length limit".
phrase_len = [None, *range(1, 11)]
phrase_len

[None, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

In [82]:
# Values of get_processed_essays' look_back_only argument tried in the grid searches below.
look_back_vals = [True,False]

## Grid Search with Anaphora Prediction Filters

In [87]:
import itertools

filter_to_predicted_tags = True  # only resolve words predicted as anaphora
format_ana_tags = True  # Format tags as "{ANAPHORA}:[code]"

# Every combination of: look-back flag, word POS filter, chain POS filter, max phrase length.
# itertools.product varies the last argument fastest, matching the original nested loops.
grid_ana = list(itertools.product(look_back_vals, dict_pos_filter.items(),
                                  dict_pos_ch_filter.items(), phrase_len))

# set up progress bar
iprogress_bar = IntProgress(min=0, max=len(grid_ana))
display(iprogress_bar)

LOOK_BACK = "Look back"
MAX_PHRASE = "Max phrase"
POS_FLTR = "POS filter"
POS_CHAIN_FLTR = "Pos chain filter"

rows_ana = []
# Loop names differ from the module-level `pos_filter` (POS tag set cell) so this
# cell does not overwrite it.
for lb_only, (pos_key, pos_fltr), (pos_ch_key, pos_ch_fltr), max_len in grid_ana:
    proc_essays = get_processed_essays(
        essays=essays, format_ana_tags=format_ana_tags,
        filter_to_predicted_tags=filter_to_predicted_tags, look_back_only=lb_only,
        max_cref_phrase_len=max_len, ner_ch_filter=None,
        pos_filter=pos_fltr, pos_ch_filter=pos_ch_fltr)
    metrics = get_metrics_raw(proc_essays, all_ana_tags, micro_only=True)
    row = metrics["MICRO_F1"]
    row[LOOK_BACK] = lb_only
    row[MAX_PHRASE] = blank_if_none(max_len)
    row[POS_FLTR] = blank_if_none(pos_key)
    row[POS_CHAIN_FLTR] = blank_if_none(pos_ch_key)
    rows_ana.append(row)
    iprogress_bar.value += 1

df_results_ana = pd.DataFrame(rows_ana)

In [88]:
df_disp = df_results_ana.loc[:, ["f1_score", "precision", "recall",
                                 LOOK_BACK, MAX_PHRASE, POS_FLTR, POS_CHAIN_FLTR]]
# Best five configurations by F1.
df_disp.sort_values(by="f1_score", ascending=False).head(5)

Unnamed: 0,f1_score,precision,recall,Look back,Max phrase,POS filter,Pos chain filter
0,0.037534,0.241379,0.020349,True,-,-,-
10,0.037534,0.241379,0.020349,True,10,-,-
21,0.037534,0.241379,0.020349,True,10,-,NN
20,0.037534,0.241379,0.020349,True,9,-,NN
109,0.037534,0.241379,0.020349,True,10,PRN+DT,-


## Grid Search without Anaphora Predictions

In [None]:
import itertools

filter_to_predicted_tags = False  # just use the raw coref chains
format_ana_tags = True  # Format tags as "{ANAPHORA}:[code]"

# Every combination of: look-back flag, word POS filter, chain POS filter, max phrase length.
# itertools.product varies the last argument fastest, matching the original nested loops.
grid_chain = list(itertools.product(look_back_vals, dict_pos_filter.items(),
                                    dict_pos_ch_filter.items(), phrase_len))

# set up progress bar
iprogress_bar = IntProgress(min=0, max=len(grid_chain))
display(iprogress_bar)

LOOK_BACK = "Look back"
MAX_PHRASE = "Max phrase"
POS_FLTR = "POS filter"
POS_CHAIN_FLTR = "Pos chain filter"

rows_chain = []
# Loop names differ from the module-level `pos_filter` (POS tag set cell) so this
# cell does not overwrite it.
for lb_only, (pos_key, pos_fltr), (pos_ch_key, pos_ch_fltr), max_len in grid_chain:
    proc_essays = get_processed_essays(
        essays=essays, format_ana_tags=format_ana_tags,
        filter_to_predicted_tags=filter_to_predicted_tags, look_back_only=lb_only,
        max_cref_phrase_len=max_len, ner_ch_filter=None,
        pos_filter=pos_fltr, pos_ch_filter=pos_ch_fltr)
    metrics = get_metrics_raw(proc_essays, all_ana_tags, micro_only=True)
    row = metrics["MICRO_F1"]
    row[LOOK_BACK] = lb_only
    row[MAX_PHRASE] = blank_if_none(max_len)
    row[POS_FLTR] = blank_if_none(pos_key)
    row[POS_CHAIN_FLTR] = blank_if_none(pos_ch_key)
    rows_chain.append(row)
    iprogress_bar.value += 1

df_results_chain = pd.DataFrame(rows_chain)

In [None]:
df_disp = df_results_chain.loc[:, ["f1_score", "precision", "recall",
                                   LOOK_BACK, MAX_PHRASE, POS_FLTR, POS_CHAIN_FLTR]]
# Best five configurations by F1.
df_disp.sort_values(by="f1_score", ascending=False).head(5)

In [None]:
# TODO - delete this cell.
# `rows` is not defined anywhere in this notebook (the grid-search cells build `rows_ana`
# and `rows_chain`), so a bare `del rows` raises NameError on a fresh kernel.
# Remove the variable only if a stale one is still present.
globals().pop("rows", None)