In [7]:
from mica_text_coref.coref.movie_coref import data

from collections import Counter
import jsonlines
import math
import numpy as np
import re
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import (RobertaTokenizerFast, RobertaTokenizer, RobertaModel,
                          BertTokenizerFast, BertTokenizer, BertModel)
from transformers import AutoTokenizer

In [2]:
roberta_fast_tokenizer = AutoTokenizer.from_pretrained(
    "roberta-large", use_fast=True)
bert_fast_tokenizer = AutoTokenizer.from_pretrained(
    "bert-large-cased", use_fast=True)

In [3]:
corpus = data.CorefCorpus(("/home/sbaruah_usc_edu/mica_text_coref/data/"
                           "movie_coref/results/regular/movie.jsonlines"))

In [4]:
parse_texts = ["".join(document.parse) for document in corpus]

In [8]:
Counter(list("".join(parse_texts)))

Counter({'S': 9608,
         'N': 101599,
         'C': 11794,
         'D': 73447,
         'E': 3637,
         'O': 1214,
         'T': 355,
         'M': 150})

In [11]:
"SNCDE".index("C")

2

In [11]:
text = ("The model can behave as an encoder (with only self-attention) as well "
"as a decoder, in which case a layer of cross-attention is added between the "
"self-attention layers, following the architecture described in Attention is all "
"you need_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion "
"Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.")
print(text)

The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in Attention is all you need_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.


In [15]:
roberta_output = roberta_fast_tokenizer(text, return_offsets_mapping=True)
bert_output = bert_fast_tokenizer(text, return_offsets_mapping=True)
print(roberta_output)
print(bert_output)

{'input_ids': [0, 133, 1421, 64, 18871, 25, 41, 9689, 15362, 36, 5632, 129, 1403, 12, 2611, 19774, 43, 25, 157, 25, 10, 5044, 15362, 6, 11, 61, 403, 10, 10490, 9, 2116, 12, 2611, 19774, 16, 355, 227, 5, 1403, 12, 2611, 19774, 13171, 6, 511, 5, 9437, 1602, 11, 35798, 16, 70, 47, 240, 1215, 30, 4653, 1173, 12599, 605, 1543, 6, 440, 424, 840, 10129, 254, 6, 234, 8907, 30471, 271, 6, 18493, 2413, 4890, 329, 330, 1688, 405, 6, 12655, 1499, 1454, 6, 11572, 260, 234, 4, 11507, 6, 10920, 281, 329, 20838, 8, 12285, 493, 6189, 366, 1350, 11040, 4, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 3), (4, 9), (10, 13), (14, 20), (21, 23), (24, 26), (27, 30), (30, 34), (35, 36), (36, 40

In [16]:
print("roberta:")
for input_id, (begin, end) in zip(
    roberta_output["input_ids"], roberta_output["offset_mapping"]):
    token = text[begin: end]
    print(f"input_id = {input_id}, token = '{token}'")
print()
print("bert:")
for input_id, (begin, end) in zip(
    bert_output["input_ids"], bert_output["offset_mapping"]):
    token = text[begin: end]
    print(f"input_id = {input_id}, token = '{token}'")

roberta:
input_id = 0, token = ''
input_id = 133, token = 'The'
input_id = 1421, token = 'model'
input_id = 64, token = 'can'
input_id = 18871, token = 'behave'
input_id = 25, token = 'as'
input_id = 41, token = 'an'
input_id = 9689, token = 'enc'
input_id = 15362, token = 'oder'
input_id = 36, token = '('
input_id = 5632, token = 'with'
input_id = 129, token = 'only'
input_id = 1403, token = 'self'
input_id = 12, token = '-'
input_id = 2611, token = 'att'
input_id = 19774, token = 'ention'
input_id = 43, token = ')'
input_id = 25, token = 'as'
input_id = 157, token = 'well'
input_id = 25, token = 'as'
input_id = 10, token = 'a'
input_id = 5044, token = 'dec'
input_id = 15362, token = 'oder'
input_id = 6, token = ','
input_id = 11, token = 'in'
input_id = 61, token = 'which'
input_id = 403, token = 'case'
input_id = 10, token = 'a'
input_id = 10490, token = 'layer'
input_id = 9, token = 'of'
input_id = 2116, token = 'cross'
input_id = 12, token = '-'
input_id = 2611, token = 'att'
inpu

In [24]:
texts = ["""Daenerys had dismissed Tyrion long ago, and now she stood before one of the windows in her private bedchamber, staring out onto the open sea. It was pitch black outside, though the occasional strikes of lightning lit up the sky and revealed the rough waves crashing against the shore.""",
"""It was cold, despite the fire raging in the hearth, though she hardly felt it wrapped in her cloak. She had owned it for barely a day and yet it had already become her favorite item in her entire wardrobe, easily displacing anything she had seen made for her arrival to Westeros.""",
"""She wrapped it tighter around her, the softness of the fur that lined it doing a great deal to comfort her, though after a while all it did was make her think of the man who had given her the cloak in the first place."""]
print(texts)

['Daenerys had dismissed Tyrion long ago, and now she stood before one of the windows in her private bedchamber, staring out onto the open sea. It was pitch black outside, though the occasional strikes of lightning lit up the sky and revealed the rough waves crashing against the shore.', 'It was cold, despite the fire raging in the hearth, though she hardly felt it wrapped in her cloak. She had owned it for barely a day and yet it had already become her favorite item in her entire wardrobe, easily displacing anything she had seen made for her arrival to Westeros.', 'She wrapped it tighter around her, the softness of the fur that lined it doing a great deal to comfort her, though after a while all it did was make her think of the man who had given her the cloak in the first place.']


In [38]:
class CharacterRecognitionDataset(Dataset):
    """PyTorch dataset for the character recognition model.
    """
    def __init__(self, corpus: data.CorefCorpus, tokenizer: PreTrainedTokenizer,
                 seq_length: int, obey_scene_boundaries: bool) -> None:
        """Initializer for the character recognition model data.

        Args:
            corpus: CorefCorpus
            tokenizer: Transformer tokenizer
            seq_length: maximum number of tokens (not sub-tokens) in a sequence
            obey_scene_boundaries: if true, sequences do not cross scene
            boundaries
        """
        super().__init__()
        tokens_list: list[list[str]] = []
        labels_list: list[torch.IntTensor] = []
        parse_tags_list: list[torch.IntTensor] = []
        parse_tag_set = ["S", "N", "C", "D"]

        for document in corpus:
            tokens = document.token
            labels = torch.zeros(len(tokens), dtype=int)
            parse_tags = document.parse
            for mentions in document.clusters.values():
                for mention in mentions:
                    labels[mention.head] = 1

            # Find token-level scene boundaries
            if obey_scene_boundaries:
                scene_boundaries = np.zeros(len(tokens), dtype=int)
                found_content_tag = False
                i = 0
                while i < len(tokens):
                    if parse_tags[i] == "S":
                        if found_content_tag:
                            scene_boundaries[i] = 1
                        found_content_tag = False
                        while i < len(tokens) and parse_tags[i] == "S":
                            i += 1
                    else:
                        if parse_tags[i] in "ND":
                            found_content_tag = True
                        i += 1

            # Create movie parse tensors
            parse_tag_tensor = torch.zeros(len(tokens), dtype=int)
            for i, tag in enumerate(parse_tags):
                try:
                    parse_tag_tensor[i] = parse_tag_set.index(tag) + 1
                except Exception:
                    pass

            # Segment document into sequences
            i = 0
            while i < len(tokens):
                end = i + seq_length
                if obey_scene_boundaries and (
                   np.any(scene_boundaries[i: end] == 1)):
                    end = i + np.nonzero(scene_boundaries[i: end] == 1)[0][0] + 1
                tokens_list.append(tokens[i: end])
                labels_list.append(labels[i: end])
                parse_tags_list.append(parse_tag_tensor[i : end])
                i = end

        # Find token character offsets
        token_char_offset_list: list[list[tuple[int, int]]] = []
        text_list: list[str] = []
        max_n_tokens_per_sequence = -np.inf
        for tokens in tokens_list:
            token_char_offset: list[tuple[int, int]] = []
            c = 0
            for token in tokens:
                token_char_offset.append((c, c + len(token)))
                c += len(token) + 1
            token_char_offset_list.append(token_char_offset)
            text_list.append(" ".join(tokens))
            max_n_tokens_per_sequence = max(max_n_tokens_per_sequence,
                                            len(tokens))

        # Encode
        encoding = tokenizer(text_list, padding="longest", return_tensors="pt",
                             return_offsets_mapping=True,
                             return_attention_mask=True)
        
        # Find token to subtoken offset
        token_offset = torch.zeros(
            (len(text_list), max_n_tokens_per_sequence, 2), dtype=int)
        for i, (subtoken_char_offset, token_char_offset, attention_mask) in (
            enumerate(zip(encoding["offset_mapping"], token_char_offset_list,
                          encoding["attention_mask"]))):
            j, k = 0, 0
            n_subtokens = attention_mask.sum()
            while j < n_subtokens and k < len(token_char_offset):
                if subtoken_char_offset[j, 0] == subtoken_char_offset[j, 1]:
                    j += 1
                elif subtoken_char_offset[j, 0] == token_char_offset[k][0]:
                    l = j
                    while l < n_subtokens and (
                        subtoken_char_offset[l, 1] != token_char_offset[k][1]):
                        l += 1
                    token_offset[i, k, 0] = j
                    token_offset[i, k, 1] = l
                    k += 1
                    j = l + 1
                else:
                    assert False, "Something went wrong!"

        self.subtoken_ids = encoding["input_ids"]
        self.attention_mask = encoding["attention_mask"]
        self.token_offset = token_offset
        self.parse_tag_ids = pad_sequence(parse_tags_list, batch_first=True,
                                          padding_value=0)
        self.label_ids = pad_sequence(labels_list, batch_first=True,
                                      padding_value=0)
    
    def __len__(self) -> int:
        return len(self.subtoken_ids)
    
    def __getitem__(self, i: int) -> (
        tuple[torch.LongTensor, torch.FloatTensor, torch.LongTensor,
              torch.LongTensor, torch.LongTensor]):
        return (self.subtoken_ids[i], self.attention_mask[i],
                self.token_offset[i], self.parse_tag_ids[i], self.label_ids[i])

In [39]:
ds = CharacterRecognitionDataset(corpus, roberta_fast_tokenizer, seq_length=256,
                                 obey_scene_boundaries=True)

In [41]:
for tensor_tuple in ds:
    for tensor in tensor_tuple:
        print(f"{tensor.dtype} {tensor.shape}", end=",")
    print()

torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torch.Size([256, 2]),torch.int64 torch.Size([256]),torch.int64 torch.Size([256]),
torch.int64 torch.Size([482]),torch.int64 torch.Size([482]),torch.int64 torc

In [42]:
print(f"subtoken ids   = {ds.subtoken_ids.dtype} {ds.subtoken_ids.shape}")
print(f"attention mask = {ds.attention_mask.dtype} {ds.attention_mask.shape}")
print(f"token offset   = {ds.token_offset.dtype} {ds.token_offset.shape}")
print(f"parse tag ids  = {ds.parse_tag_ids.dtype} {ds.parse_tag_ids.shape}")
print(f"label ids      = {ds.label_ids.dtype} {ds.label_ids.shape}")

subtoken ids   = torch.int64 torch.Size([1553, 482])
attention mask = torch.int64 torch.Size([1553, 482])
token offset   = torch.int64 torch.Size([1553, 256, 2])
parse tag ids  = torch.int64 torch.Size([1553, 256])
label ids      = torch.int64 torch.Size([1553, 256])


In [44]:
for subtoken_ids, attention_mask, token_offset, parse_tag_ids, label_ids in (
    zip(ds.subtoken_ids, ds.attention_mask, ds.token_offset, ds.parse_tag_ids,
        ds.label_ids)):
    subtokens = roberta_fast_tokenizer.convert_ids_to_tokens(subtoken_ids)
    tokens = []
    i = 0
    while i < len(token_offset):
        j, k = token_offset[i]
        if j == k == 0:
            break
        _subtokens = subtokens[j: k + 1]
        tokens.append("".join(_subtokens))
        i += 1
    print(subtokens)
    print(tokens)
    print()

['<s>', 'F', 'ĠO', 'ĠR', 'ĠY', 'ĠO', 'ĠU', 'ĠR', 'ĠC', 'ĠO', 'ĠN', 'ĠS', 'ĠI', 'ĠD', 'ĠE', 'ĠR', 'ĠAT', 'ĠI', 'ĠO', 'ĠN', 'ĠBEST', 'ĠAD', 'AP', 'TED', 'ĠSC', 'RE', 'EN', 'PLAY', 'ĠChristopher', 'ĠMarkus', 'Ġ&', 'ĠStephen', 'ĠMcF', 'eely', 'ĠAV', 'ENG', 'ERS', 'Ġ:', 'ĠEND', 'GAME', 'ĠAdapt', 'ed', 'ĠScreen', 'play', 'ĠWritten', 'Ġby', 'ĠChristopher', 'ĠMarkus', 'Ġand', 'ĠStephen', 'ĠMcF', 'eely', 'ĠEXT', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', 