<a href="https://colab.research.google.com/github/riccardocappi/Text-Adversarial-Attack/blob/adversarial-training/adversarial_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Dependencies

In [None]:
!pip install textattack

## Imports

In [None]:
import textattack
import transformers
from textattack.datasets import HuggingFaceDataset
from textattack import Attacker
from textattack.attack_recipes import BAEGarg2019

# Helper classes

In [None]:
class FixedHuggingFaceDataset(HuggingFaceDataset):
    """``HuggingFaceDataset`` with seeded shuffling and contiguous sub-sampling.

    All parent arguments are forwarded unchanged. Extra arguments:

    Args:
        seed: Seed used to shuffle the split when ``shuffle`` is true, so the
            same examples are selected on every run.
        subset_size: If not None, keep at most this many examples, taken after
            shuffling and after dropping the first ``offset`` examples.
        offset: Number of leading examples (after shuffling) to drop. Applied
            whether or not ``subset_size`` is given.

    Raises:
        ValueError: If ``offset`` or ``subset_size`` is negative.
    """

    def __init__(self, name_or_dataset, subset=None, split="train", dataset_columns=None, label_map=None,
                 label_names=None, output_scale_factor=None, shuffle=False, seed=69, subset_size=None, offset=0):
        # Validate before the (potentially slow) dataset load in the parent constructor.
        if offset < 0:
            raise ValueError(f"offset must be non-negative, got {offset}")
        if subset_size is not None and subset_size < 0:
            raise ValueError(f"subset_size must be non-negative, got {subset_size}")

        super().__init__(name_or_dataset=name_or_dataset, subset=subset, split=split, dataset_columns=dataset_columns,
                         label_map=label_map, label_names=label_names, output_scale_factor=output_scale_factor,
                         shuffle=shuffle)
        if shuffle:
            # NOTE(review): the parent also receives ``shuffle``, but no seed can be passed to it;
            # reshuffle here with an explicit seed for reproducibility. flatten_indices() rewrites
            # the rows in shuffled order so the skip/take below read them directly.
            self._dataset = self._dataset.shuffle(seed=seed).flatten_indices()
        # Apply offset on its own so it is honoured even when no subset_size is requested.
        if offset:
            self._dataset = self._dataset.skip(offset)
        if subset_size is not None:
            self._dataset = self._dataset.take(subset_size)

# Loading model

In [None]:
# Model and tokenizer both come from the same Hugging Face hub repository.
# NOTE(review): the name suggests a BERT-base-uncased classifier fine-tuned on IMDB — confirm on the hub.
pretrained_name = "textattack/bert-base-uncased-imdb"
model = transformers.AutoModelForSequenceClassification.from_pretrained(pretrained_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_name)

# Wrap model + tokenizer in the interface TextAttack's attacks and Trainer expect.
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

# Loading dataset

In [None]:
# Seeded random subsets of IMDB: 1024 examples from the train split for
# (adversarial) training and 64 from the test split for evaluation and attacks.
subset_train = FixedHuggingFaceDataset(
    "imdb",
    split="train",
    shuffle=True,
    subset_size=1024,
)
subset_eval = FixedHuggingFaceDataset(
    "imdb",
    split="test",
    shuffle=True,
    subset_size=64,
)

# Adversarial training with the BAE attack (BAEGarg2019 recipe)

In [None]:
# Build the BAE attack recipe against the current model. `model_wrapper` wraps the
# same `model` object that the Trainer below uses, so this attack targets that model.
attack = BAEGarg2019.build(model_wrapper)

### Attack the model before adversarial training (baseline)

In [None]:
# Baseline robustness: run the attack on the evaluation subset before any
# adversarial training. num_examples=-1 asks TextAttack to attack the whole dataset.
attack_args = textattack.AttackArgs(
    num_examples=-1,
    parallel=True,
    disable_stdout=True,
)
attacker = Attacker(attack, subset_eval, attack_args)
# Attack results for the original (not adversarially trained) model.
adv_exp_bae_bert_imdb = attacker.attack_dataset()

### Define adversarial training hyperparameters

In [None]:
# Adversarial-training configuration. output_dir is set explicitly so the
# "best_model" checkpoint written by the trainer lands at a known, stable path.
# Without it, textattack uses its default output directory instead.
training_args = textattack.TrainingArgs(
    num_epochs=5,
    num_clean_epochs=0,            # no clean-only warm-up epochs before adversarial training
    num_train_adv_examples=128,    # adversarial examples generated per attack round
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=1,
    attack_epoch_interval=1,       # regenerate adversarial examples every epoch
    log_to_tb=True,
    parallel=True,
    output_dir="outputs/BAEGarg",
)

# Trainer fine-tunes the model behind `model_wrapper` on `subset_train` augmented
# with examples produced by `attack`, and evaluates on `subset_eval`.
trainer = textattack.Trainer(
    model_wrapper,
    task_type="classification",
    attack=attack,
    train_dataset=subset_train,
    eval_dataset=subset_eval,
    training_args=training_args,
)

### Adversarial training

In [None]:
# Run adversarial training with the settings in `training_args`; checkpoints
# (including "best_model") are written under `training_args.output_dir`.
# NOTE(review): the wrapped `model` appears to be trained in place — confirm before
# reusing `model`/`attack` afterwards as a "pre-training" baseline.
trainer.train()

# Evaluate the adversarially trained model's robustness (attack it again)

In [None]:
import os

# Reload the "best_model" checkpoint saved by trainer.train(). The path is derived
# from training_args.output_dir, not hard-coded, so it always matches where
# the trainer actually wrote the checkpoint.
# NOTE(review): "best" is presumably selected by score on subset_eval, the same set
# attacked below — keep that in mind when comparing before/after results.
best_model_dir = os.path.join(training_args.output_dir, "best_model")
fine_tuned_model = transformers.AutoModelForSequenceClassification.from_pretrained(best_model_dir)
fine_tuned_tokenizer = transformers.AutoTokenizer.from_pretrained(best_model_dir)
model_wrapper_fine_tuned = textattack.models.wrappers.HuggingFaceModelWrapper(fine_tuned_model, fine_tuned_tokenizer)
# Fresh BAE attack bound to the reloaded model.
fine_tuned_attack = BAEGarg2019.build(model_wrapper_fine_tuned)

In [None]:
# Same attack settings as the baseline run (whole eval subset, parallel, quiet), so
# results before and after adversarial training are directly comparable.
fine_tuned_attack_args = textattack.AttackArgs(
    num_examples=-1,
    parallel=True,
    disable_stdout=True,
)
fine_tuned_attacker = Attacker(fine_tuned_attack, subset_eval, fine_tuned_attack_args)
# Attack results for the adversarially trained model.
fine_tuned_eval = fine_tuned_attacker.attack_dataset()