### This notebook jointly trains a BART-v2 model to generate both the conclusion and the counter-argument

In [8]:
import os
import sys
# Configure CUDA device ordering (see issue #152) and the device visibility mask
# in one step; both values are written to the process environment.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [10]:
# Add the local source directory to the import path.
# Guarded so that re-running this cell does not pile up duplicate sys.path entries.
if '../src-py' not in sys.path:
    sys.path.append('../src-py')

In [12]:
import transformers
import datasets
from utils import *
from mt_bart_v2 import *

# Record the library versions this run was executed with
versions_banner = f"Running on transformers v{transformers.__version__} and datasets v{datasets.__version__}"
print(versions_banner)

Running on transformers v4.9.1 and datasets v1.10.2


In [13]:
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from datasets import load_dataset, load_metric, Dataset

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

import ray
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
from ray.tune import CLIReporter

In [14]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
# Display the selected device to confirm whether CUDA was picked up
device

device(type='cuda')

In [16]:
# Tokenizer matching the 'facebook/bart-base' checkpoint used for the models below
tokenizer = BartTokenizer.from_pretrained(
    pretrained_model_name_or_path='facebook/bart-base',
)

In [17]:
def get_model(params):
    """Build a fixed-weight BartModelV2 and the data collator that goes with it.

    The model is created from the 'facebook/bart-base' checkpoint with
    ``compute_dynamic_weights=False``, ``attention_to_conc=False`` and
    ``conc_decoder=True``, then moved to the notebook-level ``device``. Both
    ``model.conclusion_decoder`` and ``model.counter_decoder`` are then
    initialised from the decoder of a plain pretrained ``BartModel``.

    Args:
        params: ``None`` to use a weight of 0.5 for both losses, or a mapping
            providing the keys ``'conc_loss_weight'`` and
            ``'counter_loss_weight'``.

    Returns:
        tuple: ``(data_collator, model)`` where ``data_collator`` is a
        ``DataCollatorForSeq2Seq`` built from the notebook-level ``tokenizer``
        and the new model.

    Raises:
        KeyError: if ``params`` is given but lacks one of the two weight keys.
    """
    # Identity check: `== None` would invoke a custom __eq__ on `params`.
    if params is None:
        conc_loss_weight, counter_loss_weight = 0.5, 0.5
    else:
        conc_loss_weight = params['conc_loss_weight']
        counter_loss_weight = params['counter_loss_weight']

    model = BartModelV2.from_pretrained(
        'facebook/bart-base',
        compute_dynamic_weights=False,
        conc_loss_weight=conc_loss_weight,
        counter_loss_weight=counter_loss_weight,
        attention_to_conc=False,
        conc_decoder=True,
    ).to(device)

    # The plain BART model is only a source of weights. It stays on the CPU
    # because load_state_dict copies the tensors into the target parameters,
    # so a second full model never has to sit in GPU memory.
    pretrained_decoder_state = BartModel.from_pretrained('facebook/bart-base').decoder.state_dict()

    # Start both task-specific decoders from the same pretrained decoder weights.
    model.conclusion_decoder.load_state_dict(pretrained_decoder_state)
    model.counter_decoder.load_state_dict(pretrained_decoder_state)

    data_collator = DataCollatorForSeq2Seq(tokenizer, model)

    return data_collator, model

In [19]:
# Root folder of the generation data, relative to this notebook
data_fold = (
    '../../../data-ceph/arguana/arg-generation/'
    'multi-taks-counter-argument-generation/'
)

In [20]:
#Taking unique posts from valid dataset and sample only 1500 instances
# valid_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc.pkl')
# valid_unique_df = valid_df.drop_duplicates('post_id')
# valid_sample_df = valid_unique_df.sample(1500)
# valid_sample_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_sample.pkl')

# test_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/test_concusion_comp_remove_75sem_perc.pkl')
# test_unique_df = test_df.drop_duplicates('post_id')
# test_unique_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/test_concusion_comp_remove_75sem_perc_sample.pkl')

#Taking unique posts from valid dataset and sample only 1500 instances
# valid_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_all.pkl')
# valid_unique_df = valid_df.drop_duplicates('post_id')
# valid_sample_df = valid_unique_df.sample(1500)
# valid_sample_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_all_sample.pkl')

# test_df = pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/test_concusion_all.pkl')
# test_unique_df = test_df.drop_duplicates('post_id')
# test_unique_df = test_df.sample(2500)
# test_unique_df.to_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/test_concusion_all_sample.pkl')

In [21]:
# train_ds = Dataset.from_pandas(pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/train_conclusion_comp_remove_75sem_perc.pkl'))
# valid_ds = Dataset.from_pandas(pd.read_pickle(data_fold+'/reddit_data/conclusion_and_ca_generation/valid_conclusion_comp_remove_75sem_perc_sample.pkl'))

# Load the pickled train / validation frames (in that order) and wrap each one
# in a `datasets.Dataset`.
train_ds, valid_ds = (
    Dataset.from_pandas(
        pd.read_pickle(data_fold + '/reddit_data/conclusion_and_ca_generation/' + pickle_name)
    )
    for pickle_name in ('train_conclusion_all.pkl', 'valid_conclusion_all_sample.pkl')
)

In [22]:
#Encoding function for joint generation of conclusion and counter
def _join_if_tokenized(texts):
    """Return ``texts`` as flat strings, space-joining entries that are lists of strings."""
    # The decision is made from the first entry only; an empty batch has no
    # first entry to inspect, so it is passed through unchanged.
    if len(texts) > 0 and isinstance(texts[0], list):
        return [' '.join(parts) for parts in texts]
    return texts


def preprocess_function(examples, tokenizer, premises_clm, counter_clm, conclusion_clm, max_input_length=512, max_conc_length=100, max_counter_length=200):
    """Tokenize one batch for joint conclusion / counter generation.

    Args:
        examples: batch mapping column names to lists of values.
        tokenizer: tokenizer called on the texts; it must also provide
            ``as_target_tokenizer()`` as a context manager.
        premises_clm: name of the column holding the input text.
        counter_clm: name of the column holding the counter text.
        conclusion_clm: name of the column holding the conclusion text.
        max_input_length: length the inputs are truncated / padded to.
        max_conc_length: length the conclusions are truncated / padded to.
        max_counter_length: length the counters are truncated / padded to.

    Returns:
        The tokenizer output for the inputs, extended with:
        ``labels`` / ``decoder_attention_mask`` (from the counters) and
        ``conclusion_labels`` / ``conclusion_decoder_attention_mask``
        (from the conclusions).

    A column whose entries are lists of strings is space-joined into one string
    per example before tokenization. An empty batch is handled without error.
    """
    premises = _join_if_tokenized(examples[premises_clm])
    counters = _join_if_tokenized(examples[counter_clm])
    conclusions = _join_if_tokenized(examples[conclusion_clm])

    # Every sequence is truncated and padded to its fixed maximum length.
    model_inputs = tokenizer(premises, max_length=max_input_length, truncation=True, padding='max_length')

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        counter_labels = tokenizer(counters, max_length=max_counter_length, truncation=True, padding='max_length')
        conclusion_labels = tokenizer(conclusions, max_length=max_conc_length, truncation=True, padding='max_length')

    # NOTE(review): padded label positions keep the tokenizer's padding ids (they
    # are not replaced with -100) - confirm the model's loss masks them, e.g. via
    # the decoder attention masks stored below.
    model_inputs["conclusion_labels"] = conclusion_labels["input_ids"]
    model_inputs["conclusion_decoder_attention_mask"] = conclusion_labels['attention_mask']
    model_inputs["labels"] = counter_labels["input_ids"]
    model_inputs["decoder_attention_mask"] = counter_labels['attention_mask']

    return model_inputs

In [23]:
#downsample the training dataset
#tmp_ds = train_ds.train_test_split(0.005)
#train_ds = tmp_ds['test']

In [24]:
# Number of examples in the training split
len(train_ds)

295914

In [25]:
# Number of examples in the validation split
len(valid_ds)

1500

In [26]:
# train_tokenized_ds = train_ds.map(lambda x :preprocess_function(x, tokenizer, 'masked_premises', 'counter', 'title'), batched=True)
# valid_tokenized_ds = valid_ds.map(lambda x :preprocess_function(x, tokenizer, 'masked_premises', 'counter', 'title'), batched=True)

def encode_batch(batch):
    """Tokenize one batch using 'post' as input, 'counter' and 'title' as targets."""
    return preprocess_function(batch, tokenizer, 'post', 'counter', 'title')


# Apply the same batched encoding to both splits
train_tokenized_ds = train_ds.map(encode_batch, batched=True)
valid_tokenized_ds = valid_ds.map(encode_batch, batched=True)

  0%|          | 0/296 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [17]:
# #Train the model for different config
# batch_size = 32
# for conc_loss_weight, count_loss_weight in [(0.0, 1.0), (0.8, 0.2), (0.6, 0.4)]:
#         data_collator, model = get_model({'conc_loss_weight': conc_loss_weight, 'counter_loss_weight': count_loss_weight})
#         args = Seq2SeqTrainingArguments(
#             "../data/output/joint-con-counter-bart-model-no-attention-finetuned/{}-{}".format(str(conc_loss_weight).replace('.','-'), str(count_loss_weight).replace('.','-')),
#             evaluation_strategy = "steps",
#             learning_rate=2e-5,
#             per_device_train_batch_size=batch_size,
#             per_device_eval_batch_size=batch_size,
#             weight_decay=0.01,
#             save_total_limit=5,
#             num_train_epochs=3,
#             load_best_model_at_end=True,
#             predict_with_generate=True,
#             metric_for_best_model='bert-fscore',
#             label_names=['conclusion_labels', 'counter_labels']
#         )

#         trainer = Seq2TwoSeqTrainer(
#             model,
#             args,
#             train_dataset=train_tokenized_ds,
#             eval_dataset=valid_tokenized_ds,
#             data_collator=data_collator,
#             tokenizer=tokenizer,
#             compute_metrics=lambda x : compute_metrics(x, tokenizer)
#         )
        
#         trainer.train()
#         trainer.save_model()

#### Train a dynamic weighting model:

In [27]:
# Per-device batch size, used for both training and evaluation in the training arguments
batch_size = 32

In [28]:
# Instantiate the joint model from the 'facebook/bart-base' checkpoint with
# dynamic weight computation and the conclusion decoder enabled, then move it
# to the selected device.
model = BartModelV2.from_pretrained(
    'facebook/bart-base',
    compute_dynamic_weights=True,
    conc_decoder=True,
)
model = model.to(device)

Some weights of BartModelV2 were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['conclusion_decoder.layernorm_embedding.bias', 'log_vars', 'conclusion_decoder.embed_tokens.weight', 'conclusion_decoder.layernorm_embedding.weight', 'conclusion_decoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
# Plain pretrained BART, used only as a source of decoder weights. It is kept on
# the CPU: load_state_dict copies the tensors into the target parameters, so a
# second full model does not need to occupy GPU memory for the whole training run.
original_bart_model = BartModel.from_pretrained('facebook/bart-base')
# Initialise the conclusion decoder from the pretrained BART decoder weights
model.conclusion_decoder.load_state_dict(original_bart_model.decoder.state_dict())

<All keys matched successfully>

In [30]:
# data_collator= DataCollatorForSeq2Seq(tokenizer, model)

# args = Seq2SeqTrainingArguments(
#     "../data/output/joint-con-counter-bart-model-no-attention-finetuned/dynamic-weight",
#     evaluation_strategy = "steps",
#     eval_steps=500,
#     learning_rate=2e-5,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     weight_decay=0.01,
#     save_total_limit=5,
#     num_train_epochs=6,
#     load_best_model_at_end=True,
#     predict_with_generate=True,
#     metric_for_best_model='bert-fscore',
#     label_names=['labels', 'conclusion_labels']
# )

# trainer = Seq2SeqTrainer(
#     model,
#     args,
#     train_dataset=train_tokenized_ds,
#     eval_dataset=valid_tokenized_ds,
#     data_collator=data_collator,
#     tokenizer=tokenizer,
#     compute_metrics=lambda x : compute_metrics(x, tokenizer)
# )

# Batch collator built from the tokenizer and the joint model
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

args = Seq2SeqTrainingArguments(
    output_dir="../data/output/joint-con-counter-bart-model-no-attention-finetuned-on-all-data/dynamic-weight",
    # Optimisation settings
    num_train_epochs=6,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluation schedule and checkpoint selection
    evaluation_strategy="steps",
    eval_steps=500,
    predict_with_generate=True,
    metric_for_best_model='bert-fscore',
    load_best_model_at_end=True,
    save_total_limit=5,
    # Both target columns produced during preprocessing are treated as labels
    label_names=['labels', 'conclusion_labels'],
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_tokenized_ds,
    eval_dataset=valid_tokenized_ds,
    tokenizer=tokenizer,
    # compute_metrics also needs the tokenizer, so it is bound here
    compute_metrics=lambda eval_preds: compute_metrics(eval_preds, tokenizer),
)

In [None]:
# Run the training loop configured on `trainer`
trainer.train()
# Save the trained model once training has finished (no explicit path is passed here)
trainer.save_model()

The following columns in the training set  don't have a corresponding argument in `BartModelV2.forward` and have been ignored: title, split, comment_id, post, counter, n_sentences, __index_level_0__, post_id.
***** Running training *****
  Num examples = 295914
  Num Epochs = 6
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 55488


Step,Training Loss,Validation Loss
