In [2]:
import transformers
from transformers import AutoTokenizer
from datasets import load_dataset
from evaluate import load
import pandas as pd
import re

In [3]:
sample_df = pd.read_parquet("gs://scraped-news-article-data-null/fine-tune-summary--1.parquet")
sample_df = sample_df.sample(frac=1, random_state=93).reset_index(drop=True)
sample_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
0,reuters,33048,Macro Matters,Drop in business activity flags UK recession r...,2023-01-24T10:55:00,"LONDON, Jan 24 (Reuters) - British private-sec...",* \n* Flash composite PMI 47.8 vs Reuters poll...,BULLETS
1,cnbc,8589,CNBC Disruptor 50,Cybereason CEO told the world about DarkSide's...,2021-05-27T12:53:48+00:00,Cybereason CEO Lior Div on disrupting the cybe...,* Cybereason ranked No. 32 on CNBC's Disruptor...,BULLETS
2,reuters,53400,Europe,Belarus says it will host Russian nuclear weap...,2023-03-28T16:54:00,"LONDON, March 28 (Reuters) - Belarus on Tuesda...",* \n* Belarus justifies its decision to host R...,BULLETS
3,reuters,70915,Asian Markets,LIVE MARKETS Europe gains after three weeks of...,2022-03-11T17:11:00,March 11 - Welcome to the home for real-time c...,"* \n* S&P 500, Nasdaq down, Dow up\n* Energy l...",BULLETS
4,cnbc,12123,Personal Finance,How to teach your kids about money,2017-11-19T16:00:00+00:00,## Kids and money\n\n## On the Money\n\nIn ord...,"When it comes to teaching your own children, t...",PLAIN


In [4]:
clean_regex = re.compile(r"\*[\s\n]*(?=\*)")
sample_df["summary"] = sample_df.summary.apply(lambda s: clean_regex.sub(" ", s).strip())
sample_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
0,reuters,33048,Macro Matters,Drop in business activity flags UK recession r...,2023-01-24T10:55:00,"LONDON, Jan 24 (Reuters) - British private-sec...",* Flash composite PMI 47.8 vs Reuters poll 49....,BULLETS
1,cnbc,8589,CNBC Disruptor 50,Cybereason CEO told the world about DarkSide's...,2021-05-27T12:53:48+00:00,Cybereason CEO Lior Div on disrupting the cybe...,* Cybereason ranked No. 32 on CNBC's Disruptor...,BULLETS
2,reuters,53400,Europe,Belarus says it will host Russian nuclear weap...,2023-03-28T16:54:00,"LONDON, March 28 (Reuters) - Belarus on Tuesda...",* Belarus justifies its decision to host Russi...,BULLETS
3,reuters,70915,Asian Markets,LIVE MARKETS Europe gains after three weeks of...,2022-03-11T17:11:00,March 11 - Welcome to the home for real-time c...,"* S&P 500, Nasdaq down, Dow up\n* Energy leads...",BULLETS
4,cnbc,12123,Personal Finance,How to teach your kids about money,2017-11-19T16:00:00+00:00,## Kids and money\n\n## On the Money\n\nIn ord...,"When it comes to teaching your own children, t...",PLAIN


In [4]:
sample_df.summary_type.value_counts()

summary_type
BULLETS    29381
PLAIN        798
Name: count, dtype: int64

In [5]:
train_df = sample_df.iloc[:21125]
eval_df = sample_df.iloc[21125:]
eval_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
21125,cnbc,937,Media,"Russia detains Wall Street Journal reporter, p...",2023-03-30T14:43:26+00:00,An undated ID photo of U.S. journalist Evan Ge...,* Russian authorities detained Wall Street Jou...,BULLETS
21126,reuters,97761,Asian Markets,Sri Lanka closes in on $2.9 bln IMF deal after...,2023-03-08T12:04:00,"COLOMBO/WASHINGTON, March 7 (Reuters) - Sri La...","* Bonds jump, rupee soars on prospect of IMF d...",BULLETS
21127,cnbc,7932,Asia Economy,Here's a list of the Australian exports hit by...,2020-12-18T01:32:13+00:00,A general view of a Australian flag is seen ou...,* The two countries' relationship has deterior...,BULLETS
21128,reuters,25437,EmploymentClass Actions & Multi-District Litig...,Activision sex bias settlement would derail st...,2022-05-20T18:44:00,(Reuters) - The California agency that enforce...,* State and federal agencies both sued Activis...,BULLETS
21129,cnbc,11128,Markets,"First Republic jumps nearly 30%, leads comebac...",2023-03-21T10:56:56+00:00,## In this article\n\nFollow your favorite sto...,* The move comes after a speech from Treasury ...,BULLETS


In [6]:
from datasets import Dataset, DatasetDict

model_checkpoint = "t5-small"
metric = load("rouge")
train_data = Dataset.from_pandas(train_df[["body", "summary", "summary_type"]])
eval_data = Dataset.from_pandas(eval_df[["body", "summary", "summary_type"]])
raw_datasets = DatasetDict({
    "train": train_data,
    "eval": eval_data
})

In [7]:
raw_datasets["eval"][1]

{'body': 'COLOMBO/WASHINGTON, March 7 (Reuters) - Sri Lanka looks set to get a sign-off on a long-awaited $2.9 billion four-year bailout from the International Monetary Fund (IMF) on March 20 after the crisis-hit country secured new financing support from China.\n\nThe IMF and the island nation confirmed on Tuesday that Sri Lanka had received assurances from all its major bilateral creditors, a key step to deploy financing and an important moment for the country engulfed in its worst economic crisis since independence from Britain in 1948.\n\nSri Lankan President Ranil Wickremesinghe told parliament there were signs the economy was improving, but there was still insufficient foreign currency for all imports, making the IMF deal crucial so other creditors could also start releasing funds.\n\n"Sri Lanka has completed all prior actions that were required by the IMF," Wickremesinghe said, and that he and the central bank governor had sent a letter of intent to the IMF.\n\n"I welcome the pr

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
prefix_bullets = "summarize in bullet points: "
prefix_plain = "summarize as paragraph: "
max_input_length = 2048
max_target_length = 512

def preprocess_function(examples):
    inputs = []
    for body, type in zip(examples["body"], examples["summary_type"]):
        if type == "BULLETS":
            inputs.append(prefix_bullets + body)
        elif type == "PLAIN":
            inputs.append(prefix_plain + body)
        else:
            raise ValueError("typo")
            
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs


tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Map:   0%|          | 0/21125 [00:00<?, ? examples/s]

Map:   0%|          | 0/9054 [00:00<?, ? examples/s]

In [9]:
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM

model_name = "t5"
BATCH_TRAIN = 4
BATCH_EVAL = 8
GRADIENT_STEP = 1
LEARNING_RATE = 2e-5
EPOCHS = 4
LAMBDA = 0.01

args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    evaluation_strategy = "epoch",
    learning_rate=LEARNING_RATE,
    gradient_checkpointing=True,
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    weight_decay=LAMBDA,
    num_train_epochs=EPOCHS,
    predict_with_generate=True,
    fp16=True,
    seed=93
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [10]:
sample = tokenizer(prefix_bullets + tokenized_datasets["eval"][1]["body"], 
                   max_length=max_input_length, truncation=True, 
                   return_tensors="pt")
print(tokenizer.batch_decode(model.generate(**sample, max_new_tokens=512, temperature=0), skip_special_tokens=True)[0])



a new letter from the IMF to Sri Lanka resolved the stalemate. the country has been waiting for about 187 days to finalise a bailout. the country has been engulfed in its worst economic crisis since independence from Britain.


In [11]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Extract a few results
    result = {key: value * 100 for key, value in result.items()}

    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [12]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

try:
    results = trainer.train(resume_from_checkpoint = True)
except ValueError as e:
    results = trainer.train(resume_from_checkpoint = False)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.2856,2.045327,20.8744,7.9883,17.4206,17.9476,18.9569
2,2.2001,1.985655,21.2979,8.2911,17.7365,18.2436,18.9419
3,2.1983,1.962469,21.4974,8.4313,17.9433,18.4425,18.946
4,2.1944,1.950999,21.6483,8.5324,18.0239,18.5228,18.9463


In [13]:
sample = tokenizer(prefix_bullets + tokenized_datasets["eval"][1]["body"], 
                   max_length=max_input_length, truncation=True, 
                   return_tensors="pt").to("cuda")
print(tokenizer.batch_decode(model.generate(**sample, max_new_tokens=512, temperature=0), skip_special_tokens=True)[0])



* IMF, IMF confirm Sri Lanka to get sign-off on March 20 * IMF says it has received assurances from all creditors * IMF says it will not add items to agenda unless members act
