In [1]:
!pip install rouge
!pip install transformers[torch]
!pip install accelerate

Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: rouge
Successfully installed rouge-1.0.1
Collecting accelerate>=0.20.3 (from transformers[torch])
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.25.0


In [2]:
import os
import pandas as pd
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments, Trainer
from torch.utils.data import DataLoader, Dataset
import torch
from rouge import Rouge
from tqdm import tqdm

In [3]:
#data loading
train_data = pd.read_csv('/content/train.csv')
test_data = pd.read_csv('/content/test.csv')

In [4]:
#selecting relevant column
train_data = train_data[['dialogue', 'summary']]
test_data = test_data[['dialogue', 'summary']]

In [5]:
train_data = train_data.sample(frac=0.10, random_state=42)

In [6]:
test_data[:10]

Unnamed: 0,dialogue,summary
0,"#Person1#: Ms. Dawson, I need you to take a di...",Ms. Dawson helps #Person1# to write a memo to ...
1,"#Person1#: Ms. Dawson, I need you to take a di...",In order to prevent employees from wasting tim...
2,"#Person1#: Ms. Dawson, I need you to take a di...",Ms. Dawson takes a dictation for #Person1# abo...
3,#Person1#: You're finally here! What took so l...,#Person2# arrives late because of traffic jam....
4,#Person1#: You're finally here! What took so l...,#Person2# decides to follow #Person1#'s sugges...
5,#Person1#: You're finally here! What took so l...,#Person2# complains to #Person1# about the tra...
6,"#Person1#: Kate, you never believe what's happ...",#Person1# tells Kate that Masha and Hero get d...
7,"#Person1#: Kate, you never believe what's happ...",#Person1# tells Kate that Masha and Hero are g...
8,"#Person1#: Kate, you never believe what's happ...",#Person1# and Kate talk about the divorce betw...
9,"#Person1#: Happy Birthday, this is for you, Br...",#Person1# and Brian are at the birthday party ...


In [7]:
class SummarizationDataset(Dataset):
    def __init__(self, dialogue_list, summary_list, tokenizer, max_input_length=1024, max_target_length=150):
        self.dialogue_list = dialogue_list
        self.summary_list = summary_list
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.dialogue_list)

    def __getitem__(self, idx):
        dialogue = self.dialogue_list[idx]
        summary = self.summary_list[idx]

        # Tokenize and pad/truncate
        inputs = self.tokenizer(dialogue, max_length=self.max_input_length, padding='max_length', truncation=True, return_tensors="pt")
        targets = self.tokenizer(summary, max_length=self.max_target_length, padding='max_length', truncation=True, return_tensors="pt")

        input_ids = inputs["input_ids"].squeeze()
        attention_mask = inputs["attention_mask"].squeeze()
        labels = targets["input_ids"].squeeze()

        # Replace padding token id's of the labels by -100 so it's ignored in the loss.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [8]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

In [9]:
#create DataLoader
train_dataset = SummarizationDataset(train_data['dialogue'].tolist(), train_data['summary'].tolist(), tokenizer)
test_dataset = SummarizationDataset(test_data['dialogue'].tolist(), test_data['summary'].tolist(), tokenizer)
dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [10]:
#fine-tuning the model
model_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
optimizer = AdamW(model_bart.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(dataloader) * 5)

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]



In [11]:
#training arguments
training_args = TrainingArguments(
    output_dir='./summarization_results',
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=1,
    warmup_steps=500,
    weight_decay=0.01,
    learning_rate=5e-5,
    logging_dir='./logs',
)

#trainer
trainer = Trainer(
    model=model_bart,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

In [12]:
trainer.train()

Step,Training Loss


TrainOutput(global_step=312, training_loss=1.8535346006735778, metrics={'train_runtime': 393.0897, 'train_samples_per_second': 6.34, 'train_steps_per_second': 0.794, 'total_flos': 1519463253934080.0, 'train_loss': 1.8535346006735778, 'epoch': 2.0})

In [13]:
trainer.save_model('./bart_model')

In [14]:
tokenizer.save_pretrained('./bart_model')

('./bart_model/tokenizer_config.json',
 './bart_model/special_tokens_map.json',
 './bart_model/vocab.json',
 './bart_model/merges.txt',
 './bart_model/added_tokens.json')

In [15]:
results = trainer.evaluate()

In [16]:
results

{'eval_loss': 1.5692164897918701,
 'eval_runtime': 94.4599,
 'eval_samples_per_second': 15.88,
 'eval_steps_per_second': 15.88,
 'epoch': 2.0}