In [1]:
import json
import os
import re
from typing import Any

import jsonlines
import nltk
import numpy as np
import pandas as pd
import spacy
from spacy import displacy
import tqdm
import unidecode

In [2]:
# Configuration: input locations and the spaCy model name used by the cells
# below. All paths are absolute.

# CSV read with pandas; the "movie" and "robust" columns are used downstream.
screenplay_parse_file = ("/home/sbaruah_usc_edu/mica_text_coref/data/"
                         "movie_coref/parse.csv")
# Text file read line by line; each line is split into a movie and a rater.
movie_and_raters_file = ("/home/sbaruah_usc_edu/mica_text_coref/data/"
                         "movie_coref/movies.txt")
# Directory holding one "<movie>.txt" screenplay per movie.
screenplays_dir = ("/home/sbaruah_usc_edu/mica_text_coref/data/"
                   "movie_coref/screenplay")
# Directory holding one "<movie>.csv" annotation file per movie.
annotations_dir = "/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/csv"
# Name passed to spacy.load().
spacy_model = "en_core_web_sm"

In [3]:
# Initialize the spaCy pipeline and the list that collects per-movie records.
# NOTE(review): spacy.require_gpu() presumably raises when no GPU is
# available -- confirm before running this on a CPU-only machine.
spacy.require_gpu()
nlp = spacy.load(spacy_model)
movie_data = []

# Read the screenplay parse table and the (movie, rater) pairs.
movies, raters = [], []
parse_df = pd.read_csv(screenplay_parse_file, index_col=None)
with open(movie_and_raters_file) as fr:
    for line in fr:
        # Each line must contain exactly two whitespace-separated fields;
        # a blank or malformed line makes this unpacking raise ValueError.
        movie, rater = line.split()
        movies.append(movie)
        raters.append(rater)

In [36]:
# Build one record per movie containing the tokens, their POS / NER /
# screenplay-parse tags, per-token speakers, sentence offsets, and character
# clusters of [begin, end, head] token indices (begin and end inclusive).
movie_data = []

for movie, rater in tqdm.tqdm(zip(movies, raters), total=len(movies),
                              unit="movie"):
    # Read movie script, movie parse, and movie coreference annotations
    script_filepath = os.path.join(screenplays_dir, f"{movie}.txt")
    annotation_filepath = os.path.join(annotations_dir, f"{movie}.csv")
    with open(script_filepath, encoding="utf-8") as fr:
        script = fr.read()
    lines = script.splitlines()
    parsetags = parse_df[parse_df["movie"] == movie]["robust"].tolist()
    annotation_df = pd.read_csv(annotation_filepath, index_col=None)
    # (begin, end, character) triples sorted by character offset
    items = sorted(zip(annotation_df["begin"].tolist(),
                       annotation_df["end"].tolist(),
                       annotation_df["entityLabel"].tolist()))

    # Find non-whitespace offset of coreference annotations.
    # wsbegin / wsend count only the non-whitespace characters that precede
    # the mention boundary, so they can be compared with token offsets that
    # are computed the same way further below. The count is accumulated
    # incrementally from the previous mention to avoid rescanning the script.
    begins, ends, characters, wsbegins, wsends = [], [], [], [], []
    prev_begin, prev_wsbegin = 0, 0
    for begin, end, character in items:
        wsbegin = prev_wsbegin + len(
            re.sub(r"\s", "", script[prev_begin: begin]))
        prev_begin, prev_wsbegin = begin, wsbegin
        wsend = wsbegin + len(re.sub(r"\s", "", script[begin: end]))
        begins.append(begin)
        ends.append(end)
        characters.append(character)
        wsbegins.append(wsbegin)
        wsends.append(wsend)

    # Find segments (blocks of adjacent lines with same movieparser tags).
    # Each segment is whitespace-normalized and pre-tokenized with
    # nltk.wordpunct_tokenize; segments that end up empty are dropped.
    i = 0
    segment_texts, segment_tags = [], []
    while i < len(lines):
        j = i + 1
        while j < len(lines) and parsetags[j] == parsetags[i]:
            j += 1
        segment = re.sub(r"\s+", " ", " ".join(lines[i: j]).strip())
        segment = (" ".join(nltk.wordpunct_tokenize(segment))).strip()
        segment = re.sub(r"\s+", " ", segment.strip())
        if segment:
            segment_texts.append(segment)
            segment_tags.append(parsetags[i])
        i = j

    # Run each segment through spacy pipeline
    docs = nlp.pipe(segment_texts, batch_size=10200)

    # Collect the spaCy tokens of every segment together with their
    # sentence ids, movieparser tags, and non-whitespace character offsets.
    #   c: running count of non-whitespace characters seen so far
    #   s: running sentence id
    #   n: number of tokens emitted before the current sentence
    (tokens, token_heads, token_postags, token_nertags, token_begins,
        token_ends, token_tags, token_sentids) = [], [], [], [], [], [], [], []
    c, s, n = 0, 0, 0
    for i, doc in enumerate(docs):
        for sent in doc.sents:
            for stoken in sent:
                text = stoken.text
                ascii_text = unidecode.unidecode(text, errors="strict")
                assert ascii_text != "", f"token={text}, ascii_text={ascii_text}"
                postag = stoken.tag_
                nertag = stoken.ent_type_
                if not nertag:
                    nertag = "-"
                token_begin = c
                c += len(re.sub(r"\s+", "", text))
                token_end = c
                token_sentid = s
                tokens.append(ascii_text)
                # stoken.head.i is relative to the start of the doc, whereas
                # n is the global index of the first token of this sentence,
                # so the sentence start must be subtracted to obtain the
                # global index of the head.
                token_heads.append(n + stoken.head.i - sent.start)
                token_begins.append(token_begin)
                token_ends.append(token_end)
                token_sentids.append(token_sentid)
                token_tags.append(segment_tags[i])
                token_postags.append(postag)
                token_nertags.append(nertag)
            n += len(sent)
            s += 1

    # Index tokens by their non-whitespace begin / end offset. setdefault
    # keeps the first token for a given offset, which is what list.index
    # would return.
    first_index_of_begin, first_index_of_end = {}, {}
    for index, (token_begin, token_end) in enumerate(
            zip(token_begins, token_ends)):
        first_index_of_begin.setdefault(token_begin, index)
        first_index_of_end.setdefault(token_end, index)

    # Match mentions to tokens
    mention_begins, mention_ends = [], []
    for wsbegin, wsend, begin, end in zip(wsbegins, wsends, begins, ends):
        i = first_index_of_begin.get(wsbegin)
        j = first_index_of_end.get(wsend)
        if j is None:
            # A mention ending in "." that is directly followed by ".." is
            # retried with the end moved one character to the left.
            mention = script[begin: end].rstrip()
            right_context = script[end:].lstrip()
            if mention.endswith(".") and right_context.startswith(".."):
                j = first_index_of_end.get(wsend - 1)
        if i is None or j is None:
            mention = script[begin: end]
            # Clamp at 0 so that a mention near the start of the script does
            # not produce a negative (wrapping) slice index.
            context = script[max(begin - 10, 0): end + 10]
            print(f"mention = '{mention}'")
            print(f"context = '{context}'")
            if i is None:
                print("Could not match start of mention")
            if j is None:
                print("Could not match end of mention")
            print()
        mention_begins.append(i)
        mention_ends.append(j)

    # Create speakers array: every token tagged "D" or "E" that follows a
    # block of "C" tokens (before the next "S" or "C" token) gets the text of
    # that "C" block as its speaker; all other tokens keep "-".
    speakers = np.full(len(tokens), fill_value="-", dtype=object)
    i = 0
    while i < len(tokens):
        if token_tags[i] == "C":
            j = i + 1
            while j < len(tokens) and token_tags[j] == token_tags[i]:
                j += 1
            k = j
            utterance_token_indices = []
            while k < len(tokens) and token_tags[k] not in "SC":
                if token_tags[k] in "DE":
                    utterance_token_indices.append(k)
                k += 1
            if utterance_token_indices:
                speaker = " ".join(tokens[i: j])
                # Strip parenthesized parts unless nothing would remain
                cleaned_speaker = re.sub(r"\([^\)]+\)", "", speaker).strip()
                speaker = cleaned_speaker if cleaned_speaker else speaker
                for index in utterance_token_indices:
                    speakers[index] = speaker
            i = k
        else:
            i += 1
    speakers = speakers.tolist()

    # Create character to mention and head
    clusters: dict[str, list[list[int]]] = {}
    for character, mention_begin, mention_end in zip(characters, mention_begins,
                                                     mention_ends):
        # Mentions that could not be aligned to tokens were reported above;
        # they cannot be represented as token spans, so leave them out.
        if mention_begin is None or mention_end is None:
            continue
        if character not in clusters:
            clusters[character] = []
        # The head is the single token of the span whose dependency head is
        # itself or lies outside the span; fall back to the last token if
        # there is not exactly one such token.
        token_indexes_with_outside_head = []
        for i in range(mention_begin, mention_end + 1):
            head_index = token_heads[i]
            if (head_index == i or head_index < mention_begin or
                head_index > mention_end):
                token_indexes_with_outside_head.append(i)
        mention_head = mention_end
        if len(token_indexes_with_outside_head) == 1:
            mention_head = token_indexes_with_outside_head[0]
        clusters[character].append([mention_begin, mention_end, mention_head])

    # Find sentence offsets: [start, end) token ranges sharing a sentence id
    sentence_offsets: list[list[int]] = []
    i = 0
    while i < len(token_sentids):
        j = i + 1
        while j < len(token_sentids) and token_sentids[i] == token_sentids[j]:
            j += 1
        sentence_offsets.append([i, j])
        i = j

    # Create movie json
    movie_data.append({
        "movie": movie,
        "rater": rater,
        "token": tokens,
        "pos": token_postags,
        "ne": token_nertags,
        "parse": token_tags,
        "speaker": speakers,
        "sent_offset": sentence_offsets,
        "clusters": clusters
    })

100%|██████████| 9/9 [00:23<00:00,  2.59s/movie]


In [37]:
def remove_characters(movie_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
    '''Removes character names preceding an utterance.

    A run of tokens tagged "C" is removed when at least one token tagged "D"
    or "E" follows it before the next token tagged "S" or "C". Sentence
    offsets and cluster mentions are re-indexed to the shortened token list;
    sentences and mentions made up entirely of removed tokens are dropped,
    and characters left without any mention are dropped from the clusters.

    Args:
        movie_data: list of movie records with the keys "movie", "rater",
            "token", "pos", "ne", "parse", "speaker", "sent_offset"
            ([start, end) token ranges) and "clusters" (character ->
            list of [begin, end, head] with inclusive end).

    Returns:
        A new list of records with the same keys. The input is not modified.

    Raises:
        AssertionError: if a sentence or a mention contains both removed and
            retained tokens.
    '''
    # Initialize new movie data
    new_movie_data = []

    # Loop over movies
    tbar = tqdm.tqdm(movie_data, total=len(movie_data), unit="movie")
    for mdata in tbar:
        movie, rater = mdata["movie"], mdata["rater"]
        tokens, postags, nertags = mdata["token"], mdata["pos"], mdata["ne"]
        parsetags, speakers = mdata["parse"], mdata["speaker"]
        sentence_offsets, clusters = mdata["sent_offset"], mdata["clusters"]
        tbar.set_description(movie)

        # removed[x] == -1 if tokens[x] is removed, otherwise it is the
        # number of removed tokens in tokens[:x], i.e. how far tokens[x]
        # moves to the left.
        removed = np.zeros(len(tokens), dtype=int)
        i = 0
        while i < len(tokens):
            if parsetags[i] == "C":
                # tokens[i: j] is the block of character-name tokens
                j = i + 1
                while j < len(tokens) and parsetags[j] == parsetags[i]:
                    j += 1
                # Scan up to the next "S" / "C" token for utterance tokens
                k = j
                has_utterance = False
                while k < len(tokens) and parsetags[k] not in "SC":
                    if parsetags[k] in "DE":
                        has_utterance = True
                    k += 1
                if has_utterance:
                    removed[i: j] = -1
                    removed[j:] += j - i
                i = k
            else:
                i += 1

        # Find the new tokens, pos tags, ner tags, parse tags, and speakers
        kept = [index for index in range(len(tokens)) if removed[index] != -1]
        newtokens = [tokens[index] for index in kept]
        newpostags = [postags[index] for index in kept]
        newnertags = [nertags[index] for index in kept]
        newparsetags = [parsetags[index] for index in kept]
        newspeakers = [speakers[index] for index in kept]

        # Find new sentence offsets
        new_sentence_offsets = []
        for begin, end in sentence_offsets:
            dropped = removed[begin: end] == -1
            assert dropped.all() or not dropped.any(), (
                "All tokens or none of the tokens of a sentence should be "
                "removed")
            if not dropped.any():
                # No token inside the sentence is removed, so both ends move
                # by the same amount. The shift is read once, before either
                # bound changes, and from the first token rather than from
                # removed[end], which may be out of range or belong to a
                # removed token.
                shift = int(removed[begin])
                new_sentence_offsets.append([begin - shift, end - shift])

        # Find new clusters
        new_clusters: dict[str, list[list[int]]] = {}
        for character, mentions in clusters.items():
            new_mentions = []
            for begin, end, head in mentions:
                dropped = removed[begin: end + 1] == -1
                assert dropped.all() or not dropped.any(), (
                            "All tokens or none of the tokens of a mention"
                            f" should be removed, mention = [{begin},{end},{head}]")
                if not dropped.any():
                    # int() keeps numpy integers out of the output so that
                    # the record stays JSON serializable.
                    new_mentions.append([int(begin - removed[begin]),
                                         int(end - removed[end]),
                                         int(head - removed[head])])
            if new_mentions:
                new_clusters[character] = new_mentions

        # Create movie json
        new_movie_data.append({
            "movie": movie,
            "rater": rater,
            "token": newtokens,
            "pos": newpostags,
            "ne": newnertags,
            "parse": newparsetags,
            "speaker": newspeakers,
            "sent_offset": new_sentence_offsets,
            "clusters": new_clusters
        })

    return new_movie_data

In [39]:
# Variant of the data without the character-name tokens that precede an
# utterance.
movie_nocharacters_data = remove_characters(movie_data)

basterds: 100%|██████████| 9/9 [00:00<00:00, 15.42movie/s]          


In [40]:
def add_says(movie_data: list[dict[str, Any]]) -> list[dict[str, Any]]:
    '''
    Inserts 'says' between character name and utterance block. Give the token
    'says' a unique tag `A`.

    A 'says' token (pos "VBZ", ne "-", parse "A", speaker "-") is inserted
    after every run of tokens tagged "C" that is followed by at least one
    token tagged "D" or "E" before the next token tagged "S" or "C".
    Sentence offsets and cluster mentions are shifted accordingly. When the
    token that follows the inserted 'says' is tagged "D" or "E", the
    character-name sentence is merged with the next sentence, so that name,
    'says' and the first utterance sentence form one sentence. Otherwise the
    'says' token is included at the end of the character-name sentence.

    Args:
        movie_data: list of movie records with the keys "movie", "rater",
            "token", "pos", "ne", "parse", "speaker", "sent_offset"
            ([start, end) token ranges) and "clusters" (character ->
            list of [begin, end, head] with inclusive end).

    Returns:
        A new list of records with the same keys. The input is not modified.
    '''
    # Initialize new movie data
    new_movie_data = []

    # Loop over each movie
    tbar = tqdm.tqdm(movie_data, total=len(movie_data), unit="movie")
    for mdata in tbar:
        movie, rater = mdata["movie"], mdata["rater"]
        tokens, postags, nertags = mdata["token"], mdata["pos"], mdata["ne"]
        parsetags, speakers = mdata["parse"], mdata["speaker"]
        sentence_offsets, clusters = mdata["sent_offset"], mdata["clusters"]
        tbar.set_description(movie)
        n_tokens = len(tokens)

        # added[x] is the number of 'says' added in tokens[:x], i.e. how far
        # tokens[x] moves to the right.
        added = np.zeros(n_tokens, dtype=int)
        i = 0
        while i < n_tokens:
            if parsetags[i] == "C":
                # tokens[i: j] is the block of character-name tokens
                j = i + 1
                while j < n_tokens and parsetags[j] == parsetags[i]:
                    j += 1
                # Scan up to the next "S" / "C" token for utterance tokens
                k = j
                has_utterance = False
                while k < n_tokens and parsetags[k] not in "SC":
                    if parsetags[k] in "DE":
                        has_utterance = True
                    k += 1
                if has_utterance:
                    added[j:] += 1
                i = k
            else:
                i += 1

        # Find new tokens, pos tags, ner tags, parse tags, and speakers
        newtokens, newpostags, newnertags, newparsetags, newspeakers = (
            [], [], [], [], [])
        for index in range(n_tokens):
            newtokens.append(tokens[index])
            newpostags.append(postags[index])
            newnertags.append(nertags[index])
            newparsetags.append(parsetags[index])
            newspeakers.append(speakers[index])
            # A step in `added` between this token and the next one marks
            # the position of an inserted 'says'.
            if index < n_tokens - 1 and added[index] < added[index + 1]:
                newtokens.append("says")
                newpostags.append("VBZ")
                newnertags.append("-")
                newparsetags.append("A")
                newspeakers.append("-")

        # Find new sentence offsets. The sentence counter starts at zero for
        # every movie; it must not reuse the token index left over from the
        # scan above.
        new_sentence_offsets = []
        n_sentences = len(sentence_offsets)
        sentence_index = 0
        while sentence_index < n_sentences:
            begin, end = sentence_offsets[sentence_index]
            if (sentence_index < n_sentences - 1 and end < n_tokens
                    and added[end - 1] < added[end]
                    and parsetags[end] in "DE"):
                # 'says' is inserted right after this sentence and an
                # utterance follows: merge with the next sentence.
                sentence_index += 1
                end = sentence_offsets[sentence_index][1]
            # `end` is exclusive and equals len(tokens) for the last
            # sentence, where added[end] does not exist; the shift of the
            # last token applies there instead.
            end_shift = added[end] if end < n_tokens else added[end - 1]
            new_sentence_offsets.append(
                [int(begin + added[begin]), int(end + end_shift)])
            sentence_index += 1

        # Find new clusters
        new_clusters: dict[str, list[list[int]]] = {}
        for character, mentions in clusters.items():
            # int() keeps numpy integers out of the output so that the
            # record stays JSON serializable.
            new_clusters[character] = [
                [int(begin + added[begin]), int(end + added[end]),
                 int(head + added[head])]
                for begin, end, head in mentions]

        # Create the new movie json
        new_movie_data.append({
            "movie": movie,
            "rater": rater,
            "token": newtokens,
            "pos": newpostags,
            "ne": newnertags,
            "parse": newparsetags,
            "speaker": newspeakers,
            "sent_offset": new_sentence_offsets,
            "clusters": new_clusters
        })

    return new_movie_data

In [41]:
# Variant of the data with a 'says' token inserted after each block of
# character-name tokens (movieparser tag "C") that is followed by an utterance.
movie_addsays_data = add_says(movie_data)

basterds: 100%|██████████| 9/9 [00:00<00:00, 33.13movie/s]   


In [44]:
# Write the unmodified movie records to disk as indented JSON. The context
# manager closes (and thereby flushes) the file even if serialization fails.
with open(
        "/home/sbaruah_usc_edu/mica_text_coref/data/temp/movie_data.json",
        "w") as fw:
    json.dump(movie_data, fw, indent=2)

In [45]:
def convert_offsets_to_list(offsets: list[list[int]]) -> list[int]:
    """Expand sentence offsets into one sentence id per token.

    Args:
        offsets: [start, end) token ranges, one per sentence, in order.

    Returns:
        A flat list with (end - start) copies of the sentence's position in
        `offsets` for every sentence, e.g. [[0, 2], [2, 5]] gives
        [0, 0, 1, 1, 1].
    """
    ids: list[int] = []
    for sentence_id, (begin, end) in enumerate(offsets):
        ids.extend([sentence_id] * (end - begin))
    return ids

def prepare_for_wlcoref(movie_data: list[dict[str, Any]]) -> (
    list[dict[str, Any]]):
    """Convert to jsonlines format which can be used as input to the word-level
    coreference model.

    Args:
        movie_data: list of movie records; the keys "movie", "token",
            "sent_offset" and "speaker" are read.

    Returns:
        One dictionary per movie with the keys "document_id" (the movie name
        prefixed with "wb/"), "cased_words", "sent_id" (derived from the
        sentence offsets with convert_offsets_to_list) and "speaker".
    """
    return [
        {
            "document_id": f"wb/{mdata['movie']}",
            "cased_words": mdata["token"],
            "sent_id": convert_offsets_to_list(mdata["sent_offset"]),
            "speaker": mdata["speaker"]
        }
        for mdata in movie_data
    ]

In [46]:
# Convert the three variants (original, with 'says' inserted, and without
# character names) to the word-level coreference input format.
wl_movie_data = prepare_for_wlcoref(movie_data)
wl_movie_addsays_data = prepare_for_wlcoref(movie_addsays_data)
wl_movie_nocharacters_data = prepare_for_wlcoref(movie_nocharacters_data)

In [4]:
class Mention:

    def __init__(self, begin: int, end: int, head: int | None) -> None:
        self.begin = begin
        self.end = end
        self.head = head

    def __hash__(self) -> int:
        return hash((self.begin, self.end))

    def __lt__(self, other: "Mention") -> bool:
        return (self.begin, self.end) < (other.begin, other.end)

    def __repr__(self) -> str:
        return f"({self.begin},{self.end})"

class CorefDocument:
    """A single annotated movie script.

    Built from a dictionary with the keys "movie", "rater", "token",
    "parse", "pos", "ner", "speaker", "sent_offset" and "clusters". Each
    cluster (character -> list of mention tuples) is stored as a set of
    Mention objects.
    """

    def __init__(self, json: dict[str, Any]) -> None:
        self.movie: str = json["movie"]
        self.rater: str = json["rater"]
        self.token: list[str] = json["token"]
        self.parse: list[str] = json["parse"]
        self.pos: list[str] = json["pos"]
        self.ner: list[str] = json["ner"]
        self.speaker: list[str] = json["speaker"]
        self.sentence_offsets: list[tuple[int, int]] = json["sent_offset"]
        self.clusters: dict[str, set[Mention]] = {}
        for character, mentions in json["clusters"].items():
            self.clusters[character] = {Mention(*x) for x in mentions}

    def __repr__(self) -> str:
        """Render the sentences followed by the mentions of every character.

        Mentions are listed in sorted order, three per row, each formatted as
        "<mention text> (<head token>)" and padded to 25 characters.
        """
        desc = "Script\n=====\n\n"
        for i, j in self.sentence_offsets:
            sentence = self.token[i: j]
            desc += f"{sentence}\n"
        desc += "\n\nClusters\n========\n\n"
        for character, mentions in self.clusters.items():
            desc += f"{character}\n"
            mention_texts = []
            for mention in sorted(mentions):
                mention_text = " ".join(
                    self.token[mention.begin: mention.end + 1])
                # The head may be None; show a placeholder instead of failing
                # on self.token[None].
                if mention.head is None:
                    mention_head = "-"
                else:
                    mention_head = self.token[mention.head]
                mention_texts.append(f"{mention_text} ({mention_head})")
            # Stepping through the list in threes yields ceil(len / 3) rows
            # without needing the math module.
            for start in range(0, len(mention_texts), 3):
                row_mention_texts = mention_texts[start: start + 3]
                row_desc = "     ".join(
                    [f"{mention_text:25s}" for mention_text in row_mention_texts])
                desc += row_desc + "\n"
            desc += "\n"
        return desc

class CorefCorpus:
    """An ordered collection of CorefDocument objects.

    When a file path is given, every record of that jsonlines file is turned
    into a CorefDocument; without a path the corpus starts out empty.
    """

    def __init__(self, file: str | None = None) -> None:
        self.documents: list[CorefDocument] = []
        if file is None:
            return
        with jsonlines.open(file) as reader:
            for record in reader:
                document = CorefDocument(record)
                self.documents.append(document)

    def __len__(self) -> int:
        """Number of documents in the corpus."""
        n_documents = len(self.documents)
        return n_documents

    def __getitem__(self, i) -> CorefDocument:
        """Document(s) selected by `i`, using list indexing semantics."""
        selected = self.documents[i]
        return selected

In [3]:
# Print the "movie" field of every record in the results file, one per line,
# in file order.
with jsonlines.open("/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/results/regular/movie.jsonlines") as reader:
    for obj in reader:
        print(obj["movie"])

avengers_endgame
dead_poets_society
john_wick
prestige
quiet_place
zootopia
shawshank
bourne
basterds


In [5]:
# Load every record of the results file into a CorefCorpus.
corpus = CorefCorpus("/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/results/regular/movie.jsonlines")

In [6]:
# Number of documents loaded into the corpus.
len(corpus.documents)

9