In [None]:
# Re-import edited project modules (e.g. travis_attack.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Makes %lprun available for line-by-line profiling.
# NOTE(review): no %lprun call is visible in this notebook — drop if unused.
%load_ext line_profiler

In [None]:
import transformers, nltk, pandas as pd, torch, string
from datasets import load_dataset, load_from_disk, DatasetDict, ClassLabel
from pprint import pprint
from datetime import datetime
import argparse
import functools


from textattack import Attack, AttackArgs,Attacker
from textattack.models.wrappers import HuggingFaceModelWrapper
from textattack.datasets import HuggingFaceDataset
from textattack.loggers import CSVLogger # tracks a dataframe for us.
from textattack.attack_recipes import AttackRecipe
from textattack.search_methods import BeamSearch
from textattack.constraints import Constraint
from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
from textattack.goal_functions import UntargetedClassification
from textattack.metrics.attack_metrics.attack_success_rate import AttackSuccessRate
from textattack.metrics.attack_metrics.words_perturbed import WordsPerturbed
from textattack.metrics.attack_metrics.attack_queries import AttackQueries
from textattack.metrics.quality_metrics.perplexity import Perplexity
from textattack.metrics.quality_metrics.use import USEMetric
from sentence_transformers.util import pytorch_cos_sim

from travis_attack.utils import display_all, merge_dicts, append_df_to_csv, set_seed
from travis_attack.data import prep_dsd_rotten_tomatoes,prep_dsd_simple,prep_dsd_financial
from travis_attack.config import Config
from travis_attack.models import _prepare_vm_tokenizer_and_model, get_vm_probs, prepare_models, get_nli_probs
from travis_attack.baseline_attacks import AttackRecipes, setup_baselines_parser
from fastcore.basics import in_jupyter


import warnings
warnings.filterwarnings("ignore", message="FutureWarning: The frame.append method is deprecated") 

path_baselines = "./baselines/"

# Fix the random seed once, up front, so attack runs are repeatable.
# NOTE(review): which RNGs `set_seed` seeds (random / numpy / torch?) is defined in travis_attack.utils — confirm.
set_seed(1000)


In [None]:
# Export this notebook to a plain Python script so it can run outside Jupyter
# (the config cell below then reads command-line options). Markdown is excluded,
# as are cells tagged 'hide'.
# NOTE(review): this cell presumably carries the 'hide' tag itself; otherwise the exported
# script would also contain this shell call — confirm.
!jupyter nbconvert \
    --TagRemovePreprocessor.enabled=True \
    --TagRemovePreprocessor.remove_cell_tags="['hide']" \
    --TemplateExporter.exclude_markdown=True \
    --to python "baselines.ipynb"

[NbConvertApp] Converting notebook baselines.ipynb to python


In [None]:
######### CONFIG (default values) #########
param_d = {
    'ds_name': "financial",
    'split': 'test',
    'sts_threshold': 0.8,
    'contradiction_threshold': 0.2,
    'acceptability_threshold': 0.5,
    'pp_letter_diff_threshold': 30,
}
###########################################

# Outside Jupyter (i.e. run as the exported script), command-line options take
# precedence over the defaults above; options left unset (None) keep the default.
if not in_jupyter():
    parser = setup_baselines_parser()
    newargs = vars(parser.parse_args())
    param_d.update({key: value for key, value in newargs.items() if value is not None})

In [None]:
### Common attack components
# Build the baseline attack recipes from the config in param_d.
# Both names are used by the attack loop: attack_recipes.ds.dsd_raw supplies the
# dataset split to attack, and each attack_list entry is a dict with
# 'attack_num', 'attack_code' and 'attack_recipe'.
attack_recipes = AttackRecipes(param_d)
attack_list = attack_recipes.get_attack_list()

Reusing dataset financial_phrasebank (/data/tproth/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/a6d468761d4e0c8ae215c77367e1092bead39deb08fbf4bffd7c0a6991febbf0)


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

Loading cached split indices for dataset at /data/tproth/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/a6d468761d4e0c8ae215c77367e1092bead39deb08fbf4bffd7c0a6991febbf0/cache-8cd79ddb74449b16.arrow and /data/tproth/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/a6d468761d4e0c8ae215c77367e1092bead39deb08fbf4bffd7c0a6991febbf0/cache-767acf0081605736.arrow
Loading cached split indices for dataset at /data/tproth/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/a6d468761d4e0c8ae215c77367e1092bead39deb08fbf4bffd7c0a6991febbf0/cache-8bc5c0fec38cf232.arrow and /data/tproth/.cache/huggingface/datasets/financial_phrasebank/sentences_50agree/1.0.0/a6d468761d4e0c8ae215c77367e1092bead39deb08fbf4bffd7c0a6991febbf0/cache-ea0232ccd720d2dc.arrow





HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=2.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=1.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=1.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=2.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=1.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Flattening the indices', max=1.0, style=ProgressStyle(des…

textattack: Unknown if model of class <class 'transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.





If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`


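## Attack

Run every baseline recipe in `attack_list` on the selected split. Each run logs per-example results to its own timestamped CSV in `./baselines/` and appends a one-row metric summary to `results.csv`.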

In [None]:
# Run every baseline recipe on the selected split. Each run writes per-example
# results to its own timestamped CSV and appends a one-row summary to results.csv.
hf_dataset = HuggingFaceDataset(attack_recipes.ds.dsd_raw[param_d['split']], dataset_columns=(['text'], 'label'))
for attack_json in attack_list:
    print("Now doing attack recipe number", attack_json['attack_num'], "with code", attack_json['attack_code'])
    datetime_now = datetime.now().strftime("%Y-%m-%d_%H%M%S")
    # Per-run record built as a copy. Writing these keys into param_d itself made
    # the shared config silently carry the last run's datetime/attack_num/attack_code
    # into later cells and re-runs. Key order matches the old behaviour, so
    # results.csv columns are unchanged.
    run_d = {
        **param_d,
        'datetime': datetime_now,
        'attack_num': attack_json['attack_num'],
        'attack_code': attack_json['attack_code'],
    }
    filename = f"{path_baselines}{datetime_now}_{run_d['ds_name']}_{run_d['split']}_{run_d['attack_code']}.csv"
    attack_args = AttackArgs(num_examples=-1, enable_advance_metrics=True,
                             log_to_csv=filename, csv_coloring_style='plain', disable_stdout=True)
    attacker = Attacker(attack_json['attack_recipe'], hf_dataset, attack_args)

    attack_results = attacker.attack_dataset()

    # Metric objects are created fresh for every run.
    # NOTE(review): some textattack metrics may keep state across calculate()
    # calls — confirm before hoisting them out of the loop.
    attack_result_metrics = {
        **AttackSuccessRate().calculate(attack_results),
        **WordsPerturbed().calculate(attack_results),
        **AttackQueries().calculate(attack_results),
        **Perplexity().calculate(attack_results),
        **USEMetric().calculate(attack_results),
    }
    # Kept out of the one-row summary. Presumably it is not a scalar — verify.
    # The default avoids a KeyError if the metric is not reported.
    attack_result_metrics.pop('num_words_changed_until_success', None)
    # `d` and `filename` stay defined after the loop; the analysis cell below reads
    # the last run's CSV through `filename`.
    d = merge_dicts(run_d, attack_result_metrics)
    summary_df = pd.Series(d).to_frame().T
    append_df_to_csv(summary_df, f"{path_baselines}results.csv")

textattack: Logging to CSV at path ./baselines/2022-07-04_145655_financial_test_LM-WR-BS-b2m5.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  2
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapMaskedLM(
    (method):  bae
    (masked_lm_name):  RobertaForCausalLM
    (max_length):  512
    (max_candidates):  5
    (min_confidence):  0.0005
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDiffConstraint(
        (compare_against_original):  True
      )
    (4): LCPConstraint(
        (compare_against_original):  True
      )
    (5): RepeatModification
    (6): StopwordModification
  (is_black_box):  True
) 



  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 2 / 1 / 0 / 3: 100%|██████████| 3/3 [00:22<00:00,  7.65s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 2      |
| Number of failed attacks:     | 1      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 33.33% |
| Attack success rate:          | 66.67% |
| Average perturbed word %:     | 26.24% |
| Average num. words per input: | 16.0   |
| Avg num queries:              | 80.33  |
| Average Original Perplexity:  | 573.64 |
| Average Attack Perplexity:    | 761.11 |
| Average Attack USE Score:     | 0.71   |
+-------------------------------+--------+


textattack: Logging to CSV at path ./baselines/2022-07-04_145745_financial_test_LM-WR-BS-b5m25.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  5
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapMaskedLM(
    (method):  bae
    (masked_lm_name):  RobertaForCausalLM
    (max_length):  512
    (max_candidates):  25
    (min_confidence):  0.0005
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDiffConstraint(
        (compare_against_original):  True
      )
    (4): LCPConstraint(
        (compare_against_original):  True
      )
    (5): RepeatModification
    (6): StopwordModification
  (is_black_box):  True
) 



  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 2 / 1 / 0 / 3: 100%|██████████| 3/3 [01:14<00:00, 24.72s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 2      |
| Number of failed attacks:     | 1      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 33.33% |
| Attack success rate:          | 66.67% |
| Average perturbed word %:     | 13.57% |
| Average num. words per input: | 16.0   |
| Avg num queries:              | 434.0  |
| Average Original Perplexity:  | 573.64 |
| Average Attack Perplexity:    | 825.23 |
| Average Attack USE Score:     | 0.78   |
+-------------------------------+--------+


textattack: Logging to CSV at path ./baselines/2022-07-04_145926_financial_test_LM-WR-BS-b10m50.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  10
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapMaskedLM(
    (method):  bae
    (masked_lm_name):  RobertaForCausalLM
    (max_length):  512
    (max_candidates):  50
    (min_confidence):  0.0005
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDiffConstraint(
        (compare_against_original):  True
      )
    (4): LCPConstraint(
        (compare_against_original):  True
      )
    (5): RepeatModification
    (6): StopwordModification
  (is_black_box):  True
) 



  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 2 / 1 / 0 / 3: 100%|██████████| 3/3 [02:27<00:00, 49.30s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 2      |
| Number of failed attacks:     | 1      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 33.33% |
| Attack success rate:          | 66.67% |
| Average perturbed word %:     | 10.63% |
| Average num. words per input: | 16.0   |
| Avg num queries:              | 907.0  |
| Average Original Perplexity:  | 573.64 |
| Average Attack Perplexity:    | 830.06 |
| Average Attack USE Score:     | 0.82   |
+-------------------------------+--------+


textattack: Logging to CSV at path ./baselines/2022-07-04_150221_financial_test_LM-WADR-BS-b5m25.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  5
  )
  (goal_function):  UntargetedClassification
  (transformation):  CompositeTransformation(
    (0): WordSwapMaskedLM(
        (method):  bae
        (masked_lm_name):  RobertaForCausalLM
        (max_length):  512
        (max_candidates):  25
        (min_confidence):  0.0005
      )
    (1): WordInsertionMaskedLM(
        (masked_lm_name):  RobertaForCausalLM
        (max_length):  512
        (max_candidates):  25
        (min_confidence):  0.0005
      )
    (2): WordMergeMaskedLM(
        (masked_lm_name):  RobertaForCausalLM
        (max_length):  512
        (max_candidates):  25
        (min_confidence):  0.0005
      )
    )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDi


  0%|          | 0/75175004 [00:00<?, ?B/s][A
  0%|          | 9216/75175004 [00:00<37:40, 33256.82B/s][A
  0%|          | 31744/75175004 [00:00<29:29, 42457.37B/s][A
  0%|          | 88064/75175004 [00:00<21:39, 57795.07B/s][A
  0%|          | 102400/75175004 [00:00<18:26, 67830.66B/s][A
  0%|          | 191488/75175004 [00:00<13:24, 93191.25B/s][A
  0%|          | 223232/75175004 [00:01<11:14, 111139.99B/s][A
  0%|          | 352256/75175004 [00:01<08:12, 151963.33B/s][A
  1%|          | 404480/75175004 [00:01<06:52, 181410.56B/s][A
  1%|          | 520192/75175004 [00:01<05:11, 239673.00B/s][A
  1%|          | 581632/75175004 [00:01<04:47, 259204.23B/s][A
  1%|          | 688128/75175004 [00:01<04:12, 294973.02B/s][A
  1%|          | 839680/75175004 [00:01<03:14, 382804.46B/s][A
  1%|          | 912384/75175004 [00:02<03:32, 350223.42B/s][A
  1%|▏         | 1015808/75175004 [00:02<03:08, 393195.69B/s][A
  2%|▏         | 1160192/75175004 [00:02<02:30, 492036.57B/s][A

 14%|█▍        | 10429440/75175004 [00:18<02:03, 522268.54B/s][A
 14%|█▍        | 10504192/75175004 [00:18<01:53, 570385.71B/s][A
 14%|█▍        | 10567680/75175004 [00:18<02:00, 536039.21B/s][A
 14%|█▍        | 10663936/75175004 [00:19<01:48, 594717.16B/s][A
 14%|█▍        | 10728448/75175004 [00:19<01:55, 559359.77B/s][A
 14%|█▍        | 10823680/75175004 [00:19<01:47, 601123.03B/s][A
 14%|█▍        | 10887168/75175004 [00:19<01:52, 571413.06B/s][A
 15%|█▍        | 10947584/75175004 [00:19<01:53, 564010.84B/s][A
 15%|█▍        | 11005952/75175004 [00:19<01:53, 564332.43B/s][A
 15%|█▍        | 11064320/75175004 [00:19<02:00, 530452.20B/s][A
 15%|█▍        | 11152384/75175004 [00:19<01:53, 562660.92B/s][A
 15%|█▍        | 11215872/75175004 [00:20<02:03, 519086.53B/s][A
 15%|█▌        | 11312128/75175004 [00:20<01:52, 565435.28B/s][A
 15%|█▌        | 11375616/75175004 [00:20<01:59, 532732.86B/s][A
 15%|█▌        | 11463680/75175004 [00:20<01:46, 599595.02B/s][A
 15%|█▌   

 25%|██▌       | 18824192/75175004 [00:33<01:32, 610438.11B/s][A
 25%|██▌       | 18887680/75175004 [00:33<01:36, 584879.95B/s][A
 25%|██▌       | 18948096/75175004 [00:33<01:39, 562361.76B/s][A
 25%|██▌       | 19006464/75175004 [00:33<01:39, 563854.78B/s][A
 25%|██▌       | 19063808/75175004 [00:33<01:40, 557371.30B/s][A
 25%|██▌       | 19143680/75175004 [00:33<01:34, 595115.37B/s][A
 26%|██▌       | 19205120/75175004 [00:33<01:35, 588788.99B/s][A
 26%|██▌       | 19265536/75175004 [00:33<01:42, 545905.09B/s][A
 26%|██▌       | 19328000/75175004 [00:34<01:40, 557953.59B/s][A
 26%|██▌       | 19391488/75175004 [00:34<01:36, 576163.65B/s][A
 26%|██▌       | 19464192/75175004 [00:34<01:33, 594011.12B/s][A
 26%|██▌       | 19524608/75175004 [00:34<01:35, 585194.25B/s][A
 26%|██▌       | 19584000/75175004 [00:34<01:42, 540155.12B/s][A
 26%|██▌       | 19647488/75175004 [00:34<01:38, 562026.37B/s][A
 26%|██▌       | 19720192/75175004 [00:34<01:33, 595511.01B/s][A
 26%|██▋  

 37%|███▋      | 27488256/75175004 [00:48<01:23, 570624.69B/s][A
 37%|███▋      | 27553792/75175004 [00:48<01:35, 496670.12B/s][A
 37%|███▋      | 27648000/75175004 [00:48<01:23, 571510.24B/s][A
 37%|███▋      | 27714560/75175004 [00:48<01:32, 510328.16B/s][A
 37%|███▋      | 27815936/75175004 [00:48<01:21, 583176.06B/s][A
 37%|███▋      | 27883520/75175004 [00:48<01:37, 485898.46B/s][A
 37%|███▋      | 27983872/75175004 [00:48<01:27, 542143.44B/s][A
 37%|███▋      | 28047360/75175004 [00:49<01:32, 507709.29B/s][A
 37%|███▋      | 28143616/75175004 [00:49<01:23, 560893.67B/s][A
 38%|███▊      | 28206080/75175004 [00:49<01:27, 536182.48B/s][A
 38%|███▊      | 28296192/75175004 [00:49<01:17, 602380.83B/s][A
 38%|███▊      | 28362752/75175004 [00:49<01:30, 515857.84B/s][A
 38%|███▊      | 28455936/75175004 [00:49<01:25, 549476.91B/s][A
 38%|███▊      | 28516352/75175004 [00:49<01:28, 530200.36B/s][A
 38%|███▊      | 28615680/75175004 [00:50<01:18, 596462.25B/s][A
 38%|███▊ 

 53%|█████▎    | 39692288/75175004 [01:11<01:09, 509623.67B/s][A
 53%|█████▎    | 39784448/75175004 [01:11<01:00, 582995.08B/s][A
 53%|█████▎    | 39853056/75175004 [01:11<01:09, 509987.15B/s][A
 53%|█████▎    | 39952384/75175004 [01:11<01:02, 559799.79B/s][A
 53%|█████▎    | 40015872/75175004 [01:12<01:07, 520049.70B/s][A
 53%|█████▎    | 40112128/75175004 [01:12<00:58, 598497.52B/s][A
 53%|█████▎    | 40180736/75175004 [01:12<01:07, 518091.09B/s][A
 54%|█████▎    | 40271872/75175004 [01:12<00:59, 589629.73B/s][A
 54%|█████▎    | 40340480/75175004 [01:12<01:08, 511996.75B/s][A
 54%|█████▍    | 40439808/75175004 [01:12<01:01, 562594.38B/s][A
 54%|█████▍    | 40503296/75175004 [01:12<01:06, 522267.33B/s][A
 54%|█████▍    | 40599552/75175004 [01:13<00:57, 599872.67B/s][A
 54%|█████▍    | 40668160/75175004 [01:13<01:06, 517708.79B/s][A
 54%|█████▍    | 40752128/75175004 [01:13<00:59, 580366.83B/s][A
 54%|█████▍    | 40818688/75175004 [01:13<01:08, 502452.15B/s][A
 54%|█████

 66%|██████▌   | 49800192/75175004 [01:28<00:42, 592314.11B/s][A
 66%|██████▋   | 49870848/75175004 [01:29<00:48, 519292.43B/s][A
 66%|██████▋   | 49959936/75175004 [01:29<00:45, 554038.03B/s][A
 67%|██████▋   | 50023424/75175004 [01:29<00:48, 515918.33B/s][A
 67%|██████▋   | 50119680/75175004 [01:29<00:42, 595109.71B/s][A
 67%|██████▋   | 50188288/75175004 [01:29<00:48, 515090.68B/s][A
 67%|██████▋   | 50287616/75175004 [01:29<00:44, 562019.53B/s][A
 67%|██████▋   | 50351104/75175004 [01:29<00:47, 524641.29B/s][A
 67%|██████▋   | 50448384/75175004 [01:30<00:43, 570001.05B/s][A
 67%|██████▋   | 50510848/75175004 [01:30<00:47, 522726.03B/s][A
 67%|██████▋   | 50608128/75175004 [01:30<00:40, 602063.69B/s][A
 67%|██████▋   | 50675712/75175004 [01:30<00:47, 516205.86B/s][A
 68%|██████▊   | 50776064/75175004 [01:30<00:43, 564504.58B/s][A
 68%|██████▊   | 50839552/75175004 [01:30<00:46, 526092.13B/s][A
 68%|██████▊   | 50935808/75175004 [01:30<00:42, 567142.69B/s][A
 68%|█████

 80%|███████▉  | 59815936/75175004 [01:46<00:27, 561850.98B/s][A
 80%|███████▉  | 59879424/75175004 [01:46<00:29, 522730.75B/s][A
 80%|███████▉  | 59975680/75175004 [01:46<00:25, 594698.71B/s][A
 80%|███████▉  | 60043264/75175004 [01:46<00:29, 518079.69B/s][A
 80%|███████▉  | 60136448/75175004 [01:46<00:25, 587287.22B/s][A
 80%|████████  | 60204032/75175004 [01:46<00:29, 514651.53B/s][A
 80%|████████  | 60304384/75175004 [01:47<00:26, 564890.72B/s][A
 80%|████████  | 60367872/75175004 [01:47<00:28, 523905.90B/s][A
 80%|████████  | 60464128/75175004 [01:47<00:24, 596457.11B/s][A
 81%|████████  | 60531712/75175004 [01:47<00:28, 519375.36B/s][A
 81%|████████  | 60623872/75175004 [01:47<00:24, 587043.28B/s][A
 81%|████████  | 60690432/75175004 [01:47<00:28, 512126.07B/s][A
 81%|████████  | 60791808/75175004 [01:47<00:25, 562172.19B/s][A
 81%|████████  | 60855296/75175004 [01:48<00:27, 524358.70B/s][A
 81%|████████  | 60951552/75175004 [01:48<00:25, 565671.71B/s][A
 81%|█████

 93%|█████████▎| 69655552/75175004 [02:03<00:09, 575350.74B/s][A
 93%|█████████▎| 69721088/75175004 [02:03<00:10, 506995.72B/s][A
 93%|█████████▎| 69816320/75175004 [02:03<00:09, 574710.89B/s][A
 93%|█████████▎| 69881856/75175004 [02:03<00:10, 506971.28B/s][A
 93%|█████████▎| 69976064/75175004 [02:03<00:09, 573661.36B/s][A
 93%|█████████▎| 70041600/75175004 [02:04<00:10, 507908.29B/s][A
 93%|█████████▎| 70135808/75175004 [02:04<00:08, 572268.59B/s][A
 93%|█████████▎| 70201344/75175004 [02:04<00:09, 507800.74B/s][A
 94%|█████████▎| 70295552/75175004 [02:04<00:08, 572587.66B/s][A
 94%|█████████▎| 70361088/75175004 [02:04<00:09, 505574.46B/s][A
 94%|█████████▎| 70448128/75175004 [02:04<00:08, 562899.46B/s][A
 94%|█████████▍| 70511616/75175004 [02:04<00:09, 496175.96B/s][A
 94%|█████████▍| 70616064/75175004 [02:04<00:07, 573641.61B/s][A
 94%|█████████▍| 70682624/75175004 [02:05<00:08, 511446.11B/s][A
 94%|█████████▍| 70775808/75175004 [02:05<00:07, 574537.42B/s][A
 94%|█████

2022-07-04 15:04:40,292 copying /tmp/tmpjt6n3dw0 to cache at /home/tproth/.flair/models/en-upos-ontonotes-fast-v0.4.pt





2022-07-04 15:04:40,438 removing temp file /tmp/tmpjt6n3dw0
2022-07-04 15:04:40,448 loading file /home/tproth/.flair/models/en-upos-ontonotes-fast-v0.4.pt


  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 3 / 0 / 0 / 3: 100%|██████████| 3/3 [06:50<00:00, 136.94s/it]



+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 3      |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 0      |
| Original accuracy:            | 100.0% |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 68.76% |
| Average num. words per input: | 16.0   |
| Avg num queries:              | 2265.0 |
| Average Original Perplexity:  | 446.47 |
| Average Attack Perplexity:    | 664.48 |
| Average Attack USE Score:     | 0.82   |
+-------------------------------+--------+


textattack: Logging to CSV at path ./baselines/2022-07-04_150938_financial_test_CF-WR-BS-b5m25.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): BeamSearch(
    (beam_width):  5
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapEmbedding(
    (max_candidates):  25
    (embedding):  WordEmbedding
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDiffConstraint(
        (compare_against_original):  True
      )
    (4): LCPConstraint(
        (compare_against_original):  True
      )
    (5): RepeatModification
    (6): StopwordModification
  (is_black_box):  True
) 



  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 2 / 1 / 0 / 3: 100%|██████████| 3/3 [01:14<00:00, 24.92s/it]



+-------------------------------+---------+
| Attack Results                |         |
+-------------------------------+---------+
| Number of successful attacks: | 2       |
| Number of failed attacks:     | 1       |
| Number of skipped attacks:    | 0       |
| Original accuracy:            | 100.0%  |
| Accuracy under attack:        | 33.33%  |
| Attack success rate:          | 66.67%  |
| Average perturbed word %:     | 13.57%  |
| Average num. words per input: | 16.0    |
| Avg num queries:              | 362.0   |
| Average Original Perplexity:  | 573.64  |
| Average Attack Perplexity:    | 1137.85 |
| Average Attack USE Score:     | 0.83    |
+-------------------------------+---------+


textattack: Logging to CSV at path ./baselines/2022-07-04_151119_financial_test_LM-WR-GA-m25p60mi20mr5.csv
  0%|          | 0/3 [00:00<?, ?it/s]

Attack(
  (search_method): ImprovedGeneticAlgorithm(
    (pop_size):  60
    (max_iters):  20
    (temp):  0.3
    (give_up_if_no_improvement):  False
    (post_crossover_check):  False
    (max_crossover_retries):  20
    (max_replace_times_per_index):  5
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapMaskedLM(
    (method):  bae
    (masked_lm_name):  RobertaForCausalLM
    (max_length):  512
    (max_candidates):  25
    (min_confidence):  0.0005
  )
  (constraints): 
    (0): StsScoreConstraint(
        (compare_against_original):  True
      )
    (1): ContradictionScoreConstraint(
        (compare_against_original):  True
      )
    (2): AcceptabilityScoreConstraint(
        (compare_against_original):  True
      )
    (3): PpLetterDiffConstraint(
        (compare_against_original):  True
      )
    (4): LCPConstraint(
        (compare_against_original):  True
      )
    (5): RepeatModification
    (6): StopwordModification
  (is_black_box):  Tr

  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
  self.df = self.df.append(row, ignore_index=True)
[Succeeded / Failed / Skipped / Total] 2 / 1 / 0 / 3: 100%|██████████| 3/3 [03:32<00:00, 70.91s/it]



+-------------------------------+---------+
| Attack Results                |         |
+-------------------------------+---------+
| Number of successful attacks: | 2       |
| Number of failed attacks:     | 1       |
| Number of skipped attacks:    | 0       |
| Original accuracy:            | 100.0%  |
| Accuracy under attack:        | 33.33%  |
| Attack success rate:          | 66.67%  |
| Average perturbed word %:     | 20.36%  |
| Average num. words per input: | 16.0    |
| Avg num queries:              | 917.67  |
| Average Original Perplexity:  | 573.64  |
| Average Attack Perplexity:    | 1080.47 |
| Average Attack USE Score:     | 0.75    |
+-------------------------------+---------+


## Example-specific metrics 

In [None]:
def display_adv_example(df, max_colwidth=480):
    """Display original vs. perturbed text side by side.

    Args:
        df: DataFrame with 'original_text' and 'perturbed_text' columns, such as
            the per-attack CSV written via AttackArgs(log_to_csv=...).
        max_colwidth: pandas column width used while rendering, wide enough to
            read full examples. Defaults to the previously hard-coded 480.
    """
    from IPython.display import display
    # Scope the wider column setting to this call. Previously the global
    # pd.options were changed for the rest of the session.
    with pd.option_context('display.max_colwidth', max_colwidth):
        display(df[['original_text', 'perturbed_text']])

# def add_vm_score_and_label_flip(df, dataset, cfg, vm_tokenizer, vm_model): 
#     truelabels = torch.tensor(dataset._dataset['label'], device =cfg.device)
#     orig_probs =  get_vm_probs(df['original_text'].tolist(), cfg, vm_tokenizer, vm_model, return_predclass=False)
#     pp_probs = get_vm_probs(df['perturbed_text'].tolist(), cfg, vm_tokenizer, vm_model, return_predclass=False)
#     orig_predclass = torch.argmax(orig_probs, axis=1)
#     pp_predclass = torch.argmax(pp_probs, axis=1)
#     orig_truelabel_probs = torch.gather(orig_probs, 1, truelabels[:,None]).squeeze()
#     pp_truelabel_probs   = torch.gather(pp_probs, 1,   truelabels[:,None]).squeeze()
#     pp_predclass_probs   = torch.gather(pp_probs, 1,   pp_predclass[ :,None]).squeeze()
    
#     df['truelabel'] = truelabels.cpu().tolist()
#     df['orig_predclass'] = orig_predclass.cpu().tolist()
#     df['pp_predclass'] = pp_predclass.cpu().tolist()
#     df['orig_truelabel_probs'] = orig_truelabel_probs.cpu().tolist()
#     df['pp_truelabel_probs'] = pp_truelabel_probs.cpu().tolist()
#     df['vm_scores'] = (orig_truelabel_probs - pp_truelabel_probs).cpu().tolist()
#     df['label_flip'] = ((pp_predclass != truelabels) * 1).cpu().tolist()
#     return df

# def add_sts_score(df, sts_model, cfg): 
#     orig_embeddings  = sts_model.encode(df['original_text'].tolist(),  convert_to_tensor=True, device=cfg.device)
#     pp_embeddings    = sts_model.encode(df['perturbed_text'].tolist(), convert_to_tensor=True, device=cfg.device)
#     df['sts_scores'] = pytorch_cos_sim(orig_embeddings, pp_embeddings).diagonal().cpu().tolist()
#     return df

# def add_contradiction_score(df, cfg, nli_tokenizer, nli_model): 
#     contradiction_scores = get_nli_probs(df['original_text'].tolist(), df['perturbed_text'].tolist(), cfg, nli_tokenizer, nli_model)
#     df['contradiction_scores'] =  contradiction_scores[:,cfg.contra_label].cpu().tolist()
#     return df 

# def get_df_mean_cols(df): 
#     cols = ['label_flip', 'vm_scores', 'sts_scores',
#             'contradiction_scores', 'sts_threshold_met', 'contradiction_threshold_met']
#     s = df[cols].mean()
#     s.index = [f"{o}_mean" for o in s.index]
#     return dict(s)

# def get_cts_summary_stats(df): 
#     cols = ['vm_scores', 'sts_scores', 'contradiction_scores']
#     df_summary = df[cols].describe(percentiles=[.1,.25,.5,.75,.9]).loc[['std','10%','25%','50%','75%','90%']]
#     tmp_d = dict()
#     for c in cols: 
#         s = df_summary[c]
#         s.index = [f"{c}_{o}" for o in s.index]
#         tmp_d = merge_dicts(tmp_d, dict(s))
#     return tmp_d


In [None]:
#filename1 = f"/data/tproth/travis_attack/baselines/2022-04-21_044443_rotten_tomatoes_valid_BeamSearchLMAttack_beam_sz=2_max_candidates=5.csv"
#filename = filename1
# Load the per-example CSV of the *last* recipe run by the attack loop: `filename`
# is left over from that loop, so the attack cell must have been run first.
df = pd.read_csv(filename)
#display_adv_example(df)

#df = add_vm_score_and_label_flip(df, dataset, cfg, vm_tokenizer, vm_model)
#df = df.query("result_type != 'Skipped'")
#df = add_sts_score(df, sts_model, cfg)
#df = add_contradiction_score(df, cfg, nli_tokenizer, nli_model)

#df['sts_threshold_met'] = df['sts_scores'] > d['sts_threshold']
#df['contradiction_threshold_met'] = df['contradiction_scores'] < d['contradiction_threshold']
#df.to_csv(f"{filename[:-4]}_processed.csv", index=False)

#d = merge_dicts(d, get_df_mean_cols(df))
#d = merge_dicts(d, get_cts_summary_stats(df))



## Old code 

In [None]:
# df1 = df.sample(5)
# orig_l = df1['original_text'].tolist()
# pp_l = df1['perturbed_text'].tolist()
# print(orig_l)
# print(pp_l)

In [None]:
# for orig, adv in zip(df1['original_text'].tolist(), df1['perturbed_text'].tolist()): 
#     print(f"{orig}{adv}")
#     print()

In [None]:
#df.iloc[104][['original_text', 'perturbed_text']].values

In [None]:
#filename1 = f"/data/tproth/travis_attack/baselines/2022-04-20_133329_rotten_tomatoes_valid_BeamSearchCFEmbeddingAttack_beam_sz=1_max_candidates=1_processed.csv"
#df = pd.read_csv(filename1)
#display_all(df.sample(2))

In [None]:
#df_results = pd.read_csv(f"/data/tproth/travis_attack/baselines/results.csv")