<a href="https://colab.research.google.com/github/wooohun/BERT-Summarizer/blob/main/BART_Abstractive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install dependencies
!pip install transformers
!pip install torch
!pip install datasets
!pip install evaluate

In [20]:
import pandas as pd
import torch
import nltk
import evaluate
from datasets import load_dataset, load_metric
from evaluate import evaluator

In [3]:
# install kaggle
!pip install -q kaggle
!mkdir ~/.kaggle

# get kaggle api token from account -> API -> create new API Token
# move kaggle api token to kaggle folder
!cp -v kaggle.json ~/.kaggle

'kaggle.json' -> '/root/.kaggle/kaggle.json'


In [4]:
# download dataset
# !chmod 600 /root/.kaggle/kaggl
!kaggle datasets download -d gowrishankarp/newspaper-text-summarization-cnn-dailymail
!unzip newspaper-text-summarization-cnn-dailymail

Downloading newspaper-text-summarization-cnn-dailymail.zip to /content
100% 503M/503M [00:07<00:00, 55.2MB/s]
100% 503M/503M [00:07<00:00, 69.1MB/s]
Archive:  newspaper-text-summarization-cnn-dailymail.zip
  inflating: cnn_dailymail/test.csv  
  inflating: cnn_dailymail/train.csv  
  inflating: cnn_dailymail/validation.csv  


In [None]:
# grab datasets
dataset = load_dataset("cnn_dailymail")
# formatted as:
# DatasetDict{
#   train: Dataset{
#     features:[]
#     num_rows: int 
#   }
#   test: {}
#   validation: {}  
# }

In [6]:
from transformers import BartTokenizerFast, BartForConditionalGeneration

# using fast tokenizer
tokenizer = BartTokenizerFast.from_pretrained('facebook/bart-large')
model  = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

In [7]:
def preprocessing(dataset):
  inputs = [article for article in dataset['article']]
  tokenized_inputs = tokenizer(inputs, max_length = 1024, truncation=True)

  with tokenizer.as_target_tokenizer():
    labels = tokenizer(dataset['highlights'], max_length = 128, truncation=True)

  tokenized_inputs['labels'] = labels['input_ids']
  return tokenized_inputs

In [8]:
# process all datasets in batches using fast tokenizer for efficiency
processed_dataset = dataset.map(preprocessing, batched = True)

Map:   0%|          | 0/287113 [00:00<?, ? examples/s]



Map:   0%|          | 0/11490 [00:00<?, ? examples/s]

Map:   0%|          | 0/13368 [00:00<?, ? examples/s]

In [27]:
tokenizer.batch_decode(processed_dataset['train']['labels'][0], skip_special_tokens=True)

['',
 'B',
 'ishop',
 ' John',
 ' Fold',
 'a',
 ',',
 ' of',
 ' North',
 ' Dakota',
 ',',
 ' is',
 ' taking',
 ' time',
 ' off',
 ' after',
 ' being',
 ' diagnosed',
 '.',
 '\n',
 'He',
 ' contracted',
 ' the',
 ' infection',
 ' through',
 ' contaminated',
 ' food',
 ' in',
 ' Italy',
 '.',
 '\n',
 'Church',
 ' members',
 ' in',
 ' Fargo',
 ',',
 ' Grand',
 ' For',
 'ks',
 ' and',
 ' Jam',
 'est',
 'own',
 ' could',
 ' have',
 ' been',
 ' exposed',
 '.',
 '']

In [29]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

In [32]:
# create training args
batch_size = 8
num_train_epochs = 8
logging_steps = len(processed_dataset['train'])

args = Seq2SeqTrainingArguments(
    output_dir = "facebook-bart-large-finetuned-cnn-dailymail",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps
)



In [None]:
# install metrics
!pip install bert_score
!pip install rouge_score

In [26]:
rouge = evaluate.load('rouge')
bert_score = evaluate.load('bertscore')

Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

In [28]:
import numpy as np

# metric computation function to pass into trainer object
def metric_compute(predicted):
  predictions, labels = predicted
  # decode predictions, labels for metric computation
  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  # batch_decode returns list of tokens, use nltk to convert list of tokens to list of sentences
  decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
  decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

  # compute bert_score
  bert_score_res = bert_score.compute(
      predictions = decoded_preds, references = decoded_labels, use_stemmer=True
  )
  rouge_res = rouge.compute(
      predictions = decoded_preds, references = decoded_labels, use_stemmer=True
  )
  result = {key: value.mid.fmeasure * 100 for key, value in bert_score_res.items()}
  for key, value in rouge_res.items():
    result[key] = value.mid.fmeasure * 100
  
  return {key: round(val, 4) for key, val in result.items()}

In [None]:
task_eval = evaluator('summarization')

eval_res = task_eval.compute(
    model_or_pipeline = model,
    data = processed_dataset,
    metric = evaluate.combine(['accuracy', 'bertscore', 'rouge'])
)

ValueError: ignored

In [30]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [33]:
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=metric_compute
)

In [None]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
