In [None]:
import sys
import textattack
from textattack.models.wrappers import ModelWrapper
from typing import List, Tuple
from transformers import pipeline
from datasets import load_dataset
from string import punctuation
import argparse
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import os
import tarfile
import requests
from io import BytesIO
from bs4 import BeautifulSoup
import torch
import importlib.util
from pathlib import Path
import numpy as np
import pandas as pd
import shutil
import subprocess
import zipfile
import json
from collections import OrderedDict
from torch.nn.functional import softmax

In [None]:
!pip install requirements.txt

In [None]:
class EmotionWrapper(ModelWrapper):
    """TextAttack ``ModelWrapper`` adapter around a text-classification callable.

    The wrapped model is queried once per input string, and its output is
    flattened into plain lists of float scores.
    """

    def __init__(self, model):
        # Callable (e.g. a Hugging Face pipeline) queried for each input text.
        self.model = model

    def __call__(self, input_texts: List[str]) -> List[List[float]]:
        """Score every text in ``input_texts`` with the wrapped model.

        Args:
            input_texts: the strings to classify.

        Returns:
            One list per input text, in input order. Each list holds the
            ``'score'`` field of every entry in ``self.model(text)[0]``, in the
            order the model returned them (presumably one probability per label
            when the pipeline is built with ``return_all_scores=True``).
        """
        # The model is called on one text at a time, not as a batch.
        return [
            [entry['score'] for entry in self.model(text)[0]]
            for text in input_texts
        ]

# Command-line options. `args` was previously used without ever being defined
# (NameError). parse_known_args() tolerates the extra arguments a Jupyter
# kernel puts on sys.argv (e.g. `-f kernel.json`).
parser = argparse.ArgumentParser(
    description="Run the BadCharacters2021 attack against an emotion classifier."
)
parser.add_argument(
    "--perturbation_type",
    default="homoglyphs",
    help="Perturbation type forwarded to BadCharacters2021.build.",
)
parser.add_argument(
    "--store_results",
    action="store_true",
    help="Keep the results/emotion directory after the run (deleted otherwise).",
)
parser.add_argument(
    "--num_examples",
    type=int,
    default=10,
    help="Number of dataset examples to attack.",
)
args, _unknown_args = parser.parse_known_args()

results_dir = "results/emotion"
# Make sure the CSV log has a directory to be written into.
os.makedirs(results_dir, exist_ok=True)

# CPU (device=-1) emotion classifier returning scores for every label.
model = pipeline("text-classification", model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=True, device=-1)
model_wrapper = EmotionWrapper(model)

# NOTE(review): confirm "targeted_bonus" is the intended goal function for a
# sequence classifier in this version of the BadCharacters2021 recipe.
attack = textattack.attack_recipes.BadCharacters2021.build(
    model_wrapper,
    goal_function_type="targeted_bonus",
    perturbation_type=args.perturbation_type,
)
dataset = textattack.datasets.HuggingFaceDataset("emotion", split="test")
print(dataset[0])
attack_args = textattack.AttackArgs(
    num_examples=args.num_examples,
    log_to_csv=f"{results_dir}/log.csv",
)
attacker = textattack.Attacker(attack, dataset, attack_args)
attacker.attack_dataset()

# Discard the run's output unless the caller asked to keep it.
if not args.store_results and os.path.isdir(results_dir):
    shutil.rmtree(results_dir)