In [1]:
%load_ext autoreload

In [2]:
import os
import sys
from scipy import stats
import re
import random
from argparse import Namespace

sys.path.append('../src-py/')

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [10]:
%autoreload

from utils import *
from ca_utils import *

2022-08-10 16:22:29,919 loading file ../../data-ceph/arguana/arg-generation/claim-target-tagger/model/final-model.pt
2022-08-10 16:22:57,845 SequenceTagger predicts: Dictionary with 4 tags: <unk>, B-CT, I-CT, O


In [4]:
import pickle
import torch
import json

import nltk
import numpy as np
import pandas as pd

from pathlib import Path
from tabulate import tabulate
pd.set_option('display.max_colwidth', None)

import matplotlib.pyplot as plt

In [5]:
from datasets import load_dataset, load_metric, Dataset

In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

## General:

In [7]:
local_home_dir = '../sample-data'

data_unique_path = '../sample-data/test_conclusion_all_preprocessed.pkl'
data_path = '../sample-data/test_conclusion_all.pkl'

In [11]:
def generate_ds_attacks(ds, model, tokenizer, premises_clm, conclusion_clm, gen_kwargs, skip_special_tokens=True, batch_size=5):
    
    ds = ds.map(lambda x :preprocess_function(x, tokenizer, premises_clm, 'counter', conclusion_clm=conclusion_clm), batched=True)
    ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    attacks = generate_counters(model, tokenizer, dataloader, gen_kwargs, skip_special_tokens=skip_special_tokens)
    
    return attacks

def create_predictions_df(reddit_sample_valid_ds, gen_kwargs, premises_clm='masked_premises'):
   
    known_conc_attacks  = generate_ds_attacks(reddit_sample_valid_ds, known_conclusion_model, known_conclusion_tokenizer, premises_clm, 'title', gen_kwargs)
    bart_conc_attacks   = generate_ds_attacks(reddit_sample_valid_ds, known_conclusion_model, known_conclusion_tokenizer, premises_clm, 'bart_conclusion', gen_kwargs)
    masked_conc_attacks = generate_ds_attacks(reddit_sample_valid_ds, known_conclusion_model, known_conclusion_tokenizer, premises_clm, None, gen_kwargs)
    
    #update max_gen_length to account to the generated conclusion
    gen_kwargs['max_length'] = gen_kwargs['max_length'] + 50
    joint_conc_baseline_attacks  = generate_ds_attacks(reddit_sample_valid_ds, pred_conclusion_model, pred_conclusion_tokenizer, premises_clm, None, gen_kwargs, skip_special_tokens=False)

    reddit_pred_df = pd.DataFrame(list(zip(
                                           reddit_sample_valid_ds['post_id'],
                                           reddit_sample_valid_ds['title'], 
                                           reddit_sample_valid_ds['conclusion_targets'],
                                           reddit_sample_valid_ds['conclusion_stance'],
                                           reddit_sample_valid_ds['bart_conclusion'], 
                                           reddit_sample_valid_ds[premises_clm],
                                           reddit_sample_valid_ds['counter'], 
                                           known_conc_attacks, masked_conc_attacks, 
                                           bart_conc_attacks, joint_conc_baseline_attacks)), 
                    columns=['post_id', 'conclusion', 'conclusion_target', 'conclusion_stance', 'bart_conclusion', 
                             'premises', 'gt_attack', 'known_conc_attacks', 'masked_conc_attacks', 
                             'bart_conc_attacks',  'joint_conc_baseline_attacks'])

    reddit_pred_df['argument'] = reddit_pred_df.apply(lambda row: row['conclusion'] + ' : ' + ' '.join(row['premises']), axis=1)
    reddit_pred_df['premises'] = reddit_pred_df['premises'].apply(lambda x: ' '.join(x))

    #process the jointly generated conclusion and counter
    reddit_pred_df['joint_conc_baseline'] = reddit_pred_df['joint_conc_baseline_attacks'].apply (lambda x: x.split('<counter>')[0])
    reddit_pred_df['joint_conc_baseline_attacks'] = reddit_pred_df['joint_conc_baseline_attacks'].apply (lambda x: x.split('<counter>')[1] if '<counter>' in x else x)
    reddit_pred_df['joint_conc_baseline'] = reddit_pred_df['joint_conc_baseline'].apply (lambda x: re.sub('<s>|</s>|<conclusion>|<counter>|<pad>', '', x).strip())
    reddit_pred_df['joint_conc_baseline_attacks'] = reddit_pred_df['joint_conc_baseline_attacks'].apply (lambda x: re.sub('<s>|</s>|<conclusion>|<counter>|<pad>', '', x).strip())

    return reddit_pred_df

----------------------

## Generated Predictions:

In [12]:
known_conclusion_model = BartForConditionalGeneration.from_pretrained(local_home_dir + '/models/known-conc-model/checkpoint-9500').to(device)
known_conclusion_tokenizer = BartTokenizer.from_pretrained(local_home_dir + '/models/known-conc-model/checkpoint-9500')

pred_conclusion_model = BartForConditionalGeneration.from_pretrained(local_home_dir + '/models/pred-conc-model').to(device)
pred_conclusion_tokenizer = BartTokenizer.from_pretrained(local_home_dir + '/models/pred-conc-model')

In [13]:
valid_df = pd.read_pickle( data_unique_path)

In [14]:
#Create a dataset
print('Testing on {} posts'.format(len(valid_df)))
valid_ds = Dataset.from_pandas(valid_df.sample(10))
valid_ds = valid_ds.flatten_indices()

Testing on 8533 posts


Flattening the indices:   0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
#Generate counters according to best parameters...
gen_kwargs = {
    "do_sample": True, 
    "max_length":100,
    "min_length":50,
    "top_k": 50,
    "no_repeat_ngram_size":3,
    "top_p":0.95, 
    "num_beams":4
}

#generate predictions
reddit_pred_df = create_predictions_df(valid_ds, gen_kwargs, premises_clm='post')

In [None]:
reddit_pred_df.to_pickle('../data-sample/output/test_all_reddit_pred_test_with_sampling_4beam_df.pkl')

---------