In [4]:
# Add the project-local shared_encoder directory to the import path.
# NOTE(review): the later imports of multitask_model, multitask_data_collator, etc.
# presumably resolve from this directory, so this cell has to run before them — confirm.
import sys
sys.path.append('../../multitask-learning-transformers/shared_encoder')

In [1]:
import logging
import torch
import torch.nn as nn
import nltk
import numpy as np
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from tqdm import tqdm as tqdm1

In [2]:
# Check that a CUDA device is visible; the evaluation section moves the model to 'cuda'.
torch.cuda.is_available()

True

In [41]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [7]:
import transformers
#from accelerate import Accelerator
from filelock import FileLock
from transformers import set_seed
from transformers.file_utils import is_offline_mode
from utils.arguments import parse_args
from multitask_model import MultitaskModel
#from preprocess import convert_to_features
from multitask_data_collator import MultitaskTrainer, NLPDataCollator
from multitask_eval import multitask_eval_fn
from checkpoint_model import save_model
from pathlib import Path

In [8]:
# from evaluate_bleu import *
from sklearn.metrics import f1_score, classification_report

In [9]:
# Fix the random seed through transformers' set_seed helper before any model is built.
set_seed(42)

In [16]:
import pandas as pd
import pickle

# Load the pre-processed training and validation frames from pickle files.
# NOTE(review): read_pickle unpickles arbitrary Python objects — only point it at trusted files.
df_train = pd.read_pickle('../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/preprocessed_train_conclusion_all.pkl')
df_validation = pd.read_pickle('../../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/sample_valid_conclusion_all_preprocessed.pkl')

In [17]:
# Names of the two generation tasks; the same strings key the task models and tokenized datasets below.
tasks = ['counter_gen', 'title_gen']

In [20]:
# For both splits, build "<post><s><target>" strings for the title and counter targets.
# Each value is passed through str() element-wise before concatenation.
for frame in (df_train, df_validation):
    post_text = frame['post'].map(str)
    frame['title_input'] = post_text + '<s>' + frame['title'].map(str)
    frame['counter_input'] = post_text + '<s>' + frame['counter'].map(str)

In [21]:
# Keep only the two concatenated input columns from each split.
generation_input_columns = ['title_input', 'counter_input']
df_train_new = df_train[generation_input_columns]
df_validate_new = df_validation[generation_input_columns]

In [22]:
# Hugging Face checkpoint that every task model starts from.
model_name = 'facebook/bart-base'

In [23]:
# One checkpoint name per task; only model_names[0] is read when the multitask model is created below.
model_names = [model_name] * 2

In [24]:
class MultitaskBartModel(transformers.PreTrainedModel):
    """
    Container for several single-task models that share one encoder module.

    Each task model is stored in ``taskmodels_dict`` under its task name, and
    calls are dispatched to a task model by that name.
    """

    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features

        Args:
            encoder: the encoder module shared by the task models.
            taskmodels_dict: mapping of task name -> task model; it is wrapped
                in an ``nn.ModuleDict``.
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def _from_checkpoints(cls, checkpoint_for_task, model_type_dict, model_config_dict):
        """
        Build every task model and make them all use the first model's encoder.

        Args:
            checkpoint_for_task: callable mapping a task name to the name or
                path handed to ``from_pretrained`` for that task.
            model_type_dict: mapping of task name -> model class.
            model_config_dict: mapping of task name -> config object.

        Returns:
            A new instance holding the shared encoder and the task models.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                checkpoint_for_task(task_name),
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                # The first task model donates its encoder ...
                shared_encoder = model.model.encoder
            else:
                # ... and every later task model has its own encoder replaced by it.
                model.model.encoder = shared_encoder
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models.

        We do this by creating each single-task model, and having them share
        the same encoder transformer. Every task is initialised from the same
        ``model_name`` checkpoint.
        """
        return cls._from_checkpoints(
            lambda task_name: model_name,
            model_type_dict,
            model_config_dict,
        )

    @classmethod
    def load(cls, model_folder, model_type_dict, model_config_dict):
        """
        This loads a MultitaskModel using the model class and config objects
        from single-task models.

        Each task is read from ``{model_folder}/{task_name}_model``.
        """
        return cls._from_checkpoints(
            lambda task_name: f"{model_folder}/{task_name}_model",
            model_type_dict,
            model_config_dict,
        )

    def forward(self, task_name, **kwargs):
        """Run the forward pass of the task model registered under ``task_name``."""
        return self.taskmodels_dict[task_name](**kwargs)

    def generate_text(self, task_name, input_, num_beams=1, early_stopping=False, max_length=512, length_penalty=1.0):
        """
        Call ``generate`` on the task model registered under ``task_name``.

        ``input_`` is passed as the first positional argument; the remaining
        arguments are forwarded as keyword arguments of the same name.
        """
        return self.taskmodels_dict[task_name].generate(
            input_,
            num_beams=num_beams,
            early_stopping=early_stopping,
            max_length=max_length,
            length_penalty=length_penalty,
        )

    def resize_token_embeddings(self, new_num_tokens):
        """Resize the token embeddings of every task model; returns nothing."""
        for task_name, model in self.taskmodels_dict.items():
            model.resize_token_embeddings(new_num_tokens)
        

In [48]:
# Both tasks use the same model class, and each gets its own config object
# loaded from the same checkpoint.
generation_task_names = ("counter_gen", "title_gen")
multitask_model = MultitaskBartModel.create(
    model_name=model_names[0],
    model_type_dict={
        task_name: transformers.BartForConditionalGeneration
        for task_name in generation_task_names
    },
    model_config_dict={
        task_name: transformers.AutoConfig.from_pretrained('facebook/bart-base')
        for task_name in generation_task_names
    },
)

loading configuration file https://huggingface.co/facebook/bart-base/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/current//sile2804/.cache/huggingface/transformers/f5310d276a6d1648d00c32fadc8bf7b4607e0fbd5b404fc4a0045960aa2bdfdb.a243ed957122436adb0b8d8e9d20f896f45c174b6324d625ca0a20a84f72a910
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos

In [49]:
# Tokenizer for the facebook/bart-base checkpoint; later cells use it for preprocessing, saving and decoding.
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/bart-base')

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/facebook/bart-base/resolve/main/config.json from cache at /mnt/ceph/storage/data-tmp/current//sile2804/.cache/huggingface/transformers/f5310d276a6d1648d00c32fadc8bf7b4607e0fbd5b404fc4a0045960aa2bdfdb.a243ed957122436adb0b8d8e9d20f896f45c174b6324d625ca0a20a84f72a910
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim"

In [28]:
from datasets import Dataset
# Wrap the training frame in a datasets.Dataset so it can be tokenized with .map() further down.
dataset = Dataset.from_pandas(df_train)
# title_dataset = Dataset.from_pandas(df_train_new)

In [29]:
# Display the dataset's column names and row count.
dataset

Dataset({
    features: ['post_id', 'split', 'comment_id', 'title', 'post', 'n_sentences', 'counter', 'bot_comment', 'counter_conclusion', 'counter_conclusions', 'title_input', 'counter_input', '__index_level_0__'],
    num_rows: 25704
})

In [30]:
def preprocess_function(examples, tokenizer, input_clm, output_clm, max_input_length=512, max_target_length=512, ignore_pad_token_for_loss=True):
    """
    Tokenize one batch of source/target texts for sequence-to-sequence training.

    Args:
        examples: batch mapping of column name -> list of values.
        tokenizer: tokenizer called on the texts; it must also provide
            ``as_target_tokenizer()`` for the targets.
        input_clm: name of the column holding the source texts.
        output_clm: name of the column holding the target texts.
        max_input_length: length the source sequences are truncated/padded to.
        max_target_length: length the target sequences are truncated/padded to.
        ignore_pad_token_for_loss: when True (default), padding positions in the
            labels are replaced by -100. Because the targets are padded to
            ``max_target_length``, leaving the pad id in place would make the
            loss count padding as real targets. Map -100 back to the pad id
            before decoding such labels.

    Returns:
        The tokenizer output for the sources, with a ``"labels"`` entry holding
        the target token ids.
    """
    input_examples = examples[input_clm]
    output_examples = examples[output_clm]

    # Examples stored as lists of strings are joined into a single string each.
    # The length guard keeps an empty batch from failing on the [0] lookup.
    if len(input_examples) > 0 and isinstance(input_examples[0], list):
        input_examples = [' '.join(x) for x in input_examples]
    if len(output_examples) > 0 and isinstance(output_examples[0], list):
        output_examples = [' '.join(x) for x in output_examples]

    processed_output = tokenizer(input_examples, max_length=max_input_length, truncation=True, padding='max_length')

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(output_examples, max_length=max_target_length, truncation=True, padding='max_length')

    label_ids = labels["input_ids"]
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if ignore_pad_token_for_loss and pad_token_id is not None:
        label_ids = [
            [(token_id if token_id != pad_token_id else -100) for token_id in sequence]
            for sequence in label_ids
        ]

    processed_output["labels"] = label_ids
    return processed_output

In [53]:
# Tokenize the posts twice: once with the counter-argument as target (max 256 tokens)
# and once with the title as target (max 100 tokens).
# All original columns are dropped so only the tokenizer outputs remain. The list is
# read from the dataset itself rather than hard-coded, so the cell keeps working when
# the source frame gains or loses a column (e.g. '__index_level_0__').
source_columns = dataset.column_names
tokenized_counter_ds = dataset.map(lambda a: preprocess_function(a, tokenizer, 'post', 'counter', 512, 256),
                                   remove_columns=source_columns,
                                   batched=True)
tokenized_conclusion_ds = dataset.map(lambda a: preprocess_function(a, tokenizer, 'post', 'title', 512, 100),
                                      remove_columns=source_columns,
                                      batched=True)

  0%|          | 0/26 [00:00<?, ?ba/s]

  0%|          | 0/26 [00:00<?, ?ba/s]

In [54]:
# Task name -> tokenized dataset; this is the mapping passed to the trainer as train_dataset.
tokenized_dict = {'title_gen':tokenized_conclusion_ds,'counter_gen':tokenized_counter_ds}

In [55]:
# Display the columns and row count left after tokenization.
tokenized_conclusion_ds

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 25704
})

In [56]:
# Inspect the first tokenized counter-argument example.
tokenized_counter_ds[0]

{'input_ids': [0,
  118,
  679,
  14,
  10,
  2352,
  16,
  101,
  143,
  97,
  1963,
  442,
  265,
  8,
  3891,
  144,
  582,
  1122,
  2556,
  4,
  939,
  524,
  45,
  10,
  3458,
  621,
  2185,
  98,
  939,
  109,
  45,
  216,
  5,
  1498,
  10575,
  9,
  5,
  903,
  8,
  1408,
  9,
  10,
  2352,
  53,
  939,
  109,
  1346,
  51,
  64,
  3363,
  10,
  205,
  1280,
  9,
  1055,
  4,
  45,
  4378,
  686,
  141,
  6030,
  42,
  1566,
  16,
  6,
  53,
  24,
  982,
  14,
  52,
  115,
  2364,
  41,
  943,
  6121,
  325,
  1932,
  228,
  76,
  8,
  11,
  10,
  86,
  147,
  52,
  25,
  10,
  247,
  32,
  11,
  35245,
  9,
  1932,
  9,
  1126,
  27892,
  10,
  2352,
  1302,
  5701,
  7,
  162,
  4,
  25434,
  705,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  

In [58]:
# Task name -> tokenized dataset.
# NOTE(review): this holds the same two datasets as tokenized_dict, and the trainer cell
# passes tokenized_dict; this variable is not referenced again in the visible cells.
train_dataset = {
    'counter_gen':tokenized_counter_ds, 'title_gen':tokenized_conclusion_ds
}

# train

Best parameters: learning rate 5e-05, epochs 7, batch size 4, num_beams 10; F1: 0.7980, BLEU: 0.4150

In [59]:
# Hyper-parameters for the joint fine-tuning run.
training_args = transformers.TrainingArguments(
    output_dir='multitask_trainer_output',
    overwrite_output_dir=True,
    learning_rate=5e-5,
    do_train=True,
    num_train_epochs=7,
    per_device_train_batch_size=4,
    save_steps=3000,
    report_to='none',
)
# Collator that batches the tokenized examples.
seq2seq_collator = DataCollatorForSeq2Seq(tokenizer, multitask_model)

trainer = MultitaskTrainer(
    model=multitask_model,
    args=training_args,
    data_collator=seq2seq_collator,
    train_dataset=tokenized_dict,
)

PyTorch: setting up devices


In [60]:
# Start fine-tuning (the recorded run below was stopped with a KeyboardInterrupt).
trainer.train()

***** Running training *****
  Num examples = 51408
  Num Epochs = 7
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 44982


Step,Training Loss


KeyboardInterrupt: 

In [52]:
def save_model(multitask_model, suffix='', task_names=None):
    """
    Write each task model's config, weights and the tokenizer to disk.

    For every task the files go to ``./multitask_model{suffix}/{task_name}_model/``:
    ``config.json``, ``pytorch_model.bin`` (the task model's ``state_dict``) and
    the files written by ``tokenizer.save_pretrained``. This definition replaces
    the ``save_model`` imported from ``checkpoint_model`` in the import cell.

    Args:
        multitask_model: model exposing ``taskmodels_dict`` (task name -> model).
        suffix: text appended to the ``multitask_model`` folder name.
        task_names: tasks to save. Defaults to every task in ``taskmodels_dict``;
            the previous hard-coded pair of names raised ``KeyError`` for models
            built with other task names.

    Note:
        Uses the notebook-level ``tokenizer`` variable.
    """
    if task_names is None:
        task_names = list(multitask_model.taskmodels_dict.keys())

    for task_name in task_names:
        task_model = multitask_model.taskmodels_dict[task_name]
        task_dir = Path(f"./multitask_model{suffix}/{task_name}_model")
        # Create the target folder first; writing the config fails if it is missing.
        task_dir.mkdir(parents=True, exist_ok=True)

        task_model.config.to_json_file(str(task_dir / "config.json"))
        torch.save(task_model.state_dict(), str(task_dir / "pytorch_model.bin"))
        tokenizer.save_pretrained(str(task_dir))

In [None]:
# Persist the current task models under ./multitask_model/ (empty suffix).
save_model(multitask_model)

# eval

In [None]:
# Rebuild the multitask model from the per-task folders under 'multitask_model_best'.
# NOTE(review): the configs are read from 'multitask_model/...' while the weights are read
# from 'multitask_model_best/...' — confirm the two folders are meant to differ.
# NOTE(review): num_labels is not assigned in any cell visible here, and the task names
# ('tfidf_cluster', 'input_comment_pairs') differ from the ones trained above
# ('counter_gen', 'title_gen') — confirm this section belongs to this notebook.
multitask_model = MultitaskBartModel.load(
    model_folder='multitask_model_best',
    model_type_dict={
        "tfidf_cluster": transformers.BartForSequenceClassification,
        "input_comment_pairs": transformers.BartForConditionalGeneration,
    },
    model_config_dict={
        "tfidf_cluster": transformers.AutoConfig.from_pretrained(
            'multitask_model/tfidf_cluster_model', num_labels=num_labels
        ),
        "input_comment_pairs": transformers.AutoConfig.from_pretrained(
            'multitask_model/input_comment_pairs_model'
        ),
    },
)

In [None]:
# Move the model to the GPU; assigning to `_` keeps the cell from printing the model repr.
_ = multitask_model.to('cuda')

In [None]:
# Switch the model to evaluation mode; assigning to `_` keeps the cell from printing the model repr.
_ = multitask_model.eval()

In [None]:
def eval_cluster_fn(multitask_model, model_name, features_dict, batch_size=8):
    """
    Classify the 'tfidf_cluster' validation split and print the macro F1.

    Args:
        multitask_model: model called as ``multitask_model(task_name, **inputs)``;
            the first element of its output is used as the logits.
        model_name: unused; kept so existing calls keep working.
        features_dict: ``features_dict['tfidf_cluster']['validation']`` must
            support ``len()`` and slicing into a mapping with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        batch_size: number of examples per forward pass.

    Returns:
        ``(preds_all, refs_all)`` as flat lists of plain Python numbers
        (predicted class index, reference label), one entry per example.
    """
    metric = load_metric("f1")
    task_name = 'tfidf_cluster'
    validation_set = features_dict[task_name]["validation"]
    val_len = len(validation_set)

    preds_all = []
    refs_all = []

    for index in tqdm(range(0, val_len, batch_size)):
        # Slice the split once per batch instead of once per column.
        batch = validation_set[index : min(index + batch_size, val_len)]

        inputs = {
            "input_ids": torch.as_tensor(batch['input_ids'], dtype=torch.long).to(multitask_model.device),
            "attention_mask": torch.as_tensor(batch['attention_mask'], dtype=torch.long).to(multitask_model.device),
        }

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = multitask_model(task_name, **inputs)[0]

        # argmax over the logits selects the same class as argmax over their softmax.
        preds_all.extend(torch.argmax(logits, dim=1).cpu().tolist())
        refs_all.extend(torch.as_tensor(batch["labels"]).tolist())

    metric.add_batch(predictions=preds_all, references=refs_all)
    print(f"F1: {metric.compute(average='macro')}")
    return preds_all, refs_all

In [None]:
# Run the classification evaluation.
# NOTE(review): features_dict is not assigned in any cell visible here — confirm where it comes from.
preds_all, refs_all = eval_cluster_fn(multitask_model, model_name, features_dict)

In [None]:
# Per-class precision / recall / F1 for the predictions collected above.
print(classification_report(refs_all, preds_all))

In [None]:
def eval_bleu_fn(multitask_model, model_name, features_dict, batch_size=8, num_beams=7):
    """
    Generate text for the 'input_comment_pairs' validation split and report BLEU.

    Args:
        multitask_model: model exposing ``generate_text(task_name, input_ids, ...)``
            and a ``device`` attribute.
        model_name: unused; kept so existing calls keep working.
        features_dict: ``features_dict['input_comment_pairs']['validation']`` must
            support ``len()`` and slicing into a mapping with ``'input_ids'`` and
            ``'labels'``.
        batch_size: number of examples per generation call.
        num_beams: beam width forwarded to ``generate_text``.

    Returns:
        ``(preds_all, refs_all, in_all)``: decoded predictions, decoded
        references and decoded inputs, one string per example.

    Note:
        Relies on the notebook-level ``tokenizer`` and on ``print_bleu``.
        NOTE(review): the only visible import that looks like it provides
        ``print_bleu`` (``from evaluate_bleu import *``) is commented out — confirm.
    """
    preds_all = []
    refs_all = []
    # Was used without being initialised inside the function (NameError on a fresh kernel).
    in_all = []
    for task_name in ['input_comment_pairs']:
        validation_set = features_dict[task_name]["validation"]
        val_len = len(validation_set)

        for index in tqdm(range(0, val_len, batch_size)):
            # Slice the split once per batch instead of once per column.
            batch = validation_set[index : min(index + batch_size, val_len)]

            # as_tensor accepts both plain lists and tensors, so .to() below always works.
            input_ = torch.as_tensor(batch['input_ids'], dtype=torch.long)
            labels = torch.as_tensor(batch["labels"], dtype=torch.long)

            # Negative label ids (e.g. the -100 used to mask the loss) cannot be decoded;
            # map them to the pad token, which skip_special_tokens then drops.
            if tokenizer.pad_token_id is not None:
                labels = labels.masked_fill(labels < 0, tokenizer.pad_token_id)

            with torch.no_grad():
                outputs = multitask_model.generate_text(task_name, input_.to(multitask_model.device), num_beams=num_beams, early_stopping=True, max_length=512)

            in_all.extend([tokenizer.decode(inp, skip_special_tokens=True) for inp in input_])
            preds_all.extend([tokenizer.decode(out, skip_special_tokens=True) for out in outputs])
            refs_all.extend([tokenizer.decode(ref, skip_special_tokens=True) for ref in labels])

    print_bleu(refs_all, preds_all)
    return preds_all, refs_all, in_all

In [None]:
# Run the generation evaluation with num_beams=1; this rebinds preds_all and refs_all.
preds_all, refs_all, in_all = eval_bleu_fn(multitask_model, model_name, features_dict, num_beams=1)

In [None]:
# Save the evaluated model under ./multitask_model_best/.
save_model(multitask_model, '_best')

In [None]:
# Write one prediction per line. The encoding is pinned to UTF-8 so generated text with
# non-ASCII characters is written the same way whatever the platform's default encoding is.
with open('multitask-tfidf-tagging-preds.txt', 'w', encoding='utf-8') as f:
    for pred in preds_all:
        f.write("%s\n" % pred)

In [None]:
# Show the number of exact matches.
# NOTE(review): in the visible cells `c` is only assigned in the next cell, so on a fresh
# kernel this cell would need to run after it — confirm the intended order.
c

In [None]:
# Count exact matches in `c` and print every example whose prediction differs
# from its reference (1-based position, input, reference, prediction).
c = 0
for i in range(len(refs_all)):
    if refs_all[i] == preds_all[i]:
        c += 1
        continue

    print(i+1, in_all[i])
    print('\tRef:',  refs_all[i])
    print('\tOur:', preds_all[i])
    print()