In [None]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
import transformers, nltk, pandas as pd, torch
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from pprint import pprint
from datetime import datetime
import argparse
import functools


from textattack import Attack, AttackArgs,Attacker
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.datasets import HuggingFaceDataset
from textattack.loggers import CSVLogger # tracks a dataframe for us.
from textattack.attack_recipes import AttackRecipe
from textattack.search_methods import BeamSearch
from textattack.constraints import Constraint
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
from textattack.goal_functions import UntargetedClassification
from textattack.metrics.attack_metrics.attack_success_rate import AttackSuccessRate
from textattack.metrics.attack_metrics.words_perturbed import WordsPerturbed
from textattack.metrics.attack_metrics.attack_queries import AttackQueries
from textattack.metrics.quality_metrics.perplexity import Perplexity
from textattack.metrics.quality_metrics.use import USEMetric
from sentence_transformers.util import pytorch_cos_sim

from travis_attack.utils import display_all, merge_dicts, append_df_to_csv
from travis_attack.data import prep_dsd_rotten_tomatoes,prep_dsd_simple,prep_dsd_financial
from travis_attack.config import Config
from travis_attack.models import _prepare_vm_tokenizer_and_model, get_vm_probs, prepare_models, get_nli_probs
from fastcore.basics import in_jupyter

# Directory prefix (relative to the working directory) under which the per-run
# attack CSVs and the shared results.csv summary are written.
path_baselines = "./baselines/"
# Timestamp captured once when this cell runs; it is stored in the run config
# and used as the leading part of the run's CSV filename.
datetime_now = datetime.now().strftime("%Y-%m-%d_%H%M%S")


In [None]:
!jupyter nbconvert \
    --TagRemovePreprocessor.enabled=True \
    --TagRemovePreprocessor.remove_cell_tags="['hide']" \
    --TemplateExporter.exclude_markdown=True \
    --to python "baselines.ipynb"

[NbConvertApp] Converting notebook baselines.ipynb to python


In [None]:
def setup_baselines_parser():
    """Build the command-line parser used when this notebook runs as a script.

    Every option is optional and defaults to None, so the caller can tell
    which settings were actually supplied on the command line.

    Returns:
        argparse.ArgumentParser: parser exposing one ``--<name>`` flag per
        entry of the option table below.
    """
    # (option name, converter); a converter of None keeps the raw string.
    option_table = (
        ("ds_name", None),
        ("split", None),
        ("attack_name", None),
        ("num_examples", int),
        ("beam_sz", int),
        ("max_candidates", int),
        ("sts_threshold", float),
        ("contradiction_threshold", float),
    )
    parser = argparse.ArgumentParser()
    for option_name, converter in option_table:
        parser.add_argument(f"--{option_name}", type=converter)
    #parser.add_argument('args', nargs=argparse.REMAINDER)  # activate to put keywords in kwargs.
    return parser

In [None]:
######### CONFIG (default values) #########
# Settings for one baseline run. The same dict later collects the run's
# aggregate metrics, so it ends up as one row of the results summary.
d = {
    "datetime": datetime_now,
    "ds_name": "rotten_tomatoes",
    "split": 'valid',
    "attack_name": 'BeamSearchCFEmbeddingAttack',
    "num_examples": -1,
    "beam_sz": 1,
    "max_candidates": 2,
    "sts_threshold": 0.9,
    "contradiction_threshold": 0.1,
}
###########################################

# When executed as a script, command-line options take precedence over the
# defaults above. Options that were not supplied parse to None and are ignored.
if not in_jupyter():
    parser = setup_baselines_parser()
    newargs = vars(parser.parse_args())
    d.update({k: v for k, v in newargs.items() if v is not None})

In [None]:
class StsScoreConstraint(Constraint):
    """Accept a candidate only when it stays semantically close to the reference text.

    Closeness is the cosine similarity between the embeddings that `sts_model`
    produces for the two texts; the candidate passes when that similarity is
    strictly greater than `sts_threshold`.
    """

    def __init__(self, sts_model, sts_threshold):
        # The positional True is the base-class flag for comparing against the
        # original text rather than the previous candidate (the original
        # author's note marks this as their understanding, not a certainty).
        super().__init__(True)
        self.sts_model = sts_model
        self.sts_threshold = sts_threshold

    @functools.lru_cache(maxsize=2**14)
    def get_embedding(self, text):
        """Return `sts_model.encode(text)`, memoised on (instance, text)."""
        return self.sts_model.encode(text)

    def _check_constraint(self, transformed_text, current_text):
        """Return True when the similarity of the two texts exceeds the threshold."""
        reference_vec = self.get_embedding(current_text.text)
        candidate_vec = self.get_embedding(transformed_text.text)
        similarity = pytorch_cos_sim(reference_vec, candidate_vec).item()
        return similarity > self.sts_threshold
        

class ContradictionScoreConstraint(Constraint):
    """Accept a candidate only when the NLI model does not see it as contradicting the reference text.

    The score is read from position [0][0] of `get_nli_probs`' output and the
    candidate passes when it is strictly below `contradiction_threshold`.
    NOTE(review): this presumes index 0 is the NLI model's contradiction
    class -- confirm against the model's label order.
    """

    def __init__(self, cfg, nli_tokenizer, nli_model, contradiction_threshold):
        # Positional True: same base-class flag as the other custom constraint.
        super().__init__(True)
        self.contradiction_threshold = contradiction_threshold
        self.nli_model     = nli_model
        self.nli_tokenizer = nli_tokenizer
        self.cfg = cfg

    def _check_constraint(self, transformed_text, current_text):
        """Return True when the contradiction score is under the threshold."""
        nli_probs = get_nli_probs(
            current_text.text,
            transformed_text.text,
            self.cfg,
            self.nli_tokenizer,
            self.nli_model,
        )
        contradiction_score = nli_probs.cpu()[0][0].item()
        return contradiction_score < self.contradiction_threshold
    

class BeamSearchCFEmbeddingAttack(AttackRecipe):
    """Untargeted classification + word embedding swap + [no repeat, no stopword, STS, contradiction score] constraints + beam search"""
    @staticmethod
    def build(model_wrapper, cfg,  sts_model, nli_tokenizer, nli_model, beam_sz=2,
              max_candidates=5, sts_threshold=0.6, contradiction_threshold=0.8):
        """Assemble the attack.

        `beam_sz` is the beam width of the search, `max_candidates` caps the
        substitutions the embedding swap proposes, and the two thresholds are
        passed to the STS and contradiction constraints.
        """
        # Original author's note: this is the stopword list textattack uses by default.
        english_stopwords = nltk.corpus.stopwords.words("english")
        return Attack(
            UntargetedClassification(model_wrapper),
            [
                RepeatModification(),
                StopwordModification(english_stopwords),
                StsScoreConstraint(sts_model, sts_threshold),
                ContradictionScoreConstraint(cfg, nli_tokenizer, nli_model, contradiction_threshold),
            ],
            WordSwapEmbedding(max_candidates=max_candidates),
            BeamSearch(beam_width=beam_sz),
        )

    
class BeamSearchLMAttack(AttackRecipe):
    """Untargeted classification + language model word swap + [no repeat, no stopword, STS, contradiction score] constraints + beam search"""
    @staticmethod
    def build(model_wrapper, cfg,  sts_model, nli_tokenizer, nli_model, beam_sz=2,
              max_candidates=5, sts_threshold=0.6, contradiction_threshold=0.8):
        """Assemble the attack.

        `beam_sz` is the beam width of the search, `max_candidates` caps the
        substitutions the masked-LM swap proposes, and the two thresholds are
        passed to the STS and contradiction constraints.
        """
        # Original author's note: this is the stopword list textattack uses by default.
        english_stopwords = nltk.corpus.stopwords.words("english")
        lm_word_swap = WordSwapMaskedLM(
            method='bae',
            masked_language_model='distilroberta-base',
            max_candidates=max_candidates,
        )
        return Attack(
            UntargetedClassification(model_wrapper),
            [
                RepeatModification(),
                StopwordModification(english_stopwords),
                StsScoreConstraint(sts_model, sts_threshold),
                ContradictionScoreConstraint(cfg, nli_tokenizer, nli_model, contradiction_threshold),
            ],
            lm_word_swap,
            BeamSearch(beam_width=beam_sz),
        )

    

In [None]:
# Map the configured attack name to its recipe class. An unrecognised name is
# rejected here with a clear message, rather than leaving `attack_recipe`
# undefined and failing later with a NameError.
ATTACK_RECIPES_BY_NAME = {
    'BeamSearchLMAttack': BeamSearchLMAttack,
    'BeamSearchCFEmbeddingAttack': BeamSearchCFEmbeddingAttack,
}
if d['attack_name'] not in ATTACK_RECIPES_BY_NAME:
    raise ValueError(
        f"Unknown attack_name {d['attack_name']!r}; "
        f"expected one of {sorted(ATTACK_RECIPES_BY_NAME)}"
    )
attack_recipe = ATTACK_RECIPES_BY_NAME[d['attack_name']]
# CSV that the attack log for this run is written to; the name encodes the
# timestamp, dataset, split, attack and search settings.
filename = f"{path_baselines}{datetime_now}_{d['ds_name']}_{d['split']}_{d['attack_name']}_beam_sz={d['beam_sz']}_max_candidates={d['max_candidates']}.csv"

## Attack 

In [None]:
# Pick the config adjustment and dataset preparation that match the configured
# dataset name. An unrecognised name is rejected explicitly; otherwise `cfg`
# and `dsd` would stay undefined and the failure would show up as a NameError.
if d['ds_name'] == "financial_phrasebank":
    cfg = Config().adjust_config_for_financial_dataset()
    dsd = prep_dsd_financial(cfg)
elif d['ds_name'] == "rotten_tomatoes":
    cfg = Config().adjust_config_for_rotten_tomatoes_dataset()
    dsd = prep_dsd_rotten_tomatoes(cfg)
elif d['ds_name'] == "simple":
    cfg = Config().adjust_config_for_simple_dataset()
    dsd = prep_dsd_simple(cfg)
else:
    raise ValueError(
        f"Unknown ds_name {d['ds_name']!r}; expected one of "
        "['financial_phrasebank', 'rotten_tomatoes', 'simple']"
    )
# Wrap the requested split so textattack can iterate over it.
dataset = HuggingFaceDataset(dsd[d['split']])


In [None]:
# prepare_models returns nine values; the three in the middle are not needed
# in this notebook and are discarded. It also returns the config, which is rebound.
(vm_tokenizer, vm_model, _, _, _,
 sts_model, nli_tokenizer, nli_model, cfg) = prepare_models(cfg)

# textattack queries the victim model through this wrapper.
vm_model_wrapper = HuggingFaceModelWrapper(vm_model, vm_tokenizer)

# Build the selected recipe with the search and constraint settings from the run config.
attack = attack_recipe.build(
    vm_model_wrapper,
    cfg,
    sts_model,
    nli_tokenizer,
    nli_model,
    beam_sz=d['beam_sz'],
    max_candidates=d['max_candidates'],
    sts_threshold=d['sts_threshold'],
    contradiction_threshold=d['contradiction_threshold'],
)

# Per-example results are logged to `filename`, which is read back later for post-processing.
attack_args = AttackArgs(
    num_examples=d['num_examples'],
    enable_advance_metrics=True,
    log_to_csv=filename,
    csv_coloring_style='plain',
    disable_stdout=True,
)
attacker = Attacker(attack, dataset, attack_args)

Some weights of the model checkpoint at microsoft/deberta-base-mnli were not used when initializing DebertaForSequenceClassification: ['config']
- This IS expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
textattack: Unknown if model of class <class 'transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.


In [None]:
# Echo the effective configuration so it is recorded alongside the run output.
print("Current config for attack:", d, sep="\n")

Current config for attack:
{'datetime': '2022-04-20_132518', 'ds_name': 'rotten_tomatoes', 'split': 'valid', 'attack_name': 'BeamSearchCFEmbeddingAttack', 'num_examples': 4, 'beam_sz': 1, 'max_candidates': 1, 'sts_threshold': 0.6, 'contradiction_threshold': 0.8}


In [None]:
# Run the attack over the dataset. The returned results feed the aggregate
# metrics below; the per-example log is also written to the CSV configured
# through log_to_csv.
attack_results = attacker.attack_dataset()

textattack: Logging to CSV at path ./baselines/2022-04-20_132518_rotten_tomatoes_valid_BeamSearchCFEmbeddingAttack_beam_sz=1_max_candidates=1.csv
  0%|          | 0/4 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  1
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapEmbedding(
    (max_candidates):  1
    (embedding):  WordEmbedding
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): RepeatModification
    (3): StopwordModification
  (is_black_box):  True
) 



[Succeeded / Failed / Skipped / Total] 1 / 3 / 0 / 4: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 1      |
| Number of failed attacks:     | 3      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 75.0%  |
| Attack success rate:          | 25.0%  |
| Average perturbed word %:     | 11.11% |
| Average num. words per input: | 16.0   |
| Avg num queries:              | 54.5   |
| Average Original Perplexity:  | 55.91  |
| Average Attack Perplexity:    | 619.99 |
| Average Attack USE Score:     | 0.81   |
+-------------------------------+--------+


## General metrics 

In [None]:
# Compute each aggregate metric in turn and collect everything in one flat
# dict; on a key clash, a later metric overwrites an earlier one.
attack_result_metrics = {}
for metric_cls in (AttackSuccessRate, WordsPerturbed, AttackQueries, Perplexity, USEMetric):
    attack_result_metrics.update(metric_cls().calculate(attack_results))

# This entry is left out of the run summary.
# NOTE(review): presumably because it is not a single scalar -- confirm.
del attack_result_metrics['num_words_changed_until_success']

# Fold the metrics into the run config so they land in the same summary row.
d = merge_dicts(d, attack_result_metrics)

## Example-specific metrics 

In [None]:
def display_adv_example(df):
    """Show the original and perturbed text columns of `df` side by side.

    The column width is widened to 480 characters so the examples are
    readable, but only for the duration of this call: the pandas display
    option is restored afterwards instead of being changed globally.

    Args:
        df: DataFrame with 'original_text' and 'perturbed_text' columns.
    """
    # Imported from the public IPython.display module; importing `display`
    # from IPython.core.display is deprecated.
    from IPython.display import display
    with pd.option_context("display.max_colwidth", 480):
        display(df[['original_text', 'perturbed_text']])

def add_vm_score_and_label_flip(df, dataset, cfg, vm_tokenizer, vm_model):
    """Add victim-model predictions and scores for each attacked example to `df`.

    Row i of `df` is paired with entry i of ``dataset._dataset['label']``.
    When the dataset holds more labels than `df` has rows (an attack run on
    only part of the dataset), just the first ``len(df)`` labels are used.
    NOTE(review): this assumes the logged rows are the first ``len(df)``
    dataset examples in dataset order -- confirm for the attack settings used.

    Columns added (the frame is modified in place and also returned):
        truelabel            -- label taken from the dataset
        orig_predclass       -- argmax class for 'original_text'
        pp_predclass         -- argmax class for 'perturbed_text'
        orig_truelabel_probs -- probability of the true label, original text
        pp_truelabel_probs   -- probability of the true label, perturbed text
        vm_scores            -- orig_truelabel_probs - pp_truelabel_probs
        label_flip           -- 1 if pp_predclass differs from truelabel, else 0

    Raises:
        ValueError: if the dataset provides fewer labels than `df` has rows.
    """
    n_rows = len(df)
    labels = list(dataset._dataset['label'])[:n_rows]
    if len(labels) != n_rows:
        raise ValueError(
            f"dataset provides {len(labels)} labels but df has {n_rows} rows"
        )
    truelabels = torch.tensor(labels, dtype=torch.long, device=cfg.device)

    orig_probs = get_vm_probs(df['original_text'].tolist(), cfg, vm_tokenizer, vm_model, return_predclass=False)
    pp_probs   = get_vm_probs(df['perturbed_text'].tolist(), cfg, vm_tokenizer, vm_model, return_predclass=False)
    orig_predclass = torch.argmax(orig_probs, dim=1)
    pp_predclass   = torch.argmax(pp_probs, dim=1)

    # Select, per row, the probability assigned to the true label. squeeze(1)
    # removes only the gathered column, so a one-row batch stays 1-D rather
    # than collapsing to a 0-d tensor.
    label_index = truelabels[:, None]
    orig_truelabel_probs = torch.gather(orig_probs, 1, label_index).squeeze(1)
    pp_truelabel_probs   = torch.gather(pp_probs, 1, label_index).squeeze(1)

    df['truelabel'] = truelabels.cpu().tolist()
    df['orig_predclass'] = orig_predclass.cpu().tolist()
    df['pp_predclass'] = pp_predclass.cpu().tolist()
    df['orig_truelabel_probs'] = orig_truelabel_probs.cpu().tolist()
    df['pp_truelabel_probs'] = pp_truelabel_probs.cpu().tolist()
    df['vm_scores'] = (orig_truelabel_probs - pp_truelabel_probs).cpu().tolist()
    df['label_flip'] = ((pp_predclass != truelabels) * 1).cpu().tolist()
    return df

def add_sts_score(df, sts_model, cfg):
    """Add an 'sts_scores' column: cosine similarity of each original/perturbed pair.

    Both text columns are encoded with `sts_model` on `cfg.device`, then row i
    of the original embeddings is compared with row i of the perturbed ones.
    The similarity is computed row-wise, so the work grows linearly with the
    number of rows; no all-pairs matrix is built just to read its diagonal.

    The frame is modified in place and also returned.
    """
    orig_embeddings = sts_model.encode(df['original_text'].tolist(),  convert_to_tensor=True, device=cfg.device)
    pp_embeddings   = sts_model.encode(df['perturbed_text'].tolist(), convert_to_tensor=True, device=cfg.device)
    pair_similarity = torch.nn.functional.cosine_similarity(orig_embeddings, pp_embeddings, dim=1)
    df['sts_scores'] = pair_similarity.cpu().tolist()
    return df

def add_contradiction_score(df, cfg, nli_tokenizer, nli_model):
    """Add a 'contradiction_scores' column taken from column 0 of the NLI output.

    `get_nli_probs` is called once with all original texts and all perturbed
    texts. NOTE(review): column 0 is treated as the contradiction class --
    confirm against the NLI model's label order.

    The frame is modified in place and also returned.
    """
    originals = df['original_text'].tolist()
    perturbed = df['perturbed_text'].tolist()
    nli_probs = get_nli_probs(originals, perturbed, cfg, nli_tokenizer, nli_model)
    df['contradiction_scores'] = nli_probs[:, 0].cpu().tolist()
    return df

def get_df_mean_cols(df):
    """Return the column means of the per-example score and flag columns.

    The result is a dict keyed ``"<column>_mean"`` for each of the six
    columns listed below, in that order.
    """
    averaged_columns = [
        'label_flip',
        'vm_scores',
        'sts_scores',
        'contradiction_scores',
        'sts_threshold_met',
        'contradiction_threshold_met',
    ]
    return dict(df[averaged_columns].mean().add_suffix("_mean"))

def get_cts_summary_stats(df):
    """Return spread statistics for the three continuous score columns.

    For each of 'vm_scores', 'sts_scores' and 'contradiction_scores', the
    standard deviation and the 10/25/50/75/90th percentiles are taken from
    ``DataFrame.describe`` and stored under ``"<column>_<stat>"``. The
    per-column dicts are combined with `merge_dicts`.
    """
    score_columns = ['vm_scores', 'sts_scores', 'contradiction_scores']
    wanted_stats = ['std', '10%', '25%', '50%', '75%', '90%']
    summary = df[score_columns].describe(percentiles=[.1, .25, .5, .75, .9]).loc[wanted_stats]
    stats = dict()
    for column in score_columns:
        column_stats = summary[column].add_prefix(f"{column}_")
        stats = merge_dicts(stats, dict(column_stats))
    return stats


In [None]:
# To re-process an earlier run instead, point `filename` at that run's CSV
# before executing this cell.
df = pd.read_csv(filename)
#display_adv_example(df)

# Victim-model columns are added before skipped rows are dropped, because the
# dataset labels are matched to the logged rows by position.
df = add_vm_score_and_label_flip(df, dataset, cfg, vm_tokenizer, vm_model)

# .copy() makes the filtered frame independent, so the column assignments
# below do not write into a selection of the unfiltered frame (which triggers
# pandas' SettingWithCopyWarning).
df = df.query("result_type != 'Skipped'").copy()
df = add_sts_score(df, sts_model, cfg)
df = add_contradiction_score(df, cfg, nli_tokenizer, nli_model)

# Flag, per example, whether each constraint threshold of this run was met.
df['sts_threshold_met'] = df['sts_scores'] > d['sts_threshold']
df['contradiction_threshold_met'] = df['contradiction_scores'] < d['contradiction_threshold']
# filename ends in ".csv"; strip those four characters and append the suffix.
df.to_csv(f"{filename[:-4]}_processed.csv", index=False)

# Extend the run summary with the means and spread statistics of the
# per-example scores, then append it as one row to the shared results file.
d = merge_dicts(d, get_df_mean_cols(df))
d = merge_dicts(d, get_cts_summary_stats(df))

summary_df = pd.Series(d).to_frame().T
append_df_to_csv(summary_df, f"{path_baselines}results.csv")

NameError: name 'dataset' is not defined

## Old code 

In [None]:
# df1 = df.sample(5)
# orig_l = df1['original_text'].tolist()
# pp_l = df1['perturbed_text'].tolist()
# print(orig_l)
# print(pp_l)

In [None]:
# for orig, adv in zip(df1['original_text'].tolist(), df1['perturbed_text'].tolist()): 
#     print(f"{orig}{adv}")
#     print()

In [None]:
#df.iloc[104][['original_text', 'perturbed_text']].values

In [None]:
#filename1 = f"/data/tproth/travis_attack/baselines/2022-04-20_133329_rotten_tomatoes_valid_BeamSearchCFEmbeddingAttack_beam_sz=1_max_candidates=1_processed.csv"
#df = pd.read_csv(filename1)
#display_all(df.sample(2))

In [None]:
#df_results = pd.read_csv(f"/data/tproth/travis_attack/baselines/results.csv")