In [15]:
import string
from typing import Any
from collections import defaultdict
import dill
from sklearn.linear_model import LogisticRegression
import numpy as np

from CrossValidation import cross_validation
from MIRA import CostSensitiveMIRA
from Settings import Settings

from function_helpers import get_function_names
from results_procesor import ResultsProcessor
from cost_functions import micro_f1_cost_plusepsilon
from window_based_tagger_config import get_config
from shift_reduce_helper import *

In [6]:
# Global settings
settings = Settings()

# Dataset folders: every path string keeps a trailing "/" so names can be appended directly.
root_folder = settings.data_directory + "CoralBleaching/Thesis_Dataset/"
training_folder = root_folder + "Training/"
test_folder = root_folder + "Test/"

coref_root = root_folder + "CoReference/"
coref_output_folder = coref_root + "CRel/"

config = get_config(training_folder)


def _load_dill(fname):
    """Return the object stored with dill in the file `fname`."""
    # NOTE(review): dill.load, like pickle.load, can run arbitrary code from the file;
    # only point this at files you produced yourself.
    with open(fname, "rb") as fin:
        return dill.load(fin)


train_fname = coref_output_folder + "training_crel_anatagged_essays_most_recent_code.dill"
pred_tagged_essays_train = _load_dill(train_fname)

test_fname = coref_output_folder + "test_crel_anatagged_essays_most_recent_code.dill"
pred_tagged_essays_test = _load_dill(test_fname)

print(len(pred_tagged_essays_train), len(pred_tagged_essays_test))

Results Dir: /Users/simon.hughes/Google Drive/Phd/Results/
Data Dir:    /Users/simon.hughes/Google Drive/Phd/Data/
Root Dir:    /Users/simon.hughes/GitHub/NlpResearch/
Public Data: /Users/simon.hughes/GitHub/NlpResearch/Data/PublicDatasets/
902 226


In [8]:
# Relation tags obtained from both the training and the test essays.
cr_tags = get_cr_tags(
    train_tagged_essays=pred_tagged_essays_train,
    tag_essays_test=pred_tagged_essays_test,
)

set_cr_tags = set(cr_tags)
# Peek at ten of them; a set has no defined order, so the sample is arbitrary.
list(set_cr_tags)[:10]

['Causer:1->Result:4',
 'Causer:14->Result:6',
 'Causer:2->Result:3',
 'Causer:5->Result:5b',
 'Causer:12->Result:50',
 'Causer:4->Result:5b',
 'Causer:3->Result:5',
 'Causer:11->Result:14',
 'Causer:13->Result:7',
 'Causer:7->Result:1']

In [66]:
# Sentinel pseudo-word that flatten_essay appends after every sentence.
SENT = "<SENT>"

def get_reg_tags(tags):
    """Return the regular tags contained in `tags`.

    A tag is kept when it contains neither "->" nor ":" and its first character
    is a digit (e.g. "50" or "5b"). Relation tags such as "Causer:7->Result:50"
    are therefore dropped.

    Args:
        tags: iterable of tag strings.

    Returns:
        list of the matching tags, in the iteration order of `tags`.
    """
    # t[:1] rather than t[0]: an empty tag is skipped instead of raising IndexError.
    return [t for t in tags if "->" not in t and ":" not in t and t[:1].isdigit()]

def flatten_essay(essay):
    """Collapse an essay's sentences into one token list with sentence markers.

    The items of each sentence are concatenated and followed by a
    (SENT, set()) marker. The predicted tags of the same sentence are
    concatenated in parallel and followed by a bare SENT.

    Args:
        essay: object exposing `sentences` and `pred_tagged_sentences`,
            both indexed by sentence number.

    Returns:
        (flat_tagged_sent, flat_pred_tags): two lists of equal length.

    Raises:
        AssertionError: if the two flattened lists differ in length.
    """
    all_tokens, all_pred_tags = [], []
    for sent_no, sentence in enumerate(essay.sentences):
        all_tokens += sentence
        all_tokens.append((SENT, set()))

        all_pred_tags += essay.pred_tagged_sentences[sent_no]
        all_pred_tags.append(SENT)
    assert len(all_tokens) == len(all_pred_tags), "Tagged essay should be the same length as the predicted tags"
    return all_tokens, all_pred_tags

def get_tags_relations_for(tagged_sentence, predicted_tags, cr_tags):
    """Locate predicted tag runs and the candidate relations that connect them.

    `tagged_sentence` and `predicted_tags` are walked in parallel (same index).
    Tokens whose word is found in `string.punctuation` or equals SENT are
    ignored completely: they neither start, extend nor interrupt a run.

    Args:
        tagged_sentence: sequence of (word, tags) pairs; `tags` must provide
            `.intersection`.
        predicted_tags: one predicted tag per token; EMPTY_TAG means "no tag".
        cr_tags: iterable of relation tags to look for in each token's `tags`.

    Returns:
        A 5-tuple:
          pos_tag_seq: list of (tag, index), one entry per run of a predicted
              tag, where index is the run's first token. A run continues over
              a single untagged token that sits between two tokens carrying
              the same tag.
          pos_crels: list of ((l_tag, l_index), (r_tag, r_index)) pairs, where
              l_tag / r_tag are the two values `normalize_cr` returns for a
              relation tag and both runs overlap the same relation occurrence.
              Pairs of one relation occurrence are listed by ascending token
              index (l side outermost).
          tag2span: dict mapping each (tag, start index) run to the
              (smallest, largest) token index assigned to it.
          sent_reg_predicted_tags: set of every predicted tag seen.
          sent_act_cr_tags: set of every relation tag from `cr_tags` seen.
    """
    sent_reg_predicted_tags = set()
    sent_act_cr_tags = set()
    tag2ixs = defaultdict(list)

    tag_seq = [None]  # seed with None
    crel_set_seq = [set()]

    pos_tag_seq = []
    # most recent (name, start index) for every regular tag AND every relation tag
    latest_tag_posns = {}
    # (relation tag, start index) -> set of (regular tag, start index) runs seen under it
    crel_child_tags = defaultdict(set)
    for i, (wd, tags) in enumerate(tagged_sentence):
        # NOTE(review): this is a substring test against the punctuation string, so it also
        # matches "" and multi-character slices such as "()" but not "..." - confirm intent.
        if wd in string.punctuation:
            continue
        if wd == SENT:
            continue

        active_tag = None
        rtag = predicted_tags[i]
        if rtag != EMPTY_TAG:
            active_tag = rtag
            sent_reg_predicted_tags.add(active_tag)
            # if no prev tag and the current matches -2 (a gap of one), skip over
            if active_tag != tag_seq[-1] and \
                    not (tag_seq[-1] is None and (len(tag_seq) > 2) and active_tag == tag_seq[-2]):
                latest_tag_posns[active_tag] = (active_tag, i)
                pos_tag_seq.append((active_tag, i))
            # need to be after we update the latest tag position
            tag2ixs[latest_tag_posns[active_tag]].append(i)
        tag_seq.append(active_tag)

        active_crels = tags.intersection(cr_tags)
        # Sorted once and reused below: a set of strings iterates in a hash-dependent order,
        # which would make the insertion order of crel_child_tags differ between runs.
        ordered_crels = sorted(active_crels)
        for cr in ordered_crels:
            sent_act_cr_tags.add(cr)
            # same one-token gap rule as for regular tags
            bridges_gap = (len(crel_set_seq) > 2) and cr in crel_set_seq[-2]
            if cr not in crel_set_seq[-1] and not bridges_gap:
                latest_tag_posns[cr] = (cr, i)
        crel_set_seq.append(active_crels)

        # to have child tags, need a tag sequence and a current valid regular tag
        if not active_tag or len(active_crels) == 0:
            continue

        for crel in ordered_crels:
            l, r = normalize_cr(crel)
            if active_tag in (l, r):
                crel_child_tags[latest_tag_posns[crel]].add(latest_tag_posns[active_tag])

    pos_crels = []
    # Relation occurrences are visited in insertion order (guaranteed for dicts on Python 3.7+).
    for (crelation, _), tag_pairs in crel_child_tags.items():
        l, r = normalize_cr(crelation)
        # unsupported relation
        if l not in sent_reg_predicted_tags or r not in sent_reg_predicted_tags:
            continue
        tag2pair = defaultdict(list)
        # tag_pairs is a set: order it by token index so the output is reproducible
        for taga, ixa in sorted(tag_pairs, key=lambda pair: (pair[1], pair[0])):
            tag2pair[taga].append((taga, ixa))
        # un-supported relation
        if l not in tag2pair or r not in tag2pair:
            continue

        l_pairs = tag2pair[l]
        r_pairs = tag2pair[r]
        for pairsa in l_pairs:
            for pairsb in r_pairs:
                if pairsa != pairsb:
                    pos_crels.append((pairsa, pairsb))

    tag2span = dict()
    for tagpos, ixs in tag2ixs.items():
        tag2span[tagpos] = (min(ixs), max(ixs))

    return pos_tag_seq, pos_crels, tag2span, sent_reg_predicted_tags, sent_act_cr_tags

In [13]:
# Flattened (tokens, predicted tags) pair for every training essay, keyed by essay name.
essay2collapsed = {essay.name: flatten_essay(essay) for essay in pred_tagged_essays_train}

In [48]:
def get_sent_crels(sent):
    """Collect every relation tag that occurs on any word of a sentence.

    Args:
        sent: iterable of (word, tags) pairs.

    Returns:
        set of the tags that are also members of the notebook-level `set_cr_tags`.
    """
    found = set()
    for _, word_tags in sent:
        found |= set_cr_tags.intersection(word_tags)
    return found

# (essay name, sentence index) -> (shared relation tags, previous sentence, sentence)
# for every sentence that shares at least one relation tag with the sentence before it.
cross_sent_crels = dict()
for essay in pred_tagged_essays_train:
    sentences = essay.sentences
    # start at 1: the first sentence has no predecessor
    for ix in range(1, len(sentences)):
        sent = sentences[ix]
        sent_crels = get_sent_crels(sent)
        if not sent_crels:
            continue
        prev_sent = sentences[ix - 1]
        prev_crels = get_sent_crels(prev_sent)
        crel_crossing = prev_crels & sent_crels
        if crel_crossing:
            cross_sent_crels[(essay.name, ix)] = (crel_crossing, prev_sent, sent)
len(cross_sent_crels)

221

In [91]:
# Pick the first recorded crossing and print both sentences word by word:
# word, relation tags on the word, regular tags on the word.
essay_name, sent_ix = list(cross_sent_crels)[0] #0,5
crel_crossing, prev_sent, sent = cross_sent_crels[(essay_name, sent_ix)]

print(essay_name, sent_ix)
print(crel_crossing)
for sentence in (prev_sent, sent):
    print("*" * 80)
    for wd, tags in sentence:
        print(wd.ljust(30), set_cr_tags.intersection(tags), get_reg_tags(tags))

EBA1415_BGJD_1_CB_ES-05975.ann 8
{'Causer:7->Result:50'}
********************************************************************************
during                         set() []
bleaching                      set() []
;                              set() []
corals                         {'Causer:7->Result:50'} ['50']
turn                           {'Causer:7->Result:50'} ['50']
white                          {'Causer:7->Result:50'} ['50']
due                            {'Causer:7->Result:50'} []
to                             {'Causer:7->Result:50'} []
the                            {'Causer:7->Result:50'} []
ejection                       {'Causer:7->Result:50'} ['7']
or                             {'Causer:7->Result:50'} ['7']
death                          {'Causer:7->Result:50'} ['7']
of                             {'Causer:7->Result:50'} ['7']
the                            {'Causer:7->Result:50'} ['7']
zooxanthellae                  {'Causer:7->Result:50'} ['7']
.               

In [49]:
# Run the extraction over every flattened training essay; the results are not kept.
# The loop variables are underscore-prefixed on purpose: reusing `essay_name` /
# `flat_tagged_sent` here would overwrite the essay selected for inspection, so a
# top-to-bottom run would inspect the last essay of the dict while still using the
# `sent_ix` that belongs to the selected one.
for _name, (_tagged_tokens, _pred_tags) in essay2collapsed.items():
    get_tags_relations_for(_tagged_tokens, _pred_tags, cr_tags)

In [92]:
# Run the extraction for the essay currently bound to `essay_name`.
# Uncomment the next line to inspect a specific essay instead:
# essay_name = "EBA1415_BLHT_5_CB_ES-05205.ann"
flat_tagged_sent, flat_pred_tags = essay2collapsed[essay_name]

(pos_tag_seq,
 pos_crels,
 tag2span,
 sent_reg_predicted_tags,
 sent_act_cr_tags) = get_tags_relations_for(flat_tagged_sent, flat_pred_tags, cr_tags)

In [95]:
# Print flat index, word, regular tags and relation tags for the tokens whose running
# sentence count is `sent_ix - 1` or `sent_ix`.
sent_cnt = 0
sentences_to_show = (sent_ix, sent_ix - 1)
for ix, (wd, tags) in enumerate(flat_tagged_sent):
    if wd == SENT:
        # counted before the check below, so a printed marker closes the sentence before it
        sent_cnt += 1
    if sent_cnt not in sentences_to_show:
        continue
    reg_tag_text = ",".join(get_reg_tags(tags))
    print(str(ix).ljust(5), wd.ljust(30), reg_tag_text.ljust(5), tags.intersection(set_cr_tags))
    

113   <SENT>                               set()
114   during                               set()
115   bleaching                            set()
116   ;                                    set()
117   corals                         50    {'Causer:7->Result:50'}
118   turn                           50    {'Causer:7->Result:50'}
119   white                          50    {'Causer:7->Result:50'}
120   due                                  {'Causer:7->Result:50'}
121   to                                   {'Causer:7->Result:50'}
122   the                                  {'Causer:7->Result:50'}
123   ejection                       7     {'Causer:7->Result:50'}
124   or                             7     {'Causer:7->Result:50'}
125   death                          7     {'Causer:7->Result:50'}
126   of                             7     {'Causer:7->Result:50'}
127   the                            7     {'Causer:7->Result:50'}
128   zooxanthellae                  7     {'Causer:7->Result:50'}


In [96]:
# Candidate ((tag, index), (tag, index)) pairs returned by the last get_tags_relations_for call.
pos_crels

[(('7', 17), ('50', 25)), (('7', 123), ('50', 115)), (('7', 123), ('50', 135))]

In [100]:
# (tag, start index) of every predicted-tag run returned by the last get_tags_relations_for call.
pos_tag_seq

[('7', 17),
 ('50', 25),
 ('50', 95),
 ('6', 99),
 ('14', 102),
 ('50', 115),
 ('7', 123),
 ('50', 135),
 ('50', 165)]

In [97]:
# (smallest, largest) token index recorded for the '7' run that starts at flat index 123.
tag2span[('7',123)]

(123, 128)

In [99]:
# (smallest, largest) token index recorded for the '50' run that starts at flat index 135.
tag2span[('50',135)]

(135, 136)