In [1]:
import os
import sys

# Expose only physical GPU 1 to this process; must run before torch initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [None]:
import csv
import random
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
import torch
import transformers
from rouge_score import rouge_scorer
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.nn import CrossEntropyLoss, MSELoss
import os
from model import MultiTaskBart
from model import OurModel
from utils import parse_df
import time
import sys
import tqdm
import pandas as pd
from datasets import load_dataset, load_metric, Dataset

import nltk
pd.options.display.max_colwidth = None  # show generated texts in full, never truncated with "..."

In [None]:
# Run on the GPU when one is visible; with CUDA_VISIBLE_DEVICES='1', cuda:0 is physical GPU 1.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using GPU? ", torch.cuda.is_available())
if torch.cuda.is_available():
    # torch.cuda.get_device_name raises on CPU-only machines, so only query it when CUDA is usable.
    print("Device name:", torch.cuda.get_device_name(0))

# NOTE(review): 'multi-taks' looks like a typo but is presumably the real directory name — verify before renaming.
data_dir = '../../data-ceph/arguana/arg-generation/multi-taks-counter-argument-generation/reddit_data/conclusion_and_ca_generation/'
teacher_model_path = '../multitask-counter-arg-generation/data/output/stance_classification/best_model/'

batch_size = 4

In [None]:
# Teacher stance classifier; get_stance_scores uses it to score generated counter-arguments.
stance_classifier_teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
stance_classifier_teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_path)
# Pipelines take a device ordinal: 0 is the first visible GPU and -1 is the CPU.
# Hard-coding 0 would fail on machines without CUDA.
arg_stance_pipeline = TextClassificationPipeline(
    model=stance_classifier_teacher_model,
    tokenizer=stance_classifier_teacher_tokenizer,
    framework='pt',
    task='stance_classification',
    device=0 if torch.cuda.is_available() else -1,
)

In [5]:
def _load_checkpoint(checkpoint_dir):
    """Load one OurModel checkpoint on top of facebook/bart-large, with a freshly loaded config per call."""
    bart_config = transformers.AutoConfig.from_pretrained('facebook/bart-large')
    return OurModel.load(checkpoint_dir, 'facebook/bart-large', model_config=bart_config)

# For each model variant, the checkpoint that performed best on the validation set.
model_without_stance = _load_checkpoint('../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4.baseline_1/trained_models/models-global-step-5500')
model_with_stance = _load_checkpoint('../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4.baseline_2/trained_models/models-global-step-4000')

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/bart-large')

# Move both generators to the selected device, then switch them to eval mode.
for _generator in (model_without_stance, model_with_stance):
    _generator.to(device)
for _generator in (model_with_stance, model_without_stance):
    _ = _generator.eval()
del _generator

In [7]:
def chunks(lst, n):
    """Lazily split a sliceable sequence into consecutive pieces of ``n`` items.

    The final piece holds whatever remains and may be shorter than ``n``.
    Each piece is a slice of ``lst``, so it has the same type as ``lst``.
    """
    yield from (lst[offset:offset + n] for offset in range(0, len(lst), n))
        
def get_stance_scores(sents1, sents2, stance_pipeline=None):
    """Classify the stance of each ``(sents1[i], sents2[i])`` pair with the teacher stance classifier.

    Each pair is joined as ``"<sents1[i]> </s> <sents2[i]>"`` and sent to the pipeline with
    truncation enabled. Pairs are formed with ``zip``, so surplus items in the longer input are ignored.

    Args:
        sents1: first text of each pair (the conclusion, as called from counters_coherence).
        sents2: second text of each pair (a counter-argument sentence, as called from counters_coherence).
        stance_pipeline: callable with the Hugging Face text-classification pipeline interface.
            Defaults to the module-level ``arg_stance_pipeline``.

    Returns:
        tuple ``(counter_rate, stance_labels, stance_scores)``:
            counter_rate  -- fraction (0-1, not a percentage) of pairs predicted as label 1.
                             Label 1 presumably means "is a counter". ``nan`` when there are no pairs.
            stance_labels -- int label per pair, parsed from the trailing ``_<id>`` of the label name.
            stance_scores -- the pipeline's ``score`` (confidence of the predicted label) per pair.
    """
    text_inputs = [first + ' </s> ' + second for first, second in zip(sents1, sents2)]
    if not text_inputs:
        # Nothing to classify; the rate below would otherwise divide by zero.
        return float('nan'), [], []

    if stance_pipeline is None:
        stance_pipeline = arg_stance_pipeline
    stance_results = stance_pipeline(text_inputs, truncation=True)

    # NOTE(review): assumes label names look like 'LABEL_<id>' (the transformers default) — confirm
    # against the teacher model's id2label.
    stance_labels = [int(result['label'].split('_')[-1]) for result in stance_results]
    stance_scores = [result['score'] for result in stance_results]
    counter_rate = sum(stance_labels) / len(stance_labels)
    return counter_rate, stance_labels, stance_scores

def counters_coherence(post_conclusions, post_counters):
    """Score each counter-argument by the stance its sentences take toward the paired conclusion.

    Each counter is split into sentences with ``nltk.sent_tokenize``. Every
    ``(conclusion, sentence)`` pair is classified by ``get_stance_scores``. A sentence's signed
    score is the classifier confidence, negated when the predicted label is 0. A counter's score
    is the mean signed score over its sentences.

    Args:
        post_conclusions: conclusions, aligned index-by-index with ``post_counters``.
        post_counters: generated counter-argument strings.

    Returns:
        list with one score per entry of ``post_counters``, in the same order. A counter that
        yields no sentences (e.g. an empty string) scores ``nan``, as does one with no paired
        conclusion.
    """
    counter_sentences = [nltk.sent_tokenize(counter) for counter in post_counters]

    # Flatten to one (conclusion, sentence) pair per sentence, in counter order, so the flat
    # classifier output can be sliced back per counter below.
    pairs = [
        (conclusion, sentence)
        for sentences, conclusion in zip(counter_sentences, post_conclusions)
        for sentence in sentences
    ]
    if not pairs:
        # zip(*[]) cannot be unpacked into two sequences, and there is nothing to classify.
        return [float('nan')] * len(counter_sentences)

    conclusions, sentences = zip(*pairs)
    _, stance_labels, stance_scores = get_stance_scores(conclusions, sentences)
    signed_scores = [-score if label == 0 else score for score, label in zip(stance_scores, stance_labels)]

    counter_scores = []
    offset = 0
    for sentences_of_counter in counter_sentences:
        count = len(sentences_of_counter)
        window = signed_scores[offset:offset + count]
        # np.mean([]) emits a RuntimeWarning and returns nan; return nan explicitly instead.
        counter_scores.append(np.mean(window) if window else float('nan'))
        offset += count
    return counter_scores

def get_best_counters(conclusions, counters, num_sequences, score_fn=None):
    """Pick the highest-scoring (conclusion, counter) candidate from each group of ``num_sequences``.

    ``conclusions`` and ``counters`` are paired index-by-index. The pairs are then cut into
    consecutive groups of ``num_sequences``, one group per input post.

    Args:
        conclusions: generated conclusions, one per candidate.
        counters: generated counter-arguments, one per candidate.
        num_sequences: number of candidates generated per post.
        score_fn: ``(chunk_conclusions, chunk_counters) -> list of floats``. Higher is better and
            ``nan`` means unscorable. Defaults to ``counters_coherence``.

    Returns:
        list of ``(conclusion, counter)`` tuples, one per group. Unscorable (``nan``) candidates
        never win; if a whole group is unscorable, its first candidate is returned.

    Raises:
        ValueError: if ``conclusions`` and ``counters`` differ in length, since pairing would
            silently misalign.
    """
    if len(conclusions) != len(counters):
        raise ValueError(
            f"got {len(conclusions)} conclusions but {len(counters)} counters; they must pair 1:1"
        )
    if score_fn is None:
        score_fn = counters_coherence

    pairs = list(zip(conclusions, counters))
    best_counters = []
    for start in range(0, len(pairs), num_sequences):
        chunk_conclusions, chunk_counters = zip(*pairs[start:start + num_sequences])
        scores = np.asarray(score_fn(chunk_conclusions, chunk_counters), dtype=float)
        # np.argmax treats nan as the maximum; map nan to -inf so unscorable candidates lose.
        best = int(np.argmax(np.where(np.isnan(scores), -np.inf, scores)))
        best_counters.append((chunk_conclusions[best], chunk_counters[best]))

    return best_counters

def generate_counters(model, tokenizer, data_loader, argument_gen_kwargs, conclusion_gen_kwargs, skip_special_tokens=True):
    """Run both generation heads of ``model`` over every batch and decode the outputs.

    For each batch, ``model.generate_counter_argument`` and ``model.generate_conclusion`` are
    called on the same inputs (moved to the module-level ``device``) with their respective kwargs.
    Their token ids are decoded with ``tokenizer.batch_decode``. Runs in eval mode under
    ``torch.no_grad()``.

    Returns:
        ``(conclusions, counter_arguments)``: two flat lists of decoded strings, in data-loader order.
    """
    conclusions, counter_arguments = [], []

    def _decode(token_ids):
        return tokenizer.batch_decode(token_ids.cpu().numpy(), skip_special_tokens=skip_special_tokens)

    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            argument_tokens = model.generate_counter_argument(ids, mask, argument_gen_kwargs)
            conclusion_tokens = model.generate_conclusion(ids, mask, conclusion_gen_kwargs)

            counter_arguments.extend(_decode(argument_tokens))
            conclusions.extend(_decode(conclusion_tokens))

    return conclusions, counter_arguments

In [8]:
# Re-ranking run: each generator is asked for 10 returned sequences per input. The counter-arguments
# are later grouped in 10s and the best-scoring one is kept (get_best_counters).
conclusion_gen_kwargs = dict(
    do_sample=False,
    # NOTE(review): no max_length is passed, so generation presumably falls back to a model/default
    # limit. If that limit is below min_length=30, conclusions get cut off mid-sentence, as the
    # displayed results suggest. Confirm in OurModel.generate_conclusion.
    min_length=30,
    top_p=0.95,  # NOTE(review): top_p only affects sampling; presumably a no-op with do_sample=False
    num_beams=10,
    num_return_sequences=10,
)

# (An earlier variant used do_sample=True, num_beams=5, num_return_sequences=5 and no
# no_repeat_ngram_size/top_k for the counter-arguments.)
argument_gen_kwargs = dict(
    do_sample=True,  # sampling combined with beam search (num_beams > 1)
    max_length=100,
    min_length=50,
    top_p=0.95,
    no_repeat_ngram_size=3,
    top_k=50,
    num_beams=10,
    num_return_sequences=10,
)

In [9]:
# Sample split for this run; the full split is data_dir + 'test_conclusion_all_preprocessed.pkl'.
df = pd.read_pickle(data_dir + 'sample_test_conclusion_all_preprocessed.pkl')[['post_id', 'title', 'post', 'counter']]
# `post` holds a sequence of strings per row (presumably sentences — TODO confirm); join it into one text.
df = df.assign(post=df['post'].apply(' '.join))

In [10]:
# Don't carry df's index over as an '__index_level_0__' column. The previous
# remove_columns=['__index_level_0__'] only worked because this df has a non-range index,
# and it raises on a default RangeIndex.
ds = Dataset.from_pandas(df[['post']], preserve_index=False)
ds = ds.map(
    lambda rows: tokenizer(rows['post'], padding='max_length', max_length=256, truncation=True),
    batched=True,
)

ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
# Keep dataset order (no shuffling) so generations line up with df rows.
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

  0%|          | 0/2 [00:00<?, ?ba/s]

In [11]:
# Generate all candidate conclusions and counter-arguments with the model trained without stance.
no_stance_conclusions, no_stance_counter_arguments = generate_counters(
    model_without_stance,
    tokenizer,
    data_loader=dataloader,
    argument_gen_kwargs=argument_gen_kwargs,
    conclusion_gen_kwargs=conclusion_gen_kwargs,
)
#stance_conclusions, stance_counter_arguments = generate_counters(model_with_stance, tokenizer, dataloader, argument_gen_kwargs, conclusion_gen_kwargs)

  next_indices = next_tokens // vocab_size
  next_indices = next_tokens // vocab_size


In [12]:
# Re-rank: keep the best-scoring (conclusion, counter) candidate from each post's group.
best_no_stance_conclusions, best_no_stance_counters = zip(*get_best_counters(
    no_stance_conclusions,
    no_stance_counter_arguments,
    num_sequences=argument_gen_kwargs['num_return_sequences'],
))
#best_stance_conclusions, best_stance_counters = zip(*get_best_counters(stance_conclusions, stance_counter_arguments, argument_gen_kwargs['num_return_sequences']))

In [13]:
# Every generated candidate, grouped per input post.
df['all_pred_counter_arguments_no_stance'] = list(chunks(no_stance_counter_arguments, argument_gen_kwargs['num_return_sequences']))
# Conclusions come from the conclusion generator, so group them by *its* num_return_sequences.
# It equals the argument value (10) in this run, but the two are configured independently.
df['all_pred_conclusions_no_stance'] = list(chunks(no_stance_conclusions, conclusion_gen_kwargs['num_return_sequences']))

#df['all_pred_counter_arguments_stance'] = list(chunks(stance_counter_arguments, argument_gen_kwargs['num_return_sequences']))
#df['all_pred_conclusions_stance'] = list(chunks(stance_conclusions, conclusion_gen_kwargs['num_return_sequences']))

#-----------------

# The re-ranked (best-scoring) candidate per post.
df['pred_counter_arguments_no_stance'] = best_no_stance_counters
df['pred_conclusions_no_stance'] = best_no_stance_conclusions

#df['pred_counter_arguments_stance'] = best_stance_counters
#df['pred_conclusions_stance'] = best_stance_conclusions

In [16]:
df.loc[:, ['title', 'pred_conclusions_no_stance', 'pred_counter_arguments_no_stance']].head(5)

Unnamed: 0,title,pred_conclusions_no_stance,pred_counter_arguments_no_stance
500025,I Don't Believe in the Big Bang,"I don't think the Big Bang is a fact, and don't believe it's","the big bang is a scientific theory, not a fact. it can't be proven or disproved, but it can be argued that the universe was created in 6 days and rested on the 7th. there is no reason to doubt the existence of the big bang."
534021,There is no reason for Britain to remain in the EU.,The UK should leave the EU and stay in the EEC. Why should,"i don't think britain has any interest in leaving the eu. british sovereignty is not an issue, it's a matter of national sovereignty. it's not like the uk has the right to tell you what you can and can't do, or what you are allowed to do, and how you can get around them. if you want to be an island nation, you have to be able to make your own laws, and if you don't"
415645,The Trump Administration Would Be Significantly More Harmful to US Domestic and International Policy W/out Moderates Like Mattis,Donald Trump's decision to ban on transgender people in the military is a good thing,"i don't think that mattis should be held accountable for the actions of the president. the president has every right to do whatever he wants, but that doesn't mean that he has a moral obligation to follow through on his campaign promises."
507298,Poisoning the well against fascists is a dishonest and therefore ineffective way of combating fascists,"Everyone seems to be a fascist these days, at least to the ""normies""","i think it's important to note that fascism is a very broad term. it can be applied to a wide range of political movements and ideologies, and it can also be used to refer to any group of people who have a particular set of beliefs. the definition of 'fascism' has changed a lot in the past few decades, but that doesn't mean that it's no longer a useful term. i don't think you have to be a fascist to be an"
505763,Women have an easier time finding romantic partners than men,Women have an easier time finding a romantic partner than men. I don't think,"i think you're looking at it the wrong way. it's not that women have an easier time finding romantic partners than men, it's that women are much more likely to date incels than men. i'm not saying that men have a harder time finding a romantic partner than women, i'm just saying that they are less likely to be able to find a partner who isn't an incel."


In [17]:
# Persist predictions so far; the same file is written again after the single-prediction run below.
df.to_pickle(path='../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4/results/test_preds_df.pkl')

#### Predict a single counter-argument per post (no re-ranking):

In [18]:
# Single-prediction run: one returned sequence per input, so there is nothing to re-rank.
conclusion_gen_kwargs = dict(
    do_sample=False,  # with num_beams=1 this is greedy decoding
    min_length=30,
    top_p=0.95,  # NOTE(review): only affects sampling; presumably a no-op with do_sample=False
    num_beams=1,
    num_return_sequences=1,
)

argument_gen_kwargs = dict(
    do_sample=True,
    max_length=100,
    min_length=50,
    top_p=0.95,
    top_k=50,
    no_repeat_ngram_size=3,
    num_beams=4,
    num_return_sequences=1,
)

In [19]:
# One conclusion and one counter-argument per post, using the single-prediction settings above.
no_stance_conclusions, no_stance_counter_arguments = generate_counters(
    model_without_stance,
    tokenizer,
    data_loader=dataloader,
    argument_gen_kwargs=argument_gen_kwargs,
    conclusion_gen_kwargs=conclusion_gen_kwargs,
)
#stance_conclusions, stance_counter_arguments       = generate_counters(model_with_stance, tokenizer, dataloader, argument_gen_kwargs, conclusion_gen_kwargs)

  next_indices = next_tokens // vocab_size


In [20]:
df = df.assign(single_pred_counter_arguments_no_stance=no_stance_counter_arguments)
#df['single_pred_counter_arguments_stance'] = stance_counter_arguments

In [21]:
df.loc[:, ['title', 'single_pred_counter_arguments_no_stance', 'pred_counter_arguments_no_stance']].head(5)

Unnamed: 0,title,single_pred_counter_arguments_no_stance,pred_counter_arguments_no_stance
500025,I Don't Believe in the Big Bang,"science is the best we have to go off of. the big bang theory is the only thing that can explain the existence of the universe. if there is a god, then the universe can't be true. the universe is just a simulation, and nothing more.","the big bang is a scientific theory, not a fact. it can't be proven or disproved, but it can be argued that the universe was created in 6 days and rested on the 7th. there is no reason to doubt the existence of the big bang."
534021,There is no reason for Britain to remain in the EU.,"the eu is not a single entity. britain is a part of the eu, but it does not have the right to take part in it. if british citizens want to live in europe, they are free to do so, but that doesn't mean that they have a say in how they live their lives.","i don't think britain has any interest in leaving the eu. british sovereignty is not an issue, it's a matter of national sovereignty. it's not like the uk has the right to tell you what you can and can't do, or what you are allowed to do, and how you can get around them. if you want to be an island nation, you have to be able to make your own laws, and if you don't"
415645,The Trump Administration Would Be Significantly More Harmful to US Domestic and International Policy W/out Moderates Like Mattis,"i think you're missing a few things. first of all, he is a narcissist. he's a sociopath. second, he's an idiot. third, he can't be trusted to be a good president. fourth, he has shown that he is not capable of being a president. fifth, he doesn't care about the military. sixth, he isn't a good commander in chief. seventh, he hates his own country. eighth, he believes that the","i don't think that mattis should be held accountable for the actions of the president. the president has every right to do whatever he wants, but that doesn't mean that he has a moral obligation to follow through on his campaign promises."
507298,Poisoning the well against fascists is a dishonest and therefore ineffective way of combating fascists,"fascism is not a movement, it's a term. it is a movement. fascism is a political movement. it's not an ideology, it is an ideology. fascism isn't a movement; it's an ideology that exists and has existed for as long as humans have existed, and it still exists today. there's no reason to believe that it will change in the near future.","i think it's important to note that fascism is a very broad term. it can be applied to a wide range of political movements and ideologies, and it can also be used to refer to any group of people who have a particular set of beliefs. the definition of 'fascism' has changed a lot in the past few decades, but that doesn't mean that it's no longer a useful term. i don't think you have to be a fascist to be an"
505763,Women have an easier time finding romantic partners than men,"women have a harder time finding a partner than men, because they are more likely to be in a relationship with someone who is anorexic. that's because women are less likely to want to date someone who has an incel mindset. i'm not saying that men have a hard time finding romantic partners because they have an ancel mindset, but i am saying that a lot of people who are incel have a much harder time in finding a romantic partner because they","i think you're looking at it the wrong way. it's not that women have an easier time finding romantic partners than men, it's that women are much more likely to date incels than men. i'm not saying that men have a harder time finding a romantic partner than women, i'm just saying that they are less likely to be able to find a partner who isn't an incel."


In [22]:
# Overwrites the earlier save; the frame now also holds the single-prediction column.
df.to_pickle(path='../multitask-counter-arg-generation/data/output/ca-final-models/mt-v4/results/test_preds_df.pkl')

In [24]:
print('Am done...', file=sys.stdout)

Am done...
