In [8]:
import os
import warnings
warnings.filterwarnings('ignore')
from textattack.attack_recipes import TextFoolerJin2019
from textattack.models.wrappers import HuggingFaceModelWrapper
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from textattack import Attacker, AttackArgs
from textattack.datasets import HuggingFaceDataset
import pandas as pd

In [None]:
os.system("textattack train "
          "--model bert-base-uncased "
          "--dataset imdb "
          "--model-max-length 128 "
          "--per-device-train-batch-size 32 "
          "--num-epochs 3 "
          "--output-dir ./exp")

In [9]:
def get_model_attack(dataset_name, model_type):
    # Load IMDb/SST-2 model
    if model_type == "bert":
        if dataset_name == "imdb":
            model_name = "textattack/bert-base-uncased-imdb"  # Pretrained BERT on IMDb
        else:
            model_name = "textattack/bert-base-uncased-SST-2" # Pretrained BERT on SST2
            
    elif model_type == "roberta":
        if dataset_name == "imdb":
            model_name = "textattack/roberta-base-imdb"  # Pretrained roBERTa on IMDb
        else:
            model_name = "textattack/roberta-base-SST-2" # Pretrained roBERTa on SST2
        
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
    attack = TextFoolerJin2019.build(model_wrapper)
    return attack
    

def get_dataset(dataset_name):
    if dataset_name == "imdb":
        return HuggingFaceDataset("imdb", split="test")
    else:
        return HuggingFaceDataset("glue", "sst2", split="validation")
    
def evaluate(dataset_name):
    df_db = pd.read_csv(f"{model_type}_{dataset_name}_textfooler_results.csv")
    success_rate = df_db["result_type"].value_counts(normalize=True).get("Successful", 0.0)
    print(f"{dataset_name} TextFooler Success Rate: {success_rate:.2%}")
    
def evasion_eval(dataset_name, model_type):
    attack = get_model_attack(dataset_name, model_type)
    db = get_dataset(dataset_name)

    attack_args = AttackArgs(
        num_examples=20,       # Number of samples to attack
        disable_stdout=True,
        log_to_csv=f"{model_type}_{dataset_name}_textfooler_results.csv"
    )

    attacker = Attacker(attack, db, attack_args)
    attacker.attack_dataset()

    evaluate(dataset_name)

In [10]:
dataset_name = "imdb" # or imdb
model_type = "bert" #or roberta
   
evasion_eval(dataset_name, model_type)

textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
textattack: Loading [94mdatasets[0m dataset [94mimdb[0m, split [94mtest[0m.
textattack: Logging to CSV at path bert_imdb_textfooler_results.csv


Attack(
  (search_method): GreedyWordSwapWIR(
    (wir_method):  delete
  )
  (goal_function):  UntargetedClassification
  (transformation):  WordSwapEmbedding(
    (max_candidates):  50
    (embedding):  WordEmbedding
  )
  (constraints): 
    (0): WordEmbeddingDistance(
        (embedding):  WordEmbedding
        (min_cos_sim):  0.5
        (cased):  False
        (include_unknown_words):  True
        (compare_against_original):  True
      )
    (1): PartOfSpeech(
        (tagger_type):  nltk
        (tagset):  universal
        (allow_verb_noun_swap):  True
        (compare_against_original):  True
      )
    (2): UniversalSentenceEncoder(
        (metric):  angular
        (threshold):  0.840845057
        (window_size):  15
        (skip_text_shorter_than_window):  True
        (compare_against_original):  False
      )
    (3): RepeatModification
    (4): StopwordModification
    (5): InputColumnModification(
        (matching_column_labels):  ['premise', 'hypothesis']
       



  0%|                                                               | 0/20 [00:00<?, ?it/s][A[A

  5%|██▊                                                    | 1/20 [01:04<20:31, 64.80s/it][A[A

[Succeeded / Failed / Skipped / Total] 1 / 0 / 0 / 1:   5%| | 1/20 [01:04<20:31, 64.82s/it][A[A

[Succeeded / Failed / Skipped / Total] 1 / 0 / 0 / 1:  10%| | 2/20 [01:08<10:20, 34.50s/it][A[A

[Succeeded / Failed / Skipped / Total] 2 / 0 / 0 / 2:  10%| | 2/20 [01:09<10:21, 34.51s/it][A[A

[Succeeded / Failed / Skipped / Total] 2 / 0 / 0 / 2:  15%|▏| 3/20 [01:11<06:43, 23.74s/it][A[A

[Succeeded / Failed / Skipped / Total] 3 / 0 / 0 / 3:  15%|▏| 3/20 [01:11<06:43, 23.75s/it][A[A

[Succeeded / Failed / Skipped / Total] 3 / 0 / 0 / 3:  20%|▏| 4/20 [01:41<06:44, 25.30s/it][A[A

[Succeeded / Failed / Skipped / Total] 4 / 0 / 0 / 4:  20%|▏| 4/20 [01:41<06:44, 25.31s/it][A[A

[Succeeded / Failed / Skipped / Total] 4 / 0 / 1 / 5:  25%|▎| 5/20 [01:41<05:03, 20.25s/it][A[A

[Succeed


+-------------------------------+--------+
| Attack Results                |        |
+-------------------------------+--------+
| Number of successful attacks: | 18     |
| Number of failed attacks:     | 0      |
| Number of skipped attacks:    | 2      |
| Original accuracy:            | 90.0%  |
| Accuracy under attack:        | 0.0%   |
| Attack success rate:          | 100.0% |
| Average perturbed word %:     | 11.04% |
| Average num. words per input: | 203.75 |
| Avg num queries:              | 624.67 |
+-------------------------------+--------+
imdb TextFooler Success Rate: 90.00%



