# Text Summary with T5 from Huggingface Pytorch

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time
from torch import nn
from torch.nn import CrossEntropyLoss 

import copy
from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration

from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.modeling_t5 import T5Stack, T5LayerNorm, T5DenseReluDense, T5LayerFF, T5Attention, T5LayerSelfAttention, T5LayerCrossAttention, T5Block, T5PreTrainedModel

In [2]:
BATCH_SIZE = 4

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

model_size = 't5-small'

In [3]:
print(device)

cuda:0


In [4]:
class MyT5(T5ForConditionalGeneration):
    
    def __init__(self, config):   
        super().__init__(config)
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_input_ids_translate=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        translation_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        Returns:
        Examples::
            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)
            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits
            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """

        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

                
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        sequence_output = decoder_outputs[0]
        
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)
         
        # Machine translation Decode
        decoder_input_ids_translate = self._shift_right(decoder_input_ids_translate)
#         print(decoder_input_ids_translate.shape, decoder_attention_mask, decoder_inputs_embeds, output_attentions)
        mt_decoder_outputs = self.mt_decoder(
            input_ids=decoder_input_ids_translate,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

       
        mt_sequence_output = mt_decoder_outputs[0]
        mt_sequence_output = mt_sequence_output * (self.model_dim ** -0.5)
        mt_lm_logits = self.mt_lm_head(mt_sequence_output)
        
        loss = None
        if labels is not None:
#             print(lm_logits.view(-1, lm_logits.size(-1)).shape, labels.view(-1).shape)
#             print(mt_lm_logits.view(-1, mt_lm_logits.size(-1)).shape, translation_labels.view(-1).shape)
            
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            
            summary_loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            tranlsation_loss = loss_fct(mt_lm_logits.view(-1, mt_lm_logits.size(-1)), translation_labels.view(-1))
            
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
            loss = summary_loss + tranlsation_loss
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


In [5]:
mt_config = T5Config(model_size)
mt_config

T5Config {
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "vocab_size": "t5-small"
}

In [6]:
mt_decoder = T5ForConditionalGeneration.from_pretrained(model_size).decoder
mt_lm_head = T5ForConditionalGeneration.from_pretrained(model_size).lm_head

In [7]:
myT5 = MyT5.from_pretrained(model_size)

myT5.mt_decoder = mt_decoder
myT5.mt_lm_head = mt_lm_head

myT5 = myT5.to(device)


## Define Model

In [8]:
tokenizer = T5Tokenizer.from_pretrained(model_size)
task_specific_params = myT5.config.task_specific_params
if task_specific_params is not None:
    myT5.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(myT5.parameters(),lr=learning_rate, weight_decay=0.0001)

## Define Pytorch Dataset

In [9]:
def read_files(name):
    article_path = "../data/%s/article" % name
    highlights_path = "../data/%s/highlights" % name
    
    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [10]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(myT5.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", padding='max_length', truncation=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", padding='max_length',  truncation=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [11]:
def get_dataset(name):
    articles, highlights = read_files(name)
    return MyDataset(articles, highlights)

In [12]:
train_ds = get_dataset("train")
test_ds = get_dataset("test")
val_ds = get_dataset("val")

In [13]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

## Define Step function

In [14]:
pad_token_id = tokenizer.pad_token_id
def step(input_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    t_y_ids = input_ids[:, :-1].contiguous()
    
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    
    t_lm_labels = input_ids[:, 1:].clone()
    t_lm_labels[input_ids[:, 1:] == pad_token_id] = -100
#     print(inputs_ids.device, attention_mask.device, y_ids.device, lm_labels.device)
    output = myT5(
        input_ids, 
        attention_mask=attention_mask, 
        decoder_input_ids=y_ids, 
        decoder_input_ids_translate=t_y_ids, 
        labels=lm_labels,
        translation_labels=t_lm_labels,
        return_dict=True
    )
    return output['loss'] 

## Train

In [15]:
EPOCHS = 1
log_interval = 200
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    myT5.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(myT5.parameters(), 0.5)
        optimizer.step()
            
        if (i + 1) % log_interval == 0:
            with torch.no_grad():
                x, x_mask, y = next(iter(val_loader))
                x = x.to(device)
                x_mask = x_mask.to(device)
                y = y.to(device)
                
                v_loss = step(x, x_mask, y)
                v_loss = v_loss.item()
                
                
                elapsed = time.time() - start_time
                print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len(train_loader),
                    elapsed * 1000 / log_interval,
                    loss.item(), v_loss))
                start_time = time.time()
                val_loss.append(v_loss)
                
                

| epoch   0 | [  199/71779] | ms/batch 224.50 | loss  6.86 | val loss  6.71


KeyboardInterrupt: 