In [12]:
import io
import os
import re

In [13]:
# Characters stripped out of CoNLL tokens by clean_token (a token consisting of
# exactly one of these characters is kept as-is).
REMOVED_CHAR = list("/%*")
# PTB-style escaped tokens mapped back to the literal character they stand for.
NORMALIZE_DICT = dict([
    ("/.", "."),
    ("/?", "?"),
    ("-LRB-", "("),
    ("-RRB-", ")"),
    ("-LCB-", "{"),
    ("-RCB-", "}"),
    ("-LSB-", "["),
    ("-RSB-", "]"),
])

In [14]:
def clean_token(token):
    '''
    Normalize one CoNLL word so it can be used as a token.

    PTB escapes listed in NORMALIZE_DICT (e.g. "-LRB-") are first mapped to their
    literal character. Unless the result is itself one of REMOVED_CHAR, every
    occurrence of those characters is then stripped out. A token that ends up
    empty is replaced by ",".
    '''
    result = NORMALIZE_DICT.get(token, token)
    if result not in REMOVED_CHAR:
        # Only strip when the token is not just a bare removed character.
        for removed in REMOVED_CHAR:
            result = result.replace(removed, '')
    return result or ","

In [18]:
def load_file(full_name, debug=False):
    '''
    Load a *._conll file ("#begin document" / "#end document" CoNLL-2012 layout).

    Input:
        full_name (str): path to the file
        debug (bool): print a trace of how every line is parsed
    Output: list of tuples, one per conll document in the file, where the tuple contains:
        (utts_text ([str]): the utterances of the document, each one being its tokens
                            joined with a space after every token
         utts_tokens ([[str]]): the tokens (conll words, passed through clean_token) of each utterance
         utts_corefs ([[dict]]): for each utterance, its coref objects (dicts) with the following properties:
            coref['label'] (str): id of the coreference cluster,
            coref['start'] (int): start index (index of first token in the utterance),
            coref['end'] (int): end index (index of last token in the utterance, inclusive;
                                stays None if the mention is never closed).
         utts_speakers ([str]): the speaker associated to each utterance of the document
         name (str): name of the document
         part (str): part of the document
        )
    Raises:
        ValueError: on a malformed "#begin"/"#end" line or a line with an unexpected
            number of columns.
        AssertionError: when a token line disagrees with the current document name/part,
            token index or speaker, or carries an unparseable coref tag.
    '''
    docs = []
    utts_text = []
    utts_tokens = []
    utts_corefs = []
    utts_speakers = []
    tokens = []
    corefs = []
    index = 0
    speaker = ""
    name = ""
    part = ""

    def store_utterance():
        # Append the pending utterance (tokens, corefs, speaker) to the current document.
        utts_text.append(u''.join(t + u' ' for t in tokens))
        utts_tokens.append(tokens)
        utts_speakers.append(speaker)
        utts_corefs.append(corefs)

    with io.open(full_name, 'rt', encoding='utf-8', errors='strict') as f:
        for li, line in enumerate(f):
            cols = line.split()
            if debug: print("line", li, "cols:", cols)
            if len(cols) == 0:
                # End of utterance (blank line); repeated blank lines are no-ops.
                if tokens:
                    if debug: print("End of utterance")
                    store_utterance()
                    tokens, corefs, index, speaker = [], [], 0, ""
            elif len(cols) == 2:
                # End of doc: "#end document".
                if debug: print("End of doc")
                if cols[0] != "#end":
                    raise ValueError("Error on end line " + line)
                if tokens:
                    # The last utterance was not followed by a blank line: keep it
                    # instead of silently losing it at the next "#begin".
                    store_utterance()
                    tokens, corefs, index, speaker = [], [], 0, ""
                if debug: print("Saving doc")
                docs.append((utts_text, utts_tokens, utts_corefs, utts_speakers, name, part))
                utts_text, utts_tokens, utts_corefs, utts_speakers = [], [], [], []
            elif len(cols) == 5:
                # New doc: "#begin document (<name>); part <part>".
                if debug: print("New doc")
                if cols[0] != "#begin":
                    raise ValueError("Error on begin line " + line)
                name_match = re.match(r"\((.*)\);", cols[2])
                if name_match is None:
                    raise ValueError("Error on begin line " + line)
                name = name_match.group(1)
                part = cols[4]
                if debug: print("New doc", name, part, name[:2])
                tokens, corefs, index, speaker = [], [], 0, ""
            elif len(cols) >= 10:
                # Token line: cols[0] doc name, cols[1] part, cols[2] token index,
                # cols[3] word, cols[9] speaker, cols[-1] coref tags ("-" if none).
                if debug: print("Inside utterance")
                assert (cols[0] == name and int(cols[1]) == int(part)), "Doc name or part error " + line
                assert (int(cols[2]) == index), "Index error on " + line
                if speaker:
                    assert (cols[9] == speaker), "Speaker changed in " + line + speaker
                else:
                    speaker = cols[9]
                    if debug: print("speaker", speaker)
                if cols[-1] != u'-':
                    coref_expr = cols[-1].split(u'|')
                    if debug: print("coref_expr", coref_expr)
                    for tok in coref_expr:
                        if debug: print("coref tok", tok)
                        # "(N" opens, "N)" closes, "(N)" opens and closes a mention of cluster N.
                        match = re.match(r"^(\(?)(\d+)(\)?)$", tok)
                        assert match is not None, "Error parsing coref " + tok + " in " + line
                        num = match.group(2)  # never empty: the pattern requires \d+
                        if match.group(1) == u'(':
                            if debug: print("New coref", num)
                            corefs.append({'label': num, 'start': index, 'end': None})
                        if match.group(3) == u')':
                            # Close the most recently opened, still-open mention of this cluster.
                            j = next((i for i in range(len(corefs) - 1, -1, -1)
                                      if corefs[i]['label'] == num and corefs[i]['end'] is None),
                                     None)
                            assert (j is not None), "coref closing error " + line
                            if debug: print("End coref", num)
                            corefs[j]['end'] = index
                tokens.append(clean_token(cols[3]))
                index += 1
            else:
                raise ValueError("Line not standard " + line)
    return docs

In [19]:
# Directory holding the CoNLL-2012 English *_gold_conll splits. Set the CONLL_DATA_DIR
# environment variable to point elsewhere instead of editing a machine-specific path here.
CONLL_DATA_DIR = os.environ.get(
    "CONLL_DATA_DIR",
    "/home/forcerequestspring19_gmail_com/neural/neural-coref/conll_data",
)
train_path = os.path.join(CONLL_DATA_DIR, "train.english.v4_gold_conll")
test_path = os.path.join(CONLL_DATA_DIR, "test.english.v4_gold_conll")
dev_path = os.path.join(CONLL_DATA_DIR, "dev.english.v4_gold_conll")

In [20]:
# Parse the CoNLL-2012 test split into a list of per-document tuples.
test_data = load_file(full_name=test_path)

In [21]:
# Parse the CoNLL-2012 train split into a list of per-document tuples.
train_data = load_file(full_name=train_path)

In [23]:
import numpy as np

In [29]:
# Persist the parsed splits. Each document is a ragged tuple (lists of different
# lengths plus strings), so it has to be stored as an object array: since NumPy 1.24,
# converting ragged nested sequences raises ValueError unless dtype=object is given.
# Object arrays are pickled, so reload with np.load(path, allow_pickle=True) -- only
# for files you created yourself.
np.save("train_data.npy", np.array(train_data, dtype=object))
np.save("test_data.npy", np.array(test_data, dtype=object))

In [None]:
import spacy
from spacy import displacy
from collections import Counter
import en_core_web_lg
# Load the large English spaCy pipeline. The mention extractor below relies on its
# tagger/parser/NER output (token.tag_, token.dep_, token.head, doc.ents).
nlp = en_core_web_lg.load()

In [25]:
import re

In [26]:
# Module-level switch for the verbose trace printed by extract_mentions_spans.
# NOTE: _extract_from_sent has its own `debug` parameter (default False) that shadows this flag.
debug = True

In [27]:
# Lowercased words never proposed as mentions when _extract_from_sent runs with blacklist=True.
NO_COREF_LIST = ["i", "me", "my", "you", "your"]

MENTION_TYPE = {"PRONOMINAL": 0, "NOMINAL": 1, "PROPER": 2, "LIST": 3}
# Inverse of MENTION_TYPE, derived from it so the two tables cannot drift apart.
MENTION_LABEL = {type_id: label for label, type_id in MENTION_TYPE.items()}

# NOTE(review): NN/NNS are common-noun tags despite the PROPERS name -- confirm intended.
PROPERS_TAGS = ["NN", "NNS", "NNP", "NNPS"]
# Named-entity labels accepted as mentions by extract_mentions_spans. spaCy's
# OntoNotes-trained English models label facilities "FAC"; "FACILITY" is kept for
# compatibility with label sets that spell it out.
ACCEPTED_ENTS = ["PERSON", "NORP", "FACILITY", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK_OF_ART", "LANGUAGE"]
WHITESPACE_PATTERN = r"\s+|_+"
UNKNOWN_WORD = "*UNK*"
MISSING_WORD = "<missing>"
# Upper bound on the head-chain walk done for "'s" tokens in _extract_from_sent.
MAX_ITER = 100

In [28]:
def extract_mentions_spans(doc):
    '''
    Collect candidate mention spans from a spacy parsed Doc.

    Candidates are the named entities whose label is in ACCEPTED_ENTS, followed by
    the pronoun / noun-phrase spans that _extract_from_sent finds in each sentence
    (with blacklisting enabled). Empty spans and repeated (start, end) boundaries
    are dropped; the first occurrence wins and the order is otherwise preserved.
    '''
    if debug:
        print('===== doc ====:', doc)
        for tok in doc:
            print("🚧 span search:", tok, "head:", tok.head, "tag:", tok.tag_, "pos:", tok.pos_, "dep:", tok.dep_)

    # Named entities first, then per-sentence syntactic candidates.
    candidates = [ent for ent in doc.ents if ent.label_ in ACCEPTED_ENTS]
    if debug: print("==-- ents:", [(ent, ent.label_) for ent in candidates])
    for sent in doc.sents:
        candidates.extend(_extract_from_sent(doc, sent, True))

    seen_bounds = set()
    unique_spans = []
    for candidate in candidates:
        bounds = (candidate.start, candidate.end)
        if candidate.end <= candidate.start or bounds in seen_bounds:
            continue
        seen_bounds.add(bounds)
        unique_spans.append(candidate)
    return unique_spans

In [29]:
def _extract_from_sent(doc, span, blacklist=True, debug=False):
    '''
    Extract Pronouns and Noun phrases mentions from a spacy Span

    Scans every token of `span` (one sentence of `doc`) and proposes candidate
    mention spans, all of them slices of `doc`:
      * pronoun (tag PRP*): the pronoun alone, plus its full subtree
        (left_edge .. right_edge) when it has any left or right children;
      * "'s": walks up the head chain starting at the token's head (at most
        MAX_ITER steps, never examining the token that is its own head) and, at
        the first ancestor whose dep is nsubj, stores the span covering the
        subtrees of all nsubj/dep children of that ancestor's head (the
        ancestor included);
      * any other kept token: the span covering the subtrees of the tokens
        whose head is this token, trimmed by cleanup_endings; when that span
        contains a conj/prep token, a second span built while ignoring the
        conj/prep children is stored as well.

    Args:
        doc: spacy Doc that `span` belongs to.
        span: spacy Span (a sentence) whose tokens are scanned.
        blacklist: if True, skip tokens whose lowercase form is in NO_COREF_LIST.
        debug: if True, print a trace. NOTE: this parameter shadows the
            module-level `debug` flag; extract_mentions_spans calls this
            function without passing it, so the trace is off there.

    Returns:
        list of spacy Spans in discovery order. Duplicates are not removed here,
        and pronoun / "'s" spans are not trimmed by cleanup_endings.
    '''
    # Tags considered as mention heads; re.match anchors only at the start of the tag.
    # NOTE(review): `N.*` also matches any other tag starting with N (e.g. NFP) -- confirm intended.
    keep_tags = re.compile(r"N.*|PRP.*|DT|IN")
    leave_dep = ["det", "compound", "appos"]      # deps skipped unless also in keep_dep
    keep_dep = ["nsubj", "dobj", "iobj", "pobj"]  # deps that bypass the tag / leave_dep filter
    nsubj_or_dep = ["nsubj", "dep"]               # sibling deps gathered for "'s" mentions
    conj_or_prep = ["conj", "prep"]               # children ignored for the second, conj-free span
    remove_pos = ["CCONJ", "INTJ", "ADP"]         # POS trimmed from both ends of a span
    lower_not_end = ["'s", ',', '.', '!', '?', ':', ';']  # lowercase forms trimmed from both ends

    # Utility to remove bad endings
    # Returns half-open (start, end) doc indices covering `token` and the given
    # child edge indices, after trimming tokens whose POS is in remove_pos or whose
    # lowercase form is in lower_not_end, first from the right end, then from the
    # left end. If every token is trimmed, start == end.
    def cleanup_endings(left, right, token):
        minchild_idx = min(left + [token.i]) if left else token.i
        maxchild_idx = max(right + [token.i]) if right else token.i
        # Clean up endings and beginnings
        while maxchild_idx >= minchild_idx and (doc[maxchild_idx].pos_ in remove_pos
                                           or doc[maxchild_idx].lower_ in lower_not_end):
            if debug: print("Removing last token", doc[maxchild_idx].lower_, doc[maxchild_idx].tag_)
            maxchild_idx -= 1 # We don't want mentions finishing with 's or conjunctions/punctuation
        while minchild_idx <= maxchild_idx and (doc[minchild_idx].pos_ in remove_pos
                                           or doc[minchild_idx].lower_ in lower_not_end):
            if debug: print("Removing first token", doc[minchild_idx].lower_, doc[minchild_idx].tag_)
            minchild_idx += 1 # We don't want mentions starting with 's or conjunctions/punctuation
        return minchild_idx, maxchild_idx+1

    mentions_spans = []
    for token in span:
        if debug: print("🚀 tok:", token, "tok.tag_:", token.tag_, "tok.pos_:", token.pos_, "tok.dep_:", token.dep_)

        # Optionally skip blacklisted words (NO_COREF_LIST).
        if blacklist and token.lower_ in NO_COREF_LIST:
            if debug: print("token in no_coref_list")
            continue
        # Keep a token only if its tag matches keep_tags and its dep is not in
        # leave_dep -- a dep listed in keep_dep always keeps it.
        if (not keep_tags.match(token.tag_) or token.dep_ in leave_dep) and not token.dep_ in keep_dep:
            if debug: print("not pronoun or no right dependency")
            continue

        # pronoun
        if re.match(r"PRP.*", token.tag_):
            if debug: print("PRP")
            endIdx = token.i + 1

            # NOTE: rebinding `span` here does not affect the enclosing loop,
            # which already iterates over the original sentence Span.
            span = doc[token.i: endIdx]
            if debug: print("==-- PRP store:", span)
            mentions_spans.append(span)

            # when pronoun is a part of conjunction (e.g., you and I)
            # (any child at all triggers this: the pronoun's whole subtree is stored)
            if token.n_rights > 0 or token.n_lefts > 0:
                span = doc[token.left_edge.i : token.right_edge.i+1]
                if debug: print("==-- in conj store:", span)
                mentions_spans.append(span)
            continue

        # Add NP mention
        if debug:
            print("NP or IN:", token.lower_)
            if token.tag_ == 'IN':
                print("IN tag")
        # Take care of 's
        if token.lower_ == "'s":
            if debug: print("'s detected")
            h = token.head
            j = 0
            # Walk up the head chain until h is its own head (the root) or MAX_ITER steps.
            while h.head.i != h.i and j < MAX_ITER:
                if debug:
                    print("token head:", h, h.dep_, "head:", h.head)
                    print(id(h.head), id(h))
                if h.dep_ == "nsubj":
                    # Span over the subtrees of every nsubj/dep child of h's head (h included).
                    minchild_idx = min((c.left_edge.i for c in doc if c.head.i == h.head.i and c.dep_ in nsubj_or_dep),
                                       default=token.i)
                    maxchild_idx = max((c.right_edge.i for c in doc if c.head.i == h.head.i and c.dep_ in nsubj_or_dep),
                                       default=token.i)
                    if debug: print("'s', i1:", doc[minchild_idx], " i2:", doc[maxchild_idx])
                    span = doc[minchild_idx : maxchild_idx+1]
                    if debug: print("==-- 's' store:", span)
                    mentions_spans.append(span)
                    break
                h = h.head
                j += 1
            # Fails only if the walk hit MAX_ITER before reaching the root.
            assert j != MAX_ITER
            continue

        # clean up
        # Debug-only listing of the tokens headed by `token` (scans the whole doc).
        for c in doc:
            if debug and c.head.i == token.i: print("🚧 token in span:", c, "- head & dep:", c.head, c.dep_)
        # Subtree edges of every doc token whose head is `token`. Because this scans
        # the doc rather than token.children, a token that is its own head (the
        # sentence root) contributes its own edges too.
        left = list(c.left_edge.i for c in doc if c.head.i == token.i)
        right = list(c.right_edge.i for c in doc if c.head.i == token.i)
        # An IN token with dep "mark" and no children: use the subtrees of its
        # head's children (its siblings, itself included) instead.
        if token.tag_ == 'IN' and token.dep_ == "mark" and len(left) == 0 and len(right) == 0:
            left = list(c.left_edge.i for c in doc if c.head.i == token.head.i)
            right = list(c.right_edge.i for c in doc if c.head.i == token.head.i)
        if debug:
            print("left side:", left)
            print("right side:", right)
            minchild_idx = min(left) if left else token.i
            maxchild_idx = max(right) if right else token.i
            print("full span:", doc[minchild_idx:maxchild_idx+1])
        start, end = cleanup_endings(left, right, token)
        # Everything was trimmed away.
        if start == end:
            continue
        # NOTE(review): cleanup_endings already trims a leading "'s" (it is in
        # lower_not_end), so this check appears unreachable.
        if doc[start].lower_ == "'s":
            continue # we probably already have stored this mention
        span = doc[start:end]
        if debug:
            print("cleaned endings span:", doc[start:end])
            print("==-- full span store:", span)
        mentions_spans.append(span)
        if debug and token.tag_ == 'IN':
            print("IN tag")
        # If the span contains a conj/prep token, also store the span obtained by
        # ignoring `token`'s conj/prep children.
        if any(tok.dep_ in conj_or_prep for tok in span):
            if debug: print("Conjunction found, storing first element separately")
            # Debug-only listing of the non-conj/prep children.
            for c in doc:
                if c.head.i == token.i and c.dep_ not in conj_or_prep:
                    if debug: print("left no conj:", c, 'dep & edge:', c.dep_, c.left_edge)
                    if debug: print("right no conj:", c, 'dep & edge:', c.dep_, c.right_edge)
            left_no_conj = list(c.left_edge.i for c in doc if c.head.i == token.i and c.dep_ not in conj_or_prep)
            right_no_conj = list(c.right_edge.i for c in doc if c.head.i == token.i and c.dep_ not in conj_or_prep)
            if debug: print("left side no conj:", [doc[i] for i in left_no_conj])
            if debug: print("right side no conj:", [doc[i] for i in right_no_conj])
            start, end = cleanup_endings(left_no_conj, right_no_conj, token)
            if start == end:
                continue
            span = doc[start:end]
            if debug: print("==-- full span store:", span)
            mentions_spans.append(span)
    if debug: print("mentions_spans inside", mentions_spans)
    return mentions_spans

In [31]:
# Smoke test: run the mention extractor on a short sentence containing a name and
# two pronouns; the returned span list is displayed as the cell output.
doc = nlp("He is Manoj, he went to college")
extract_mentions_spans(doc=doc)

===== doc ====: He is Manoj, he went to college
🚧 span search: He head: is tag: PRP pos: PRON dep: nsubj
🚧 span search: is head: went tag: VBZ pos: VERB dep: ccomp
🚧 span search: Manoj head: is tag: NNP pos: PROPN dep: attr
🚧 span search: , head: went tag: , pos: PUNCT dep: punct
🚧 span search: he head: went tag: PRP pos: PRON dep: nsubj
🚧 span search: went head: went tag: VBD pos: VERB dep: ROOT
🚧 span search: to head: went tag: IN pos: ADP dep: prep
🚧 span search: college head: to tag: NN pos: NOUN dep: pobj
==-- ents: [(Manoj, 'GPE')]


[Manoj, He, he, college]