In [17]:
from mica_text_coref.coref.movie_coref.data import (CharacterRecognitionDataset, 
                                                    CorefCorpus)

import collections
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

In [2]:
corpus = CorefCorpus("/home/sbaruah_usc_edu/mica_text_coref/data/"
                     "movie_coref/results/regular/movie.jsonlines")
roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)
dataset = CharacterRecognitionDataset(
    corpus, roberta_tokenizer, seq_length=256, obey_scene_boundaries=False)

In [5]:
labels = dataset.label_ids
label_distribution = collections.Counter(labels.flatten().tolist())
print(f"labels: {labels.dtype} {labels.shape}")
print(f"label distribution = {label_distribution}")

labels: torch.int64 torch.Size([792, 256])
label distribution = Counter({0: 176964, 1: 25788})


In [8]:
saved_labels = torch.load(
    "/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/results/"
    "character_recognition/epoch_1/dev/labels.pt")
saved_logits = torch.load(
    "/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/results/"
    "character_recognition/epoch_1/dev/logits.pt")

saved_labels_distribution = collections.Counter(saved_labels.flatten().tolist())
print(f"saved_labels: {saved_labels.dtype} {saved_labels.shape} "
      f"distribution = {saved_labels_distribution}")

n_pos_saved_logits = (saved_logits[:,:,1] > saved_logits[:,:,0]).sum().item()
print(f"saved_logits: {saved_logits.dtype} {saved_logits.shape} "
      f">0 = {n_pos_saved_logits}")

saved_labels: torch.int64 torch.Size([792, 256]) distribution = Counter({0: 176964, 1: 25788})
saved_logits: torch.float32 torch.Size([792, 256, 2]) >0 = 0


In [9]:
saved_token_attn_mask = torch.load(
    "/home/sbaruah_usc_edu/mica_text_coref/data/movie_coref/results/"
    "character_recognition/epoch_1/dev/token_attention_mask.pt")

In [11]:
saved_token_attn_mask.dtype, saved_token_attn_mask.shape

(torch.bool, torch.Size([792, 256]))

In [14]:
saved_token_attn_mask.all()

tensor(True, device='cuda:0')

In [16]:
class CharacterRecognition(nn.Module):
    """Character Recognition Model.
    """

    def __init__(self, 
                 encoder_name: str,
                 num_parse_tags: int,
                 parse_tag_embedding_size: int,
                 gru_hidden_size: int,
                 gru_num_layers: int,
                 gru_dropout: float,
                 gru_bidirectional: bool,
                 num_labels: int) -> None:
        """Initializer for Character Recognition Model.

        Args:
            encoder_name: Language model encoder name from transformers hub
                e.g. bert-base-cased
            num_parse_tags: Parse tag set size
            parse_tag_embedding_size: Embedding size of the parse tags
            gru_hidden_size: Hidden size of the GRU
            gru_num_layers: Number of layers of the GRU
            gru_dropout: Dropout used between the GPU layers
            gru_bidirectional: If true, the GRU is bidirectional
            num_labels: Number of labels in the label set. 2 if label_type =
                "head" or 3 if label_type = "span"
        """
        super().__init__()
        self.num_labels = num_labels
        
        self.encoder = AutoModel.from_pretrained(
            encoder_name, add_pooling_layer=False)
        self.encoder_hidden_size = self.encoder.config.hidden_size
        self.subtoken = nn.Linear(self.encoder_hidden_size, 1)
        self.parse_embedding = nn.Embedding(
            num_parse_tags, parse_tag_embedding_size)
        self.gru_input_size = (self.encoder_hidden_size +
                               parse_tag_embedding_size)
        self.gru_output_size = gru_hidden_size * (1 + int(gru_bidirectional))
        self.gru = nn.GRU(self.gru_input_size, gru_hidden_size,
                          num_layers=gru_num_layers, batch_first=True,
                          dropout=gru_dropout, bidirectional=gru_bidirectional)
        self.output = nn.Linear(self.gru_output_size, num_labels)
        self._device = "cpu"
    
    @property
    def device(self) -> torch.device:
        """Getter for model device."""
        return self._device
    
    @device.setter
    def device(self, device):
        """Setter for model device. Used by accelerate."""
        self._device = device
    
    def forward(self, subtoken_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_offset: torch.Tensor, parse_ids: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """Forward propagation for the Character Recognition Model.

        Args:
            subtoken_ids: `batch_size x max_n_subtokens` Long Tensor
            attention_mask: `batch_size x max_n_subtokens` Float/Long Tensor
            token_offset: `batch_size x max_n_tokens x 2` Long Tensor
            parse_ids: `batch_size x max_n_tokens` Long Tensor
            labels: `batch_size x max_n_tokens` Long Tensor
        
        Returns:
            Return the loss value if model is begin trained, else the logits 
            `batch_size x max_n_tokens x num_labels` Float Tensor
        """
        batch_size = len(subtoken_ids)

        # subtoken_embedding = batch_size x max_n_subtokens x encoder_hidden_size
        encoder_output = self.encoder(subtoken_ids, attention_mask)
        subtoken_embedding = encoder_output.last_hidden_state

        # _subtoken_embedding = batch_size * max_n_subtokens x encoder_hidden_size
        # subtoken_attn = batch_size * max_n_tokens x batch_size * max_n_subtokens
        _subtoken_embedding = subtoken_embedding.view(-1, self.encoder_hidden_size)
        subtoken_attn = self._attn_scores(_subtoken_embedding,
                                          token_offset.view(-1, 2))
        
        # token_embedding = batch_size x max_n_tokens x encoder_hidden_size
        token_embedding = torch.mm(
            subtoken_attn, _subtoken_embedding).reshape(
                batch_size, -1, self.encoder_hidden_size)
        
        # gru_input = batch_size x max_n_tokens x (encoder_hidden_size +
        # parse_tag_embedding_size)
        parse_input = self.parse_embedding(parse_ids)
        gru_input = torch.cat((token_embedding, parse_input), dim=2).contiguous()

        # logits = batch_size x max_n_tokens x num_labels
        gru_output, _ = self.gru(gru_input)
        logits = self.output(gru_output)

        # token attention mask = batch_size x max_n_tokens
        # TODO sabyasachee put this inside self.training and don't return
        token_attention_mask = torch.any(subtoken_attn > 0, dim=1).reshape(
                batch_size, -1)

        if self.training:
            loss = compute_loss(logits, labels, token_attention_mask,
                                self.num_labels)
            return loss
        else:
            return logits, token_attention_mask

    def _attn_scores(self,
                     subtoken_embeddings: torch.FloatTensor,
                     token_offset: torch.LongTensor) -> torch.FloatTensor:
        """ Calculates attention scores for each of the subtokens of a token.

        Args:
            subtoken_embedding: `n_subtokens x embedding_size` Float Tensor,
                embeddings for each subtoken
            token_offset: `n_tokens x 2` Long Tensor, subtoken offset of each
                token

        Returns:
            torch.FloatTensor: `n_tokens x n_subtokens` Float Tensor, attention
            weights for each subtoken of a token
        """
        n_subtokens, n_tokens = len(subtoken_embeddings), len(token_offset)
        token_begin, token_end = token_offset[:,0], token_offset[:,1]
        
        # attn_mask: n_tokens x n_subtokens, contains -∞ for subtokens outside
        # the token's offsets and 0 for subtokens inside the token's offsets
        attn_mask = torch.arange(0, n_subtokens, device=self.device).expand(
            (n_tokens, n_subtokens))
        attn_mask = ((attn_mask >= token_begin.unsqueeze(1)) * 
                     (attn_mask <= token_end.unsqueeze(1)))
        attn_mask = torch.log(attn_mask.to(torch.float))

        # attn_scores: 1 x n_subtokens
        attn_scores = self.subtoken(subtoken_embeddings).T

        # attn_scores: n_tokens x n_subtokens
        attn_scores = attn_scores.expand((n_tokens, n_subtokens))

        # -∞ for subtokens outside the token's offsets and attn_scores for
        # inside the token's offsets
        attn_scores = attn_mask + attn_scores
        del attn_mask

        # subtoken_attn contains 0 for subtokens outside the token's offsets
        subtoken_attn = torch.softmax(attn_scores, dim=1)
        return subtoken_attn
    
def compute_loss(
    logits: torch.FloatTensor, label_ids: torch.LongTensor,
    attn_mask: torch.FloatTensor, n_labels: int) -> torch.FloatTensor:
    """Compute cross entropy loss"""
    active_labels = label_ids[attn_mask == 1.]
    active_logits = logits.flatten(0, 1)[attn_mask.flatten() == 1.]
    label_distribution = torch.bincount(active_labels, minlength=n_labels)
    class_weight = torch.sqrt(len(active_labels)/(1 + label_distribution))
    cross_entrop_loss_fn = nn.CrossEntropyLoss(
        weight=class_weight, reduction="mean")
    loss = cross_entrop_loss_fn(active_logits, active_labels)
    return loss

In [18]:
dataloader = DataLoader(dataset, batch_size=64)
for batch in dataloader:
    break

In [19]:
model = CharacterRecognition("roberta-base", len(dataset.parse_tag_to_id), 32, 768, 1, 0.2, True, 2)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
model.eval();

In [23]:
batch_logits, batch_token_attention_mask = model(**batch)

In [24]:
batch_logits.dtype, batch_logits.shape, batch_token_attention_mask.dtype, batch_token_attention_mask.shape

(torch.float32, torch.Size([64, 256, 2]), torch.bool, torch.Size([64, 256]))

In [26]:
(batch_logits[:,:,1] > batch_logits[:,:,0]).sum()

tensor(8883)