In [1]:
# ignore warnings

import warnings
warnings.filterwarnings("ignore")

Installing Dependencies

In [None]:
! pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
! pip install datasets
! pip install rouge

Load dataset

In [3]:
from datasets import load_dataset
multi_news = load_dataset("multi_news", split="train")

View Data

In [4]:
multi_news.to_pandas()

Unnamed: 0,document,summary
0,"National Archives \n \n Yes, it’s that time ag...",– The unemployment rate dropped to 8.2% last m...
1,LOS ANGELES (AP) — In her first interview sinc...,"– Shelly Sterling plans ""eventually"" to divorc..."
2,"GAITHERSBURG, Md. (AP) — A small, private jet ...",– A twin-engine Embraer jet that the FAA descr...
3,Tucker Carlson Exposes His Own Sexism on Twitt...,– Tucker Carlson is in deep doodoo with conser...
4,A man accused of removing another man's testic...,– What are the three most horrifying words in ...
...,...,...
44967,"More than 670,000 copies of the Pearls’ self-p...",– The deaths of three children have been linke...
44968,Seeking out cost-conscious consumers who have ...,"– Apple is hoping its new, cheaper iPhone can ..."
44969,Click to email this to a friend (Opens in new ...,"– January Jones, who plays the beleaguered wif..."
44970,"BARRINGTON, R.I. (AP) — Women clad in yoga pan...",– A Rhode Island man who penned a letter to th...


Train test split (80:20 ratio respectively)

In [5]:
multi_news = multi_news.train_test_split(test_size=0.2)

Tokenizing Training and testing sets

In [None]:
! pip install transformers

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [8]:
prefix = "summarize: "

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    labels = tokenizer(text=examples["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [9]:
tokenized_multi_news = multi_news.map(preprocess_function, batched=True)

Map:   0%|          | 0/35977 [00:00<?, ? examples/s]

Map: 100%|██████████| 35977/35977 [01:14<00:00, 481.87 examples/s]
Map: 100%|██████████| 8995/8995 [00:22<00:00, 405.13 examples/s]


Model Loading

In [10]:
from transformers import DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model='t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")




Defining Hyperparameters

In [11]:
training_args = Seq2SeqTrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=15,
per_device_eval_batch_size=15,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=7,
fp16=True,
)


Initializing Trainer

In [12]:
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_multi_news["train"],
eval_dataset=tokenized_multi_news["test"],
tokenizer=tokenizer,
data_collator=data_collator,
)


Train the Text Summarization Model

In [13]:
trainer.train()

  0%|          | 0/16793 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  3%|▎         | 500/16793 [05:38<3:02:26,  1.49it/s]

{'loss': 3.2675, 'learning_rate': 1.9405704757934856e-05, 'epoch': 0.21}


  6%|▌         | 1000/16793 [11:14<2:56:11,  1.49it/s]

{'loss': 3.0192, 'learning_rate': 1.881021854344072e-05, 'epoch': 0.42}


  9%|▉         | 1500/16793 [16:51<2:50:58,  1.49it/s]

{'loss': 2.9807, 'learning_rate': 1.8215923301375575e-05, 'epoch': 0.63}


 12%|█▏        | 2000/16793 [22:26<2:44:38,  1.50it/s]

{'loss': 2.9638, 'learning_rate': 1.7620437086881443e-05, 'epoch': 0.83}


                                                      
 14%|█▍        | 2399/16793 [28:52<2:43:10,  1.47it/s]

{'eval_loss': 2.723966360092163, 'eval_runtime': 118.4381, 'eval_samples_per_second': 75.947, 'eval_steps_per_second': 5.066, 'epoch': 1.0}


 15%|█▍        | 2500/16793 [30:00<2:38:59,  1.50it/s]  

{'loss': 2.9431, 'learning_rate': 1.7024950872387303e-05, 'epoch': 1.04}


 18%|█▊        | 3000/16793 [35:34<2:32:36,  1.51it/s]

{'loss': 2.9272, 'learning_rate': 1.642946465789317e-05, 'epoch': 1.25}


 21%|██        | 3500/16793 [41:09<2:26:54,  1.51it/s]

{'loss': 2.9155, 'learning_rate': 1.5833978443399035e-05, 'epoch': 1.46}


 24%|██▍       | 4000/16793 [46:45<2:22:56,  1.49it/s]

{'loss': 2.889, 'learning_rate': 1.5238492228904902e-05, 'epoch': 1.67}


 27%|██▋       | 4500/16793 [52:21<2:17:24,  1.49it/s]

{'loss': 2.8774, 'learning_rate': 1.4643006014410768e-05, 'epoch': 1.88}


                                                      
 29%|██▊       | 4798/16793 [57:42<2:14:26,  1.49it/s]

{'eval_loss': 2.6847105026245117, 'eval_runtime': 120.0449, 'eval_samples_per_second': 74.93, 'eval_steps_per_second': 4.998, 'epoch': 2.0}


 30%|██▉       | 5000/16793 [59:58<2:12:50,  1.48it/s]  

{'loss': 2.8812, 'learning_rate': 1.4047519799916634e-05, 'epoch': 2.08}


 33%|███▎      | 5500/16793 [1:05:35<2:05:31,  1.50it/s]

{'loss': 2.8879, 'learning_rate': 1.34520335854225e-05, 'epoch': 2.29}


 36%|███▌      | 6000/16793 [1:11:11<2:00:23,  1.49it/s]

{'loss': 2.8569, 'learning_rate': 1.2856547370928365e-05, 'epoch': 2.5}


 39%|███▊      | 6500/16793 [1:16:51<1:55:52,  1.48it/s]

{'loss': 2.8706, 'learning_rate': 1.2262252128863217e-05, 'epoch': 2.71}


 42%|████▏     | 7000/16793 [1:22:31<1:51:17,  1.47it/s]

{'loss': 2.8603, 'learning_rate': 1.1666765914369082e-05, 'epoch': 2.92}


                                                        
 43%|████▎     | 7197/16793 [1:26:44<1:46:48,  1.50it/s]

{'eval_loss': 2.6659939289093018, 'eval_runtime': 119.7954, 'eval_samples_per_second': 75.086, 'eval_steps_per_second': 5.009, 'epoch': 3.0}


 45%|████▍     | 7500/16793 [1:30:08<1:43:07,  1.50it/s] 

{'loss': 2.8441, 'learning_rate': 1.1071279699874948e-05, 'epoch': 3.13}


 48%|████▊     | 8000/16793 [1:35:40<1:37:11,  1.51it/s]

{'loss': 2.8463, 'learning_rate': 1.0475793485380814e-05, 'epoch': 3.33}


 51%|█████     | 8500/16793 [1:41:13<1:31:38,  1.51it/s]

{'loss': 2.8609, 'learning_rate': 9.88030727088668e-06, 'epoch': 3.54}


 54%|█████▎    | 9000/16793 [1:46:45<1:26:05,  1.51it/s]

{'loss': 2.8532, 'learning_rate': 9.284821056392545e-06, 'epoch': 3.75}


 57%|█████▋    | 9500/16793 [1:52:17<1:20:26,  1.51it/s]

{'loss': 2.8381, 'learning_rate': 8.689334841898411e-06, 'epoch': 3.96}


                                                        
 57%|█████▋    | 9596/16793 [1:55:20<1:19:53,  1.50it/s]

{'eval_loss': 2.653245687484741, 'eval_runtime': 118.1676, 'eval_samples_per_second': 76.121, 'eval_steps_per_second': 5.078, 'epoch': 4.0}


 60%|█████▉    | 10000/16793 [1:59:47<1:15:00,  1.51it/s]

{'loss': 2.8341, 'learning_rate': 8.093848627404277e-06, 'epoch': 4.17}


 63%|██████▎   | 10500/16793 [2:05:20<1:09:43,  1.50it/s]

{'loss': 2.8292, 'learning_rate': 7.500744357768118e-06, 'epoch': 4.38}


 66%|██████▌   | 11000/16793 [2:10:53<1:03:57,  1.51it/s]

{'loss': 2.8363, 'learning_rate': 6.905258143273984e-06, 'epoch': 4.59}


 68%|██████▊   | 11500/16793 [2:16:26<58:25,  1.51it/s]  

{'loss': 2.8351, 'learning_rate': 6.30977192877985e-06, 'epoch': 4.79}


                                                         
 71%|███████▏  | 11995/16793 [2:23:53<53:14,  1.50it/s]

{'eval_loss': 2.646230936050415, 'eval_runtime': 118.2104, 'eval_samples_per_second': 76.093, 'eval_steps_per_second': 5.076, 'epoch': 5.0}


 71%|███████▏  | 12000/16793 [2:23:57<12:12:42,  9.17s/it]

{'loss': 2.8361, 'learning_rate': 5.7142857142857145e-06, 'epoch': 5.0}


 74%|███████▍  | 12500/16793 [2:29:30<47:35,  1.50it/s]   

{'loss': 2.8204, 'learning_rate': 5.11879949979158e-06, 'epoch': 5.21}


 77%|███████▋  | 13000/16793 [2:35:03<42:10,  1.50it/s]  

{'loss': 2.8204, 'learning_rate': 4.524504257726434e-06, 'epoch': 5.42}


 80%|████████  | 13500/16793 [2:40:36<36:27,  1.51it/s]

{'loss': 2.8248, 'learning_rate': 3.929018043232299e-06, 'epoch': 5.63}


 83%|████████▎ | 14000/16793 [2:46:09<30:59,  1.50it/s]

{'loss': 2.8313, 'learning_rate': 3.3335318287381653e-06, 'epoch': 5.84}


                                                       
 86%|████████▌ | 14394/16793 [2:52:31<27:05,  1.48it/s]

{'eval_loss': 2.641486406326294, 'eval_runtime': 118.5907, 'eval_samples_per_second': 75.849, 'eval_steps_per_second': 5.059, 'epoch': 6.0}


 86%|████████▋ | 14500/16793 [2:53:42<25:29,  1.50it/s]   

{'loss': 2.8245, 'learning_rate': 2.7380456142440306e-06, 'epoch': 6.04}


 89%|████████▉ | 15000/16793 [2:59:21<19:55,  1.50it/s]

{'loss': 2.8169, 'learning_rate': 2.142559399749896e-06, 'epoch': 6.25}


 92%|█████████▏| 15500/16793 [3:04:58<14:18,  1.51it/s]

{'loss': 2.8242, 'learning_rate': 1.5482641576847499e-06, 'epoch': 6.46}


 95%|█████████▌| 16000/16793 [3:10:31<08:47,  1.50it/s]

{'loss': 2.8226, 'learning_rate': 9.527779431906152e-07, 'epoch': 6.67}


 98%|█████████▊| 16500/16793 [3:16:04<03:15,  1.50it/s]

{'loss': 2.829, 'learning_rate': 3.58482701125469e-07, 'epoch': 6.88}


                                                       
100%|██████████| 16793/16793 [3:25:48<00:00,  1.36it/s]

{'eval_loss': 2.640953540802002, 'eval_runtime': 120.0139, 'eval_samples_per_second': 74.95, 'eval_steps_per_second': 4.999, 'epoch': 7.0}
{'train_runtime': 12348.117, 'train_samples_per_second': 20.395, 'train_steps_per_second': 1.36, 'train_loss': 2.876775576772595, 'epoch': 7.0}





TrainOutput(global_step=16793, training_loss=2.876775576772595, metrics={'train_runtime': 12348.117, 'train_samples_per_second': 20.395, 'train_steps_per_second': 1.36, 'train_loss': 2.876775576772595, 'epoch': 7.0})

Evaluate model on a single document

In [14]:
document2="Real Madrid, established in 1902, stands as one of the most iconic football clubs globally, boasting an unparalleled legacy of excellence. Initially founded as the Madrid Football Club, it earned the royal title Real in 1920 from King Alfonso XIII, setting the stage for a journey marked by triumphs and historic moments. The Galácticos era of the early 2000s, featuring legendary players like Zinedine Zidane and Ronaldo Nazário, elevated the club to unprecedented heights, securing multiple UEFA Champions League titles. Domestically, Real Madrid's La Liga dominance, with a record 34 titles, reflects a consistent pursuit of excellence, highlighted by the intense El Clásico rivalry with FC Barcelona. The Santiago Bernabéu Stadium, an iconic footballing cathedral, has been witness to countless historic moments, while legends like Alfredo Di Stéfano and Cristiano Ronaldo have left an indelible mark. Beyond the pitch, the commitment to nurturing talent at La Fábrica ensures a sustainable future, while the global fanbase, united in passion for the white jersey, solidifies Real Madrid's status as a footballing institution. Despite challenges, Real Madrid's ability to overcome adversity adds another chapter to its storied legacy, ensuring its tale resonates across generations, inspiring football enthusiasts worldwide."
human_summary2="Established in 1902, Real Madrid is a football giant with a rich history of success. From its royal origins as the Madrid Football Club to the Galácticos era's international triumphs, the club has secured a record 34 La Liga titles and 13 UEFA Champions League victories. Legends like Zinedine Zidane and Cristiano Ronaldo have graced the Santiago Bernabéu, symbolizing the club's commitment to excellence. The global fanbase, La Fábrica youth academy, and ongoing stadium renovations demonstrate Real Madrid's enduring influence. Despite challenges, the club's resilience and iconic moments ensure its legacy as a footballing institution transcends generations worldwide."

In [15]:
def predict_summary(document):
  device = model.device
  tokenized = tokenizer([document], truncation =True, padding ='longest',return_tensors='pt')
  tokenized = {k: v.to(device) for k, v in tokenized.items()}
  tokenized_result = model.generate(**tokenized, max_length=128)
  tokenized_result = tokenized_result.to('cpu')
  predicted_summary = tokenizer.decode(tokenized_result[0])
  return predicted_summary

In [16]:
human_summary2

"Established in 1902, Real Madrid is a football giant with a rich history of success. From its royal origins as the Madrid Football Club to the Galácticos era's international triumphs, the club has secured a record 34 La Liga titles and 13 UEFA Champions League victories. Legends like Zinedine Zidane and Cristiano Ronaldo have graced the Santiago Bernabéu, symbolizing the club's commitment to excellence. The global fanbase, La Fábrica youth academy, and ongoing stadium renovations demonstrate Real Madrid's enduring influence. Despite challenges, the club's resilience and iconic moments ensure its legacy as a footballing institution transcends generations worldwide."

In [17]:
predicted_summary = predict_summary(document2)
predicted_summary

'<pad> – Real Madrid, established in 1902, is one of the most iconic football clubs globally, boasting an unparalleled legacy of excellence, highlighting the intense El Clásico rivalry with FC Barcelona. The club, established in 1902, earned the royal title Real in 1920 from King Alfonso XIII, setting the stage for a journey marked by triumphs and historic moments. The club, which earned the royal title Real in 1920 from King Alfonso XIII, earned the royal title Real in 1920, securing multiple UEFA Champions League titles, s'

Evaluate using Rouge Scores

ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a suite of metrics for evaluating the quality of text summaries. It compares a set of reference summaries with a set of generated summaries, and computes a score based on the overlap between the two.

In [18]:
validation_data = load_dataset("multi_news", split="validation")

In [29]:
validation_data=validation_data.to_pandas()

In [35]:
val_data=validation_data.sample(1000)

In [36]:
val_data.shape

(1000, 2)

In [37]:
from rouge import Rouge

def get_rouge_scores(actual_summary, predicted_summary):
    rouge = Rouge()
    scores = rouge.get_scores(predicted_summary, actual_summary)
    return [scores[0]['rouge-1']['f'], scores[0]['rouge-2']['f'], scores[0]['rouge-l']['f']]

In [39]:
from tqdm import tqdm

rouge1_scores = []
rouge2_scores = []
rougel_scores = []

pred_summary_list = []

for index, row in tqdm(val_data.iterrows(), total=len(val_data)):

    doc = row['document']
    pred_summary = predict_summary(doc)
    human_summary = row['summary']

    score = get_rouge_scores(human_summary, pred_summary)

    rouge1_scores.append(score[0])
    rouge2_scores.append(score[1])
    rougel_scores.append(score[2])

    pred_summary_list.append(pred_summary)

val_data["pred_summary"] = pred_summary_list

val_data['rouge1'] = rouge1_scores
val_data['rouge2'] = rouge2_scores
val_data['rougel'] = rougel_scores

val_data


100%|██████████| 1000/1000 [26:28<00:00,  1.59s/it]


Unnamed: 0,document,summary,pred_summary,rouge1,rouge2,rougel
4830,"Good Country Index Is Released, US Not In The ...","– Skal, Sweden! Raise a glass to the Scandinav...",<pad> – The New York Times is announcing the r...,0.260000,0.085937,0.220000
1255,Fans wear shirts emblazoned with the misspelle...,– The Colorado Rockies love their All-Star sho...,<pad> – Troy Tulowitzki has a tough name to sp...,0.389610,0.104712,0.324675
80,A CAR park in a Highland town has been confirm...,– Archaeologists have uncovered another parkin...,"<pad> – A medieval Norse parliament is a ""Thin...",0.218182,0.065147,0.200000
3044,"share tweet pin email \n \n Heath Harding, 18,...","– There's a family of 12 in Montgomery, Ala., ...","<pad> – Heath Harding, 18, says he's ""the slac...",0.200000,0.024390,0.166667
4486,PITTSBURGH (KDKA) — A former Pittsburgh Public...,– An elementary school teacher is in hot water...,<pad> – A former Pittsburgh Public School teac...,0.286792,0.067606,0.226415
...,...,...,...,...,...,...
1025,A British surfer has broken his back in a wipe...,– A big-wave surfer is already dreaming of get...,<pad> – A British surfer has broken his back i...,0.253165,0.067278,0.236287
181,Obama spokesman: That Romney birth certificate...,"– Mitt Romney's joke at a rally today that ""no...","<pad> – The Obama birth certificate joke is ""g...",0.146119,0.000000,0.146119
3323,Demand for all things Napoleon has often sent ...,– Napoleon's hats were a big deal. The BBC exp...,<pad> – A two-cornered military dress hat said...,0.202335,0.033708,0.178988
3366,"Centenarians, people 100 years or older, are m...",– Centenarians really are different than most ...,<pad> – A new study has found that people 100 ...,0.342342,0.080000,0.315315


In [40]:
# average rouge 1
val_data['rouge1'].mean()

0.2661834072645412

In [41]:
# average rouge 2
val_data['rouge2'].mean()

0.08259645769586553

In [42]:
# average rouge l
val_data['rougel'].mean()

0.24077162517175035