In [1]:
import os
import sys

# Extend the import search path with the sibling source directory
# (presumably where the `model` and `utils` modules imported below live -- confirm).
sys.path.append('../src-py/')

# Restrict CUDA to device index 0. This is set in the first cell, before torch is imported.
os.environ['CUDA_VISIBLE_DEVICES']='0'

In [2]:
import csv
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import torch
import transformers
from rouge_score import rouge_scorer
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.nn import CrossEntropyLoss, MSELoss
import os
from model import MultiTaskBart
from model import OurModel
from utils import parse_df
import time
import sys
import tqdm
import pandas as pd
from datasets import load_dataset, load_metric, Dataset

import nltk
pd.set_option('display.max_colwidth', None)

In [3]:
# Run on the GPU when one is visible to this process, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using GPU? ", torch.cuda.is_available())
# Asking for a CUDA device name on a machine without CUDA raises, so only query it when a GPU exists.
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu')


# Directory holding the preprocessed data and the fine-tuned stance classifier used as teacher.
data_dir = '../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/'
teacher_model_path='../../multitask-counter-arg-generation/data/output/stance_classification/best_model/'

# Number of posts per batch handed to the DataLoader.
batch_size=4

Using GPU?  True
Device name: A100-SXM4-40GB


In [4]:
import prompted_conclusion_utils as conc_utils
from prompted_conclusion_utils import *

from transformers.generation_logits_process import * 

cuda


In [5]:
#Teacher model: stance classifier used later to score generated counters against conclusions.
stance_classifier_teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
stance_classifier_teacher_model     = AutoModelForSequenceClassification.from_pretrained(teacher_model_path)
arg_stance_pipeline = TextClassificationPipeline(
    model=stance_classifier_teacher_model,
    tokenizer=stance_classifier_teacher_tokenizer,
    framework='pt',
    task='stance_classification',
    # Pipelines take a device ordinal: 0 is the first GPU, -1 means CPU.
    # Fall back to the CPU instead of failing when no GPU is available.
    device=0 if torch.cuda.is_available() else -1,
)

In [6]:
# Best-performing checkpoints (on the validation set) of our two model variants.
# Each variant gets its own freshly loaded BART-large config object.
model_without_stance = OurModel.load(
    '../../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4.baseline_1/trained_models/models-global-step-5000',
    'facebook/bart-large',
    model_config=transformers.AutoConfig.from_pretrained('facebook/bart-large'),
)
model_with_stance = OurModel.load(
    '../../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4.baseline_2/trained_models/models-global-step-4000',
    'facebook/bart-large',
    model_config=transformers.AutoConfig.from_pretrained('facebook/bart-large'),
)

In [7]:
# Tokenizer shared by both generator variants.
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/bart-large')

# Place both models on the selected device ...
model_without_stance.to(device)
model_with_stance.to(device)

# ... and switch them to evaluation mode; assigning to `_` hides the returned value from the cell output.
_ = model_with_stance.eval()
_ = model_without_stance.eval()

In [8]:
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of at most ``n`` items.

    Slices are produced in order; the last one is shorter when ``len(lst)``
    is not a multiple of ``n``. ``lst`` must support ``len`` and slicing.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def get_stance_scores(sents1, sents2):
    """Classify the stance of each ``(sents1[i], sents2[i])`` pair with the teacher pipeline.

    Each pair is joined as ``sents1[i] + ' </s> ' + sents2[i]`` and passed to
    ``arg_stance_pipeline`` with truncation enabled. Pairs are formed with
    ``zip``, so surplus items of the longer input are ignored.

    Returns:
        A tuple ``(ratio, labels, scores)`` where ``labels`` holds the integer
        after the last underscore of each predicted label name, ``scores`` holds
        the pipeline's score for that prediction, and ``ratio`` is the mean of
        ``labels`` (per the original note: the share of cases in which a counter
        was generated -- assumes labels are 0/1, confirm against the classifier).
        For empty input, ``(0.0, [], [])`` is returned without calling the pipeline.
    """
    #compute stance score using our trained model
    text_inputs = [first + ' </s> ' + second for first, second in zip(sents1, sents2)]
    if not text_inputs:
        # Nothing to classify: avoid dividing by zero below.
        return 0.0, [], []

    stance_results = arg_stance_pipeline(text_inputs, truncation=True)
    stance_labels = [int(result['label'].split('_')[-1]) for result in stance_results]
    stance_scores = [result['score'] for result in stance_results]
    return sum(stance_labels)/len(stance_labels), stance_labels, stance_scores

def counters_coherence(post_conclusions, post_counters):
    """Score each counter by the mean signed stance of its sentences towards its conclusion.

    Every counter is split into sentences with ``nltk.sent_tokenize``; each
    sentence is paired with the conclusion at the same position and classified
    via ``get_stance_scores``. The classifier score is negated when the
    predicted label is 0 and kept as is otherwise, and the per-sentence values
    are averaged per counter.

    Returns:
        A list with one score per entry of ``post_counters``. A counter that
        yields no sentences gets ``nan`` (no evidence to score). Empty input
        returns an empty list.
    """
    counter_sentences = [nltk.sent_tokenize(counter) for counter in post_counters]
    conclusion_sentence_pairs = [
        (conclusion, sentence)
        for sentences, conclusion in zip(counter_sentences, post_conclusions)
        for sentence in sentences
    ]
    if not conclusion_sentence_pairs:
        # No sentence to classify at all: unpacking zip(*[]) would raise, so answer directly.
        return [np.nan] * len(counter_sentences)

    conclusions, sentences = zip(*conclusion_sentence_pairs)
    _, stance_labels, stance_scores = get_stance_scores(conclusions, sentences)
    signed_scores = [score * -1 if label == 0 else score for score, label in zip(stance_scores, stance_labels)]

    # The flat score list follows the order of the pairs, so walk it counter by counter.
    counter_scores = []
    offset = 0
    for sentences in counter_sentences:
        sentence_scores = signed_scores[offset: offset + len(sentences)]
        offset += len(sentences)
        # np.mean([]) is nan plus a RuntimeWarning; produce the nan explicitly instead.
        counter_scores.append(np.mean(sentence_scores) if sentence_scores else np.nan)

    return counter_scores

def get_best_counters(conclusions, counters, num_sequences):
    """Pick, for every group of ``num_sequences`` candidates, the highest-scoring pair.

    ``conclusions`` and ``counters`` are zipped and split into consecutive
    groups of ``num_sequences`` pairs (the last group may be shorter). Within a
    group the pair whose counter has the highest ``counters_coherence`` score
    is kept; ties go to the earliest candidate.

    Returns:
        A list of ``(conclusion, counter)`` tuples, one per group.
    """
    #choose best counter
    best_counters = []
    for chunk in chunks(list(zip(conclusions, counters)), num_sequences):
        chunk_conclusions, chunk_counters = zip(*chunk)
        scores = np.asarray(counters_coherence(chunk_conclusions, chunk_counters), dtype=float)
        # argmax would select a NaN score (a counter with no scorable sentence) as the maximum;
        # rank such candidates last instead.
        scores = np.where(np.isnan(scores), -np.inf, scores)
        best_idx = int(np.argmax(scores))
        best_counters.append((chunk_conclusions[best_idx], chunk_counters[best_idx]))

    return best_counters

def generate_counters(model, tokenizer, data_loader, argument_gen_kwargs, conclusion_gen_kwargs, skip_special_tokens=True):
    """Generate counter-arguments and conclusions for every batch of ``data_loader``.

    Each batch must provide ``'input_ids'`` and ``'attention_mask'`` tensors.
    They are moved to the global ``device`` and passed to
    ``model.generate_counter_argument`` and ``model.generate_conclusion`` with
    the respective kwargs dict and a shared, empty ``LogitsProcessorList``.
    The generated token ids are decoded with ``tokenizer.batch_decode``.

    Returns:
        ``(generated_conclusions, generated_counter_arguments)``: two flat lists
        of decoded strings accumulated over all batches, in batch order.
    """
    logits_processors = LogitsProcessorList()
    all_conclusions, all_counter_arguments = [], []

    def _decode(token_ids):
        # Decoding operates on a host-side numpy copy of the generated ids.
        return tokenizer.batch_decode(token_ids.cpu().numpy(), skip_special_tokens=skip_special_tokens)

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            argument_tokens = model.generate_counter_argument(ids, mask, argument_gen_kwargs, logits_processors)
            conclusion_tokens = model.generate_conclusion(ids, mask, conclusion_gen_kwargs, logits_processors)

            all_counter_arguments.extend(_decode(argument_tokens))
            all_conclusions.extend(_decode(conclusion_tokens))

    return all_conclusions, all_counter_arguments

In [9]:
# Settings handed to model.generate_conclusion by generate_counters (10 candidates per call).
conclusion_gen_kwargs = dict(
    do_sample=False,
    min_length=30,
    top_p=0.95,
    num_beams=10,
    num_return_sequences=10,
)


# Settings handed to model.generate_counter_argument by generate_counters (10 candidates per call).
argument_gen_kwargs = dict(
    do_sample=True,
    max_length=100,
    min_length=50,
    top_p=0.95,
    no_repeat_ngram_size=3,
    top_k=50,
    num_beams=10,
    num_return_sequences=10,
)

In [10]:
# Load the preprocessed test split and keep only the columns used in this notebook.
# (For a quick run, read 'sample_test_conclusion_all_preprocessed.pkl' from the same directory instead.)
df = pd.read_pickle(data_dir + 'test_conclusion_all_preprocessed.pkl')
df = df[['post_id', 'title', 'post', 'counter']]
# Each post is stored as a sequence of strings; flatten it into one space-separated text.
df['post'] = df['post'].apply(lambda parts: ' '.join(parts))

In [11]:
# Tokenise the posts (padded/truncated to 256 tokens) and expose only the model inputs as torch tensors.
ds = Dataset.from_pandas(df[['post']])
ds = ds.map(
    lambda examples: tokenizer(examples['post'], padding='max_length', max_length=256, truncation=True),
    batched=True,
    remove_columns=['__index_level_0__'],
)

ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
dataloader = DataLoader(ds, batch_size=batch_size)



  0%|          | 0/9 [00:00<?, ?ba/s]

In [None]:
# Generate candidates with the model trained without the stance objective.
no_stance_conclusions, no_stance_counter_arguments = generate_counters(
    model=model_without_stance,
    tokenizer=tokenizer,
    data_loader=dataloader,
    argument_gen_kwargs=argument_gen_kwargs,
    conclusion_gen_kwargs=conclusion_gen_kwargs,
)
#stance_conclusions, stance_counter_arguments = generate_counters(model_with_stance, tokenizer, dataloader, argument_gen_kwargs, conclusion_gen_kwargs)

In [None]:
# Keep, per post, the candidate pair whose counter scores highest; unzip into parallel tuples.
best_no_stance_conclusions, best_no_stance_counters = zip(
    *get_best_counters(
        conclusions=no_stance_conclusions,
        counters=no_stance_counter_arguments,
        num_sequences=argument_gen_kwargs['num_return_sequences'],
    )
)
#best_stance_conclusions, best_stance_counters = zip(*get_best_counters(stance_conclusions, stance_counter_arguments, argument_gen_kwargs['num_return_sequences']))

In [None]:
# Regroup the flat generation outputs into one list of candidates per post.
# Each output is grouped by the num_return_sequences of the kwargs that produced it, so the
# conclusions stay correctly aligned even if the two settings are configured differently
# (assumes each generate call yields num_return_sequences outputs per input -- confirm against the model).
df['all_pred_counter_arguments_no_stance'] = list(chunks(no_stance_counter_arguments, argument_gen_kwargs['num_return_sequences']))
df['all_pred_conclusions_no_stance'] = list(chunks(no_stance_conclusions, conclusion_gen_kwargs['num_return_sequences']))

# Best candidate per post as selected by get_best_counters.
df['pred_counter_arguments_no_stance'] = best_no_stance_counters
df['pred_conclusions_no_stance'] = best_no_stance_conclusions

In [None]:
# Show the first rows: post title next to the selected conclusion and counter.
df.loc[:, ['title', 'pred_conclusions_no_stance', 'pred_counter_arguments_no_stance']].head()

In [None]:
# Make sure the results directory exists so the write cannot fail after the long generation run.
os.makedirs('../data/output/ca-final-models/mt-v4/results', exist_ok=True)
df.to_pickle('../data/output/ca-final-models/mt-v4/results/all_test_preds_df.pkl')

#### Predict single counters:

In [None]:
# Single-output settings for model.generate_conclusion (one sequence per call, one beam).
conclusion_gen_kwargs = dict(
    do_sample=False,
    # max_length=20,
    min_length=30,
    top_p=0.95,
    num_beams=1,
    num_return_sequences=1,
)

# Single-output settings for model.generate_counter_argument (one sequence per call, four beams).
argument_gen_kwargs = dict(
    do_sample=True,
    max_length=100,
    min_length=50,
    top_p=0.95,
    top_k=50,
    no_repeat_ngram_size=3,
    num_beams=4,
    num_return_sequences=1,
)

In [None]:
# Re-run generation with the single-output settings; this rebinds the two result variables.
no_stance_conclusions, no_stance_counter_arguments = generate_counters(
    model=model_without_stance,
    tokenizer=tokenizer,
    data_loader=dataloader,
    argument_gen_kwargs=argument_gen_kwargs,
    conclusion_gen_kwargs=conclusion_gen_kwargs,
)
#stance_conclusions, stance_counter_arguments       = generate_counters(model_with_stance, tokenizer, dataloader, argument_gen_kwargs, conclusion_gen_kwargs)

In [None]:
# Store the single-output counters as a new column
# (assumes generation returned exactly one counter per row of df -- pandas raises on a length mismatch).
df['single_pred_counter_arguments_no_stance'] = no_stance_counter_arguments
#df['single_pred_counter_arguments_stance'] = stance_counter_arguments

In [None]:
# Compare, for the first rows, the single-output counter with the one selected from the candidates.
df.loc[:, ['title', 'single_pred_counter_arguments_no_stance', 'pred_counter_arguments_no_stance']].head()

In [None]:
# Write the frame, now including the single-output column, to the same file as the earlier save.
# Create the results directory first so the write cannot fail on a missing directory.
os.makedirs('../data/output/ca-final-models/mt-v4/results', exist_ok=True)
df.to_pickle('../data/output/ca-final-models/mt-v4/results/all_test_preds_df.pkl')