# Attack RobEn Extension

Here we load a model trained with a clustering approach similar to RobEn, which makes it robust against synonym-based substitution attacks. We then attack the model to measure how well it holds up.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QData/TextAttack/blob/master/docs/2notebook/Example_5_Explain_BERT.ipynb)

[![View Source on GitHub](https://img.shields.io/badge/github-view%20source-black.svg)](https://github.com/QData/TextAttack/blob/master/docs/2notebook/Example_5_Explain_BERT.ipynb)

## Pip installs

In [None]:
# Install the attribution (captum), attack (textattack) and model libraries this notebook uses.
!pip install captum
!pip install textattack
!pip install torchfile
!pip install allennlp
!pip install pytorch_transformers 
# NOTE(review): PreTrainedTokenizer is a transformers class name; confirm this is a real
# package the notebook actually needs before relying on this install.
!pip install PreTrainedTokenizer
!pip install --upgrade flair

In [None]:
from captum.attr import visualization as viz
from textattack.datasets import HuggingFaceDataset
from textattack.models.tokenizers import AutoTokenizer
from textattack.models.wrappers import ModelWrapper, HuggingFaceModelWrapper
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from IPython.display import display, HTML
from collections import OrderedDict

import torch
import numpy as np

from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset
'''
from utils_glue import InputExample
from utils import Clustering
from recoverer import ClusterRepRecoverer
'''

## Make GPU available

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# torch.cuda.set_device rejects non-CUDA devices, so only call it when a GPU was selected;
# otherwise the CPU fallback above would crash here.
if device.type == "cuda":
    torch.cuda.set_device(device)
print(device)

## Load our model and dataset

In [None]:
import os

# NOTE(review): Clustering and ClusterRepRecoverer come from the RobEn helper modules
# (utils.py / recoverer.py). Their imports at the top of the notebook are disabled inside
# a string literal, so they are imported here to avoid a NameError. Assumes those modules
# are importable from the working directory.
from utils import Clustering
from recoverer import ClusterRepRecoverer

# MRPC validation split; its labels serve as ground truth for the attack below.
dataset = load_dataset('glue', 'mrpc', split='validation')
# Directory holding the fine-tuned checkpoint, its tokenizer and the clustering pickle.
# (Named model_dir rather than `dir` to avoid shadowing the builtin.)
model_dir = 'MRPC_3_No_Stop'
clustering = Clustering.from_pickle(
    os.path.join(model_dir, 'vocab100000_ed1.pkl'),
    max_num_possibilities=10,
)
# Maps (possibly perturbed) inputs to their cluster representatives.
recoverer = ClusterRepRecoverer(model_dir, clustering)
original_model = BertForSequenceClassification.from_pretrained(model_dir)
original_tokenizer = BertTokenizer.from_pretrained(model_dir)
model_wrapper = HuggingFaceModelWrapper(original_model, original_tokenizer)
model_wrapper.model.to(device)

In [None]:
# Convert the HF dataset into a list of (input, label) pairs for TextAttack. Each input is
# a two-column OrderedDict so the attack sees, and may perturb, the full MRPC sentence pair.
# The attack loop later splits the printed result on '\n' and drops a 4-character
# "S1: " / "S2: " prefix per line, which presumes this two-column form.
new_dataset = [
    (OrderedDict([("s1", x['sentence1']), ("s2", x['sentence2'])]), x['label'])
    for x in dataset
]
print(new_dataset[0])

## Define some useful functions

In [None]:
def captum_form(encoded):
    """Collate per-example encodings into a dict of batched tensors on ``device``.

    ``encoded`` is a list of dicts, one per example, all sharing the keys of the
    first entry. Each key in the result maps to a tensor built from that key's
    values across every example, moved to the global ``device``.
    """
    batch = {}
    for key in encoded[0]:
        column = [example[key] for example in encoded]
        batch[key] = torch.tensor(column).to(device)
    return batch


def calculate(input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run the wrapped model on a batch and return its raw logits.

    Written as a plain function of ``input_ids`` (plus optional auxiliary
    inputs) so it can be handed to Captum as the forward function.
    """
    outputs = model_wrapper.model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
    )
    return outputs.logits


def display_html(html_str):
    """Render the string ``html_str`` as HTML in the notebook output."""
    rendered = HTML(html_str)
    display(rendered)

## Pick an Attribution Algorithm

In [None]:
from captum.attr import LayerIntegratedGradients

# Layer Integrated Gradients over the BERT embedding layer, driven by `calculate`.
# More algorithms are available at:
# https://github.com/pytorch/captum/blob/master/docs/algorithms_comparison_matrix.md
lig = LayerIntegratedGradients(
    forward_func=calculate,
    layer=model_wrapper.model.bert.embeddings,
)

## Pick an Attack Algorithm

In [None]:
from textattack.attack_recipes import TextFoolerJin2019
# Build TextAttack's TextFoolerJin2019 attack recipe against the wrapped model; the next
# cell runs this attack over `new_dataset`.
attack = TextFoolerJin2019.build(model_wrapper)

In [None]:
from textattack.attack_results import FailedAttackResult

# NOTE(review): InputExample comes from the RobEn helper module utils_glue.py. Its import at
# the top of the notebook is disabled inside a string literal, so it is imported here to
# avoid a NameError. Assumes utils_glue.py is importable from the working directory.
from utils_glue import InputExample


def _split_printed_pair(printed):
    """Return (s1, s2) parsed from a printed "S1: ...\\nS2: ..." two-sentence text."""
    lines = printed.split('\n')
    # Each line carries a 4-character "S1: " / "S2: " key prefix that is not part of the text.
    return lines[0][4:], lines[1][4:]


def _predict_label(text_a, text_b):
    """Return the model's argmax class for the sentence pair (text_a, text_b)."""
    # Encode as a sentence *pair* (not as the single printable "S1: ... S2: ..." string) so
    # the model receives the same kind of input it gets during the attack.
    encoded = model_wrapper.tokenizer.batch_encode([(text_a, text_b)])
    # Inference only: no autograd graph is needed for the prediction.
    with torch.no_grad():
        logits = calculate(**captum_form(encoded))
    return int(np.argmax(logits.cpu().numpy()))


results_iterable = attack.attack_dataset(new_dataset)

correct = 0
total = 0

for n, result in enumerate(results_iterable):
    # Assumes results are yielded in the same order as new_dataset.
    label = new_dataset[n][1]
    s1, s2 = _split_printed_pair(result.original_text())
    s1_pert, s2_pert = _split_printed_pair(result.perturbed_text())

    # Map both the clean and the perturbed pair to their cluster representatives.
    ex = recoverer.recover_example(
        InputExample(guid=n, text_a=s1, text_b=s2, label=label))
    ex_pert = recoverer.recover_example(
        InputExample(guid=n, text_a=s1_pert, text_b=s2_pert, label=label))

    orig = "S1: " + ex.text_a + "\n" + "S2: " + ex.text_b
    pert = "S1: " + ex_pert.text_a + "\n" + "S2: " + ex_pert.text_b
    print('Original Sentences: ', orig)
    print('Perturbed Sentences: ', pert)

    pred = _predict_label(ex.text_a, ex.text_b)
    pred_pert = _predict_label(ex_pert.text_a, ex_pert.text_b)

    print('Original prediction: ', pred)
    print('Perturbed prediction: ', pred_pert)
    print('Original label: ', label)
    # Robust accuracy: count examples whose recovered *perturbed* pair is still classified correctly.
    if pred_pert == label:
        correct += 1
    total += 1

print('Number Correct: ', correct)
print('Total: ', total)
if total:
    print('Accuracy under attack: ', correct / total)