In [None]:
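# Optional setup: uncomment to install the notebook's dependencies.
# %pip (rather than !pip) installs into the environment of the running kernel.
#%pip install transformers datasets evaluate transformers[torch]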

In [None]:
# Optional setup: rouge_score is needed by the ROUGE metric loaded further down.
#%pip install rouge_score

In [18]:
from datasets import load_dataset, load_metric
import matplotlib.pyplot as plt
from transformers import BartForConditionalGeneration, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
import torch
import pandas as pd

In [16]:
# Always build a torch.device so `device` has the same type on GPU and CPU
# machines (the CPU fallback used to be the plain string "cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the "multi_news" dataset and show its splits and the train columns.
dataset = load_dataset(path="multi_news")
print(dataset)
print("Features:", dataset["train"].column_names)

In [None]:
# Checkpoint shared by the model weights and the tokenizer.
model_ckpt = "sshleifer/distilbart-cnn-6-6"

# The two loads are independent; both read from the same checkpoint name.
model = BartForConditionalGeneration.from_pretrained(model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [None]:
def convert_examples_to_features(example_batch, max_input_length=1024, max_target_length=256):
    """Tokenize a batch of examples into model inputs and labels.

    Args:
        example_batch: Mapping with a "document" entry (source texts) and a
            "summary" entry (target texts), as passed by a batched ``map``.
        max_input_length: Truncation length applied to the documents.
        max_target_length: Truncation length applied to the summaries.

    Returns:
        A dict with "input_ids" and "attention_mask" taken from the tokenized
        documents and "labels" taken from the tokenized summaries.
    """
    source_encodings = tokenizer(
        example_batch["document"],
        max_length=max_input_length,
        truncation=True,
    )
    # text_target routes the summaries through the tokenizer's target-side path.
    summary_encodings = tokenizer(
        text_target=example_batch["summary"],
        max_length=max_target_length,
        truncation=True,
    )

    return {
        "input_ids": source_encodings["input_ids"],
        "attention_mask": source_encodings["attention_mask"],
        "labels": summary_encodings["input_ids"],
    }

# Tokenize every split; batched=True hands the function whole batches of rows.
new_dataset = dataset.map(
    function=convert_examples_to_features,
    batched=True,
)

In [7]:
# Only these columns are exposed, formatted as torch tensors.
columns = [
    "input_ids",
    "labels",
    "attention_mask",
]
new_dataset.set_format("torch", columns=columns)

In [8]:
# Collator used by the trainer to assemble batches for the seq2seq model.
seq2seq_data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
)

In [None]:
# ROUGE metric object consumed by compute_metrics.
# NOTE(review): load_metric appears to be deprecated in newer `datasets`
# releases in favour of the `evaluate` package; the result objects may differ,
# so confirm compute_metrics still works before switching.
rouge = load_metric(path="rouge")

In [10]:
def compute_metrics(pred):
    """Compute ROUGE-2 precision, recall and F-measure for an evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (generated token ids, or a tuple
            whose first element holds them) and ``label_ids`` (reference token
            ids, with -100 marking positions to ignore).

    Returns:
        A dict with "rouge2_precision", "rouge2_recall" and "rouge2_fmeasure",
        each rounded to four decimals.
    """
    import numpy as np  # local import: numpy is not imported at file level

    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid token id and cannot be decoded. np.where builds new
    # arrays, so the caller's prediction and label buffers are left untouched.
    pred_ids = np.asarray(pred_ids)
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_id)
    labels_ids = np.asarray(pred.label_ids)
    labels_ids = np.where(labels_ids != -100, labels_ids, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [None]:
class MetricsCallback(TrainerCallback):
    """Trainer callback that records the loss values seen in log events.

    Each record appended to ``self.metrics`` is a dict holding whichever of
    the keys "loss" and "eval_loss" were present in that log event.
    """

    # Log keys worth keeping; everything else in a log event is dropped.
    _TRACKED_KEYS = ("loss", "eval_loss")

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store the tracked loss values from ``logs``, if there are any."""
        # `logs` defaults to None, and a membership test on None raises TypeError.
        if not logs:
            return
        tracked = {key: val for key, val in logs.items() if key in self._TRACKED_KEYS}
        if tracked:
            self.metrics.append(tracked)


metrics_callback = MetricsCallback()

In [None]:
# Training configuration for the Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    # Mixed precision is only requested when CUDA is present, so the notebook
    # still runs on a machine without a GPU instead of failing at this cell.
    fp16=torch.cuda.is_available(),
    # Evaluate and log on the same 50-step cadence.
    evaluation_strategy="steps",
    eval_steps=50,
    logging_steps=50,
    predict_with_generate=True,
    load_best_model_at_end=True,
)

In [None]:
# Wire the model, data, metric function and loss-recording callback together.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=new_dataset["train"],
    eval_dataset=new_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=seq2seq_data_collator,
    compute_metrics=compute_metrics,
    callbacks=[metrics_callback],
)

In [None]:
# Run fine-tuning with the configuration held by `trainer`.
trainer.train()

In [None]:
# Save the trained model under the results directory.
# NOTE(review): the folder is named "bert_news_summary" although the checkpoint
# loaded in this notebook is a BART model; path kept as-is for compatibility.
trainer.save_model(output_dir='./results/bert_news_summary/model')

In [None]:
# Loss values captured by MetricsCallback, in the order they were logged.
train_losses = [x['loss'] for x in metrics_callback.metrics if 'loss' in x]
eval_losses = [x['eval_loss'] for x in metrics_callback.metrics if 'eval_loss' in x]

# The k-th record corresponds to step k * interval, so the axis starts at one
# interval rather than at 0 (assumes logging_first_step is left at its default
# and that no log events were skipped).
log_every = training_args.logging_steps
steps = range(log_every, (len(train_losses) + 1) * log_every, log_every)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(steps, train_losses, label='Train Loss')
if eval_losses:
    eval_every = training_args.eval_steps
    eval_steps = range(eval_every, (len(eval_losses) + 1) * eval_every, eval_every)
    ax.plot(eval_steps, eval_losses, label='Eval Loss')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training and Evaluation Loss')
ax.legend()
plt.show()