In [None]:
# install nlp
!pip install nlp
# Make sure that we have a recent version of pyarrow in the session before we continue - otherwise reboot Colab to activate it
import pyarrow
if int(pyarrow.__version__.split('.')[1]) < 16:
    import os
    os.kill(os.getpid(), 9)



In [None]:
import nlp

In [None]:
import logging
logging.basicConfig(level=logging.ERROR) #log only errors and supress warnings

## Dataset Used- CNN-DailyMail news articles and highlights dataset.



In [None]:
dataset_train = nlp.load_dataset('cnn_dailymail', '3.0.0', split='train[0:10000]') #287000+ samples in the training dataset, everything is cached so once youve loaded the cnn-dm dataset, you can alter the split sizes in no time 
dataset_valid = nlp.load_dataset('cnn_dailymail', '3.0.0', split='validation[:1000]')#13000+ samples in the validation set, using 1000

In [None]:
print(dataset_train)
print(dataset_valid)

Dataset(schema: {'article': 'string', 'highlights': 'string', 'id': 'string'}, num_rows: 10000)
Dataset(schema: {'article': 'string', 'highlights': 'string', 'id': 'string'}, num_rows: 1000)


## Rouge Metric for Model Evaluation 

In [None]:
!pip install rouge_score rouge_score
rouge_metric = nlp.load_metric("rouge") #rouge is a very standard metric for seq2seq tasks, especially summarization



In [None]:
!pip install transformers



In [None]:
import transformers
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

In [None]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

## Model to fine tune: BartForConditionalGeneration (seq2seq)

In [None]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')#bart-large is also available, but bart-base is much smaller and switching to large may or may not be worth the extra computation
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): SelfAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
 

In [None]:
MAX_LEN = 1024
MAX_SUMMARY_LEN = 200
BATCH_SIZE = 2 #Experimented with sizes of 2, 4, 8, and 16. 2 seems to be optimal
EPOCHS = 2 #going over 4 epochs may result in overfitting
LEARNING_RATE = 5e-5 #any value from 3e-4 to 5e-5 works reasonably well
ADAM_EPS = 1e-8
NUM_BEAMS = 4

In [None]:
params = {'batch_size': BATCH_SIZE,
          'shuffle': True,
          'num_workers': 0
        }

train_loader = DataLoader(dataset_train, **params)
valid_loader = DataLoader(dataset_valid, **params)

In [None]:
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE, eps=ADAM_EPS)

### Train Function
Pytorch implementation

In [None]:
def train(epoch):
  model.train()
  for _,data in enumerate(train_loader, 0):
    source = tokenizer.batch_encode_plus(data['article'], max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt', truncation='only_first')
    target = tokenizer.batch_encode_plus(data['highlights'], max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt', truncation='only_first')
    source_ids = source['input_ids'].squeeze()
    source_mask = source['attention_mask'].squeeze()
    target_ids = target['input_ids'].squeeze()
    target_mask = target['attention_mask'].squeeze()
    #
    y = target_ids.to(device, dtype = torch.long)
    y_ids = y[:, :-1].contiguous() #to create labels for target, we need to right shift the target ids
    lm_labels = y[:, 1:].clone().detach()
    lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100 #padding tokens should be marked with -100 value so the model can output correct loss
    ids = source_ids.to(device, dtype = torch.long)
    mask = source_mask.to(device, dtype = torch.long)

    outputs = model(input_ids = ids, attention_mask = mask, decoder_input_ids=y_ids, lm_labels=lm_labels)
    loss = outputs[0]
        
    if _%100==0:
      print(f'Epoch: {epoch}, Loss:  {loss}')
        
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

### Validation Function

In [None]:
def validate():
  model.eval()
  prediction = []
  reference = []
  with torch.no_grad():
      for _, data in enumerate(valid_loader, 0):
        source = tokenizer.batch_encode_plus(data['article'], truncation='only_first', max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt')
        target = tokenizer.batch_encode_plus(data['highlights'], truncation='only_first', max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt')
        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target['attention_mask'].squeeze()

        y = target_ids.to(device, dtype = torch.long)
        ids = source_ids.to(device, dtype = torch.long)
        mask = source_mask.to(device, dtype = torch.long)

        generated_ids = model.generate(
            input_ids = ids,
            attention_mask = mask,
            max_length=MAX_SUMMARY_LEN, 
            num_beams=NUM_BEAMS,
            repetition_penalty=2.5, 
            no_repeat_ngram_size=4,
            early_stopping=True
            )#can experiment with different values of num_beams and penalties
        pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
        ref = [tokenizer.decode(r, skip_special_tokens=True, clean_up_tokenization_spaces=True)for r in y]
        if(_%100 == 0):
          print(f"{_} batches complete")

        prediction.extend(pred)
        reference.extend(ref)
  return prediction, reference

### Training for 2 epochs (experimented with the number of epochs and ended up with 2

In [None]:
for epoch in range(EPOCHS):
    train(epoch)

Epoch: 0, Loss:  7.20695686340332
Epoch: 0, Loss:  2.081721067428589
Epoch: 0, Loss:  3.5891597270965576
Epoch: 0, Loss:  2.746739387512207
Epoch: 0, Loss:  2.469104528427124
Epoch: 0, Loss:  3.4169774055480957
Epoch: 0, Loss:  2.6619198322296143
Epoch: 0, Loss:  2.9047155380249023
Epoch: 0, Loss:  3.8466298580169678
Epoch: 0, Loss:  2.9160592555999756
Epoch: 0, Loss:  3.349177837371826
Epoch: 0, Loss:  2.5790867805480957
Epoch: 0, Loss:  3.2401647567749023
Epoch: 0, Loss:  2.3947854042053223
Epoch: 0, Loss:  2.2626261711120605
Epoch: 0, Loss:  2.6098427772521973
Epoch: 0, Loss:  3.8146822452545166
Epoch: 0, Loss:  2.4320099353790283
Epoch: 0, Loss:  2.3282997608184814
Epoch: 0, Loss:  2.0354394912719727
Epoch: 0, Loss:  2.7363333702087402
Epoch: 0, Loss:  2.1360528469085693
Epoch: 0, Loss:  2.1020212173461914
Epoch: 0, Loss:  1.9237146377563477
Epoch: 0, Loss:  2.314549684524536
Epoch: 0, Loss:  2.593585968017578
Epoch: 0, Loss:  1.7425488233566284
Epoch: 0, Loss:  2.0907280445098877


### Validate

In [None]:
predictions, references = validate()

0 batches complete
100 batches complete
200 batches complete
300 batches complete
400 batches complete


### Evaluation

In [None]:
score = rouge_metric.compute(predictions, references)
print(f"Rouge Score: {score}")

Rouge Score: {'rouge1': AggregateScore(low=Score(precision=0.28138156100908157, recall=0.3527661847311623, fmeasure=0.3055916009714199), mid=Score(precision=0.29002877421532913, recall=0.36236129551046076, fmeasure=0.3138579771459466), high=Score(precision=0.2985856825164749, recall=0.3718726211589361, fmeasure=0.32247776133352224)), 'rougeL': AggregateScore(low=Score(precision=0.19791089848290735, recall=0.24858328447546973, fmeasure=0.21483102637136084), mid=Score(precision=0.20549013202888677, recall=0.25776862848097515, fmeasure=0.22256389534739437), high=Score(precision=0.21410270286952704, recall=0.2664930538358284, fmeasure=0.23056950724572037))}


### Some Results

In [None]:
for i in range(10):
  print(f"Prediction: {predictions[i]}")
  print(f"Reference: {references[i]}")

Prediction:  China has declared the Internet to be the new battlefield in its fight against "pornography and unlawful information"
Su Changlan was detained solely for expressing peaceful views online.
The crackdown is part of the worst crackdown against freedom of expression in China in more than a decade.
Reference: China's Internet model is one of extreme control, says Amnesty's East Asia director.
Chinese authorities suppress online debate on a range of legitimate issues, she says.
While the battlefield is virtual, the impact on people's lives is real and devastating, she adds.
Prediction:  man found hanging from tree in Mississippi woods with bedsheets around his neck.
Authorities say they still have a lot of work ahead to figure out how Byrd died.
"The community deserves answers," sheriff says.
Reference: Law enforcement officials say evidence collected so far doesn't suggest foul play.
Forensics expert talks about how evidence differs in suicides and lynchings.
FBI agent says a r

In [None]:
model.eval()
from pprint import pprint
for i in range(10):
  print(f"{i}th sample")
  art = dataset_valid[i]['article']
  summ = dataset_valid[i]['highlights']
  source = tokenizer.batch_encode_plus([art], max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt', truncation=True)
  target = tokenizer.batch_encode_plus([summ], max_length=MAX_LEN, pad_to_max_length=True,return_tensors='pt', truncation=True)
  source_ids = source['input_ids']
  source_mask = source['attention_mask']
  target_ids = target['input_ids']
  target_mask = target['attention_mask']

  y = target_ids.to(device, dtype = torch.long)
  ids = source_ids.to(device, dtype = torch.long)
  mask = source_mask.to(device, dtype = torch.long)

  generated_ids = model.generate(
      input_ids = ids,
      attention_mask = mask, 
      max_length=150, 
      num_beams=4,
      repetition_penalty=2.5,
      early_stopping=True
      )
  pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
  ref = [tokenizer.decode(r, skip_special_tokens=True, clean_up_tokenization_spaces=True)for r in y]
  pprint(f"Prediction: {pred}")
  pprint(f"Reference: {ref}")

0th sample
("Prediction: [' singer-songwriter hits jogger with his car.\\nThe accident "
 'happened in Santa Ynez, California.\\nHe was driving at approximately 50 mph '
 'when he hit the jogger.\\nHis injuries are not believed to be life '
 "threatening.']")
("Reference: ['Accident happens in Santa Ynez, California, near where Crosby "
 'lives.\\nThe jogger suffered multiple fractures; his injuries are not '
 "believed to be life-threatening.']")
1th sample
("Prediction: [' fraternity was founded March 9, 1856, five years before the "
 'American Civil War.\\nThe group now boasts more than 200,000 living '
 'alumni.\\nYale University banned SAEs from campus activities last month '
 "after a string of member deaths.']")
('Reference: ["Sigma Alpha Epsilon is being tossed out by the University of '
 "Oklahoma.\\nIt's also run afoul of officials at Yale, Stanford and Johns "
 'Hopkins in recent months."]')
2th sample
('Prediction: [\' of the "Finding Jesus" episode, Candida Moss is one of 