In [49]:
import os
import sys
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import f1_score
from datasets import load_dataset, load_metric, Dataset
from transformers import DataCollatorForSeq2Seq, AdamWeightDecay, \
    TFT5ForConditionalGeneration, T5Tokenizer

In [65]:
def preprocess_function(examples):
    """ Use tokenizer to preprocess data. """
    
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    prefix = "summarize: "

    inputs = [prefix + doc for doc in examples["string"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["label"], max_length=80, truncation=True)
    print(labels)
    #model_inputs["labels"] = labels["input_ids"]

    return model_inputs


def download_and_preprocess_data(dataset):
    """ Load dataset from HuggingFace and preprocess. """
    

    
    # Tokenized using preprocess_function
    tokenized_news = dataset.map(preprocess_function, batched=True)

    return tokenized_news

In [7]:
tokenizer = T5Tokenizer.from_pretrained("t5-small",from_pt = True)

optimizer = AdamWeightDecay(
    learning_rate=2e-5, 
    weight_decay_rate=0.01
)

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
model.compile(optimizer=optimizer)

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, 
    model=model, 
    return_tensors="tf"
)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
No loss specified in compile() - the model's internal loss computation will be used as the loss. Don't panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss.


In [57]:
data = pd.read_excel("test.xlsx")
dataset = Dataset.from_pandas(data)

In [62]:
dataset

Dataset({
    features: ['string', 'label'],
    num_rows: 122
})

  0%|          | 0/1 [00:00<?, ?ba/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


KeyError: 'labels'

In [66]:
tokenized_news = download_and_preprocess_data(dataset)
tokenized_news

  0%|          | 0/1 [00:00<?, ?ba/s]

{'input_ids': [[71, 4210, 11, 7295, 13, 149, 16981, 2373, 1], [3, 9, 9251, 13, 315, 2284, 21, 16981, 1], [7295, 13, 572, 8, 5796, 19, 1692, 1], [3, 9, 13005, 344, 572, 8, 5431, 19, 1692, 11, 572, 8, 5796, 19, 1692, 1], [7295, 13, 572, 8, 1997, 19, 4459, 1], [3, 9, 4903, 13, 3, 9, 2647, 3767, 15, 1], [8, 7796, 11, 6900, 13, 3, 9, 2647, 3767, 15, 1], [46, 7295, 13, 149, 3, 9, 2647, 3767, 15, 3, 89, 4664, 1], [46, 7295, 13, 572, 2647, 3767, 15, 7, 33, 19963, 400, 17, 15, 1], [8, 10364, 13, 3, 9, 2647, 3767, 15, 1], [433, 13, 31638, 1], [570, 13, 3379, 1], [4903, 13, 3, 9, 9753, 1], [16726, 13, 753, 13, 3, 9, 810, 1], [10005, 13, 8, 810, 1], [18070, 257, 13, 97, 12, 11091, 3, 9, 4340, 1], [433, 13, 8416, 6126, 6373, 1], [9624, 122, 127, 1707, 13, 3652, 1], [9624, 122, 127, 1707, 13, 5492, 77, 257, 1308, 1], [3, 9, 4210, 13, 3, 9, 1994, 1], [46, 4332, 27866, 53, 13, 8, 16813, 13, 8, 5447, 152, 4716, 1], [16726, 13, 1128, 1], [16726, 13, 753, 13, 20818, 3, 354, 83, 23, 9, 1], [46, 677, 13, 3

Dataset({
    features: ['string', 'label', 'input_ids', 'attention_mask'],
    num_rows: 122
})

In [59]:
ds_pandas

Unnamed: 0,article,highlights,id
0,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,f001ec5c4704938247d27a44948eebb37ae98d01
1,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b...",230c522854991d053fe98a718b1defa077a8efef
2,"(CNN)If you've been following the news lately,...",Mohammad Javad Zarif has spent more time with ...,4495ba8f3a340d97a9df1476f8a35502bcce1f69
3,(CNN)Five Americans who were monitored for thr...,17 Americans were exposed to the Ebola virus w...,a38e72fed88684ec8d60dd5856282e999dc8c0ca
4,(CNN)A Duke student has admitted to hanging a ...,Student is no longer on Duke University campus...,c27cf1b136cc270023de959e7ab24638021bc43f
...,...,...,...
11485,Telecom watchdogs are to stop a rip-off that a...,Operators are charging up to 20p a minute - ev...,0ac776a4dc09ca97c136f4314fed4defb48a361a
11486,The chilling reenactment of how executions are...,Bali Nine ringleaders will face the firing squ...,fe89a6a2e28d173e5ad4c6d814c15b95aa969e3f
11487,It is a week which has seen him in deep water ...,Hardy was convicted of domestic abuse against ...,ded2f535cd6ab95d11b5f4ea29bbf2b2d3c55c50
11488,"Despite the hype surrounding its first watch, ...",Apple sold more than 61 million iPhones in the...,30ec5f280eee772a73d181bfc8514defd8026434


In [32]:
ds_pandas = ds.to_pandas()
df_pandas = tokenized_news.to_pandas()

In [60]:
ds

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 11490
})

In [19]:
df_pandas

Unnamed: 0,article,highlights,id,input_ids,attention_mask,labels
0,(CNN)The Palestinian Authority officially beca...,Membership gives the ICC jurisdiction over all...,f001ec5c4704938247d27a44948eebb37ae98d01,"[21603, 10, 41, 254, 17235, 61, 634, 10748, 92...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[19428, 1527, 8, 3, 24291, 10185, 147, 3, 1255..."
1,(CNN)Never mind cats having nine lives. A stra...,"Theia, a bully breed mix, was apparently hit b...",230c522854991d053fe98a718b1defa077a8efef,"[21603, 10, 41, 254, 17235, 61, 567, 3258, 809...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[37, 23, 9, 6, 3, 9, 8434, 63, 8885, 2153, 6, ..."
2,"(CNN)If you've been following the news lately,...",Mohammad Javad Zarif has spent more time with ...,4495ba8f3a340d97a9df1476f8a35502bcce1f69,"[21603, 10, 41, 254, 17235, 61, 5801, 25, 31, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1290, 1483, 11374, 10318, 26, 24374, 99, 65, ..."
3,(CNN)Five Americans who were monitored for thr...,17 Americans were exposed to the Ebola virus w...,a38e72fed88684ec8d60dd5856282e999dc8c0ca,"[21603, 10, 41, 254, 17235, 61, 371, 757, 5452...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1003, 5452, 130, 6666, 12, 8, 262, 4243, 9, 6..."
4,(CNN)A Duke student has admitted to hanging a ...,Student is no longer on Duke University campus...,c27cf1b136cc270023de959e7ab24638021bc43f,"[21603, 10, 41, 254, 17235, 61, 188, 15090, 12...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[6341, 19, 150, 1200, 30, 15090, 636, 4730, 11..."
...,...,...,...,...,...,...
11485,Telecom watchdogs are to stop a rip-off that a...,Operators are charging up to 20p a minute - ev...,0ac776a4dc09ca97c136f4314fed4defb48a361a,"[21603, 10, 7338, 287, 1605, 10169, 7, 33, 12,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[25667, 7, 33, 10871, 95, 12, 460, 102, 3, 9, ..."
11486,The chilling reenactment of how executions are...,Bali Nine ringleaders will face the firing squ...,fe89a6a2e28d173e5ad4c6d814c15b95aa969e3f,"[21603, 10, 37, 10191, 53, 3, 60, 35, 2708, 29...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[20241, 19636, 3, 1007, 22900, 7, 56, 522, 8, ..."
11487,It is a week which has seen him in deep water ...,Hardy was convicted of domestic abuse against ...,ded2f535cd6ab95d11b5f4ea29bbf2b2d3c55c50,"[21603, 10, 94, 19, 3, 9, 471, 84, 65, 894, 37...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[6424, 63, 47, 3, 21217, 13, 4422, 5384, 581, ..."
11488,"Despite the hype surrounding its first watch, ...",Apple sold more than 61 million iPhones in the...,30ec5f280eee772a73d181bfc8514defd8026434,"[21603, 10, 3, 4868, 8, 22980, 3825, 165, 166,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[2184, 1916, 72, 145, 3, 4241, 770, 3146, 7, 1..."


In [34]:
ds

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 11490
})

In [8]:
test_ds = tokenized_news.to_tf_dataset(
    columns=["attention_mask", "input_ids", "labels"],
    shuffle=False,
    batch_size=4,
    collate_fn=data_collator,
)

In [18]:
def compute_metrics(metric, pred, actual):
    """ Compute the model's rouge performance on an instance. """

    metric.add(predictions=pred, references=actual)
    final_score = metric.compute()
    
    return final_score

In [20]:
metric = load_metric('rouge')
result = [[] for x in range(3)]

cnt = 0
for item in test_ds:
    article = item['input_ids']
    actual = item['labels']
    
    pred = model.generate(
        do_sample=True,
        input_ids=article,
        # min_length=56,
        max_length=80,
        temperature=0.8, 
        top_k=45,
        no_repeat_ngram_size=3,
        num_beams=5,
        early_stopping=True
    )

    rouge_score = compute_metrics(metric, pred, actual)
    rouge1 = 100 * rouge_score['rouge1'][1][2]
    rouge2 = 100 * rouge_score['rouge2'][1][2]
    rougeL = 100 * rouge_score['rougeL'][1][2]

    cnt += 1 
    if cnt % 25 == 0:
        print(f'Round: {cnt * 4}')

    result[0].append(rouge1)
    result[1].append(rouge2)
    result[2].append(rougeL)

Round: 100
Round: 200
Round: 300
Round: 400
Round: 500


KeyboardInterrupt: 

In [23]:
result[0]

[41.30434782608695,
 45.86206896551724,
 41.9047619047619,
 30.0632911392405,
 34.66666666666667,
 37.919463087248324,
 33.45323741007194,
 38.666666666666664,
 37.85714285714286,
 35.815602836879435,
 32.857142857142854,
 31.967213114754095,
 36.394557823129254,
 35.56338028169014,
 40.789473684210535,
 36.61971830985915,
 40.26845637583892,
 40.833333333333336,
 33.55263157894737,
 38.16793893129771,
 40.0,
 35.416666666666664,
 31.25,
 29.411764705882355,
 33.33333333333333,
 37.5,
 34.96503496503497,
 34.10852713178294,
 32.16783216783217,
 42.857142857142854,
 33.56164383561644,
 30.47945205479452,
 35.15625000000001,
 35.338345864661655,
 39.310344827586206,
 37.03703703703703,
 28.47682119205298,
 36.59420289855072,
 38.43283582089552,
 30.41666666666667,
 36.56716417910447,
 36.0,
 30.14705882352941,
 44.26229508196722,
 33.587786259541986,
 32.22222222222222,
 33.44594594594595,
 33.54430379746836,
 34.0625,
 30.718954248366014,
 35.08064516129032,
 29.78723404255319,
 34.7517

In [24]:
result[1]

[18.545454545454547,
 20.41522491349481,
 20.095693779904305,
 10.158730158730158,
 14.381270903010035,
 15.151515151515149,
 13.718411552346572,
 19.732441471571903,
 16.48745519713262,
 13.523131672597863,
 14.695340501792115,
 13.991769547325102,
 17.064846416382252,
 15.547703180212014,
 16.5016501650165,
 14.13427561837456,
 22.22222222222222,
 19.246861924686193,
 12.871287128712872,
 21.455938697318008,
 17.573221757322173,
 14.634146341463413,
 11.808118081180812,
 12.236286919831224,
 13.240418118466902,
 11.808118081180814,
 16.49122807017544,
 11.673151750972762,
 11.929824561403509,
 25.80645161290323,
 10.996563573883163,
 12.371134020618555,
 16.470588235294116,
 14.716981132075471,
 19.031141868512112,
 14.869888475836431,
 9.302325581395348,
 14.181818181818182,
 11.985018726591761,
 11.297071129707113,
 17.602996254681642,
 10.702341137123746,
 13.284132841328415,
 21.39917695473251,
 11.49425287356322,
 12.267657992565056,
 11.525423728813559,
 14.285714285714288,
 13

In [25]:
result[2]

[25.72463768115941,
 25.517241379310345,
 30.0,
 17.721518987341774,
 20.0,
 23.154362416107382,
 18.345323741007196,
 27.666666666666668,
 25.35714285714285,
 20.921985815602838,
 21.428571428571427,
 23.36065573770492,
 23.46938775510204,
 22.887323943661972,
 24.013157894736842,
 22.535211267605636,
 28.859060402684566,
 30.000000000000004,
 22.697368421052634,
 28.24427480916031,
 28.333333333333332,
 20.833333333333336,
 19.48529411764706,
 19.327731092436974,
 18.75,
 19.485294117647058,
 24.475524475524477,
 20.542635658914726,
 19.23076923076923,
 30.00000000000001,
 18.835616438356162,
 19.178082191780824,
 21.875,
 21.428571428571427,
 27.241379310344826,
 21.48148148148148,
 16.887417218543042,
 22.10144927536232,
 20.895522388059703,
 18.750000000000004,
 19.776119402985078,
 19.666666666666668,
 19.485294117647058,
 27.04918032786885,
 20.229007633587788,
 20.370370370370374,
 17.905405405405407,
 22.468354430379744,
 18.75,
 17.64705882352941,
 20.967741935483872,
 17.375