In [None]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, AdamW, get_linear_schedule_with_warmup
from datasets import load_dataset, load_metric
from compare_mt.rouge.rouge_scorer import RougeScorer
import pickle
from torch.nn import CrossEntropyLoss

In [None]:
# Pin all work to one CUDA GPU; change GPU_ID to run on a different card.
GPU_ID = 0
torch.cuda.set_device(GPU_ID)
# Use the same explicit index so `device` and the current CUDA device always agree.
device = torch.device(f'cuda:{GPU_ID}')

In [None]:
# Pretrained BART-base checkpoint: its tokenizer, plus the conditional-generation
# model moved onto `device`.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = (
    BartForConditionalGeneration
    .from_pretrained('facebook/bart-base')
    .to(device)
)

In [None]:
# Pre-tokenised training examples, unpickled from disk.
# NOTE(review): pickle.load can execute arbitrary code — only point this at trusted files.
train_data_path = ''  # TODO: set to the pickled training-token file
with open(train_data_path, mode='rb') as train_file:
    train_token = pickle.load(train_file)

In [None]:
# Shuffle the training examples reproducibly. The same seed is also handed to torch
# so dropout draws during fine-tuning repeat across runs (full GPU determinism may
# additionally require deterministic cuDNN settings).
import random

SEED = 100
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
random.shuffle(train_token)

### Model training

In [None]:
# NOTE(review): `loss_fct` is not referenced in this cell; the loss is taken from the
# model's forward output. ignore_index=1 presumably matches BART's <pad> id — confirm
# against tokenizer.pad_token_id before relying on it elsewhere.
loss_fct = CrossEntropyLoss(ignore_index=1)

learning_rate = 3e-5
num_warmup_steps = 500
# NOTE(review): fixed total; the linear schedule reaches LR 0 at this optimizer step,
# so any optimizer steps beyond it train with LR 0.
num_training_steps = 15000
gradient_accumulation_steps = 2
max_grad_norm = 0.1
max_epoch = 1

optimizer = AdamW(model.parameters(), lr=learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

model.zero_grad()

# `option` and `ratio` are forwarded to the model's forward pass (presumably a customised
# BART — the stock forward does not accept `option`, `sent_ids` or `ratio`):
#   option='weight'   : weight multiplication; ratio = weight ratio
#   option='zscore'   : ratio is not used
#   option='division' : (m:n) internal division; ratio = [m, n]
ratio = [1, 1]

# Nothing in the loop switches to eval mode, so set train mode once.
model.train()
for _ in range(max_epoch):
    for step, batch in enumerate(train_token):
        document = torch.tensor(batch['document']).unsqueeze(0).to(device)
        summary = torch.tensor(batch['summary']).unsqueeze(0).to(device)
        sent_ids = [batch['pos'], batch['nos']]

        # Teacher forcing: decoder sees summary[:-1] and predicts summary[1:].
        loss = model(input_ids=document, decoder_input_ids=summary[..., :-1], labels=summary[..., 1:], option='division', sent_ids=sent_ids, ratio=ratio)[0]
        # Average over the micro-batches that make up one optimizer step.
        (loss / gradient_accumulation_steps).backward()

        # Also step on the final example so a partial accumulation is not dropped
        # (or carried into the next epoch).
        is_last_step = step + 1 == len(train_token)
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_step:
            # Clip the fully accumulated gradient once, right before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()

In [None]:
save_path = ''  # TODO: output directory for the fine-tuned checkpoint

# Write the fine-tuned checkpoint (tokenizer files and model weights) to `save_path` ...
tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)

# ... then read both back, so evaluation runs on exactly what was written to disk.
model = BartForConditionalGeneration.from_pretrained(save_path).to(device)
tokenizer = BartTokenizer.from_pretrained(save_path)

### Model evaluation

In [None]:
# Evaluation metrics: ROUGE-1 / ROUGE-2 / ROUGE-Lsum scorer (with stemming) and BERTScore.
from compare_mt.rouge.rouge_scorer import RougeScorer

bertscorer = load_metric("bertscore")
rouge_scorer = RougeScorer(
    ['rouge1', 'rouge2', 'rougeLsum'],
    use_stemmer=True,
)

In [None]:
# Pre-tokenised test examples, unpickled from disk.
# NOTE(review): pickle.load can execute arbitrary code — only point this at trusted files.
test_data_path = ''  # TODO: set to the pickled test-token file
with open(test_data_path, mode='rb') as test_file:
    test_token = pickle.load(test_file)

In [None]:
def get_score(model, tokenizer, test_data, *, rouge=None, bert=None,
              num_beams=6, max_length=62, no_repeat_ngram_size=3):
    """Generate a summary for every test example and return mean ROUGE / BERTScore F1.

    Args:
        model: seq2seq model exposing ``eval()``, ``device`` and ``generate``.
        tokenizer: tokenizer used to decode both generated ids and reference ids.
        test_data: sized iterable of dicts holding token-id sequences under
            ``'document'`` and ``'summary'``.
        rouge: scorer with ``score(target, prediction)`` returning per-metric objects
            with ``.fmeasure``; defaults to the notebook-level ``rouge_scorer``.
        bert: metric with ``compute(predictions=..., references=..., lang=...)``
            returning ``{'f1': [...]}``; defaults to the notebook-level ``bertscorer``.
        num_beams, max_length, no_repeat_ngram_size: forwarded to ``model.generate``.

    Returns:
        Tuple ``(rouge1, rouge2, rougeLsum, bertscore)`` of F1 scores averaged over
        all examples.

    Raises:
        ValueError: if ``test_data`` is empty.
    """
    n = len(test_data)
    if n == 0:
        raise ValueError("test_data is empty; cannot average scores")
    if rouge is None:
        rouge = rouge_scorer
    if bert is None:
        bert = bertscorer

    # Disable dropout even if the caller passes a model still in train mode.
    model.eval()
    device = model.device

    rouge1 = rouge2 = rougeLsum = 0.0
    predictions, references = [], []
    for row in test_data:
        input_tokens = torch.tensor(row['document']).unsqueeze(0).to(device)
        with torch.no_grad():
            sum_ids = model.generate(
                input_tokens,
                num_beams=num_beams,
                max_length=max_length,
                early_stopping=True,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
        # Decode hypothesis and reference identically so scores are not skewed by
        # different whitespace clean-up.
        summary = tokenizer.decode(sum_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        ref = tokenizer.decode(row['summary'], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # ROUGE: argument order is (target, prediction).
        # NOTE(review): rougeLsum treats newlines as sentence boundaries; these texts are
        # not sentence-split, so it likely behaves like plain rougeL — confirm intent.
        score = rouge.score(ref, summary)
        rouge1 += score['rouge1'].fmeasure
        rouge2 += score['rouge2'].fmeasure
        rougeLsum += score['rougeLsum'].fmeasure

        predictions.append(summary)
        references.append(ref)

    # BERTScore: one batched call instead of one call per example.
    bert_f1 = bert.compute(predictions=predictions, references=references, lang='en')['f1']

    return rouge1 / n, rouge2 / n, rougeLsum / n, sum(bert_f1) / n

In [None]:
# Score the reloaded fine-tuned model on the test set and report labelled mean F1 scores.
rouge1, rouge2, rougeLsum, bert_score = get_score(model, tokenizer, test_token)
print('ROUGE-1:', rouge1, '| ROUGE-2:', rouge2, '| ROUGE-Lsum:', rougeLsum, '| BERTScore-F1:', bert_score)