In [8]:
import os
import sys
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import f1_score
from datasets import load_dataset, load_metric
from transformers import DataCollatorForSeq2Seq, AdamWeightDecay, \
    T5ForConditionalGeneration, T5Tokenizer

In [2]:
def preprocess_function(examples):
    """ Use tokenizer to preprocess data. """
    
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    prefix = "summarize: "

    inputs = [prefix + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["highlights"], max_length=80, truncation=True)

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


def download_and_preprocess_data():
    """ Load dataset from HuggingFace and preprocess. """
    
    news_ds = load_dataset("cnn_dailymail", "3.0.0", split="test")

    # Tokenized using preprocess_function
    tokenized_news = news_ds.map(preprocess_function, batched=True)

    return tokenized_news

In [15]:
tokenizer = T5Tokenizer.from_pretrained("t5-small",from_pt = True)

optimizer = AdamWeightDecay(
    learning_rate=2e-5, 
    weight_decay_rate=0.01
)

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
model.compile(optimizer=optimizer)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, 
    model=model, 
    return_tensors="tf"
)

Downloading:   0%|          | 0.00/242M [00:00<?, ?B/s]

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


In [16]:
tokenized_news = download_and_preprocess_data()
tokenized_news

Found cached dataset cnn_dailymail (C:/Users/28165/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
Loading cached processed dataset at C:\Users\28165\.cache\huggingface\datasets\cnn_dailymail\3.0.0\3.0.0\1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de\cache-eba88d0ba3636bf1.arrow


Dataset({
    features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 11490
})

In [17]:
test_ds = tokenized_news.to_tf_dataset(
    columns=["attention_mask", "input_ids", "labels"],
    shuffle=False,
    batch_size=4,
    collate_fn=data_collator,
)

In [18]:
def compute_metrics(metric, pred, actual):
    """ Compute the model's rouge performance on an instance. """

    metric.add(predictions=pred, references=actual)
    final_score = metric.compute()
    
    return final_score

In [20]:
metric = load_metric('rouge')
result = [[] for x in range(3)]

cnt = 0
for item in test_ds:
    article = item['input_ids']
    actual = item['labels']
    
    pred = model.generate(
        do_sample=True,
        input_ids=article,
        # min_length=56,
        max_length=80,
        temperature=0.8, 
        top_k=45,
        no_repeat_ngram_size=3,
        num_beams=5,
        early_stopping=True
    )

    rouge_score = compute_metrics(metric, pred, actual)
    rouge1 = 100 * rouge_score['rouge1'][1][2]
    rouge2 = 100 * rouge_score['rouge2'][1][2]
    rougeL = 100 * rouge_score['rougeL'][1][2]

    cnt += 1 
    if cnt % 25 == 0:
        print(f'Round: {cnt * 4}')

    result[0].append(rouge1)
    result[1].append(rouge2)
    result[2].append(rougeL)

Round: 100
Round: 200
Round: 300
Round: 400
Round: 500


KeyboardInterrupt: 

In [23]:
result[0]

[41.30434782608695,
 45.86206896551724,
 41.9047619047619,
 30.0632911392405,
 34.66666666666667,
 37.919463087248324,
 33.45323741007194,
 38.666666666666664,
 37.85714285714286,
 35.815602836879435,
 32.857142857142854,
 31.967213114754095,
 36.394557823129254,
 35.56338028169014,
 40.789473684210535,
 36.61971830985915,
 40.26845637583892,
 40.833333333333336,
 33.55263157894737,
 38.16793893129771,
 40.0,
 35.416666666666664,
 31.25,
 29.411764705882355,
 33.33333333333333,
 37.5,
 34.96503496503497,
 34.10852713178294,
 32.16783216783217,
 42.857142857142854,
 33.56164383561644,
 30.47945205479452,
 35.15625000000001,
 35.338345864661655,
 39.310344827586206,
 37.03703703703703,
 28.47682119205298,
 36.59420289855072,
 38.43283582089552,
 30.41666666666667,
 36.56716417910447,
 36.0,
 30.14705882352941,
 44.26229508196722,
 33.587786259541986,
 32.22222222222222,
 33.44594594594595,
 33.54430379746836,
 34.0625,
 30.718954248366014,
 35.08064516129032,
 29.78723404255319,
 34.7517