In [None]:
!git clone https://github.com/aagohary/canard.git

In [None]:
!mkdir ./seq2seq

!python ./canard/FormatSeq2Seq.py ./canard/data/release/train.json train ./seq2seq --spacy True
!python ./canard/FormatSeq2Seq.py ./canard/data/release/dev.json dev ./seq2seq --spacy True
!python ./canard/FormatSeq2Seq.py ./canard/data/release/test.json test ./seq2seq --spacy True

In [None]:
!pip install transformers

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

In [None]:
# Pretrained checkpoint shared by the tokenizer and the seq2seq model, so the
# two can never drift apart.
BART_CHECKPOINT = 'facebook/bart-base'

tokenizer = BartTokenizer.from_pretrained(BART_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT)

In [None]:
class CanardDataset(Dataset):
    """Line-aligned source/target text pairs, tokenized on demand.

    Every line of ``data_file`` is one model input. When ``label_file`` is
    given, the line at the same position in that file is the target for the
    input. Tokenization is done lazily in ``__getitem__`` with the
    module-level ``tokenizer``, so that name must be defined before items are
    accessed.

    Args:
        data_file: Path of the text file holding one input per line.
        label_file: Optional path of the text file holding one target per
            line. When omitted the dataset is unlabeled and the returned
            items carry no ``'labels'`` entry.

    Raises:
        ValueError: If ``label_file`` is given but does not have the same
            number of lines as ``data_file``.
    """

    # Value written over padded label positions. -100 is the default
    # ``ignore_index`` of PyTorch's cross-entropy loss; this assumes the model
    # computes its loss with that default (as Hugging Face seq2seq models
    # document for their ``labels`` argument).
    IGNORE_INDEX = -100

    def __init__(self, data_file, label_file=None):
        self.data_file = data_file
        self.label_file = label_file
        self.is_train = (label_file is not None)

        self.text = self._read_lines(data_file)
        self.labels = self._read_lines(label_file) if self.is_train else None

        # Inputs and targets are paired purely by line position, so a length
        # mismatch would silently pair the wrong sentences.
        if self.is_train and len(self.labels) != len(self.text):
            raise ValueError(
                "data_file has {} lines but label_file has {} lines".format(
                    len(self.text), len(self.labels)
                )
            )

        print("Total Lines: {}".format(len(self.text)))

    @staticmethod
    def _read_lines(path):
        """Return all lines of the text file at ``path`` without newlines."""
        with open(path) as handle:
            return [line.replace('\n', '') for line in handle.readlines()]

    @staticmethod
    def _encode(sentence):
        """Tokenize one sentence into fixed-length PyTorch tensors."""
        return tokenizer(
            [sentence],
            padding='max_length',
            truncation=True,
            return_tensors="pt",
            add_prefix_space=True,
        )

    def __len__(self):
        """Returns total number of samples in the dataset"""
        return len(self.text)

    def __getitem__(self, idx):
        """Return the tokenized sample at position ``idx``.

        Returns:
            A dict with 1-D tensors ``'input_ids'`` and ``'attention_mask'``.
            For a labeled dataset it also contains ``'labels'``: the target
            token ids with every padding position replaced by
            ``IGNORE_INDEX``.
        """
        source = self._encode(self.text[idx])
        sample = {
            # The tokenizer returns a batch of one; drop the batch dimension.
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
        }

        # Only tokenize a target when one exists: an unlabeled dataset has no
        # label text, and passing None to the tokenizer would fail.
        if self.is_train:
            target_ids = self._encode(self.labels[idx])['input_ids'].squeeze(0)
            # Mask padding so it does not contribute to the training loss.
            sample['labels'] = target_ids.masked_fill(
                target_ids == tokenizer.pad_token_id, self.IGNORE_INDEX
            )
        return sample

In [None]:
# train_dat = CanardDataset('./seq2seq/train-src.txt', './seq2seq/train-tgt.txt')
val_dat = CanardDataset('./seq2seq/dev-src.txt', './seq2seq/dev-tgt.txt')
test_dat = CanardDataset('./seq2seq/test-src.txt', './seq2seq/test-tgt.txt')

In [None]:
training_args = TrainingArguments(
    output_dir='./models/bart-summarizer',          
    num_train_epochs=2,           
    per_device_train_batch_size=1, 
    per_device_eval_batch_size=1,   
    warmup_steps=500,               
    weight_decay=0.01,              
    logging_dir='./logs',
    logging_steps=100,    
    do_train=True,
    do_eval=True,
    save_steps=2000
)

trainer = Trainer(
    model=model,                       
    args=training_args,                  
    train_dataset=val_dat,        
    eval_dataset=test_dat   
)

In [None]:
# Start fine-tuning with the `trainer` object built in the previous cell.
trainer.train()