# Text Summary with T5 from Huggingface Tensorflow

In [15]:
import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model, TFT5ForConditionalGeneration
import tensorflow_datasets as tfds
import time

### Params

In [4]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

learning_rate = 3e-5

## Define Pretrained Model and Tokenizer

In [5]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    
pad_token_id = tokenizer.pad_token_id

In [6]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08, clipnorm=1.0)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

model.summary()

Model: "tf_t5for_conditional_generation"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
shared (TFSharedEmbeddings)  multiple                  16449536  
_________________________________________________________________
encoder (TFT5MainLayer)      multiple                  18881280  
_________________________________________________________________
decoder (TFT5MainLayer)      multiple                  25176064  
Total params: 60,506,880
Trainable params: 60,506,880
Non-trainable params: 0
_________________________________________________________________


## Load Dataset

In [7]:
cnn_dailymail = tfds.load(name="cnn_dailymail")

INFO:absl:No config specified, defaulting to first: cnn_dailymail/plain_text
INFO:absl:Load dataset info from /home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0
INFO:absl:Field info.description from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset cnn_dailymail (/home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0)
INFO:absl:Constructing tf.data.Dataset for split None, from /home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0


In [8]:
train_tfds = cnn_dailymail['train']
test_tfds = cnn_dailymail['test']
val_tfds = cnn_dailymail['validation']

In [9]:
len_train = len(list(train_tfds))
len_test = len(list(test_tfds))
len_val = len(list(val_tfds))

In [10]:
def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text.numpy().decode('UTF-8')

def tokenize_articles(text):
    text = normalize_text(text)
    ids = tokenizer.encode_plus((model.config.prefix + text), return_tensors="tf", max_length=512) 

    return tf.squeeze(ids['input_ids']), tf.squeeze(ids['attention_mask'])
        
def tokenize_highlights(text):
    text = normalize_text(text)
    ids = tokenizer.encode(text, return_tensors="tf", max_length=150)
    return tf.squeeze(ids)



def map_func(features):

    article_ids, attention_mask = tf.py_function(tokenize_articles, inp=[features["article"]], Tout=(tf.int32, tf.int32))
    highlights_ids = tf.py_function(tokenize_highlights, inp=[features["highlights"]], Tout=tf.int32)

    return article_ids, attention_mask, highlights_ids


In [11]:
train_ds = train_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

val_ds = val_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

test_ds = test_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

## Define Train and Validation Step

In [12]:
@tf.function
def train_step(input_ids, input_mask, y):
    # https://github.com/huggingface/transformers/blob/master/examples/summarization/bart/finetune.py
    y_ids = y[:, :-1]
    lm_labels = tf.identity(y[:, 1:])
    lm_labels = tf.where(tf.equal(y[:, 1:],pad_token_id), -100, lm_labels)

    with tf.GradientTape() as tape:
        # prediction_scores: (bs, 150, 32128)
        # decoder_past_key_value_states: (bs, 512, 512), (bs, 8, 150, 64)
        # z: (bs, 512, 512)
        predictions, _, _ = model(input_ids, attention_mask=input_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, training=True)
        loss = loss_object(y[:, 1:], predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(y[:, 1:], predictions)

In [13]:
@tf.function
def val_step(input_ids, input_mask, y):
    # https://github.com/huggingface/transformers/blob/master/examples/summarization/bart/finetune.py
    y_ids = y[:, :-1]
    lm_labels = tf.identity(y[:, 1:])
    lm_labels = tf.where(tf.equal(y[:, 1:],pad_token_id), -100, lm_labels)
    
    predictions, _, _ = model(input_ids, attention_mask=input_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, training=False)
    v_loss = loss_object(y[:, 1:], predictions)

    val_loss(v_loss)
    val_accuracy(y[:, 1:], predictions)

## Train

In [17]:
EPOCHS = 1
log_interval = 200
for epoch in range(EPOCHS):
    # reset metrics
    train_loss.reset_states()
    train_accuracy.reset_states()
    
    val_loss.reset_states()
    val_accuracy.reset_states()
    
    val_batches = iter(val_ds)
    
    start_time = time.time()
    for i, (input_ids, input_mask, y) in enumerate(train_ds):
        # training
        train_step(input_ids, input_mask, y)
        
        # validation
        if i % log_interval == 0:
            x_val, x_mask_val, y_val = next(val_batches)
            val_step(x_val, x_mask_val, y_val)
            elapsed = time.time() - start_time
            print('| epoch {:3d} | [{:5d}/{:5d}] | '
                  'ms/batch {:5.2f} | '
                  'train acc {:5.2f} | train acc {:5.2f} |'
                  'loss {:5.2f} | val loss {:5.2f}'.format(
                    epoch, i, len_train,
                    elapsed * 1000 / log_interval,
                    train_accuracy.result() * 100, val_accuracy.result() * 100, 
                    train_loss.result(),  val_loss.result()))
            start_time = time.time()
            if i > 5000:
                break


| epoch   0 | [    0/287113] | ms/batch 40.59 | train acc 22.86 | train acc 33.18 |loss 11.36 | val loss  9.26
| epoch   0 | [  200/287113] | ms/batch 328.69 | train acc 62.01 | train acc 53.67 |loss  2.77 | val loss  5.31
| epoch   0 | [  400/287113] | ms/batch 336.83 | train acc 68.89 | train acc 62.37 |loss  2.04 | val loss  3.87


KeyboardInterrupt: 

## Evaluate

### Define Rouge Score

In [18]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

### Compute Rouge Score

In [24]:
predictions = []
rouge_score = RougeScore()
for i, (input_ids, input_mask, y) in enumerate(test_ds):
    summaries = model.generate(input_ids=input_ids, attention_mask=input_mask)

    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    
    for pred_sent, real_sent in zip(pred, real):
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
        
    if i > 1:
        # otherwise it will take ages
        break


rouge_score.result()

rouge1 = 38.89, 95% confidence [35.55, 42.19]
rouge2 = 15.99, 95% confidence [13.34, 18.72]
rougeLsum = 35.85, 95% confidence [32.62, 38.75]


{'rouge1': 38.89303107315277,
 'rouge2': 15.991272587561367,
 'rougeLsum': 35.846465481520475}

### Predict some Sentences

In [25]:
for pred in predictions[:10]:
    print("------")
    print(pred)
    print("------")     

------
pred sentence: dybala is wanted by manchester united, chelsea, arsenal, inter milan, juventus and psg . palermo president maurizio zamparini has slapped a €50m price tag on paulo sybala . he is being tracked by paris saint-germain, manchester city, arsenal . the argentine is ligue 1 champions league champions pcg, with roberto

 real sentence: paulo dybala is being tracked by manchester united, chelsea and arsenal . inter milan, juventus and paris saint-germain are among others interested . inter boss roberto mancini was spotted at the palermo match on sunday . palermo president maurizio zamparini is demanding £36m from psg . read: dybala says he would love a serie a stay .
------
------
pred sentence: lib dem stephen gilbert was bitten by a dog while campaigning for re-election in cornwall . he joked that the dog had unresolved anger issues after posting a photograph online of his bloodied hand . the animal took a lump out of his right hand as he posted a leaflet through a lett