# Text Summary with T5 from Huggingface

In [1]:
import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model, TFT5ForConditionalGeneration
import tensorflow_datasets as tfds

### Params

In [2]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

## Define Pretrained Model and Tokenizer

In [3]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    
pad_token_id = tokenizer.pad_token_id

In [4]:
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

model.summary()

Model: "tf_t5for_conditional_generation"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
shared (TFSharedEmbeddings)  multiple                  16449536  
_________________________________________________________________
encoder (TFT5MainLayer)      multiple                  18881280  
_________________________________________________________________
decoder (TFT5MainLayer)      multiple                  25176064  
Total params: 60,506,880
Trainable params: 60,506,880
Non-trainable params: 0
_________________________________________________________________


## Load Dataset

In [5]:
cnn_dailymail = tfds.load(name="cnn_dailymail")

INFO:absl:No config specified, defaulting to first: cnn_dailymail/plain_text
INFO:absl:Load dataset info from /home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0
INFO:absl:Field info.description from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset cnn_dailymail (/home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0)
INFO:absl:Constructing tf.data.Dataset for split None, from /home/yannik/tensorflow_datasets/cnn_dailymail/plain_text/3.0.0


In [6]:
train_tfds = cnn_dailymail['train']
test_tfds = cnn_dailymail['test']
val_tfds = cnn_dailymail['validation']

In [7]:
len_train = len(list(train_tfds))
len_test = len(list(test_tfds))
len_val = len(list(val_tfds))

In [8]:
def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text.numpy().decode('UTF-8')

def tokenize_articles(text):
    text = normalize_text(text)
    ids = tokenizer.encode_plus((model.config.prefix + text), return_tensors="tf", max_length=512) 

    return tf.squeeze(ids['input_ids']), tf.squeeze(ids['attention_mask'])
        
def tokenize_highlights(text):
    text = normalize_text(text)
    ids = tokenizer.encode(text, return_tensors="tf", max_length=150)
    return tf.squeeze(ids)



def map_func(features):

    article_ids, attention_mask = tf.py_function(tokenize_articles, inp=[features["article"]], Tout=(tf.int32, tf.int32))
    highlights_ids = tf.py_function(tokenize_highlights, inp=[features["highlights"]], Tout=tf.int32)

    return article_ids, attention_mask, highlights_ids


In [9]:
train_ds = train_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

val_ds = val_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

test_ds = test_tfds.map(map_func)\
    .shuffle(SHUFFEL_SIZE)\
    .padded_batch(BATCH_SIZE, padded_shapes=([512],[512],[150]))\
    .prefetch(tf.data.experimental.AUTOTUNE)

## Define Train and Validation Step

In [10]:
@tf.function
def train_step(input_ids, input_mask, y):
    # https://github.com/huggingface/transformers/blob/master/examples/summarization/bart/finetune.py
    y_ids = y[:, :-1]
    lm_labels = tf.identity(y[:, 1:])
    lm_labels = tf.where(tf.equal(y[:, 1:],pad_token_id), -100, lm_labels)

    with tf.GradientTape() as tape:
        # prediction_scores: (bs, 150, 32128)
        # decoder_past_key_value_states: (bs, 512, 512), (bs, 8, 150, 64)
        # z: (bs, 512, 512)
        predictions, _, _ = model(input_ids, attention_mask=input_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, training=True)
        loss = loss_object(y[:, 1:], predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(y[:, 1:], predictions)

In [11]:
@tf.function
def val_step(input_ids, input_mask, y):
    # https://github.com/huggingface/transformers/blob/master/examples/summarization/bart/finetune.py
    y_ids = y[:, :-1]
    lm_labels = tf.identity(y[:, 1:])
    lm_labels = tf.where(tf.equal(y[:, 1:],pad_token_id), -100, lm_labels)
    
    predictions, _, _ = model(input_ids, attention_mask=input_mask, decoder_input_ids=y_ids, lm_labels=lm_labels, training=False)
    v_loss = loss_object(y[:, 1:], predictions)

    val_loss(v_loss)
    val_accuracy(y[:, 1:], predictions)

## Define Evaluation Step

### Define Rouge Score

In [12]:
from rouge_score import rouge_scorer
from rouge_score import scoring

class RougeScore:
    '''
    mostly from https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/evaluation/metrics.py 
    '''
    
    def __init__(self, score_keys=None)-> None:
        super().__init__()
        if score_keys is None:  
            self.score_keys = ["rouge1", "rouge2", "rougeLsum"]
        
        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()
        
        
    @staticmethod
    def prepare_summary(summary):
            # Make sure the summary is not bytes-type
            # Add newlines between sentences so that rougeLsum is computed correctly.
            summary = summary.replace(" . ", " .\n")
            return summary
    
    def __call__(self, target, prediction):
        """Computes rouge score.''
        Args:
        targets: string
        predictions: string
        """

        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)
        
        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return 
    
    def reset_states(self):
        self.rouge_list = []

    def result(self):
        result = self.aggregator.aggregate()
        
        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)
        
        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [13]:
rouge_score = RougeScore()

rouge_score("I want some ice cream, what are you doing","What I want is ice cream, what the hell")
rouge_score.result()

rouge1 = 55.56, 95% confidence [55.56, 55.56]
rouge2 = 37.50, 95% confidence [37.50, 37.50]
rougeLsum = 55.56, 95% confidence [55.56, 55.56]


{'rouge1': 55.55555555555556, 'rouge2': 37.5, 'rougeLsum': 55.55555555555556}

In [14]:
def test_step(input_ids, input_mask):
    predicted_ids = model.generate(input_ids=input_ids, attention_mask=input_mask)

    return predicted_ids


# def test_step(input_ids, input_mask):
#     predicted_ids = model.generate(input_ids=tf.reshape(input_ids, (1,-1)), attention_mask=tf.reshape(input_mask, (1,-1)))

#     return predicted_ids

def print_prediction(prediction, real):
    
    decoded_prediction = tokenizer.decode(tf.reshape(prediction, -1))
    decoded_label = tokenizer.decode(tf.reshape(real, -1))
    
    print("------")
    print("Pred Sentence:", decoded_prediction, "\n")
    print("True Sentence:", decoded_label)
    print("------")

## Train

In [16]:
EPOCHS = 1
for epoch in range(EPOCHS):
    # reset metrics
    train_loss.reset_states()
    train_accuracy.reset_states()
    
    test_loss.reset_states()
    test_accuracy.reset_states()
    
    val_loss.reset_states()
    val_accuracy.reset_states()
    
    val_batches = iter(val_ds)
    
    for i, (input_ids, input_mask, y) in enumerate(train_ds):
        # training
        train_step(input_ids, input_mask, y)
        
        # validation
        if i % 200 == 0:
            x_val, x_mask_val, y_val = next(val_batches)
            val_step(x_val, x_mask_val, y_val)
            template = 'Epoch {}, [{}/{}], train loss: {:.2f}, train acc: {:.2f}, val loss: {:.2f}, val acc: {:.2f}'
            print(template.format(epoch + 1,
                        i, int(len_train / BATCH_SIZE),
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        val_loss.result(),
                        val_accuracy.result() * 100))
            
            if i > 5000:
                break


Epoch 1, [0/17944], train loss: 10.51, train acc: 26.64, val loss: 9.92, val acc: 31.59
Epoch 1, [200/17944], train loss: 2.74, train acc: 62.16, val loss: 5.59, val acc: 53.19
Epoch 1, [400/17944], train loss: 2.02, train acc: 68.99, val loss: 4.14, val acc: 60.89
Epoch 1, [600/17944], train loss: 1.75, train acc: 71.71, val loss: 3.38, val acc: 65.34
Epoch 1, [800/17944], train loss: 1.60, train acc: 73.12, val loss: 2.91, val acc: 68.22
Epoch 1, [1000/17944], train loss: 1.51, train acc: 74.14, val loss: 2.64, val acc: 69.18
Epoch 1, [1200/17944], train loss: 1.44, train acc: 74.85, val loss: 2.42, val acc: 70.56
Epoch 1, [1400/17944], train loss: 1.39, train acc: 75.42, val loss: 2.25, val acc: 71.71
Epoch 1, [1600/17944], train loss: 1.35, train acc: 75.83, val loss: 2.12, val acc: 72.52
Epoch 1, [1800/17944], train loss: 1.32, train acc: 76.16, val loss: 2.04, val acc: 72.76
Epoch 1, [2000/17944], train loss: 1.30, train acc: 76.44, val loss: 1.96, val acc: 73.20
Epoch 1, [2200/1

## Evaluate

In [17]:
for i, (input_ids, input_mask, y) in enumerate(test_ds):
    predicted_ids = test_step(input_ids, input_mask)

    for j in range(predicted_ids.shape[0]):
        decoded_prediction = tokenizer.decode(predicted_ids[j])
        decoded_label = tokenizer.decode(y[j])
        rouge_score(decoded_label, decoded_prediction)
    if i > 20:
        # otherwise it will take ages
        break


print(rouge_score.result())

rouge1 = 38.73, 95% confidence [37.63, 39.97]
rouge2 = 17.35, 95% confidence [16.16, 18.63]
rougeLsum = 25.11, 95% confidence [24.00, 26.24]
{'rouge1': 38.725790158616526, 'rouge2': 17.3517251146141, 'rougeLsum': 25.10854188534256}


## Predict some Sentences

In [18]:
for input_ids, input_mask, y in test_ds.take(1):
    for i in range(input_ids.shape[0]):
        predicted_ids = test_step(tf.reshape(input_ids[i], (1,-1)), tf.reshape(input_mask[i], (1,-1)))
        print_prediction(predicted_ids, y[i])

        

------
Pred Sentence: robert bates surrendered tuesday morning, accompanied by his attorney, and posted bail of $25,000. he shot eric courtney harris in the back with a handgun. the reserve deputy was charged with second-degree manslaughter monday. his attorney reiterated that he believes the charge against his client is unwarranted. police say a sting operation caught him illegally selling a gun. 

True Sentence: reserve deputy robert bates surrenders to authorities, posts bail of $25,000. bates is charged with second-degree manslaughter in the killing of eric harris.
------
------
Pred Sentence: u.s. teenager tommy schaefer, 21, found guilty of killing sheila von wiese-mack in a suitcase at an upmarket hotel in bali. heather mack, 19, gave birth to her daughter just weeks ago. he said he killed his girlfriend's mother after a violent argument erupted over the couple's relationship. von wie-mack was reported to have a troubled relationship with her teenage daughter, 21. the pair were 

------
Pred Sentence: dr xiao-ping zhai has been helping women conceive for decades using only traditional chinese medicine. she prescribes a bespoke combination of herbs which must be taken day and night. one in seven couples suffer with infertility and an increasing number are seeking alternative therapies. bbc cameras inside london's famous private medical haven harley street to reveal secrets of her success. the clinic, which she opened in 1996, will air tonight on 

True Sentence: around one in seven couples suffer with infertility. dr xiao-ping zhai offers help via with traditional chinese medicine. uses acupuncture and prescribes course of chinese herbs.
------
------
Pred Sentence: dr fredric brandt, 65, was found hanged in his miami home. he had been suffering from depression and was left 'devastated' by recent rumors comparing him to the grotesque character of franff in tina fey’s netflix show. the pictures, taken in 1967, show him pouting beneath razor sharp cheek bones, sun