### This notebook jointly trains the BART-v2 model to generate both the conclusion and the counter-argument

In [1]:
import gc
import os
import sys
# Order CUDA devices by PCI bus id and expose only device index 1 to this
# process (see issue #152).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [2]:
# Make the project's local modules under ../src-py importable (presumably where
# `utils` and `mt_bart_v2` live -- confirm).  The membership check keeps this
# cell idempotent: re-running it no longer appends duplicate entries.
SRC_PY_DIR = '../src-py'
if SRC_PY_DIR not in sys.path:
    sys.path.append(SRC_PY_DIR)

In [3]:
import transformers
import datasets
from utils import *
from mt_bart_v2 import *

# Record the library versions next to the results so the logged outputs can be
# tied to a concrete environment.
print(f"Running on transformers v{transformers.__version__} and datasets v{datasets.__version__}")

Running on transformers v4.9.1 and datasets v1.10.2


In [4]:
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from datasets import load_dataset, load_metric, Dataset

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

import ray
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
from ray.tune import CLIReporter

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Tokenizer of the `facebook/bart-base` checkpoint.  The same instance is used
# for preprocessing, by the data collator, and by the trainer further down.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [7]:
def get_model(params):
    """Build a fresh joint conclusion/counter BART model and its data collator.

    Args:
        params: ``None`` to weight both losses with 0.5, or a mapping that
            provides the keys ``'conc_loss_weight'`` and
            ``'counter_loss_weight'``.

    Returns:
        A ``(data_collator, model)`` tuple.  ``model`` is a ``BartModelV2``
        placed on the module-level ``device``; ``data_collator`` is a
        ``DataCollatorForSeq2Seq`` built from the module-level ``tokenizer``
        and that model.

    Raises:
        KeyError: if ``params`` is given but lacks one of the two weight keys.
    """
    compute_dynamic_weights = False
    conc_loss_weight = 0.5 if params is None else params['conc_loss_weight']
    counter_loss_weight = 0.5 if params is None else params['counter_loss_weight']
    attention_to_conc = False
    conc_decoder = True

    model = BartModelV2.from_pretrained(
        'facebook/bart-base',
        compute_dynamic_weights=compute_dynamic_weights,
        conc_loss_weight=conc_loss_weight,
        counter_loss_weight=counter_loss_weight,
        attention_to_conc=attention_to_conc,
        conc_decoder=conc_decoder,
    ).to(device)

    # The stock checkpoint is only a source of decoder weights, so it stays on
    # the CPU: `load_state_dict` copies the values into the target module's
    # own (already placed) parameters, and no GPU memory is spent on a second
    # full model.
    original_bart_model = BartModel.from_pretrained('facebook/bart-base')
    decoder_state = original_bart_model.decoder.state_dict()

    # Initialise both decoders from the pretrained BART decoder.
    model.conclusion_decoder.load_state_dict(decoder_state)
    model.counter_decoder.load_state_dict(decoder_state)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model)

    return data_collator, model

In [8]:
# Root folder of the task data.  The dataset paths used below are built from it
# by plain string concatenation.
data_fold = '../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/'

In [9]:
# One-off step, kept for provenance: take the unique posts from the validation set and sample only 1500 instances.
# valid_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc.pkl')
# valid_unique_df = valid_df.drop_duplicates('post_id')
# valid_sample_df = valid_unique_df.sample(1500)
# valid_sample_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_sample.pkl')

In [10]:
# Load the pickled train / validation frames and wrap each one in a `Dataset`.
# NOTE: unpickling executes arbitrary code -- only load files from a trusted source.
train_ds, valid_ds = (
    Dataset.from_pandas(pd.read_pickle(data_fold + relative_path))
    for relative_path in (
        '/reddit_data/conclusion_and_ca_generation/train_conclusion_comp_remove_75sem_perc.pkl',
        '/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_sample.pkl',
    )
)

In [11]:
#Encoding function for joint generation of conclusion and counter
def preprocess_function(examples, tokenizer, premises_clm, counter_clm, conclusion_clm, max_input_length=512, max_conc_length=100, max_counter_length=200):
    """Tokenize one batch for joint conclusion / counter generation.

    Args:
        examples: batch mapping column names to lists of values.
        tokenizer: callable tokenizer that also provides
            ``as_target_tokenizer()``.
        premises_clm: column holding the encoder input.
        counter_clm: column holding the counter target.
        conclusion_clm: column holding the conclusion target.
        max_input_length: truncation / padding length of the input.
        max_conc_length: truncation / padding length of the conclusion.
        max_counter_length: truncation / padding length of the counter.

    Returns:
        The tokenizer output for the premises, extended with the keys
        ``conclusion_labels``, ``counter_labels``,
        ``counter_decoder_attention_mask`` and
        ``conclusion_decoder_attention_mask``.
    """
    def as_text(column):
        # A column whose first entry is a list has every entry joined with
        # single spaces; any other column is passed through unchanged.
        if isinstance(column[0], list):
            return [' '.join(parts) for parts in column]
        return column

    raw_premises = examples[premises_clm]
    raw_conclusions = examples[conclusion_clm]
    raw_counters = examples[counter_clm]

    premise_texts = as_text(raw_premises)
    counter_texts = as_text(raw_counters)
    conclusion_texts = as_text(raw_conclusions)

    # Every sequence is truncated and padded to its fixed maximum length.
    fixed_length = dict(truncation=True, padding='max_length')

    model_inputs = tokenizer(premise_texts, max_length=max_input_length, **fixed_length)

    # Targets are encoded inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        counter_enc = tokenizer(counter_texts, max_length=max_counter_length, **fixed_length)
        conclusion_enc = tokenizer(conclusion_texts, max_length=max_conc_length, **fixed_length)

    extra_fields = (
        ("conclusion_labels", conclusion_enc["input_ids"]),
        ("counter_labels", counter_enc["input_ids"]),
        ("counter_decoder_attention_mask", counter_enc['attention_mask']),
        ("conclusion_decoder_attention_mask", conclusion_enc['attention_mask']),
    )
    for field_name, field_value in extra_fields:
        model_inputs[field_name] = field_value

    return model_inputs

In [12]:
#downsample the training dataset
#tmp_ds = train_ds.train_test_split(0.005)
#train_ds = tmp_ds['test']

In [13]:
# Sanity check: number of training examples (shown as the cell output).
len(train_ds)

92397

In [14]:
# Sanity check: number of validation examples (shown as the cell output).
len(valid_ds)

1500

In [15]:
def _encode_batch(batch):
    """Encode one batch: 'masked_premises' is the input, 'counter' and 'title' are the targets."""
    return preprocess_function(batch, tokenizer, 'masked_premises', 'counter', 'title')


# Apply the same batched encoding to both splits.
train_tokenized_ds = train_ds.map(_encode_batch, batched=True)
valid_tokenized_ds = valid_ds.map(_encode_batch, batched=True)

  0%|          | 0/93 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [None]:
#Train the model for different config
batch_size = 32

# One training run per (conclusion loss weight, counter loss weight) pair.
loss_weight_configs = [(0.0, 1.0), (0.8, 0.2), (0.6, 0.4)]

for conc_loss_weight, count_loss_weight in loss_weight_configs:
    # Release the previous run's model / trainer (and the optimizer state the
    # trainer holds) *before* building the next model; otherwise both runs
    # occupy GPU memory at the same time.  This is done at the start of the
    # iteration so that the objects of the last run stay bound after the loop.
    model = trainer = data_collator = None
    gc.collect()
    torch.cuda.empty_cache()

    data_collator, model = get_model({
        'conc_loss_weight': conc_loss_weight,
        'counter_loss_weight': count_loss_weight,
    })

    # Each configuration writes to its own folder, e.g. ".../0-8-0-2".
    output_dir = "../data/output/joint-con-counter-bart-model-no-attention-finetuned/{}-{}".format(
        str(conc_loss_weight).replace('.', '-'),
        str(count_loss_weight).replace('.', '-'),
    )

    args = Seq2SeqTrainingArguments(
        output_dir,
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_total_limit=5,
        num_train_epochs=3,
        load_best_model_at_end=True,
        predict_with_generate=True,
        # Checkpoint selection metric; expected to be a key of the dict
        # returned by `compute_metrics` -- confirm against its implementation.
        metric_for_best_model='bert-fscore',
        # The two label columns produced by `preprocess_function`.
        label_names=['conclusion_labels', 'counter_labels'],
    )

    trainer = Seq2TwoSeqTrainer(
        model,
        args,
        train_dataset=train_tokenized_ds,
        eval_dataset=valid_tokenized_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=lambda x: compute_metrics(x, tokenizer),
    )

    trainer.train()
    trainer.save_model()

loading configuration file https://huggingface.co/facebook/bart-base/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/2021//sile2804/.cache/huggingface/transformers/f5310d276a6d1648d00c32fadc8bf7b4607e0fbd5b404fc4a0045960aa2bdfdb.da0f3c0e2dc1c2fecc46738a1ebf4806f2fc36aae3d5c1947f21e063e7cab34b
Model config BartConfig {
  "_name_or_path": "bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_eos_token_id": 

Step,Training Loss,Validation Loss,Bleu,Precisions,Brevity Penalty,Length Ratio,Translation Length,Reference Length,Bert-fscore
500,3.609,2.186493,0.003279,"[0.2644909246537761, 0.03488187368086603, 0.006132086554049292, 0.0012906996566251856]",0.199469,0.382834,45563,119015,-0.1
1000,2.9334,1.975879,0.008466,"[0.22590898022479397, 0.030381434003684975, 0.005624641795620243, 0.0010880356875705524]",0.591375,0.655606,78027,119015,-0.05
1500,2.7739,1.860821,0.012432,"[0.2024202590373091, 0.028100919225503618, 0.005090045517514348, 0.0008251552173042259]",1.0,1.086628,129325,119015,-0.06
2000,2.6811,1.788559,0.011168,"[0.18609079619152752, 0.02614333519241495, 0.004678032971678174, 0.0006835659356308744]",1.0,1.21783,144940,119015,-0.06
2500,2.6153,1.743694,0.012686,"[0.20776221815362791, 0.029009469412295204, 0.005226038865492922, 0.0008222416362608562]",1.0,1.018813,121254,119015,-0.03
3000,2.5551,1.713333,0.010239,"[0.18743553501165144, 0.026323909438037568, 0.00399839025846737, 0.0005571505355199853]",1.0,1.319682,157062,119015,-0.04
3500,2.5049,1.689409,0.011487,"[0.19510490855358895, 0.0281501427249131, 0.004366212833266881, 0.0007261061567201125]",1.0,1.3107,155993,119015,-0.03
4000,2.4973,1.669175,0.011338,"[0.19046334347648775, 0.02714634069100928, 0.004380215493813345, 0.0007295735545727475]",1.0,1.339201,159385,119015,-0.03
4500,2.4632,1.654655,0.011443,"[0.19289937146057626, 0.027200201017651863, 0.004477138689834485, 0.0007298802740252257]",1.0,1.350166,160690,119015,-0.02
5000,2.4427,1.641294,0.011075,"[0.18964968558262685, 0.02670676537117006, 0.004159204754859975, 0.0007142108297950411]",1.0,1.320136,157116,119015,-0.02


The following columns in the evaluation set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: __index_level_0__, comment_id, masked_premises, post, post_id, num_cand_conc, title, n_sentences, split, counter, premises_with_conclusion.
***** Running Evaluation *****
  Num examples = 1500
  Batch size = 32
Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/roberta-large/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/2021//sile2804/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_

Step,Training Loss,Validation Loss,Bleu,Precisions,Brevity Penalty,Length Ratio,Translation Length,Reference Length,Bert-fscore
500,2.2076,1.461286,0.006346,"[0.23390075603967245, 0.0325493794369054, 0.005856056482077168, 0.0011840124490451786]",0.418664,0.534563,63621,119015,-0.12
1000,1.6834,1.311789,0.008375,"[0.22013250175940982, 0.029401586870010134, 0.004999118543329892, 0.0008984264702107451]",0.641394,0.692467,82414,119015,-0.03
1500,1.5391,1.232665,0.011348,"[0.19344615069315338, 0.02713843785615962, 0.004258280822020326, 0.0007417386960536474]",1.0,1.147939,136622,119015,-0.05
2000,1.4464,1.183043,0.01106,"[0.19075367962888723, 0.026463947665499175, 0.004404190565375161, 0.0006731366689351126]",1.0,1.173701,139688,119015,-0.04
2500,1.3732,1.150451,0.011826,"[0.18884688746315242, 0.026900297879796743, 0.004519213741809811, 0.0008519778056201898]",1.0,1.211402,144175,119015,-0.04
3000,1.3199,1.130394,0.010367,"[0.17594698898810535, 0.025157232704402517, 0.004012249738081258, 0.0006503385244283408]",1.0,1.484838,176718,119015,-0.05
3500,1.2644,1.115216,0.010357,"[0.17585980644103014, 0.02524394950107503, 0.003919061648785369, 0.0006614720556084982]",1.0,1.536697,182890,119015,-0.05
4000,1.245,1.101812,0.010642,"[0.17487256404404014, 0.025800760489606278, 0.004115703757720397, 0.0006907539244849984]",1.0,1.546141,184014,119015,-0.05
4500,1.2109,1.093985,0.010128,"[0.17631593452020064, 0.024937475684988608, 0.003956733733116629, 0.0006047589442152264]",1.0,1.52443,181430,119015,-0.04
5000,1.1879,1.086706,0.010441,"[0.17759593615894934, 0.02545636295480104, 0.0041359714950219405, 0.0006355956687044249]",1.0,1.491963,177566,119015,-0.04


Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/roberta-large/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/2021//sile2804/.cache/huggingface/transformers/dea67b44b38d504f2523f3ddb6acb601b23d67bee52c942da336fa1283100990.94cae8b3a8dbab1d59b9d4827f7ce79e73124efa6bb970412cd503383a95f373
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.9.1",
  "type_vocab_size": 1,
  "u