In [269]:
import pandas as pd
import stanza
from tqdm import tqdm, trange
import re
import numpy as np
from matplotlib import pyplot as plt
from textdistance import lcsseq
from collections import Counter
import multiset
import jsonlines

## Read script, parsed_script and annotations

In [74]:
def read_annotated_data(name, data_dir="annotated-data"):
    """Load the raw script, parsed script and coreference annotations for one movie.

    Reads ``{data_dir}/{name}.script.txt`` (returned unchanged),
    ``{data_dir}/{name}.script_parsed.txt`` (stripped of surrounding
    whitespace) and ``{data_dir}/{name}.coref.csv`` (as a DataFrame).
    Files are opened with ``with`` so their handles are closed promptly.
    """
    with open(f"{data_dir}/{name}.script.txt") as f:
        script = f.read()
    with open(f"{data_dir}/{name}.script_parsed.txt") as f:
        parsed_script = f.read().strip()
    annotations = pd.read_csv(f"{data_dir}/{name}.coref.csv", index_col=None)
    return script, parsed_script, annotations


bourne_script, bourne_parsed_script, bourne_annotations = read_annotated_data("bourne")
basterds_script, basterds_parsed_script, basterds_annotations = read_annotated_data("basterds")
shawshank_script, shawshank_parsed_script, shawshank_annotations = read_annotated_data("shawshank")

## NLP pipelines

In [3]:
# `nlp`: tokenization (with sentence splitting) plus POS tagging.
# `nlp2`: tokenization only, with sentence splitting disabled, so the whole
# input is treated as a single sentence.
nlp, nlp2 = (
    stanza.Pipeline(processors="tokenize,pos"),
    stanza.Pipeline(processors="tokenize", tokenize_no_ssplit=True),
)

2020-11-18 11:12:42 INFO: Loading these models for language: en (English):
| Processor | Package |
-----------------------
| tokenize  | ewt     |
| pos       | ewt     |

2020-11-18 11:12:42 INFO: Use device: gpu
2020-11-18 11:12:42 INFO: Loading: tokenize
2020-11-18 11:12:46 INFO: Loading: pos
2020-11-18 11:12:47 INFO: Done loading processors!
2020-11-18 11:12:47 INFO: Loading these models for language: en (English):
| Processor | Package |
-----------------------
| tokenize  | ewt     |

2020-11-18 11:12:47 INFO: Use device: gpu
2020-11-18 11:12:47 INFO: Loading: tokenize
2020-11-18 11:12:47 INFO: Done loading processors!


## Flatten tokens

In [110]:
def flatten_tokens(parsed_script, pipeline=None):
    """Run a stanza pipeline over each line of a parsed script and flatten its tokens.

    Each line is split as ``line[0]`` -> tag and ``line[2:]`` -> text, i.e. the
    line is assumed to look like ``"<tag> <text>"`` (TODO confirm against the
    parser that produced the ``*.script_parsed.txt`` files).

    Args:
        parsed_script: newline-separated parsed script.
        pipeline: callable stanza pipeline applied to each line's text;
            defaults to the module-level ``nlp`` when ``None``.

    Returns:
        dict with
          - per-line lists ``"tag"``, ``"text"``, ``"doc"``. Each ``doc`` is the
            ``to_dict()`` output: a list of sentences, each a list of token dicts.
          - per-token lists ``"token"`` (token text), ``"index"``
            (``(line, sentence, token)`` position) and ``"size"`` (token length
            with whitespace removed).
    """
    if pipeline is None:
        pipeline = nlp
    lines = parsed_script.split("\n")
    tags = []
    texts = []
    docs = []
    flattened_tokens = []
    flattened_token_sizes = []
    flattened_indices = []

    for i, line in tqdm(enumerate(lines), total=len(lines)):
        tag = line[0]
        text = line[2:].strip()
        doc = pipeline(text).to_dict()

        tags.append(tag)
        texts.append(text)
        docs.append(doc)

        for j, sent in enumerate(doc):
            for k, token in enumerate(sent):
                flattened_tokens.append(token["text"])
                # Raw string: "\s" in a plain literal is an invalid escape sequence.
                flattened_token_sizes.append(len(re.sub(r"\s", "", token["text"])))
                flattened_indices.append((i, j, k))

    return {"tag": tags, "text": texts, "doc": docs, "token": flattened_tokens, "index": flattened_indices, "size": flattened_token_sizes}

In [111]:
# Tokenize each movie's parsed script, then attach its raw script and coref table.
bourne_info, basterds_info, shawshank_info = [
    flatten_tokens(parsed)
    for parsed in (bourne_parsed_script, basterds_parsed_script, shawshank_parsed_script)
]

for info, annotations, script in (
    (bourne_info, bourne_annotations, bourne_script),
    (basterds_info, basterds_annotations, basterds_script),
    (shawshank_info, shawshank_annotations, shawshank_script),
):
    info["coref"] = annotations
    info["script"] = script

100%|██████████| 649/649 [00:23<00:00, 28.20it/s]
100%|██████████| 591/591 [00:20<00:00, 28.45it/s]
100%|██████████| 525/525 [00:18<00:00, 27.97it/s]


In [248]:
# Name each movie's info dict; the name is used as the CoNLL document id.
for info, name in zip((bourne_info, basterds_info, shawshank_info), ("bourne", "basterds", "shawshank")):
    info["name"] = name

## Map annotated mentions (character offsets) to token spans

In [145]:
def find_lcsi(A, B):
    """Return the index pairs of one longest common subsequence of lists ``A`` and ``B``.

    Each returned element is an ``(i, j)`` pair with ``A[i] == B[j]``, and the
    result is an unordered ``set``. When several LCSs exist, a non-matching
    cell inherits from the diagonal predecessor, unless the cell above is
    strictly longer, and then the cell to the left if that is strictly longer
    still.
    """
    rows = [""] + A
    cols = [""] + B
    n_rows, n_cols = len(rows), len(cols)
    # length[r][c] / pairs[r][c] describe the LCS of rows[1:r+1] and cols[1:c+1].
    length = [[0] * n_cols for _ in range(n_rows)]
    pairs = [[set() for _ in range(n_cols)] for _ in range(n_rows)]

    for r in range(1, n_rows):
        for c in range(1, n_cols):
            if rows[r] == cols[c]:
                length[r][c] = length[r - 1][c - 1] + 1
                pairs[r][c] = pairs[r - 1][c - 1] | {(r - 1, c - 1)}
                continue
            src_r, src_c = r - 1, c - 1
            for cand_r, cand_c in ((r - 1, c), (r, c - 1)):
                if length[cand_r][cand_c] > length[src_r][src_c]:
                    src_r, src_c = cand_r, cand_c
            length[r][c] = length[src_r][src_c]
            pairs[r][c] = pairs[src_r][src_c]

    return pairs[n_rows - 1][n_cols - 1]

In [201]:
def find_mention_span(begin, end, info):
    """Map a character-offset mention in the raw script to a flattened-token span.

    The mention ``info["script"][begin:end]`` is tokenized with ``nlp2``
    together with left and right context. On each side, whitespace-delimited
    words are collected until 5 words or at least 50 non-space characters
    have been gathered. The function then:

    1. returns the first window of ``info["token"]`` equal to
       ``left context + mention + right context``; otherwise
    2. ranks windows by how many mention/left/right tokens they share
       (multiset overlap, all three groups must overlap). It accepts the first
       of the top 20 whose LCS alignment (``find_lcsi``) restricted to the
       mention reproduces the mention tokens exactly.

    Args:
        begin, end: character offsets of the mention in ``info["script"]``.
        info: dict with at least ``"script"`` (raw text) and ``"token"``
            (list of flattened token texts).

    Returns:
        ``(start, stop)`` flattened-token indices with ``stop`` exclusive, or
        ``(False, False)`` when no span is found. ``start`` may be ``0``, so
        callers should test ``is False`` rather than truthiness.
    """
    script = info["script"]
    tot = len(script)

    # Walk left from the mention, counting whitespace-delimited words.
    qi = begin - 1
    n_left_context_tokens = 0
    n_left_context_chars = 0
    while qi >= 0:
        while qi >= 0 and re.match(r"\s", script[qi]):
            qi -= 1
        while qi >= 0 and re.match(r"\S", script[qi]):
            qi -= 1
            n_left_context_chars += 1
        n_left_context_tokens += 1
        if n_left_context_tokens == 5 or n_left_context_chars >= 50:
            break
    qi += 1

    # Walk right from the mention in the same way.
    qj = end
    n_right_context_tokens = 0
    n_right_context_chars = 0
    while qj < tot:
        while qj < tot and re.match(r"\s", script[qj]):
            qj += 1
        while qj < tot and re.match(r"\S", script[qj]):
            qj += 1
            n_right_context_chars += 1
        n_right_context_tokens += 1
        if n_right_context_tokens == 5 or n_right_context_chars >= 50:
            break

    left_context = script[qi: begin].strip()
    right_context = script[end: qj].strip()
    mention = script[begin:end]

    lc_tokens = [token.text for token in nlp2(left_context).iter_words()]
    rc_tokens = [token.text for token in nlp2(right_context).iter_words()]
    mention_tokens = [token.text for token in nlp2(mention).iter_words()]

    lc_set = multiset.Multiset(lc_tokens)
    rc_set = multiset.Multiset(rc_tokens)
    mention_set = multiset.Multiset(mention_tokens)

    # Loop-invariant query sequence, built once instead of once per window.
    target = lc_tokens + mention_tokens + rc_tokens
    h = len(target)
    tokens = info["token"]
    ind_match_arr = []

    for i in range(len(tokens) - h + 1):
        window = tokens[i: i + h]
        if window == target:
            return i + len(lc_tokens), i + len(lc_tokens) + len(mention_tokens)
        # Greedy overlap: mention tokens first, then left, then right context.
        token_set = multiset.Multiset(window)
        nc1 = len(token_set.intersection(mention_set))
        token_set.difference_update(mention_set)
        nc2 = len(token_set.intersection(lc_set))
        token_set.difference_update(lc_set)
        nc3 = len(token_set.intersection(rc_set))
        if nc1 and nc2 and nc3:
            ind_match_arr.append((i, nc1 + nc2 + nc3))

    # Stable sort: equal scores keep their left-to-right order.
    ind_match_arr.sort(key=lambda item: item[1], reverse=True)

    mention_lo = len(lc_tokens)
    mention_hi = mention_lo + len(mention_tokens)
    for ind, _ in ind_match_arr[:20]:
        lcsi = find_lcsi(tokens[ind: ind + h], target)
        # Keep only aligned pairs whose query position lies inside the mention.
        alignment = sorted((i, j) for i, j in lcsi if mention_lo <= j < mention_hi)
        matched_tokens = [tokens[ind + i] for i, _ in alignment]
        # An empty alignment (mention with no tokens) cannot define a span.
        if alignment and matched_tokens == mention_tokens:
            return ind + alignment[0][0], ind + alignment[-1][0] + 1

    return False, False

In [202]:
def find_mention_spans(info):
    """Resolve every annotated mention in ``info["coref"]`` to a flattened-token span.

    Calls ``find_mention_span`` on each row's ``begin``/``end`` character
    offsets. It adds ``begin_span``/``end_span`` columns to ``info["coref"]``
    in place (``None``, stored as NaN, when no span was found) and prints how
    many mentions matched.
    """
    coref = info["coref"]
    n_matched = 0
    begin_span = []
    end_span = []

    for _, row in tqdm(coref.iterrows(), total=len(coref)):
        i, j = find_mention_span(row["begin"], row["end"], info)
        # A span may legitimately start at token 0, so compare against the
        # failure sentinel by identity rather than by truthiness.
        if i is False:
            begin_span.append(None)
            end_span.append(None)
        else:
            begin_span.append(i)
            end_span.append(j)
            n_matched += 1

    coref["begin_span"] = begin_span
    coref["end_span"] = end_span
    print(f"{n_matched}/{len(info['coref'])} matched")

In [203]:
# Resolve mention spans for every movie (adds begin_span/end_span columns).
for info in (bourne_info, basterds_info, shawshank_info):
    find_mention_spans(info)

100%|██████████| 911/911 [03:02<00:00,  5.00it/s]
  0%|          | 0/1005 [00:00<?, ?it/s]

869/911 matched


100%|██████████| 1005/1005 [03:06<00:00,  5.40it/s]
  0%|          | 1/887 [00:00<02:00,  7.34it/s]

968/1005 matched


100%|██████████| 887/887 [02:33<00:00,  5.77it/s]

873/887 matched





## Check whether each mention span lies within a single sentence

In [226]:
def find_doc_spans(info):
    """Locate each token-resolved mention within the parsed document structure.

    For every row of ``info["coref"]`` with a non-NaN ``begin_span``, the
    ``(element, sentence, word)`` position of each covered token is looked up
    in ``info["index"]``.
      - Mentions inside a single sentence get ``element``, ``sent``,
        ``word_begin`` and ``word_end`` (inclusive) filled in.
      - Mentions spanning several sentences/elements, unresolved mentions and
        empty spans are left as NaN.
    The four columns are added to ``info["coref"]`` in place, and summary
    counts are printed.
    """
    coref = info["coref"]
    n_rows = len(coref)
    n_no_overlap = 0
    n_overlap_element = 0
    n_overlap_sentence = 0
    n = 0
    element = np.full(n_rows, np.nan)
    sent = np.full(n_rows, np.nan)
    word_begin = np.full(n_rows, np.nan)
    word_end = np.full(n_rows, np.nan)

    # Address the arrays by row *position*, not index label, so this also
    # works when `coref` does not have a 0..n-1 RangeIndex.
    for pos, (_, row) in enumerate(coref.iterrows()):
        if pd.isna(row["begin_span"]):
            continue
        span = range(int(row["begin_span"]), int(row["end_span"]))
        if len(span) == 0:
            # Nothing to locate; leave this row's columns as NaN.
            continue
        element_indices, sentence_indices, word_indices = zip(*(info["index"][t] for t in span))

        if len(set(element_indices)) > 1:
            n_overlap_element += 1
        if len(set(element_indices)) > 1 or len(set(sentence_indices)) > 1:
            n_overlap_sentence += 1
        else:
            n_no_overlap += 1
            element[pos] = element_indices[0]
            sent[pos] = sentence_indices[0]
            word_begin[pos] = word_indices[0]
            word_end[pos] = word_indices[-1]
        n += 1

    coref["element"] = element
    coref["sent"] = sent
    coref["word_begin"] = word_begin
    coref["word_end"] = word_end

    print(f"{n_overlap_element}/{n} overlap elements")
    print(f"{n_overlap_sentence}/{n} overlap sents")
    print(f"{n_no_overlap}/{n} no overlap")

In [227]:
find_doc_spans(info=bourne_info)

1/869 overlap elements
4/869 overlap sents
865/869 no overlap


In [228]:
find_doc_spans(info=basterds_info)

1/968 overlap elements
79/968 overlap sents
889/968 no overlap


In [229]:
find_doc_spans(info=shawshank_info)

1/873 overlap elements
1/873 overlap sents
872/873 no overlap


## Check for crossing (overlapping but non-nested) mention spans within sentences

In [240]:
# Display every (element, sentence) group containing two crossing mentions,
# i.e. spans (x, y) and (a, b) with x < a <= y < b (overlapping but not nested).
# groupby drops rows whose element/sent is NaN (mentions not confined to one
# sentence). Each offending group is displayed once, however many crossing
# pairs it contains.
for info in [bourne_info, basterds_info, shawshank_info]:
    for _, df in info["coref"].groupby(["element", "sent"]):
        spans = sorted((int(b), int(e)) for b, e in zip(df["word_begin"], df["word_end"]))
        has_crossing = any(
            x < a <= y < b
            for pos, (x, y) in enumerate(spans)
            for a, b in spans[pos + 1:]
        )
        if has_crossing:
            display(df)

## Convert to CoNLL-2012 format

In [267]:
def convert_to_conll(info, conll_file):
    """Write a parsed script and its coreference mentions as a CoNLL-2012-style file.

    Output layout:
      - The whole script is one document, ``#begin document (<name>); part 0``
        ... ``#end document``.
      - There is one tab-separated line per token: tag, sentence index, word
        index, token text, ``xpos``, six ``-`` placeholder columns and the
        coreference column.
      - Sentences are separated by blank lines.

    Coreference column: ``(N)`` marks a single-token mention of entity N, and
    ``(N`` / ``N)`` mark the first / last token of a longer mention. Labels are
    joined with ``|`` in sorted order; ``-`` means no mention boundary.

    Only rows of ``info["coref"]`` with a non-NaN ``element`` are written.
    ``find_doc_spans`` sets that column only for mentions inside one sentence.
    Side effect: every token dict in ``info["doc"]`` gets a ``"coref"`` key
    holding its coreference column.
    """
    # Unique (first token, last token, entity) mentions in flattened-token indices.
    mentions = set()
    for _, row in info["coref"].iterrows():
        if pd.notna(row["element"]):
            mentions.add((int(row["begin_span"]), int(row["end_span"]) - 1, int(row["entityNum"])))

    # Keep single-token mentions apart from multi-token ones. Intersecting
    # per-token start/end label sets would drop a longer mention of the same
    # entity that starts or ends on a single-token mention's token.
    single_labels, start_labels, end_labels = {}, {}, {}
    for first, last, label in mentions:
        if first == last:
            single_labels.setdefault(first, []).append(label)
        else:
            start_labels.setdefault(first, []).append(label)
            end_labels.setdefault(last, []).append(label)

    lines = [f"#begin document ({info['name']}); part 0"]
    ind = 0  # position in the flattened token sequence
    for i, element in enumerate(info["doc"]):
        for j, sent in enumerate(element):
            for k, word in enumerate(sent):
                text_labels = (
                    [f"({label})" for label in sorted(single_labels.get(ind, []))]
                    + [f"({label}" for label in sorted(start_labels.get(ind, []))]
                    + [f"{label})" for label in sorted(end_labels.get(ind, []))]
                )
                word["coref"] = "|".join(text_labels) or "-"
                lines.append(f"{info['tag'][i]}\t{j:2d}\t{k:2d}\t{word['text']:>20s}\t{word['xpos']:5s}\t-\t-\t-\t-\t-\t-\t{word['coref']}")
                ind += 1
            lines.append("")
    lines.append("#end document")

    with open(conll_file, "w") as f:
        f.write("\n".join(lines))

In [268]:
# Export each movie to its own CoNLL file next to the input data.
for info, conll_file in (
    (bourne_info, "annotated-data/bourne.conll"),
    (basterds_info, "annotated-data/basterds.conll"),
    (shawshank_info, "annotated-data/shawshank.conll"),
):
    convert_to_conll(info, conll_file)

## Inspect the jsonlines records (dev set vs. bourne)

In [275]:
def get_recs(jsonl_file):
    """Read every record of a JSON Lines file into a list, in file order."""
    with jsonlines.open(jsonl_file) as reader:
        return list(reader)

In [284]:
dev_recs, bourne_recs = (
    get_recs("corefhoi/data/dev.english.512.jsonlines"),
    get_recs("corefhoi/data/bourne.english.512.jsonlines"),
)

In [285]:
tuple(map(len, (dev_recs, bourne_recs)))

(1029, 1)

In [286]:
dev_rec, bourne_rec = dev_recs[0], bourne_recs[0]

In [287]:
for rec in (dev_rec, bourne_rec):
    print(rec.keys())

dict_keys(['doc_key', 'tokens', 'sentences', 'speakers', 'constituents', 'ner', 'clusters', 'sentence_map', 'subtoken_map', 'pronouns'])
dict_keys(['doc_key', 'tokens', 'sentences', 'speakers', 'constituents', 'ner', 'clusters', 'sentence_map', 'subtoken_map', 'pronouns'])


In [288]:
tuple(len(rec["sentences"]) for rec in (dev_rec, bourne_rec))

(2, 22)