### This notebook is to jointly train BART-v2 model for both generating the conclusion and the counter

In [1]:
import os
import sys
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [2]:
sys.path.append('../src-py')

In [3]:
import transformers
import datasets
from mt_bart_v2 import *

print(f"Running on transformers v{transformers.__version__} and datasets v{datasets.__version__}")

Running on transformers v4.9.1 and datasets v1.10.2


In [4]:
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from datasets import load_dataset, load_metric, Dataset

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
# Input to the model should be:
# input_ids: the encoded premises
# attention_mask: the attention mask of the premises
# conclusion_decoder_input_ids: the encoded conclusion
# conclusion_decoder_attention_mask: the attention mask of the conclusion
# counter_decoder_input_ids: the encoded counter
# counter_decoder_attention_mask: the attention mask of the counter
# conclusion_labels: this is the encoded conclusion again
# counter_labels: this is the encoded counter again

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model     = BartModelV2.from_pretrained('facebook/bart-base', conc_loss_weight = 0.5, counter_loss_weight=0.5).to(device)
original_bart_model = BartModel.from_pretrained('facebook/bart-base').to(device)
#load the weights of the two decoders
model.conclusion_decoder.load_state_dict(original_bart_model.decoder.state_dict())
model.counter_decoder.load_state_dict(original_bart_model.decoder.state_dict())

Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartModelV2: ['decoder.layers.1.final_layer_norm.bias', 'decoder.layers.5.encoder_attn.q_proj.weight', 'decoder.layers.0.encoder_attn.v_proj.weight', 'decoder.layers.0.encoder_attn.out_proj.bias', 'decoder.layers.2.fc1.weight', 'decoder.layers.4.fc2.weight', 'decoder.layers.5.self_attn.k_proj.weight', 'decoder.layers.3.self_attn.k_proj.weight', 'decoder.layers.0.self_attn.k_proj.bias', 'decoder.layers.4.final_layer_norm.weight', 'decoder.layers.4.encoder_attn.k_proj.bias', 'decoder.layers.1.fc2.weight', 'decoder.layers.1.self_attn.out_proj.bias', 'decoder.layers.3.encoder_attn_layer_norm.bias', 'decoder.layers.1.encoder_attn.v_proj.bias', 'decoder.layers.2.self_attn_layer_norm.bias', 'decoder.layers.5.self_attn.v_proj.weight', 'decoder.layers.2.encoder_attn.k_proj.bias', 'decoder.layers.0.encoder_attn.q_proj.weight', 'decoder.layers.2.self_attn.q_proj.weight', 'decoder.layers.5.encoder_attn.k_pro

<All keys matched successfully>

In [7]:
data_fold = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/'

In [32]:
train_ds = Dataset.from_pandas(pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/train_conclusion_comp_remove_75sem_perc.pkl'))
valid_ds = Dataset.from_pandas(pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc.pkl'))

In [33]:
max_input_length = 512
max_target_length = 200

rouge_metric = load_metric("rouge")
bertscore_metric = load_metric('bertscore')

#Encoding function for joint generation of conclusion and counter
def preprocess_function(examples, tokenizer, premises_clm, counter_clm, conclusion_clm, max_input_length=512, max_conc_length=100, max_counter_length=200):
    premises   = examples[premises_clm]
    conclusions = examples[conclusion_clm]
    counters = examples[counter_clm]
    
        
    premises = [' '.join(x) for x in premises] if isinstance(premises[0], list) else premises
    counters = [' '.join(x) for x in counters] if isinstance(counters[0], list) else counters
    conclusions = [' '.join(x) for x in conclusions] if isinstance(conclusions[0], list) else conclusions
    
    model_inputs = tokenizer(premises, max_length=max_input_length, truncation=True, padding='max_length')
        
    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        counter_labels = tokenizer(counters, max_length=max_counter_length, truncation=True, padding='max_length')
        conclusion_labels = tokenizer(conclusions, max_length=max_conc_length, truncation=True, padding='max_length')

    #     print(text_inputs[0])
    #     print(model_inputs['input_ids'][0])
    #     print('-----------------')
    #     print(text_outputs[0])
    #     print(labels["input_ids"][0])
    
    
    model_inputs["conclusion_labels"] = conclusion_labels["input_ids"]
    model_inputs["counter_labels"] = counter_labels["input_ids"]
    #model_inputs["counter_decoder_attention_mask"] = counter_labels['attention_mask']
    #model_inputs["conclusion_decoder_attention_mask"] = conclusion_labels['attention_mask']
    
    return model_inputs

def compute_metrics(eval_pred, tokenizer):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    #compute BertScore bertscore_metric
    bertscore_result = bertscore_metric.compute(predictions=decoded_preds, references=decoded_labels, lang='en', rescale_with_baseline=True)
    result['bert-fscore'] = round(np.mean(bertscore_result['f1']), 2)

In [34]:
#downsample the training dataset
# tmp_ds = train_ds.train_test_split(0.01)
# train_ds = tmp_ds['test']

In [35]:
len(train_ds)

92397

In [36]:
#downsample the valid dataset
tmp_ds = valid_ds.train_test_split(0.1)
valid_ds = tmp_ds['test']

In [37]:
len(valid_ds)

3318

In [38]:
train_tokenized_ds = train_ds.map(lambda x :preprocess_function(x, tokenizer, 'masked_premises', 'counter', 'title'), batched=True)
valid_tokenized_ds = valid_ds.map(lambda x :preprocess_function(x, tokenizer, 'masked_premises', 'counter', 'title'), batched=True)

  0%|          | 0/93 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

In [39]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    "../data/output/joint-con-counter-bart-model-50-50",
    evaluation_strategy = "steps",
    eval_steps=1000,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=10,
    num_train_epochs=6,
    predict_with_generate=False
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [40]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [41]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_tokenized_ds,
    eval_dataset=valid_tokenized_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    #compute_metrics=compute_metrics
)

In [None]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: counter, __index_level_0__, premises_with_conclusion, bart_conclusion, post, split, masked_premises, title, post_id, comment_id, num_cand_conc, n_sentences.
***** Running training *****
  Num examples = 92397
  Num Epochs = 6
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 34650


Step,Training Loss,Validation Loss
1000,2.5589,No log
2000,2.4918,No log
3000,2.494,No log
4000,2.4668,No log
5000,2.2518,No log
6000,2.1085,No log


Saving model checkpoint to ../data/output/joint-con-counter-bart-model-50-50/checkpoint-500
Configuration saved in ../data/output/joint-con-counter-bart-model-50-50/checkpoint-500/config.json
Model weights saved in ../data/output/joint-con-counter-bart-model-50-50/checkpoint-500/pytorch_model.bin
tokenizer config file saved in ../data/output/joint-con-counter-bart-model-50-50/checkpoint-500/tokenizer_config.json
Special tokens file saved in ../data/output/joint-con-counter-bart-model-50-50/checkpoint-500/special_tokens_map.json
The following columns in the evaluation set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: counter, __index_level_0__, premises_with_conclusion, post, split, masked_premises, title, post_id, comment_id, num_cand_conc, n_sentences.
***** Running Evaluation *****
  Num examples = 3318
  Batch size = 16
Saving model checkpoint to ../data/output/joint-con-counter-bart-model-50-50/checkpoint-1000
Configuration saved in ../data/ou