In [1]:
import transformers
from transformers import AutoTokenizer
from datasets import load_dataset
from evaluate import load
import pandas as pd
import re

In [2]:
# Load the fine-tuning corpus and shuffle it with a fixed seed so that the
# positional train/eval split made further down is reproducible.
raw_path = "gs://scraped-news-article-data-null/fine-tune-summary--1.parquet"
sample_df = (
    pd.read_parquet(raw_path)
    .sample(frac=1, random_state=93)
    .reset_index(drop=True)
)
sample_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
0,reuters,41422,United Kingdom,UK's William Hill given record $24 million fin...,2023-03-28T07:26:00,"LONDON, March 28 (Reuters) - Britain's Gamblin...",* \n* Fine is largest in UK gambling\n* Regula...,BULLETS
1,reuters,109510,U.S. Markets,Indexes slip with tech-related shares; consume...,2023-05-13T00:05:00,May 12 (Reuters) - U.S. stocks ended slightly ...,* \n* U.S. consumer sentiment drops to six-mon...,BULLETS
2,reuters,48857,Exploration & ProductionClimate Change,Shell wins UK Supreme Court case on 2011 oil s...,2023-05-10T11:41:00,"LONDON, May 10 (Reuters) - The UK Supreme Cour...",* \n* Nigerians trying to sue Shell over offsh...,BULLETS
3,reuters,93908,Europe,Ukraine gets more U.S. aid as Russia-Iran ties...,2022-12-10T21:08:00,"KYIV, Dec 9 (Reuters) - The United States anno...",* \n* Washington announces new military aid fo...,BULLETS
4,cnbc,2218,Personal Finance,Secure 2.0 changes 3 key rules around required...,2023-01-03T19:26:50+00:00,President Joe Biden signed a $1.7 trillion leg...,* President Joe Biden signed a $1.7 trillion o...,BULLETS


In [3]:
# A "*" followed only by whitespace and then another "*" is a bullet with no
# text; it is collapsed to a single space and the result is trimmed.
clean_regex = re.compile(r"\*[\s\n]*(?=\*)")


def _strip_empty_bullets(summary):
    """Return ``summary`` with text-less bullet markers removed and ends trimmed."""
    return clean_regex.sub(" ", summary).strip()


sample_df["summary"] = sample_df["summary"].apply(_strip_empty_bullets)
sample_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
0,reuters,41422,United Kingdom,UK's William Hill given record $24 million fin...,2023-03-28T07:26:00,"LONDON, March 28 (Reuters) - Britain's Gamblin...",* Fine is largest in UK gambling\n* Regulator ...,BULLETS
1,reuters,109510,U.S. Markets,Indexes slip with tech-related shares; consume...,2023-05-13T00:05:00,May 12 (Reuters) - U.S. stocks ended slightly ...,* U.S. consumer sentiment drops to six-month l...,BULLETS
2,reuters,48857,Exploration & ProductionClimate Change,Shell wins UK Supreme Court case on 2011 oil s...,2023-05-10T11:41:00,"LONDON, May 10 (Reuters) - The UK Supreme Cour...",* Nigerians trying to sue Shell over offshore ...,BULLETS
3,reuters,93908,Europe,Ukraine gets more U.S. aid as Russia-Iran ties...,2022-12-10T21:08:00,"KYIV, Dec 9 (Reuters) - The United States anno...",* Washington announces new military aid for Uk...,BULLETS
4,cnbc,2218,Personal Finance,Secure 2.0 changes 3 key rules around required...,2023-01-03T19:26:50+00:00,President Joe Biden signed a $1.7 trillion leg...,* President Joe Biden signed a $1.7 trillion o...,BULLETS


In [4]:
# Class balance of the two summary styles.
sample_df["summary_type"].value_counts()

summary_type
BULLETS    2630
PLAIN        61
Name: count, dtype: int64

In [5]:
# Positional split of the already-shuffled frame: the first 2000 rows are used
# for training, every row after them is held out for evaluation.
split_at = 2000
train_df, eval_df = sample_df.iloc[:split_at], sample_df.iloc[split_at:]
eval_df.head()

Unnamed: 0,source,id,category,title,published,body,summary,summary_type
2000,reuters,35795,United Kingdom,UK's Sunak to meet EU chief in push to finalis...,2023-02-26T20:14:00,"LONDON, Feb 26 (Reuters) - British Prime Minis...",* UK and EU leaders to meet on Monday seeking ...,BULLETS
2001,cnbc,4169,CNBC Disruptor 50,The pandemic drove Clubhouse to a $4 billion v...,2023-04-27T23:02:05+00:00,In this photo illustration the Clubhouse logo ...,* Clubhouse said on Thursday that it's cutting...,BULLETS
2002,reuters,91712,Europe,"EU grants Ukraine candidate status, 'beginning...",2022-06-23T23:36:00,"BRUSSELS, June 23 (Reuters) - European Union l...","* Summit accepts Ukraine, Moldova as candidate...",BULLETS
2003,reuters,67607,Autos & Transportation,"Defying gloom, Ferrari sees strong demand for ...",2022-11-02T16:28:00,"MILAN, Nov 2 (Reuters) - Ferrari (RACE.MI) sai...",* Co lifts FY core profit forecast to over 1.7...,BULLETS
2004,reuters,42820,European Markets,"European shares log weekly gains, UK's blue-ch...",2023-02-03T17:27:00,Feb 3 (Reuters) - European shares rose on Frid...,"* STOXX 600 up 0.3%, logs gains for second str...",BULLETS


In [6]:
from datasets import Dataset, DatasetDict

# Checkpoint to fine-tune and the metric used to score its summaries.
model_checkpoint = "t5-small"
metric = load("rouge")

# Only these columns are carried over from the DataFrames into the datasets.
feature_columns = ["body", "summary", "summary_type"]
train_data = Dataset.from_pandas(train_df[feature_columns])
eval_data = Dataset.from_pandas(eval_df[feature_columns])
raw_datasets = DatasetDict({"train": train_data, "eval": eval_data})

In [7]:
# Peek at one raw evaluation record (row 1 of the "eval" split).
raw_datasets["eval"][1]

{'body': 'In this photo illustration the Clubhouse logo seen displayed on a smartphone screen.\n\nSocial audio platform Clubhouse announced Thursday that it was laying off half its staff in order to "reset" the company. It shouldn\\\'t come as a surprise.\n\nIf there was a posterchild for the tech industry\\\'s irrational exuberance during the Covid pandemic, it was Clubhouse.\n\nWith the physical world closed for business, consumers looked for other ways to congregate and find entertainment. So did celebrities. So did tech executives. So did venture capitalists.\n\nBack then, capital was still cheap and plentiful. Software was still perceived as "eating the world," in the famous words of investor Marc Andreessen. It was time for the next great social network. Clubhouse, which allowed people to listen in on discussions about topics including music, technology, fashion, technology and more technology, was on a viral curve. MC Hammer, Oprah Winfrey, and Mark Zuckerberg were there.\n\nIn 

In [8]:
# Task prefixes prepended to each article body, one per summary style.
prefix_bullets = "summarize in bullet points: "
prefix_plain = "summarize as paragraph: "

# Token budgets passed as `max_length` (with truncation) when tokenizing the
# inputs and the target summaries respectively.
max_input_length, max_target_length = 2048, 512

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    """Tokenize a batch of articles together with their reference summaries.

    Each body is prefixed with the instruction matching its ``summary_type``
    ("BULLETS" or "PLAIN") before tokenization, and the tokenized summaries
    are attached to the result under ``"labels"``.

    Args:
        examples: Batch mapping with parallel ``"body"``, ``"summary"`` and
            ``"summary_type"`` sequences.

    Returns:
        The tokenizer output for the prefixed bodies, with an added
        ``"labels"`` entry holding the ``input_ids`` of the summaries.

    Raises:
        ValueError: If a ``summary_type`` is neither "BULLETS" nor "PLAIN".
    """
    inputs = []
    # Named `summary_type` rather than `type` so the builtin is not shadowed.
    for body, summary_type in zip(examples["body"], examples["summary_type"]):
        if summary_type == "BULLETS":
            inputs.append(prefix_bullets + body)
        elif summary_type == "PLAIN":
            inputs.append(prefix_plain + body)
        else:
            # Report the offending value so bad rows can be located in the data.
            raise ValueError(
                f"Unsupported summary_type {summary_type!r}; expected 'BULLETS' or 'PLAIN'"
            )

    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


# Tokenize every split; in batched mode preprocess_function receives columns
# as lists, which is what it iterates over.
tokenized_datasets = raw_datasets.map(function=preprocess_function, batched=True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/691 [00:00<?, ? examples/s]

In [9]:
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM

# Fine-tuning hyperparameters.
model_name = "t5"
BATCH_TRAIN = 4
BATCH_EVAL = 8
GRADIENT_STEP = 1
LEARNING_RATE = 2e-5
EPOCHS = 4
LAMBDA = 0.01  # passed below as the weight-decay coefficient

args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",
    evaluation_strategy = "epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    # GRADIENT_STEP used to be declared but never handed to the trainer, so
    # editing it had no effect. Its current value (1) equals the default.
    gradient_accumulation_steps=GRADIENT_STEP,
    weight_decay=LAMBDA,
    num_train_epochs=EPOCHS,
    # Run generate() during evaluation so compute_metrics receives generated
    # token ids to decode.
    predict_with_generate=True,
    fp16=True
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [10]:
# Baseline: summarize one eval article with the checkpoint exactly as loaded,
# i.e. before the trainer cell below has run.
sample = tokenizer(prefix_bullets + tokenized_datasets["eval"][1]["body"],
                   max_length=max_input_length, truncation=True,
                   return_tensors="pt")
# Ask for greedy decoding explicitly. `temperature=0` is not a valid sampling
# temperature and is ignored (with a warning) when sampling is disabled.
generated = model.generate(**sample, max_new_tokens=512, do_sample=False)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])



social media platform Clubhouse announced it was laying off half its staff. it was a viral curve, with a sluggish economy. founders say it's a "difficult time" for people to find their friends.


In [11]:
import nltk
import numpy as np

def compute_metrics(eval_pred):
    """Score generated summaries against the references with ROUGE.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays as
            handed over by the trainer during evaluation.

    Returns:
        Dict with every ROUGE score scaled to a percentage plus ``gen_len``,
        the mean number of non-padding tokens per prediction, all rounded to
        four decimals.
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the token ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is an ignore/padding marker that the tokenizer cannot decode. The
    # labels always carry it; predictions may as well (depending on the
    # transformers version they are padded with -100 when batches are
    # gathered), so both are mapped back to the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    # Note that other metrics may not have a `use_aggregator` parameter
    # and thus will return a list, computing a metric for each sentence.
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True, use_aggregator=True)
    # Express the scores as percentages
    result = {key: value * 100 for key, value in result.items()}

    # Mean generated length, counted on the sanitized ids so that padding
    # (including any former -100 entries) is not counted as generated tokens.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
import os

from transformers import Seq2SeqTrainer
from transformers.trainer_utils import get_last_checkpoint

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["eval"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Resume only when a checkpoint really exists in the output directory.
# Looking it up explicitly (instead of catching ValueError around train())
# means a ValueError raised for any other reason is surfaced rather than
# silently triggering a second training run from scratch.
last_checkpoint = (
    get_last_checkpoint(args.output_dir) if os.path.isdir(args.output_dir) else None
)
results = trainer.train(resume_from_checkpoint=last_checkpoint)

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.891,2.307005,14.6879,5.1895,12.2995,12.7521,18.9493
2,2.5567,2.218881,17.4692,6.2135,14.5529,15.1276,18.9392
3,2.4838,2.188127,18.0004,6.5956,14.955,15.5417,18.8929


In [None]:
# Summarize eval article 1 again, now with the fine-tuned weights.
# The inputs are moved to whichever device the model lives on instead of a
# hard-coded "cuda", so the cell also runs on machines without a GPU.
sample = tokenizer(prefix_bullets + tokenized_datasets["eval"][1]["body"],
                   max_length=max_input_length, truncation=True,
                   return_tensors="pt").to(model.device)
# Ask for greedy decoding explicitly. `temperature=0` is not a valid sampling
# temperature and is ignored (with a warning) when sampling is disabled.
generated = model.generate(**sample, max_new_tokens=512, do_sample=False)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])