In [1]:
import os
import json
import math
import torch
import numpy as np
from collections import Counter, namedtuple
import matplotlib.pyplot as plt
from transformers import BertTokenizer
from keras.preprocessing.sequence import pad_sequences

from torch import nn
from transformers import BertModel
from torch.optim import AdamW

from typing import List, Tuple
import config
from copy import deepcopy, copy
import importlib
from tqdm import tqdm

In [8]:
# Re-execute the already imported ``config`` module and rebind the name to the
# module object returned by ``importlib.reload``, so later cells read the
# refreshed attributes.
config = importlib.reload(config)

In [3]:
# Record type for one mini-batch; the field order is part of the interface
# because instances are built positionally.
FrameBatch = namedtuple(
    "FrameBatch",
    ["sentence_ids", "sentences", "sentence_token_ids", "mask", "frames"],
)

def tokenize(tokenizer, tokens) -> Tuple[List[str], List[int]]:
    '''
    Split every word into sub-word pieces and record where each word starts.

    tokenizer : object exposing tokenize(str) -> sequence of str
    tokens    : list of str (the original words)

    return:
        new tokens  : list of str, "[AUTHOR]" followed by all word pieces
        mapping     : numpy int array with len(tokens) + 2 entries

    mapping[0] is 0 (the position of "[AUTHOR]"), mapping[i + 1] is the
    position in new tokens where the ith word begins, and the last entry
    equals len(new tokens).
    '''
    pieces_per_word = [tokenizer.tokenize(word) for word in tokens]

    new_tokens = ["[AUTHOR]"]
    for pieces in pieces_per_word:
        new_tokens.extend(pieces)

    # Word starts are a running total of piece counts, shifted by one for
    # the "[AUTHOR]" token that sits at position 0.
    mapping = np.zeros(len(tokens) + 2, dtype=int)
    mapping[1] = 1
    mapping[2:] = 1 + np.cumsum([len(pieces) for pieces in pieces_per_word], dtype=int)
    return new_tokens, mapping

def create_frame_label(mapping, frame_tuples, max_seq_len) -> np.ndarray:
    '''
    Turn word-level frame spans into token-level BIO label arrays.

    mapping         : int sequence as produced by tokenize
                      mapping[w + 1] is the token position where word w begins
    frame_tuples    : set of frame tuples
                      each tuple is 6-dimensional: (start, end) word indices of
                      holder, predicate, and target, end exclusive
                      every index is shifted by 1 before the mapping lookup
    max_seq_len     : int
                      max sentence length
                      if an argument's span exceeds max_seq_len, ignore the frame

    return:
        frame label : numpy int array B x L x F
                      B is the number of frame tuples that were kept
                      L = max_seq_len
                      F = 3 because of 3 arguments: holder, predicate, and target
                      frame label is 0 = O, 1 = B, or 2 = I
                      if no frame is kept the result is an empty array (size 0)
    '''
    frame_labels = []

    for frame_tuple in frame_tuples:
        # The exclusive token end of the right-most span is mapping[end + 1].
        # Checking mapping[end] (the start of the last word) would let a last
        # word made of several pieces run past max_seq_len and raise IndexError.
        last_end = max(frame_tuple[1], frame_tuple[3], frame_tuple[5])
        if mapping[last_end + 1] > max_seq_len:
            continue

        frame_label = np.zeros((max_seq_len, 3), dtype=int)
        for arg in range(3):
            begin = mapping[frame_tuple[2 * arg] + 1]
            end = mapping[frame_tuple[2 * arg + 1] + 1]
            if begin < end:
                frame_label[begin, arg] = 1             # B
                frame_label[begin + 1:end, arg] = 2     # I

        frame_labels.append(frame_label)

    return np.array(frame_labels)

class FrameDataset:
    '''
    In-memory dataset with one row per direct subjective frame.

    attributes:
        docids      : list of str docids
        train       : bool
        sentence_ids: numpy array of shape D x 2
                      file name (str) and sentence index (int)
        sentences   : str numpy array of shape D x L
                      padded tokens are [PAD]
        sentence_token_ids
                    : torch long tensor of shape D x L
        mask        : torch float tensor of shape D x L
                      1 for real tokens, 0 for padding
        frames      : torch long tensor of shape D x L x F

    D is number of direct subjective frames
    L is max seq length (config.max_seq_len)
    F = 3 for number of argument types: holder, predicate, target
    '''

    def __init__(self, docids, train=True) -> None:
        '''
        params:
            docids      : list of str docids
            train       : bool

        Reads <config.DATA_FOLDER>/mpqa2-processed/<docid>/tokenized.json for
        every docid.

        if train is true, sentence with no direct subjective frame are excluded
        else, they are included exactly once with an all-zero frame
        '''
        self.docids = docids
        self.train = train

        sentences = []
        sentence_ids = []
        sentence_token_ids = []
        frames = []
        tokenizer = BertTokenizer.from_pretrained(config.embed_model_name)
        # Reuse the vocabulary id of "[unused1]" for the "[AUTHOR]" token.
        tokenizer.vocab["[AUTHOR]"] = tokenizer.vocab.pop("[unused1]")

        for docid in self.docids:
            path = os.path.join(config.DATA_FOLDER, "mpqa2-processed", docid, "tokenized.json")
            with open(path, encoding="utf-8") as f:
                doc = json.load(f)

            for i, sentence in enumerate(doc):
                frame_tuples = self._collect_frame_tuples(sentence["dse"])
                new_tokens, mapping = tokenize(tokenizer, sentence["tokens"])
                frame_label = create_frame_label(mapping, frame_tuples, config.max_seq_len)

                if not frame_label.size:
                    if self.train:
                        continue
                    # Evaluation keeps frame-less sentences once, with an empty frame.
                    frame_label = np.zeros((1, config.max_seq_len, 3), dtype=int)

                token_ids = tokenizer.convert_tokens_to_ids(new_tokens)
                # The sentence is repeated once per frame so rows line up with frames.
                for _ in range(frame_label.shape[0]):
                    sentences.append(new_tokens)
                    sentence_ids.append([docid, i])
                    sentence_token_ids.append(token_ids)
                frames.append(frame_label)

        # Lengths must be taken before padding: every padded row has length L.
        lengths = [len(tokens) for tokens in sentences]

        self.sentences = pad_sequences(sentences, maxlen=config.max_seq_len, padding="post", truncating="post", value="[PAD]", dtype=object).astype(str)

        self.sentence_ids = np.array(sentence_ids)

        sentence_token_ids = pad_sequences(sentence_token_ids, maxlen=config.max_seq_len, padding="post", truncating="post", value=tokenizer.vocab["[PAD]"])
        self.sentence_token_ids = torch.LongTensor(sentence_token_ids)

        self.mask = self._build_mask(lengths, config.max_seq_len)

        self.frames = torch.LongTensor(np.vstack(frames))

    @staticmethod
    def _collect_frame_tuples(dse_frames):
        '''
        Return the set of (holder, predicate, target) span 6-tuples of the
        frames whose dse and target spans are lists. A frame without a holder
        span is kept only for holder type "writer" or "implicit", using the
        pseudo span [-1, 0].
        '''
        frame_tuples = set()
        for frame in dse_frames:
            if not (isinstance(frame["dse-span"], list) and isinstance(frame["target-span"], list)):
                continue
            if isinstance(frame["holder-span"], list):
                holder_span = frame["holder-span"]
            elif frame["holder-type"] in ["writer", "implicit"]:
                holder_span = [-1, 0]
            else:
                continue
            frame_tuples.add(tuple(holder_span + frame["dse-span"] + frame["target-span"]))
        return frame_tuples

    @staticmethod
    def _build_mask(lengths, max_seq_len):
        '''
        Return a float tensor of shape len(lengths) x max_seq_len whose row r
        has ones in its first min(lengths[r], max_seq_len) positions.
        '''
        mask = np.zeros((len(lengths), max_seq_len), dtype=float)
        for row, length in enumerate(lengths):
            mask[row, :length] = 1
        return torch.FloatTensor(mask)

    def to(self, device):
        '''Move the token id, mask and frame tensors to device (in place on self).'''
        self.sentence_token_ids = self.sentence_token_ids.to(device)
        self.mask = self.mask.to(device)
        self.frames = self.frames.to(device)

    def print_frame(self, i):
        '''Print sentence id, tokens and the holder/predicate/target text of row i.'''
        print("SENTENCE ID : {}".format(self.sentence_ids[i]))
        print("TOKENS : {}".format(self.sentences[i]))
        frame = self.frames[i].cpu().numpy()
        if frame.sum() > 0:
            # Any non-zero label (B or I) belongs to the argument span.
            holder = " ".join(self.sentences[i][frame[:, 0] != 0])
            predicate = " ".join(self.sentences[i][frame[:, 1] != 0])
            target = " ".join(self.sentences[i][frame[:, 2] != 0])
            print("HOLDER = '{}' PREDICATE = '{}' TARGET = '{}'".format(holder, predicate, target))
        else:
            print("EMPTY FRAME")

class FrameIterator:
    '''
    Iterate over a FrameDataset in mini-batches of similar sentence length.

    Rows are ordered by their number of non-[PAD] tokens and cut into
    consecutive batches; the last batch may hold fewer than batch_size rows.
    '''

    def __init__(self, ds: FrameDataset, batch_size: int, shuffle_batch=True, shuffle_sample=True) -> None:
        '''
        params:
            ds          : FrameDataset
            batch_size  : int, must be positive
            shuffle_batch : bool
                          if shuffle_batch is true, shuffle the order of batches
            shuffle_sample : bool
                          if shuffle_sample is true, shuffle the order of samples in adjacent batches

        raises:
            ValueError if batch_size is not positive
        '''
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
        self.ds = ds
        self.batch_size = batch_size
        self.shuffle_batch = shuffle_batch
        self.shuffle_sample = shuffle_sample
        sentence_lens = (ds.sentences != "[PAD]").sum(axis = 1)
        self.sorted_index = np.argsort(sentence_lens)
        self.n_batches = math.ceil(len(self.sorted_index)/self.batch_size)

    def __iter__(self):
        '''Start a new pass: draw a batch order and (optionally) reshuffle samples.'''
        if self.shuffle_batch:
            self.batch_sequence = np.random.permutation(self.n_batches)
        else:
            self.batch_sequence = np.arange(self.n_batches)
        self.index = self.sorted_index.copy()
        if self.shuffle_sample:
            # Shuffle inside sliding windows of three batches so samples only
            # move between neighbouring batches. At least one window is used so
            # that datasets with fewer than three batches are shuffled too.
            for i in range(max(1, self.n_batches - 2)):
                np.random.shuffle(self.index[i * self.batch_size: (i + 3) * self.batch_size])
        self.i = 0
        return self

    def __next__(self) -> FrameBatch:
        '''Return the next FrameBatch of the current pass or raise StopIteration.'''
        if self.i == self.n_batches:
            raise StopIteration
        i = self.batch_sequence[self.i]
        batch_index = self.index[i * self.batch_size : (i + 1) * self.batch_size]
        batch_sentences = self.ds.sentences[batch_index]
        batch_sentence_token_ids = self.ds.sentence_token_ids[batch_index]
        batch_sentence_ids = self.ds.sentence_ids[batch_index]
        batch_mask = self.ds.mask[batch_index]
        batch_frames = self.ds.frames[batch_index]
        self.i += 1
        return FrameBatch(batch_sentence_ids, batch_sentences, batch_sentence_token_ids, batch_mask, batch_frames)

    def __len__(self) -> int:
        '''Number of batches in one pass.'''
        return self.n_batches

In [9]:
def _read_lines(path):
    """Return the lines of the text file at *path*, closing the file afterwards."""
    with open(path) as f:
        return f.read().splitlines()


# Document id lists: the MPQA 2.0 attitude subset and the train/test/dev splits.
mpqa2_filelist = _read_lines(os.path.join(config.RAW_FOLDER, "database.mpqa.2.0/doclist.attitudeSubset"))
train_filelist = _read_lines(os.path.join(config.SRL4ORL_SPLITS_FOLDER, "filelist_train0"))
test_filelist = _read_lines(os.path.join(config.SRL4ORL_SPLITS_FOLDER, "filelist_test0"))
dev_filelist = _read_lines(os.path.join(config.SRL4ORL_SPLITS_FOLDER, "filelist_dev"))
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so
# the later ``.to(device)`` calls do not fail on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
# Build both splits; the dev split also keeps sentences without any frame.
train_ds = FrameDataset(train_filelist, train=True)
dev_ds = FrameDataset(dev_filelist, train=False)

for dataset in (train_ds, dev_ds):
    dataset.to(device)

# One (message template, array-or-tensor) pair per reported field.
train_report = (
    ("train sentences            : {} {}", train_ds.sentences),
    ("train sentence ids         : {} {}", train_ds.sentence_ids),
    ("train sentence token ids   : {} {}", train_ds.sentence_token_ids),
    ("train mask                 : {} {}", train_ds.mask),
    ("train frames               : {} {}", train_ds.frames),
)
dev_report = (
    ("dev sentences            : {} {}", dev_ds.sentences),
    ("dev sentence ids         : {} {}", dev_ds.sentence_ids),
    ("dev sentence token ids   : {} {}", dev_ds.sentence_token_ids),
    ("dev mask                 : {} {}", dev_ds.mask),
    ("dev frames               : {} {}", dev_ds.frames),
)

for template, field in train_report:
    print(template.format(field.shape, field.dtype))

print()

for template, field in dev_report:
    print(template.format(field.shape, field.dtype))

train sentences            : (3193, 100) <U18
train sentence ids         : (3193, 2) <U29
train sentence token ids   : torch.Size([3193, 100]) torch.int64
train mask                 : torch.Size([3193, 100]) torch.float32
train frames               : torch.Size([3193, 100, 3]) torch.int64

dev sentences            : (2419, 100) <U16
dev sentence ids         : (2419, 2) <U29
dev sentence token ids   : torch.Size([2419, 100]) torch.int64
dev mask                 : torch.Size([2419, 100]) torch.float32
dev frames               : torch.Size([2419, 100, 3]) torch.int64


In [6]:
def shape_and_type(tensor):
    """Return "<dtype> <shape>" for any object exposing ``dtype`` and ``shape``."""
    return f"{tensor.dtype} {tensor.shape}"

def permutation_traversal(path, pathlen, nchoices, allpaths):
    """
    Depth-first enumeration of index sequences, appended to *allpaths*.

    path     : list of choices made so far
    pathlen  : number of positions still to fill
    nchoices : list of int, nchoices[p] is the number of options at position p
    allpaths : output list; every completed path is appended to it
    """
    if pathlen == 0:
        allpaths.append(path)
        return

    position = len(nchoices) - pathlen
    for choice in range(nchoices[position]):
        # A fresh list per branch keeps sibling paths independent.
        permutation_traversal(path + [choice], pathlen - 1, nchoices, allpaths)

def permutations(nchoices):
    """
    Enumerate every index sequence of the given per-position sizes.

    nchoices : list of int, nchoices[p] is the number of options at position p

    return:
        list of lists; each inner list has len(nchoices) entries with
        0 <= entry[p] < nchoices[p]. Sequences are in lexicographic order
        (last position varies fastest). An empty nchoices yields [[]].

    Despite the name this is the Cartesian product of the index ranges.
    """
    from itertools import product

    return [list(combo) for combo in product(*(range(n) for n in nchoices))]

class FrameExtractor(nn.Module):

    def __init__(self, hparams):
        '''
        Build the sub-modules from hparams.

        hparams must expose: n_frame_arguments, n_labels, embed_model_name,
        output_embedding_size, encoder_hidden_size, encoder_num_layers,
        decoder_hidden_size, dropout and beam_width.
        '''
        super().__init__()
        self.n_frame_arguments = hparams.n_frame_arguments
        self.n_labels = hparams.n_labels

        # TODO: dropout for BertModel
        self.embedder = BertModel.from_pretrained(hparams.embed_model_name)
        self.embedding_size = self.embedder.config.hidden_size

        # One extra embedding row: index n_labels is fed to the decoder as the
        # "previous label" of the first step.
        self.output_embedding_size = hparams.output_embedding_size
        self.output_embedder = nn.Embedding(self.n_labels + 1, self.output_embedding_size)

        self.encoder = nn.LSTM(
            input_size=self.embedding_size,
            hidden_size=hparams.encoder_hidden_size,
            num_layers=hparams.encoder_num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=hparams.dropout,
        )
        # Forward and backward states are concatenated.
        self.encoder_hidden_size = 2 * hparams.encoder_hidden_size

        # labels: B, I, O
        # arguments: holder, predicate, target
        # Each decoder step sees the encoder state plus one label embedding
        # per frame argument.
        decoder_input_size = self.encoder_hidden_size + self.n_frame_arguments * self.output_embedding_size
        self.decoder_hidden_size = hparams.decoder_hidden_size
        self.decoder = nn.LSTM(
            decoder_input_size,
            self.decoder_hidden_size,
            batch_first=True,
            dropout=hparams.dropout,
        )

        self.holder_predictor = nn.Linear(self.decoder_hidden_size, self.n_labels)
        self.predicate_predictor = nn.Linear(self.decoder_hidden_size, self.n_labels)
        self.target_predictor = nn.Linear(self.decoder_hidden_size, self.n_labels)

        # Every (beam index, label per argument) combination considered at
        # one beam-search step.
        self.beam_width = hparams.beam_width
        nchoices = [self.beam_width] + [self.n_labels] * self.n_frame_arguments
        self.all_permutations = permutations(nchoices)

    def forward(self, sentences, mask, frames=None):
        '''
        sentences   : long tensor of shape B x L
        mask        : float tensor of shape B x L
        frames      : long tensor of shape B x L x F

        B = batch size
        L = max seq len
        F = self.n_frame_arguments
        '''

        device = next(self.parameters()).device
        batch_size, seq_len = sentences.shape

        embedding = self.embedder(sentences, mask).last_hidden_state
        # embedding : float tensor of shape B x L x self.embedding_size

        encoding, _ = self.encoder(embedding)
        # encoding  : float tensor of shape B x L x self.encoder_hidden_size

        if frames is not None:
            # training code
            loss_function = nn.CrossEntropyLoss()
            logits_arr = []
            nextlabel_frames = torch.zeros(frames.shape, dtype=torch.long, device=device)
            nextlabel_frames[:, :-1] = frames[:, 1:]
            # nextlabel_frames : long tensor of shape B x L x F

            t = torch.full((batch_size, 1, self.n_frame_arguments), fill_value=self.n_labels, dtype=torch.long, device=device)
            # t : long tensor of shape B x 1 x self.n_frame_arguments

            hidden = torch.randn((1, batch_size, self.decoder_hidden_size), device=device)
            cell = torch.randn((1, batch_size, self.decoder_hidden_size), device=device)

            for i in range(seq_len):
                x = encoding[:, i].reshape((batch_size, 1, -1))
                # x : float tensor of shape B x 1 x self.encoder_hidden_size

                y = self.output_embedder(t).reshape((batch_size, 1, -1))
                # y : float tensor of shape B x 1 x self.n_frame_arguments * self.output_embedding_size

                z = torch.cat((x, y), dim = 2)
                # z : float tensor of shape B x 1 x (self.encoder_hidden_size + self.n_frame_arguments * self.output_embedding_size)

                a, (hidden, cell) = self.decoder(z, (hidden, cell))
                # a : float tensor of shape B x 1 x self.decoder_hidden_size

                holder_logits = self.holder_predictor(a)
                predicate_logits = self.predicate_predictor(a)
                target_logits = self.target_predictor(a)
                logits = torch.cat((holder_logits, predicate_logits, target_logits), dim = 1)
                # logits : float tensor of shape B x self.n_frame_arguments x self.n_labels

                logits_arr.append(logits.unsqueeze(dim = 1))
                # append reshaped logits of shape B x 1 x self.n_frame_arguments x self.n_labels

                t = frames[:, i].reshape((batch_size, 1, -1)).clone()
                # t : long tensor of shape B x 1 x self.n_frame_arguments
            
            logits = torch.cat(logits_arr, dim = 1).reshape((-1, self.n_labels))
            # logits : float tensor of shape (B * L * F) x self.n_labels

            loss_mask = (frames == 1) | ((frames == 2) & (nextlabel_frames != 0))
            # loss mask : bool tensor of shape B x L x F

            frames[~loss_mask] = loss_function.ignore_index
            
            labels = frames.flatten()
            # labels : long tensor of shape B * L * F

            loss = loss_function(logits, labels)
            return loss
        
        else:
            # beam search inference

            X = encoding.repeat_interleave(self.beam_width, dim=0)
            # X : B * self.beam_width x L x self.encoder_hidden_size

            sequences = np.zeros((batch_size * self.beam_width, seq_len, self.n_frame_arguments), dtype=int)
            scores = np.zeros((batch_size * self.beam_width,))
            prev_label = torch.full((batch_size * self.beam_width, 1, self.n_frame_arguments), fill_value=self.n_labels, dtype=torch.long, device=device)
            hidden = torch.randn((1, batch_size * self.beam_width, self.decoder_hidden_size), device=device)
            cell = torch.randn((1, batch_size * self.beam_width, self.decoder_hidden_size), device=device)

            for i in range(seq_len):
                x = X[:, i].reshape((batch_size * self.beam_width, 1, -1))
                # x : B * self.beam_width x 1 x self.encoder_hidden_size

                y = self.output_embedder(prev_label).reshape((batch_size * self.beam_width, 1, -1))
                # y : B * self.beam_width x 1 x self.n_frame_arguments * self.output_embedding_size

                z = torch.cat([x, y], dim=2)
                # z : B * self.beam_width x 1 x (self.encoder_hidden_size + self.n_frame_arguments * self.output_embedding_size)

                a, (new_hidden, new_cell) = self.decoder(z, (hidden, cell))
                # a : B * self.beam_width x 1 x self.decoder_hidden_size
                # new_hidden, new_cell : 1 x B * self.beam_width x self.decoder_hidden_size

                holder_logits = self.holder_predictor(a)
                predicate_logits = self.predicate_predictor(a)
                target_logits = self.target_predictor(a)
                # (holder, predicate, target) logits : B * self.beam_width x 1 x self.n_labels

                logits = torch.cat([holder_logits, predicate_logits, target_logits], dim=1)
                # logits : B * self.beam_width x self.n_frame_arguments x self.n_labels

                logprob = torch.log_softmax(logits, dim=2).cpu().numpy()
                # logprob : B * self.beam_width x self.n_frame_arguments x self.n_labels

                all_scores = logprob + scores.reshape((batch_size * self.beam_width, 1, 1))
                # scores : B * self.beam_width x self.n_frame_arguments x self.n_labels

                new_sequences = sequences.copy()
                # copy sequences to new_sequences

                for j in range(batch_size):
                    if i < mask[j].sum():
                        sentence_scores = all_scores[j * self.beam_width: (j + 1) * self.beam_width]
                        # sentence scores : self.beam_width x self.n_frame_arguments x self.n_labels

                        sequence_permutations = copy(self.all_permutations)
                        # sequence permutations [int] of size self.beam_width * self.n_labels ^ self.n_frame_arguments

                        sequence_permutations_with_score = [permute + [sum(sentence_scores[permute[0], arg, permute[arg + 1]] for arg in range(self.n_frame_arguments))] for permute in sequence_permutations]

                        sequence_permutations_with_score = sorted(sequence_permutations_with_score, key=lambda permute: permute[-1], reverse=True)

                        sequence_permutations_with_score = sequence_permutations_with_score[:self.beam_width]
                        # sequence permutations has the top-k (k = beam width) sequences

                        for k, permute in enumerate(sequence_permutations_with_score):
                            l = j * self.beam_width + k
                            m = j * self.beam_width + permute[0]
                            new_sequences[l, :i] = sequences[m, :i]
                            new_sequences[l, i] = permute[1:-1]
                            scores[l] = permute[-1]
                            
                            for n in range(self.n_frame_arguments):
                                prev_label[l, 0, n] = permute[n + 1]
                            
                            hidden[0, l] = new_hidden[0, m]
                            cell[0, l] = new_cell[0, m]
                
                sequences = new_sequences

            return sequences.reshape(batch_size, self.beam_width, seq_len, self.n_frame_arguments), scores.reshape(batch_size, self.beam_width)