# Example attacks against Large Language Models with TextAttack

In this notebook, we will show you how to use TextAttack, a library for testing the robustness of LLMs.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zangobot/teaching_material/blob/HEAD/05-LLMSecurity.ipynb)

In [None]:
!pip install textattack tensorflow tensorflow_hub

In [None]:
import transformers

An attack in TextAttack consists of four parts:

- **Goal function**: determines whether the attack is successful by evaluating how well a perturbed attacked_text object achieves a specified goal. A common goal function is **targeted**/**untargeted classification**, where the attack tries to perturb an input to change its classification.

- **Transformation**: takes a text input and transforms it, for example by replacing words or phrases with similar ones, while trying not to change the meaning. Paraphrase and synonym substitution are two broad classes of transformations.

- **Constraints**: determine whether or not a given transformation is valid. Transformations don't perfectly preserve syntax or semantics, so additional constraints can increase the probability that these qualities are preserved from the source to the adversarial example. There are many types of constraints: overlap constraints that measure edit distance, syntactic constraints that check part-of-speech and grammar errors, and semantic constraints based on language models and sentence encoders.

- **Search method**: explores the space of possible transformations within the defined constraints and attempts to find a successful perturbation that satisfies the goal function. Some examples are greedy search with word importance ranking, beam search, brute force, and genetic algorithms.


In this way, attacking an NLP model can be framed as a combinatorial search problem. The attacker must search within all potential transformations to find a sequence of transformations that generates a successful adversarial example.
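
To make the four components concrete before touching the library, here is a minimal pure-Python sketch of the same decomposition. Everything in it is a toy assumption made for illustration (the keyword classifier, the `SYNONYMS` table, and all function names); none of it is TextAttack's API.

```python
# Toy illustration only: none of these names come from TextAttack.
from typing import List, Optional

POSITIVE = {"good", "great", "fun"}
NEGATIVE = {"bad", "boring", "dull"}
SYNONYMS = {"good": ["fine", "decent"], "great": ["solid"], "fun": ["amusing"]}


def toy_classifier(text: str) -> int:
    """Return 1 (positive) if positive keywords outnumber negative ones, else 0."""
    words = text.lower().split()
    score = sum(w in POSITIVE for w in words) - sum(w in NEGATIVE for w in words)
    return 1 if score > 0 else 0


def goal_reached(original_label: int, text: str) -> bool:
    """Goal function: untargeted classification, any label change is a success."""
    return toy_classifier(text) != original_label


def transformations(text: str) -> List[str]:
    """Transformation: replace one word with one of its listed synonyms."""
    words = text.split()
    candidates = []
    for i, word in enumerate(words):
        for synonym in SYNONYMS.get(word.lower(), []):
            candidates.append(" ".join(words[:i] + [synonym] + words[i + 1:]))
    return candidates


def satisfies_constraints(original: str, candidate: str, max_changed: int = 2) -> bool:
    """Constraint: at most `max_changed` words may differ from the original."""
    changed = sum(a != b for a, b in zip(original.split(), candidate.split()))
    return changed <= max_changed


def search(original: str, max_steps: int = 5) -> Optional[str]:
    """Search method: breadth-first search over sequences of transformations."""
    label = toy_classifier(original)
    frontier = [original]
    for _ in range(max_steps):
        next_frontier = []
        for text in frontier:
            for candidate in transformations(text):
                if not satisfies_constraints(original, candidate):
                    continue
                if goal_reached(label, candidate):
                    return candidate
                next_frontier.append(candidate)
        frontier = next_frontier
    return None  # no valid adversarial example found


adversarial = search("a good and fun movie")
print(adversarial)
```

The search needs two chained substitutions here, which is why the problem is combinatorial: a single word swap does not flip the toy classifier, but a sequence of swaps within the constraint budget does.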

TextAttack provides a set of **Attack Recipes** that assemble attacks from the literature out of these four components. We can use an existing attack recipe or build an attack from scratch.


# Example 1: Attack Recipe

In [None]:
from textattack.datasets import HuggingFaceDataset

# load the dataset from HuggingFace
dataset = HuggingFaceDataset("rotten_tomatoes", None, "test")

In [None]:
# print some examples
print(dataset[0])
print(dataset[1])
print(dataset[2])
print(dataset[3])
print(dataset[4])

In [None]:
from textattack.models.wrappers import HuggingFaceModelWrapper

# load a pre-trained BERT-based model for sequence classification from the HuggingFace model hub
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/distilbert-base-uncased-rotten-tomatoes"
)

# load the corresponding tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/distilbert-base-uncased-rotten-tomatoes"
)

# TextAttack requires both the model and the tokenizer to be wrapped by a ModelWrapper class that implements the forward pass operation given a list of input texts
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

In [None]:
print("\nMODEL: ", model)
print("\nTOKENIZER: ", tokenizer)
print("\nWRAPPER: ", model_wrapper)

Let's use TextFooler, which operates at the word level.

The main idea behind **TextFooler** is to generate adversarial examples by iteratively replacing or modifying words in the input text to maximize the model's prediction error while minimizing the perturbation's perceptibility to human readers. The framework utilizes semantic and syntactic similarity metrics to guide the perturbation process, ensuring that the modified text remains grammatically and semantically coherent.
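
Word-level attacks need a way to decide which words to perturb first. One common heuristic is deletion-based word importance ranking: remove each word in turn and measure how much the model's confidence drops. The sketch below is a simplified illustration of that idea, not TextFooler's exact scoring; `toy_positive_score` is a hypothetical stand-in for a real model's confidence.

```python
from typing import Callable, List, Tuple


def toy_positive_score(text: str) -> float:
    """Hypothetical stand-in for a model's confidence in the 'positive' class."""
    weights = {"excellent": 0.5, "enjoyable": 0.3, "movie": 0.05}
    return min(1.0, 0.1 + sum(weights.get(w, 0.0) for w in text.lower().split()))


def rank_words_by_importance(
    text: str, score_fn: Callable[[str], float]
) -> List[Tuple[str, float]]:
    """Rank words by how much the score drops when each word is deleted."""
    words = text.split()
    base_score = score_fn(text)
    ranked = []
    for i, word in enumerate(words):
        text_without_word = " ".join(words[:i] + words[i + 1:])
        ranked.append((word, base_score - score_fn(text_without_word)))
    # Largest drop first: these words matter most to the prediction
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked


ranking = rank_words_by_importance(
    "an excellent and enjoyable movie", toy_positive_score
)
for word, drop in ranking:
    print(f"{word:10s} {drop:.2f}")
```

An attack that follows this ranking tries substitutions on the most influential words first, which tends to keep the number of modified words small.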

In [None]:
from textattack.attack_recipes import TextFoolerJin2019

# load the attack
attack = TextFoolerJin2019.build(model_wrapper)

In [None]:
print("\nATTACK: ", attack)

By printing the attack we can check the definition of the four components: search method, goal function, transformation, constraints.

Let's run the attack.

In [None]:
from textattack import Attacker
from textattack import AttackArgs

attack_args = AttackArgs(num_examples = 10)

attacker = Attacker(attack, dataset, attack_args)

TextAttack attacks iterate through a dataset and, for each correctly predicted sample, search for an adversarial perturbation. If an example is incorrectly predicted to begin with, it is not attacked (Skipped), since the input already fools the model.
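
The control flow described above can be sketched in a few lines of plain Python. The names `toy_model`, `toy_attack`, and the outcome strings are assumptions made for this illustration; in TextAttack the outcomes are result objects rather than strings.

```python
from typing import List, Optional, Tuple


def toy_model(text: str) -> int:
    """Hypothetical classifier: predicts 1 if the word 'good' appears, else 0."""
    return 1 if "good" in text.lower().split() else 0


def toy_attack(text: str, label: int) -> Optional[str]:
    """Hypothetical attack: swap 'good' for 'fine' and check whether the label flips."""
    candidate = " ".join("fine" if w.lower() == "good" else w for w in text.split())
    return candidate if toy_model(candidate) != label else None


def attack_dataset(dataset: List[Tuple[str, int]]) -> List[str]:
    """Attack only the samples that the model classifies correctly."""
    outcomes = []
    for text, label in dataset:
        if toy_model(text) != label:
            outcomes.append("skipped")  # already misclassified, nothing to attack
        elif toy_attack(text, label) is not None:
            outcomes.append("successful")
        else:
            outcomes.append("failed")
    return outcomes


dataset = [("a good film", 1), ("a dull film", 1), ("a dull film", 0)]
outcomes = attack_dataset(dataset)
print(outcomes)
```

Keeping skipped samples separate matters when reporting results: the attack success rate is usually computed over the attacked samples only, not over the whole dataset.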

In [None]:
attack_results = attacker.attack_dataset()

To visualize the results, we log the AttackResult objects with a CSVLogger. This logger stores all attack results in a dataframe, which we can easily access and display using pandas and IPython utilities. Since we set color_method to 'html', the differences between the original and perturbed text are highlighted in color when rendered as HTML.
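
As a rough idea of what colored HTML output involves, the sketch below marks the words that differ between two texts. It is a hypothetical helper written for this explanation, not the logger's implementation, and it assumes the attack only substitutes words so both texts have the same number of words.

```python
import html


def highlight_changes(original: str, perturbed: str, color: str = "red") -> str:
    """Wrap every word of `perturbed` that differs from `original` in a colored tag."""
    highlighted = []
    for original_word, perturbed_word in zip(original.split(), perturbed.split()):
        safe_word = html.escape(perturbed_word)  # avoid injecting raw markup
        if original_word != perturbed_word:
            highlighted.append(f'<font color="{color}">{safe_word}</font>')
        else:
            highlighted.append(safe_word)
    return " ".join(highlighted)


highlighted = highlight_changes("a good film", "a fine film")
print(highlighted)
```

Strings like this only show colors when the dataframe is rendered with `escape=False`, as done in the cell below.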

In [None]:
import pandas as pd
from textattack.loggers import CSVLogger
from textattack.attack_results import SuccessfulAttackResult

# pd.options.display.max_colwidth = (
#     480  # increase column width so we can actually read the examples
# )

logger = CSVLogger(color_method="html")

for result in attack_results:
    if isinstance(result, SuccessfulAttackResult):
        logger.log_attack_result(result)

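# `display` and `HTML` are part of the public IPython.display module;
# importing them from IPython.core.display is deprecated
from IPython.display import display, HTML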

results = pd.DataFrame.from_records(logger.row_list)
display(HTML(results[["original_text", "perturbed_text"]].to_html(escape=False)))

# Example 2: Attack from scratch


In [None]:
# Load the dataset
from textattack.datasets import HuggingFaceDataset

dataset = HuggingFaceDataset("ag_news", None, "test")

In [None]:
# Load the model
import transformers
from textattack.models.wrappers import HuggingFaceModelWrapper

model = transformers.AutoModelForSequenceClassification.from_pretrained(  # loads a pre-trained BERT-based model for sequence classification from the HuggingFace model hub
    "textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained( # loads the corresponding tokenizer for the BERT model
    "textattack/bert-base-uncased-ag-news"
)

model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

### Banana word swap
We're going to try a very simple transformation: one that replaces any given word with the word 'banana'. In TextAttack, there's an abstract WordSwap class that handles the heavy lifting of breaking sentences into words and avoiding replacement of stopwords. We can extend WordSwap and implement a single method, _get_replacement_words, which returns the candidate replacements for a word. Here it always returns 'banana'. 🍌
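
Before subclassing, it may help to see what such a word swap produces. The plain-Python sketch below generates one candidate per replaceable word. The tiny `STOPWORDS` set and the `modified` parameter are assumptions standing in for the stopword and repeat-modification checks, and the function is not part of TextAttack.

```python
from typing import FrozenSet, List

# Deliberately tiny stopword list, for illustration only
STOPWORDS = {"a", "an", "the", "is", "of", "and"}


def banana_candidates(text: str, modified: FrozenSet[int] = frozenset()) -> List[str]:
    """Return one candidate per word that is neither a stopword nor already modified."""
    words = text.split()
    candidates = []
    for i, word in enumerate(words):
        if word.lower() in STOPWORDS or i in modified:
            continue  # skip stopwords and positions changed in an earlier step
        candidates.append(" ".join(words[:i] + ["banana"] + words[i + 1:]))
    return candidates


candidates = banana_candidates("the market is rising")
print(candidates)
```

Each candidate differs from the input in exactly one word, so the search method decides which single swap to keep at each step.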

In [None]:
from textattack.transformations import WordSwap

class BananaWordSwap(WordSwap):
    """Transforms an input by replacing any word with 'banana'."""

    def _get_replacement_words(self, word):
        """Returns a list with one item, since `_get_replacement_words` is intended to return a list of candidate replacement words."""
        return ["banana"]

In [None]:
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedySearch
from textattack.constraints.pre_transformation import (
    RepeatModification,   # constraint disallowing the modification of words which have already been modified
    StopwordModification, # constraint disallowing the modification of stopwords
)
from textattack import Attack

# Create the goal function using the model
goal_function = UntargetedClassification(model_wrapper)

# Use Banana word swap class as the attack transformation
transformation = BananaWordSwap()

# Avoid modification of already modified indices and stopwords
constraints = [RepeatModification(), StopwordModification()]

# Use the Greedy search method
search_method = GreedySearch()

# Create the attack from the 4 components:
attack = Attack(goal_function, constraints, transformation, search_method)

In [None]:
print(attack)

In [None]:
from textattack import Attacker
from textattack import AttackArgs

attack_args = AttackArgs(num_examples=10)

attacker = Attacker(attack, dataset, attack_args)

attack_results = attacker.attack_dataset()

In [None]:
import pandas as pd
from textattack.loggers import CSVLogger
from textattack.attack_results import SuccessfulAttackResult

# pd.options.display.max_colwidth = (
#     480  # increase column width so we can actually read the examples
# )

logger = CSVLogger(color_method="html")

for result in attack_results:
    if isinstance(result, SuccessfulAttackResult):
        logger.log_attack_result(result)

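# `display` and `HTML` are part of the public IPython.display module;
# importing them from IPython.core.display is deprecated
from IPython.display import display, HTML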

results = pd.DataFrame.from_records(logger.row_list)
display(HTML(results[["original_text", "perturbed_text"]].to_html(escape=False)))

It looks like some examples needed only a couple of "banana"s, while others needed up to 17 "banana" substitutions to change the predicted class.
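
To check such counts programmatically, one option is to compare the original and perturbed texts word by word. The helper below is a hypothetical sketch that assumes plain-text strings (not the HTML-colored text logged above) and a word-for-word substitution, so both texts have the same number of words.

```python
def count_substitutions(original: str, perturbed: str, replacement: str = "banana") -> int:
    """Count positions where a word was replaced by `replacement`."""
    return sum(
        original_word != perturbed_word and perturbed_word == replacement
        for original_word, perturbed_word in zip(original.split(), perturbed.split())
    )


n_bananas = count_substitutions(
    "stocks rally as markets recover",
    "banana rally as banana recover",
)
print(n_bananas)
```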