### This notebook trains a BART model to generate the counter-argument given either a ground-truth or an automatically generated conclusion

In [1]:
import os
# Pin the notebook to a single GPU. This must run before torch initialises CUDA.
# PCI_BUS_ID ordering makes the index below match `nvidia-smi`.  # see issue #152
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [2]:
import transformers
import datasets

# Log the library versions this notebook was run with; the saved outputs below depend on them.
print(f"Running on transformers v{transformers.__version__} and datasets v{datasets.__version__}")

Running on transformers v4.9.1 and datasets v1.10.2


In [3]:
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from datasets import load_dataset, load_metric, Dataset

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

In [4]:
# Load the pre-split pickled DataFrames and wrap each one as a Hugging Face Dataset.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load files you trust.
train_ds, valid_ds = (
    Dataset.from_pandas(pd.read_pickle(pickle_path))
    for pickle_path in (
        '../data/train_conclusion_comp_remove_50perc.pkl',
        '../data/valid_conclusion_comp_remove_50perc.pkl',
    )
)

In [5]:
# Truncation limits (in tokens) for encoder inputs and target labels; read by preprocess_function.
max_input_length, max_target_length = 512, 200

# Metric loaders from `datasets`. Only ROUGE is used by compute_metrics.
# NOTE(review): bertscore_metric is not referenced anywhere else in this notebook.
rouge_metric = load_metric("rouge")
bertscore_metric = load_metric('bertscore')

def preprocess_function(examples, input_clm, output_clm):
    """Tokenize a batch of examples into encoder inputs plus target ``labels``.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping column names to lists of values. Each value in
            ``input_clm`` / ``output_clm`` may be a string or a list of strings;
            lists are joined with single spaces.
        input_clm: Name of the column holding the source text.
        output_clm: Name of the column holding the target text.

    Returns:
        The tokenizer output for the source texts, with an added ``labels`` key
        holding the ``input_ids`` of the tokenized targets.

    Relies on the module-level ``tokenizer``, ``max_input_length`` and
    ``max_target_length``. ``tokenizer`` is created in a later cell, so that cell
    must have run before this function is called.
    """
    def _as_text(value):
        # Some columns store a list of strings per example; flatten it to one string.
        return ' '.join(value) if isinstance(value, list) else value

    # Normalise every element rather than peeking at the first row only. This way
    # mixed batches are joined correctly and empty batches don't raise IndexError.
    text_inputs = [_as_text(x) for x in examples[input_clm]]
    text_outputs = [_as_text(x) for x in examples[output_clm]]

    model_inputs = tokenizer(text_inputs, max_length=max_input_length, truncation=True)

    # Encode the targets inside as_target_tokenizer(), following the transformers 4.x seq2seq recipe.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(text_outputs, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def compute_metrics(eval_pred):
    """Compute ROUGE F-measures and mean generated length for ``Seq2SeqTrainer``.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays. With
            ``predict_with_generate=True`` the predictions are generated ids.
            ``predictions`` may also arrive as a tuple whose first item holds the ids.

    Returns:
        Dict mapping each key returned by ``rouge_metric.compute`` to its mid
        F-measure x 100, plus ``gen_len`` (the mean number of non-pad tokens
        per prediction). All values are rounded to 4 decimals.

    Relies on module-level ``tokenizer``, ``rouge_metric`` and ``nltk``.
    NOTE(review): nltk.sent_tokenize needs the NLTK "punkt" data, which must be downloaded beforehand.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs next to the generated ids; keep the ids only.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 marks ignored/padded positions: always possible in labels, and in predictions
    # when the evaluation loop pads batches of different widths. Such ids can't be decoded,
    # and they must not be counted as generated tokens.
    predictions = np.where(np.asarray(predictions) != -100, predictions, pad_id)
    labels = np.where(np.asarray(labels) != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence (used by rougeLsum).
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    rouge_scores = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {key: value.mid.fmeasure * 100 for key, value in rouge_scores.items()}

    # Mean generated length: non-pad tokens per prediction (-100 already mapped to pad above).
    result["gen_len"] = np.mean([np.count_nonzero(pred != pad_id) for pred in predictions])

    return {k: round(v, 4) for k, v in result.items()}

In [6]:
# Tokenizer and model are loaded from the same pretrained checkpoint.
bart_checkpoint = 'facebook/bart-base'
tokenizer = BartTokenizer.from_pretrained(bart_checkpoint)
model = BartForConditionalGeneration.from_pretrained(bart_checkpoint)

In [43]:
# Downsample the training set to a random 10% (the "test" side of the split).
# A fixed seed makes the subset reproducible across kernel restarts.
# NOTE(review): this cell overwrites train_ds, so re-running it shrinks the data again;
# re-run the loading cell first. Its execution count (43) is out of order, and the logged
# training run used 271923 examples, so this cell was likely not applied for that run. TODO confirm.
train_ds = train_ds.train_test_split(test_size=0.1, seed=42)['test']

In [7]:
# Number of training examples after any downsampling.
train_ds.num_rows

271923

In [8]:
# Downsample the validation set to a random 1% (the "test" side of the split), to keep
# generation-based evaluation fast. A fixed seed makes the subset reproducible.
# NOTE(review): this cell overwrites valid_ds, so re-running it shrinks the data again;
# re-run the loading cell first.
valid_ds = valid_ds.train_test_split(test_size=0.01, seed=42)['test']

In [9]:
# Number of validation examples after downsampling.
valid_ds.num_rows

1046

In [10]:
# Tokenize both splits with the same column mapping. `premises_with_conclusion` is the
# encoder input and `counter` is the target.
premises_w_conc_map_kwargs = dict(
    batched=True,
    fn_kwargs={'input_clm': 'premises_with_conclusion', 'output_clm': 'counter'},
)
train_tokenized_premises_w_conc_ds = train_ds.map(preprocess_function, **premises_w_conc_map_kwargs)
valid_tokenized_premises_w_conc_ds = valid_ds.map(preprocess_function, **premises_w_conc_map_kwargs)

  0%|          | 0/272 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [11]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    output_dir="../data/output/known-conclusion-bart-model",
    # Optimisation.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    # Evaluation runs once per epoch and decodes with generate() so compute_metrics can score text.
    evaluation_strategy="epoch",
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    # Keep at most 3 checkpoints on disk.
    save_total_limit=3,
)

In [12]:
# Pads inputs and labels per batch. Passing the model lets the collator derive
# decoder inputs from the labels (see the DataCollatorForSeq2Seq docs).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [13]:
# Wire the model, arguments, tokenized splits and metric function into the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_tokenized_premises_w_conc_ds,
    eval_dataset=valid_tokenized_premises_w_conc_ds,
)

In [14]:
# Fine-tune. With evaluation_strategy="epoch", validation metrics are logged after each epoch.
train_output = trainer.train()
train_output

The following columns in the training set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: conclusions_in_argument, post, comment_id, n_sentences, num_cand_conc, post_id, title, counter, __index_level_0__, masked_premises, premises_with_conclusion, split.
***** Running training *****
  Num examples = 271923
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 16996


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,3.5435,3.346712,12.6231,2.2485,9.3129,11.1684,19.8824


Saving model checkpoint to ../data/output/known-conclusion-bart-model/checkpoint-500
Configuration saved in ../data/output/known-conclusion-bart-model/checkpoint-500/config.json
Model weights saved in ../data/output/known-conclusion-bart-model/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ../data/output/known-conclusion-bart-model/checkpoint-500/tokenizer_config.json
Special tokens file saved in ../data/output/known-conclusion-bart-model/checkpoint-500/special_tokens_map.json
Deleting older checkpoint [../data/output/known-conclusion-bart-model/checkpoint-26500] due to args.save_total_limit
Saving model checkpoint to ../data/output/known-conclusion-bart-model/checkpoint-1000
Configuration saved in ../data/output/known-conclusion-bart-model/checkpoint-1000/config.json
Model weights saved in ../data/output/known-conclusion-bart-model/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in ../data/output/known-conclusion-bart-model/checkpoint-1000/tokenizer_conf

TrainOutput(global_step=16996, training_loss=3.6040863446500278, metrics={'train_runtime': 2805.1905, 'train_samples_per_second': 96.936, 'train_steps_per_second': 6.059, 'total_flos': 1.151809377997824e+17, 'train_loss': 3.6040863446500278, 'epoch': 1.0})

In [73]:
# Evaluate on the tokenized validation split with the trainer's compute_metrics.
# NOTE(review): the saved output below does not match the training run above: it shows
# batch size 2, 898 examples, epoch 3.0 and a different column set. It seems to come
# from another configuration; re-run to refresh.
eval_metrics = trainer.evaluate()
eval_metrics

The following columns in the evaluation set  don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: counter, conclusions_in_argument, __index_level_0__, weak_premises, masked_premises, premises_with_conclusion, premises, conclusion.
***** Running Evaluation *****
  Num examples = 898
  Batch size = 2


{'eval_loss': 2.9925241470336914,
 'eval_rouge1': 11.3275,
 'eval_rouge2': 1.6403,
 'eval_rougeL': 8.6976,
 'eval_rougeLsum': 10.1471,
 'eval_gen_len': 18.3552,
 'eval_runtime': 83.7533,
 'eval_samples_per_second': 10.722,
 'eval_steps_per_second': 5.361,
 'epoch': 3.0}