### This notebook jointly trains a BART-v2 model to generate both the conclusion and the counter-argument

In [1]:
import os
import sys
# Restrict this process to a single GPU: enumerate devices in PCI bus order and
# expose only device 0. This has to happen before CUDA is initialised, which is
# why it sits ahead of every torch-related import in the notebook.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})

In [2]:
# Make the project's Python sources (utils, mt_bart_v2) importable.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if '../src-py' not in sys.path:
    sys.path.append('../src-py')

In [3]:
import transformers
import datasets
from utils import *
from mt_bart_v2 import *

# Record library versions in the notebook output for reproducibility.
print("Running on transformers v{} and datasets v{}".format(transformers.__version__, datasets.__version__))

Running on transformers v4.9.1 and datasets v1.10.2


In [4]:
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from datasets import load_dataset, load_metric, Dataset

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

import ray
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
from ray.tune import CLIReporter

In [5]:
# Use the (single visible) GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Shared BART tokenizer used for encoder inputs and for both decoder targets.
tokenizer = BartTokenizer.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')

In [7]:
def get_model(params=None, compute_dynamic_weights=False):
    """Build a BartModelV2 whose conclusion and counter decoders start from bart-base.

    Parameters
    ----------
    params : dict or None
        Optional loss weights, read from the keys ``'conc_loss_weight'`` and
        ``'counter_loss_weight'``. When ``params`` is None, or a key is missing,
        that weight defaults to 0.5.
    compute_dynamic_weights : bool
        Forwarded to ``BartModelV2.from_pretrained``. Defaults to False, which
        matches the previously hard-coded value.

    Returns
    -------
    tuple
        ``(data_collator, model)``. The collator is a ``DataCollatorForSeq2Seq``
        built from the module-level ``tokenizer`` and the model. The model has
        already been moved to ``device``.
    """
    if params is None:
        params = {}
    conc_loss_weight = params.get('conc_loss_weight', 0.5)
    counter_loss_weight = params.get('counter_loss_weight', 0.5)
    attention_to_conc = False
    conc_decoder = True

    model = BartModelV2.from_pretrained('facebook/bart-base',
                                        compute_dynamic_weights=compute_dynamic_weights,
                                        conc_loss_weight=conc_loss_weight,
                                        counter_loss_weight=counter_loss_weight,
                                        attention_to_conc=attention_to_conc,
                                        conc_decoder=conc_decoder).to(device)

    # The vanilla BartModel only serves as a source of decoder weights, so it is
    # kept on CPU rather than the GPU. load_state_dict copies each tensor into the
    # target parameters on whatever device those parameters live on.
    original_bart_model = BartModel.from_pretrained('facebook/bart-base')
    decoder_state = original_bart_model.decoder.state_dict()

    # Initialise both task decoders with the same pretrained decoder weights.
    model.conclusion_decoder.load_state_dict(decoder_state)
    model.counter_decoder.load_state_dict(decoder_state)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model)

    return data_collator, model

In [8]:
# Root of the shared data directory. The "multi-taks" spelling is part of the
# existing path (the loads below succeed with it), so do not "correct" it.
data_fold = (
    '../../../data-ceph/arguana/arg-generation/'
    'multi-taks-counter-argument-generation/'
)

In [9]:
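# Provenance of the validation sample loaded below: keep one row per unique post_id, then sample 1500 of them (kept commented out; .sample() has no random_state, so re-running it would draw a different sample)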
# valid_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc.pkl')
# valid_unique_df = valid_df.drop_duplicates('post_id')
# valid_sample_df = valid_unique_df.sample(1500)
# valid_sample_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_sample.pkl')

In [10]:
# Both splits are pickled pandas DataFrames in the same directory. The
# validation file is the 1500-row sample of unique posts.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
split_dir = data_fold + '/reddit_data/conclusion_and_ca_generation/'
train_ds = Dataset.from_pandas(pd.read_pickle(split_dir + 'train_conclusion_comp_remove_75sem_perc.pkl'))
valid_ds = Dataset.from_pandas(pd.read_pickle(split_dir + 'valid_conclusion_comp_remove_75sem_perc_sample.pkl'))

In [11]:
def _join_sentences(texts):
    """Join each list-of-sentences entry into one space-separated string; pass strings through.

    Each entry is checked individually rather than only the first one. This keeps
    mixed batches correct and turns an empty batch into a no-op instead of an
    IndexError.
    """
    return [' '.join(t) if isinstance(t, list) else t for t in texts]


# Encoding function for joint generation of conclusion and counter
def preprocess_function(examples, tokenizer, premises_clm, counter_clm, conclusion_clm, max_input_length=512, max_conc_length=100, max_counter_length=200):
    """Tokenize a batch for joint conclusion + counter generation.

    Parameters
    ----------
    examples : dict
        Batch mapping column name -> list of values. Each value is either a
        string or a list of sentence strings, which are joined with spaces.
    tokenizer :
        Tokenizer called as ``tokenizer(texts, max_length=..., truncation=True,
        padding='max_length')`` and providing ``as_target_tokenizer()``.
    premises_clm, counter_clm, conclusion_clm : str
        Column names holding the encoder input, the counter target and the
        conclusion target.
    max_input_length, max_conc_length, max_counter_length : int
        Truncation/padding lengths for the input and the two targets.

    Returns
    -------
    The tokenizer's output for the premises, extended with the keys
    ``conclusion_labels``, ``counter_labels``, ``counter_decoder_attention_mask``
    and ``conclusion_decoder_attention_mask``.
    """
    premises = _join_sentences(examples[premises_clm])
    conclusions = _join_sentences(examples[conclusion_clm])
    counters = _join_sentences(examples[counter_clm])

    model_inputs = tokenizer(premises, max_length=max_input_length, truncation=True, padding='max_length')

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        counter_labels = tokenizer(counters, max_length=max_counter_length, truncation=True, padding='max_length')
        conclusion_labels = tokenizer(conclusions, max_length=max_conc_length, truncation=True, padding='max_length')

    # NOTE(review): the labels keep the tokenizer's pad id (not -100). Presumably
    # BartModelV2 masks padded positions via the *_decoder_attention_mask
    # entries below -- confirm before changing the padding scheme.
    model_inputs["conclusion_labels"] = conclusion_labels["input_ids"]
    model_inputs["counter_labels"] = counter_labels["input_ids"]
    model_inputs["counter_decoder_attention_mask"] = counter_labels['attention_mask']
    model_inputs["conclusion_decoder_attention_mask"] = conclusion_labels['attention_mask']

    return model_inputs

In [12]:
# Optional: downsample the training set to 0.5% of its rows (e.g. for quick debugging runs)
#tmp_ds = train_ds.train_test_split(0.005)
#train_ds = tmp_ds['test']

In [13]:
# Number of rows in the training split.
train_ds.num_rows

92397

In [14]:
# Number of rows in the validation sample.
valid_ds.num_rows

1500

In [15]:
def _encode_batch(batch):
    """Tokenize one batch: 'masked_premises' is the encoder input; 'counter' and 'title' (the conclusion) are the two targets."""
    return preprocess_function(batch, tokenizer,
                               premises_clm='masked_premises',
                               counter_clm='counter',
                               conclusion_clm='title')


train_tokenized_ds = train_ds.map(_encode_batch, batched=True)
valid_tokenized_ds = valid_ds.map(_encode_batch, batched=True)

  0%|          | 0/93 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [16]:
# # Train fixed-weight models for several (conclusion, counter) loss-weight configurations
# batch_size = 32
# for conc_loss_weight, count_loss_weight in [(0.0, 1.0), (0.8, 0.2), (0.6, 0.4)]:
#         data_collator, model = get_model({'conc_loss_weight': conc_loss_weight, 'counter_loss_weight': count_loss_weight})
#         args = Seq2SeqTrainingArguments(
#             "../data/output/joint-con-counter-bart-model-no-attention-finetuned/{}-{}".format(str(conc_loss_weight).replace('.','-'), str(count_loss_weight).replace('.','-')),
#             evaluation_strategy = "steps",
#             learning_rate=2e-5,
#             per_device_train_batch_size=batch_size,
#             per_device_eval_batch_size=batch_size,
#             weight_decay=0.01,
#             save_total_limit=5,
#             num_train_epochs=3,
#             load_best_model_at_end=True,
#             predict_with_generate=True,
#             metric_for_best_model='bert-fscore',
#             label_names=['conclusion_labels', 'counter_labels']
#         )

#         trainer = Seq2TwoSeqTrainer(
#             model,
#             args,
#             train_dataset=train_tokenized_ds,
#             eval_dataset=valid_tokenized_ds,
#             data_collator=data_collator,
#             tokenizer=tokenizer,
#             compute_metrics=lambda x : compute_metrics(x, tokenizer)
#         )
        
#         trainer.train()
#         trainer.save_model()

#### Train a model with dynamic loss weighting:

In [17]:
# Per-device batch size, used for both training and evaluation in the trainer cell below.
batch_size: int = 32

In [18]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
# Dynamic-weighting variant: no fixed conclusion/counter loss weights are passed.
model = BartModelV2.from_pretrained('facebook/bart-base', compute_dynamic_weights=True, conc_decoder=True).to(device)
# The vanilla BartModel is only a source of decoder weights. It is kept on CPU so
# it does not occupy GPU memory for the rest of the training run; load_state_dict
# copies each tensor onto the device of the target parameters.
original_bart_model = BartModel.from_pretrained('facebook/bart-base')
# Initialise both task decoders with the same pretrained decoder weights.
model.conclusion_decoder.load_state_dict(original_bart_model.decoder.state_dict())
model.counter_decoder.load_state_dict(original_bart_model.decoder.state_dict())

Some weights of the model checkpoint at facebook/bart-base were not used when initializing BartModelV2: ['decoder.layers.1.fc2.weight', 'decoder.layers.5.encoder_attn.out_proj.weight', 'decoder.layers.0.self_attn.q_proj.bias', 'decoder.layers.5.self_attn_layer_norm.weight', 'decoder.layers.4.self_attn.v_proj.weight', 'decoder.layers.5.self_attn.out_proj.weight', 'decoder.layers.4.encoder_attn.q_proj.weight', 'decoder.layers.3.encoder_attn.q_proj.bias', 'decoder.layers.3.self_attn.q_proj.weight', 'decoder.layers.0.fc2.weight', 'decoder.layers.2.self_attn.q_proj.bias', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.4.fc1.bias', 'decoder.layers.0.self_attn.v_proj.bias', 'decoder.layers.3.self_attn.k_proj.weight', 'decoder.layers.2.self_attn.out_proj.bias', 'decoder.layers.0.encoder_attn_layer_norm.bias', 'decoder.layers.5.self_attn.v_proj.weight', 'decoder.layers.2.encoder_attn_layer_norm.weight', 'decoder.layers.2.fc2.weight', 'decoder.layers.1.self_attn.q_proj.weight', 

<All keys matched successfully>

In [19]:
# Collator for the dynamic-weighting model built in the previous cell.
data_collator = DataCollatorForSeq2Seq(tokenizer, model)

# Six epochs of fine-tuning with step-based evaluation (every 500 steps in the
# log below). The best checkpoint by 'bert-fscore' (the metric reported by
# compute_metrics) is restored at the end of training.
args = Seq2SeqTrainingArguments(
    output_dir="../data/output/joint-con-counter-bart-model-no-attention-finetuned/dynamic-weight",
    num_train_epochs=6,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="steps",
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model='bert-fscore',
    predict_with_generate=True,
    # Both target columns are declared as labels.
    label_names=['conclusion_labels', 'counter_labels'],
)


def _compute_metrics_with_tokenizer(eval_pred):
    """Adapt compute_metrics(eval_pred, tokenizer) to the one-argument callback the trainer expects."""
    return compute_metrics(eval_pred, tokenizer)


trainer = Seq2TwoSeqTrainer(
    model,
    args,
    train_dataset=train_tokenized_ds,
    eval_dataset=valid_tokenized_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=_compute_metrics_with_tokenizer,
)

trainer.train()
trainer.save_model()

The following columns in the training set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: comment_id, num_cand_conc, post_id, bart_conclusion, masked_premises, n_sentences, split, title, counter, post, __index_level_0__, premises_with_conclusion.
***** Running training *****
  Num examples = 92397
  Num Epochs = 6
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 17328


Step,Training Loss,Validation Loss,Bleu,Precisions,Brevity Penalty,Length Ratio,Translation Length,Reference Length,Bert-fscore
500,2.4384,1.562732,0.010045,"[0.1859771730254287, 0.02499662858573988, 0.004271165236429032, 0.0008629753804035154]",0.878011,0.88488,105314,119015,-0.19
1000,1.8559,1.387012,0.009773,"[0.2070409862344996, 0.026977340294775765, 0.004461474421235764, 0.0009219988936013276]",0.793835,0.812427,96691,119015,-0.06
1500,1.6922,1.292078,0.012336,"[0.20521524341993358, 0.029255378816772485, 0.00484801730462339, 0.0007955997215400974]",1.0,1.136151,135219,119015,-0.04
2000,1.5854,1.234214,0.011701,"[0.208154675785874, 0.028939535296992315, 0.004889909496150209, 0.0007576794924460267]",0.957357,0.958241,114045,119015,-0.03
2500,1.5022,1.195369,0.011911,"[0.22869433676845988, 0.031101946868706244, 0.0053128344847718206, 0.0008788387479016491]",0.882299,0.888712,105770,119015,-0.01
3000,1.4366,1.17029,0.011809,"[0.19956609839399125, 0.028422949704334736, 0.004587927968799163, 0.0007472459178917307]",1.0,1.173491,139663,119015,-0.02
3500,1.3716,1.151296,0.011637,"[0.19555195331496192, 0.028017678834506465, 0.004556124979335427, 0.0007346802471197195]",1.0,1.295845,154225,119015,-0.03
4000,1.3427,1.133117,0.011995,"[0.19642952661139343, 0.028022423341888425, 0.004585152838427947, 0.0008203501999172756]",1.0,1.256648,149560,119015,-0.02
4500,1.2986,1.122034,0.011733,"[0.18651517759798014, 0.027400590380274353, 0.004595083785835231, 0.0008069741414855392]",1.0,1.464269,174270,119015,-0.03
5000,1.2653,1.110842,0.011727,"[0.19796477119115094, 0.02764970568455208, 0.004603000007006088, 0.0007505328074883349]",1.0,1.224493,145733,119015,-0.02


The following columns in the evaluation set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: comment_id, num_cand_conc, post_id, masked_premises, n_sentences, split, title, counter, post, __index_level_0__, premises_with_conclusion.
***** Running Evaluation *****
  Num examples = 1500
  Batch size = 32
Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/roberta-large/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/2021//sile2804/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_