## BERT Summary Extractor

In this notebook we load a pretrained BERT model, then use its encodings to evaluate whether a sentence is important or not.

The usual go-to approach would be to append a classification layer on top of BERT to get our score, but nothing prevents us from computing the loss directly on BERT's output embeddings, which makes experiments faster.

In [None]:
%%capture
!pip install transformers
!pip install datasets
!pip install rouge_score

In [None]:
from tqdm import tqdm
import gc, re

import torch, torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from datasets import load_dataset, load_metric
from transformers import BertModel, BertTokenizer
from transformers.optimization import AdamW

In [None]:
# Token budgets: articles are padded/truncated to the first value,
# highlights (reference summaries) to the second.
encoder_max_length, decoder_max_length = 400, 75
# Training batch size and the Hugging Face checkpoint shared by every cell.
batch_size = 4
model_name = 'bert-base-uncased'

In [None]:
%%capture

# just to download stuff at the beginning, hiding the logs after
# (the %%capture magic at the top of this cell swallows the download output).
# Each call's return value is discarded; the point is only that later cells
# find the checkpoint, dataset and tokenizer already in the local cache.
BertModel.from_pretrained(model_name)
load_dataset("cnn_dailymail", "3.0.0")
BertTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
"""
Uses https://huggingface.co/datasets/viewer/?dataset=cnn_dailymail and a tokenizer
to create the hot-encodings to feed to the NN
"""

class CNNDataset(Dataset):
  def __init__(self, mode='train', n_articles=10000):
    super().__init__()
    raw_data  = load_dataset("cnn_dailymail", "3.0.0", split=f"{mode}[:{n_articles}]")  # download dataset from huggingface
    self.tokenizer = BertTokenizer.from_pretrained(model_name)    # load tokenizer to hot-encode articles
    self.data = self.preprocess_data(raw_data)                                          # preprocess data (hot-encode article/highlights)

  def preprocess_data(self, raw_data):
    data = []
    punctuations = set(self.tokenizer(". , ; : ? !"))
    for i in range(len(raw_data)):

      # hot encode article
      hot_article = self.tokenizer(raw_data[i]['article'], padding="max_length", truncation=True, max_length=encoder_max_length, return_tensors="pt").input_ids[0]
      hot_high  = self.tokenizer(raw_data[i]['highlights'], padding="max_length", truncation=True, max_length=decoder_max_length, return_tensors="pt").input_ids[0]

      # n-grams of highlights
      high_2grams, words = set(), set()
      prec = hot_high[0].item()
      for tok in hot_high[1:]:
        tok = tok.item()
        high_2grams.add((prec, tok))
        words.add(tok)
        prec = tok

      # n-grams of article (find best sentences)
      labels = torch.zeros_like(hot_article).float()
      for i, tok in enumerate(hot_article[1:]):
        tok = tok.item()
        if (prec, tok) in high_2grams:
          labels[i+1] = 1.
        prec = tok

      # important sentences
      labels2 = torch.zeros_like(hot_article).float()
      sum, prev = 0, 0
      for i, tok in enumerate(hot_article):
        tok = tok.item()
        sum += (tok in words and not tok in punctuations)*0.4 + labels[i]
        if tok in punctuations:
          labels[prev:i+1] = sum / (i+1-prev)
          sum, prev = 0, i+1


      data.append( {
          'input_ids':hot_article,       # article
          'decoder_input_ids':hot_high,  # summary
          'labels':labels,                            # summary (excluding <pad> from loss)
          'labels2':labels2,                            # summary (excluding <pad> from loss)
          'attention_mask':self.tokenizer(raw_data[i]['article'], padding="max_length", truncation=True, max_length=encoder_max_length, return_tensors="pt").attention_mask[0],
          'decoder_attention_mask':self.tokenizer(raw_data[i]['highlights'], padding="max_length", truncation=True, max_length=decoder_max_length, return_tensors="pt").attention_mask[0]
      } )
    return data
  
  def __len__(self):
    return len(self.data)
  
  def __getitem__(self, idx):
    return self.data[idx]


In [None]:
class ModelSummarizer(nn.Module):
  """Pretrained BERT encoder used directly as a per-token scorer.

  No head is added on top: the first two features of BERT's first output
  (the last hidden state) serve as the two predicted scores of every token.
  """

  def __init__(self):
    super().__init__()
    encoder = BertModel.from_pretrained(model_name)
    self.bert = encoder.cuda()

  def forward(self, batch_data):
    """Run BERT on ``batch_data`` (passed as keyword arguments) and return
    the first two features of its first output for every position."""
    hidden = self.bert(**batch_data)[0]
    return hidden[..., :2]

In [None]:
"""
Instantiate model, dataloader & optimizer
"""

# model trained on xetreme summarization task
model = ModelSummarizer().cuda()

# data
if 'train_data' not in dir():
  train_data = CNNDataset(mode='train', n_articles=10000)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, pin_memory=True)

# optimizer
params = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "LayerNorm.weight"])],
        "weight_decay": 1e-8,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in ["bias", "LayerNorm.weight"])],
        "weight_decay": 0.0,
    },
]
optim = torch.optim.AdamW(params, lr=3e-4)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234

## Fine-tuning not working

--> I think the model was pretrained with a particular loss, which is why fine-tuning does not work properly (or we simply lack computing power).

In [None]:

model.train()
# element-wise squared error; reduced by hand so padding positions can be masked out
mse = nn.MSELoss(reduction='none')
for epoch in range(3):
  running_loss, n_steps = 0.0, 0
  for data in tqdm(train_loader):

    # move to GPU
    for k in data:
      data[k] = data[k].cuda()

    # forward: BERT only receives the article ids and its attention mask
    inputs = {k: v for k, v in data.items() if k in {'input_ids', 'attention_mask'}}
    logits = model(inputs)

    # targets: channel 0 = labels, channel 1 = labels2
    important_sentences = torch.stack((data['labels'], data['labels2']), dim=2)

    # mean squared error over real (non-padding) tokens only
    mask = data['attention_mask'].unsqueeze(2).to(logits.dtype)
    loss = (mse(logits, important_sentences) * mask).sum() / (mask.sum() * logits.size(2)).clamp(min=1)

    # backward
    loss.backward()
    optim.step()
    optim.zero_grad()
    running_loss += loss.item()
    n_steps += 1

    # print stats: running mean since the start of the epoch, scaled by 1000
    if n_steps % 100 == 0:
      print(f'loss e:{epoch} = {1000*running_loss/n_steps:.5f}')
  print(f'FINAL loss e:{epoch} = {1000*running_loss/max(n_steps, 1):.5f}')

  4%|▍         | 100/2500 [01:07<26:51,  1.49it/sTrue]

loss e:0 = 108.38672


  8%|▊         | 200/2500 [02:15<25:49,  1.48it/sTrue]

loss e:0 = 91.03345


 12%|█▏        | 300/2500 [03:22<24:33,  1.49it/sTrue]

loss e:0 = 79.37600


 16%|█▌        | 400/2500 [04:29<23:25,  1.49it/sTrue]

loss e:0 = 79.16633


 20%|██        | 500/2500 [05:36<22:20,  1.49it/sTrue]

loss e:0 = 74.73874


 24%|██▍       | 600/2500 [06:43<21:17,  1.49it/sTrue]

loss e:0 = 70.69356


 28%|██▊       | 700/2500 [07:50<20:07,  1.49it/sTrue]

loss e:0 = 67.50783


 32%|███▏      | 800/2500 [08:57<19:05,  1.48it/sTrue]

loss e:0 = 65.85662


 36%|███▌      | 900/2500 [10:04<17:49,  1.50it/sTrue]

loss e:0 = 64.15568


 40%|████      | 1000/2500 [11:11<16:47,  1.49it/sTrue]

loss e:0 = 63.24111


 44%|████▍     | 1100/2500 [12:18<15:38,  1.49it/sTrue]

loss e:0 = 62.53050


 48%|████▊     | 1200/2500 [13:26<14:35,  1.48it/sTrue]

loss e:0 = 61.82925


 52%|█████▏    | 1300/2500 [14:33<13:27,  1.49it/sTrue]

loss e:0 = 61.10020


 56%|█████▌    | 1400/2500 [15:40<12:18,  1.49it/sTrue]

loss e:0 = 60.56487


 60%|██████    | 1500/2500 [16:47<11:08,  1.50it/sTrue]

loss e:0 = 60.20140


 64%|██████▍   | 1600/2500 [17:54<10:02,  1.49it/sTrue]

loss e:0 = 59.78812


 68%|██████▊   | 1700/2500 [19:01<08:56,  1.49it/sTrue]

loss e:0 = 59.12681


 72%|███████▏  | 1800/2500 [20:08<07:49,  1.49it/sTrue]

loss e:0 = 58.93726


 76%|███████▌  | 1900/2500 [21:15<06:42,  1.49it/sTrue]

loss e:0 = 58.56278


 80%|████████  | 2000/2500 [22:22<05:36,  1.49it/sTrue]

loss e:0 = 58.29725


 84%|████████▍ | 2100/2500 [23:29<04:28,  1.49it/sTrue]

loss e:0 = 58.27571


 88%|████████▊ | 2200/2500 [24:36<03:21,  1.49it/sTrue]

loss e:0 = 58.08323


 92%|█████████▏| 2300/2500 [25:43<02:14,  1.49it/sTrue]

loss e:0 = 57.74633


 96%|█████████▌| 2400/2500 [26:50<01:07,  1.49it/sTrue]

loss e:0 = 57.66664


100%|██████████| 2500/2500 [27:57<00:00,  1.49it/sTrue]


loss e:0 = 57.32804
FINAL loss e:0 = 57.32804


  4%|▍         | 100/2500 [01:07<26:46,  1.49it/sTrue]

loss e:1 = 54.36537


  8%|▊         | 200/2500 [02:14<25:42,  1.49it/sTrue]

loss e:1 = 54.77852


 12%|█▏        | 300/2500 [03:21<24:34,  1.49it/sTrue]

loss e:1 = 55.12221


 16%|█▌        | 400/2500 [04:28<23:29,  1.49it/sTrue]

loss e:1 = 54.08503


 20%|██        | 500/2500 [05:36<22:22,  1.49it/sTrue]

loss e:1 = 54.47310


 24%|██▍       | 600/2500 [06:43<21:22,  1.48it/sTrue]

loss e:1 = 53.72087


 28%|██▊       | 700/2500 [07:50<20:07,  1.49it/sTrue]

loss e:1 = 52.91765


 32%|███▏      | 800/2500 [08:57<19:05,  1.48it/sTrue]

loss e:1 = 53.01071


 36%|███▌      | 900/2500 [10:04<17:53,  1.49it/sTrue]

loss e:1 = 52.69689


 40%|████      | 1000/2500 [11:11<16:51,  1.48it/sTrue]

loss e:1 = 52.88388


 44%|████▍     | 1100/2500 [12:19<15:42,  1.49it/sTrue]

loss e:1 = 53.05958


 48%|████▊     | 1200/2500 [13:26<14:35,  1.49it/sTrue]

loss e:1 = 53.12324


 52%|█████▏    | 1300/2500 [14:33<13:22,  1.49it/sTrue]

loss e:1 = 53.03467


 56%|█████▌    | 1400/2500 [15:40<12:18,  1.49it/sTrue]

loss e:1 = 53.07370


 60%|██████    | 1500/2500 [16:47<11:13,  1.48it/sTrue]

loss e:1 = 53.20308


 64%|██████▍   | 1600/2500 [17:54<10:05,  1.49it/sTrue]

loss e:1 = 53.21259


 68%|██████▊   | 1700/2500 [19:01<08:58,  1.49it/sTrue]

loss e:1 = 52.92717


 72%|███████▏  | 1800/2500 [20:09<07:52,  1.48it/sTrue]

loss e:1 = 53.06513


 76%|███████▌  | 1900/2500 [21:16<06:43,  1.49it/sTrue]

loss e:1 = 52.98858


 80%|████████  | 2000/2500 [22:23<05:35,  1.49it/sTrue]

loss e:1 = 52.99394


 84%|████████▍ | 2100/2500 [23:30<04:30,  1.48it/sTrue]

loss e:1 = 53.21387


 88%|████████▊ | 2200/2500 [24:38<03:21,  1.49it/sTrue]

loss e:1 = 53.23862


 92%|█████████▏| 2300/2500 [25:45<02:14,  1.49it/sTrue]

loss e:1 = 53.10579


 96%|█████████▌| 2400/2500 [26:52<01:07,  1.49it/sTrue]

loss e:1 = 53.21461


100%|██████████| 2500/2500 [27:59<00:00,  1.49it/sTrue]


loss e:1 = 53.04889
FINAL loss e:1 = 53.04889


  4%|▍         | 100/2500 [01:07<26:46,  1.49it/sTrue]

loss e:2 = 54.10666


  8%|▊         | 200/2500 [02:14<25:41,  1.49it/sTrue]

loss e:2 = 54.55161


 12%|█▏        | 300/2500 [03:21<24:36,  1.49it/sTrue]

loss e:2 = 54.91836


 16%|█▌        | 400/2500 [04:28<23:30,  1.49it/sTrue]

loss e:2 = 53.91171


 20%|██        | 500/2500 [05:35<22:23,  1.49it/sTrue]

loss e:2 = 55.44209


 24%|██▍       | 600/2500 [06:42<21:23,  1.48it/sTrue]

loss e:2 = 54.51003


 28%|██▊       | 700/2500 [07:49<20:16,  1.48it/sTrue]

loss e:2 = 53.56457


 32%|███▏      | 800/2500 [08:57<19:05,  1.48it/sTrue]

loss e:2 = 53.54537


 36%|███▌      | 900/2500 [10:04<17:51,  1.49it/sTrue]

loss e:2 = 53.16516


 40%|████      | 1000/2500 [11:11<16:48,  1.49it/sTrue]

loss e:2 = 53.27829


 44%|████▍     | 1100/2500 [12:18<15:42,  1.49it/sTrue]

loss e:2 = 53.39341


 48%|████▊     | 1200/2500 [13:25<14:30,  1.49it/sTrue]

loss e:2 = 53.41968


 52%|█████▏    | 1300/2500 [14:33<13:22,  1.49it/sTrue]

loss e:2 = 53.30017


 56%|█████▌    | 1400/2500 [15:40<12:18,  1.49it/sTrue]

loss e:2 = 53.32029


 60%|██████    | 1500/2500 [16:47<11:13,  1.48it/sTrue]

loss e:2 = 53.42912


 64%|██████▍   | 1600/2500 [17:54<10:05,  1.49it/sTrue]

loss e:2 = 53.42063


 68%|██████▊   | 1700/2500 [19:01<08:56,  1.49it/sTrue]

loss e:2 = 53.12414


 72%|███████▏  | 1800/2500 [20:08<07:49,  1.49it/sTrue]

loss e:2 = 53.24325


 76%|███████▌  | 1900/2500 [21:15<06:43,  1.49it/sTrue]

loss e:2 = 53.15115


 80%|████████  | 2000/2500 [22:22<05:34,  1.49it/sTrue]

loss e:2 = 53.14184


 84%|████████▍ | 2100/2500 [23:29<04:27,  1.50it/sTrue]

loss e:2 = 53.35096


 88%|████████▊ | 2200/2500 [24:36<03:21,  1.49it/sTrue]

loss e:2 = 53.36307


 92%|█████████▏| 2300/2500 [25:43<02:14,  1.49it/sTrue]

loss e:2 = 53.22148


 96%|█████████▌| 2400/2500 [26:50<01:07,  1.49it/sTrue]

loss e:2 = 53.32303


100%|██████████| 2500/2500 [27:57<00:00,  1.49it/sTrue]

loss e:2 = 53.15068
FINAL loss e:2 = 53.15068





In [None]:
if 'test_data' not in dir():
  test_data = CNNDataset(mode='test', n_articles=1000)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, pin_memory=True)

test_tokenizer = test_data.tokenizer
# sentence delimiters only: [CLS]/[SEP] must not open or close a sentence
sentence_end_ids = set(test_tokenizer.encode(". ; ? !", add_special_tokens=False))
special_ids = set(test_tokenizer.all_special_ids)
n_sentences = 3  # length of the extractive summary, in sentences

predicted, real = [], []

model.eval()
with torch.no_grad():
  for batch in test_loader:

    # forward on the CURRENT test article
    inputs = {k: v.cuda() for k, v in batch.items() if k in {'input_ids', 'attention_mask'}}
    token_scores = model(inputs).sum(dim=2)[0].tolist()
    token_ids = batch['input_ids'][0].tolist()

    # split into sentences at token level; a sentence's score is the sum of its token scores
    sentences, sentence_scores = [], []
    current, total = [], 0.0
    for tok, val in zip(token_ids, token_scores):
      if tok in special_ids:
        continue
      current.append(tok)
      total += val
      if tok in sentence_end_ids:
        sentences.append(current)
        sentence_scores.append(total)
        current, total = [], 0.0
    if not sentences and current:
      # no delimiter in the (truncated) article: use the whole text as one sentence
      sentences.append(current)
      sentence_scores.append(total)

    # keep the best sentences, reported in their original document order
    ranked = sorted(range(len(sentences)), key=sentence_scores.__getitem__, reverse=True)
    chosen = sorted(ranked[:n_sentences])
    final = ' '.join(test_tokenizer.decode(sentences[j]) for j in chosen)

    # save decoded summaries
    real_summary = test_tokenizer.decode(batch['decoder_input_ids'][0], skip_special_tokens=True)
    predicted.append(final)
    real.append(real_summary)

  rouge = load_metric('rouge')

Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


Downloading:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [None]:
score = rouge.compute(predictions=predicted, references=real)

# one line per ROUGE variant, showing the `mid` of its aggregated score
# (presumably the bootstrap median -- confirm against the metric's docs)
for metric_name, aggregate in score.items():
  print(metric_name, aggregate.mid)

# show one predicted summary next to its reference
print('\n', predicted[2], '\n\n', real[2])

rouge1 Score(precision=0.1936956623231523, recall=0.2970620887408911, fmeasure=0.2226695122536908)
rouge2 Score(precision=0.05014442975045013, recall=0.0796845203961831, fmeasure=0.05839132807868044)
rougeL Score(precision=0.13192943844870555, recall=0.20167446552470977, fmeasure=0.1510992744850564)
rougeLsum Score(precision=0.13196855563764165, recall=0.2014674954178118, fmeasure=0.15113328949771845)

 i asked him about the science behind climate change and public health and the message he wants the average american to take away, as well as how enforceable his action plan iswhile in lhe credits the clean air act with making americans " a lot " healthier, in addition to being able to " see the mountains in the background because they aren't covered in smog 

 " no challenge poses more of a public threat than climate change, " the president says. he credits the clean air act with making americans " a lot " healthier.
