## Setup

If using Google Colab (recommended), run these setup cells to download the rest of the repository.

In [None]:
!git clone https://github.com/smcaleese/text-counterfactual-explanation-methods

In [None]:
%cd text-counterfactual-explanation-methods
%pwd

Install necessary dependencies.

In [None]:
%pip install transformers datasets textdistance openai

## Download datasets

Download the selected dataset (SST-2 or QNLI), clean its sentences, and create a list of input sentences.

Select a dataset and the number of samples to use.

In [None]:
# Number of training examples kept from the selected dataset (the first
# `num_samples` rows of the train split are used by the loading cell).
num_samples = 1000

# Dataset to run the notebook on; later cells branch on "sst_2" and "qnli".
dataset = "sst_2"
# dataset = "qnli"

In [None]:
import re

def format_sentence(sentence, dataset):
    """Lower-case a space-tokenised sentence and re-attach its punctuation.

    The rewrite rules are applied in order to the lower-cased input:

    1. spaced-out contractions are joined ("we ' ve" -> "we've");
    2. "- -" becomes "--";
    3. spaces around a hyphen between word characters are removed
       ("well - made" -> "well-made");
    4. the whitespace before ``. ! , ? : ; ' )`` and before "n't" is dropped;
    5. the whitespace after "(" is dropped;
    6. for ``dataset == "qnli"`` only, a whitespace-delimited "[sep]" (which
       was lower-cased by the first step) is restored to " [SEP] ".

    Args:
        sentence: Text to normalise.
        dataset: Dataset name; only "qnli" changes the behaviour (rule 6).

    Returns:
        The normalised, lower-cased sentence.
    """
    # Ordered (pattern, replacement) pairs; every rule sees the output of the
    # previous one, so the order matters.
    rewrite_rules = [
        (r"\s(')\s(ve|re|s|t|ll|d)", r"\1\2"),
        (r"-\s-", "--"),
        (r"(\w)\s-\s(\w)", r"\1-\2"),
        (r"\s([.!,?:;')]|n't)", r"\1"),
        (r"([(])\s", r"\1"),
    ]
    if dataset == "qnli":
        rewrite_rules.append((r"\s(\[sep\])\s", " [SEP] "))

    normalised = sentence.lower()
    for pattern, replacement in rewrite_rules:
        normalised = re.sub(pattern, replacement, normalised)
    return normalised

In [None]:
from datasets import load_dataset

# Load the train split of the selected dataset, keep its first `num_samples`
# rows and normalise the text with format_sentence. The *_formatted lists and
# *_labels_subset lists are the values written to CSV by the export cell.
if dataset == "sst_2":
    sst = load_dataset("stanfordnlp/sst2")

    # Full "sentence" / "label" columns of the train split.
    sst_sentences = sst["train"]["sentence"]
    sst_labels = sst["train"]["label"]

    sst_sentences_subset = sst_sentences[:num_samples]
    sst_labels_subset = sst_labels[:num_samples]

    sst_sentences_subset_formatted = [format_sentence(sentence, dataset) for sentence in sst_sentences_subset]

elif dataset == "qnli":
    qnli = load_dataset("glue", "qnli")

    # Each QNLI row contributes a question, a candidate answer ("sentence"
    # column) and a label.
    qnli_questions = qnli["train"]["question"]
    qnli_answers = qnli["train"]["sentence"]
    qnli_labels = qnli["train"]["label"]

    qnli_questions_subset = qnli_questions[:num_samples]
    qnli_answers_subset = qnli_answers[:num_samples]
    qnli_labels_subset = qnli_labels[:num_samples]

    # Questions and answers are formatted separately; they are only joined
    # with " [SEP] " later, when the input text is assembled.
    qnli_questions_subset_formatted = [format_sentence(sentence, dataset) for sentence in qnli_questions_subset]
    qnli_answers_subset_formatted = [format_sentence(sentence, dataset) for sentence in qnli_answers_subset]


Write the sentences to `input/sst_2-input.csv` or `input/qnli-input.csv`, depending on the selected dataset.

In [None]:
%pwd

In [None]:
import os

import pandas as pd

# DataFrame.to_csv does not create missing parent directories, so make sure
# ./input exists before writing (a no-op when it is already there).
os.makedirs("./input", exist_ok=True)

# Persist the formatted subset so later cells can reload it from
# input/<dataset>-input.csv.
if dataset == "sst_2":
    df_sst = pd.DataFrame({
        "original_text": sst_sentences_subset_formatted,
        "original_label": sst_labels_subset
    })
    df_sst.to_csv(f"./input/{dataset}-input.csv", index=False)

elif dataset == "qnli":
    # QNLI keeps question and answer in separate columns.
    df_qnli = pd.DataFrame({
        "original_question": qnli_questions_subset_formatted,
        "original_answer": qnli_answers_subset_formatted,
        "original_label": qnli_labels_subset
    })
    df_qnli.to_csv(f"./input/{dataset}-input.csv", index=False)

## Choose dataset

In [None]:
# Classifier checkpoint and the task description interpolated into the FIZLE
# prompts, chosen by dataset.
if dataset == "sst_2":
    model_name, fizle_task = (
        "textattack/bert-base-uncased-SST-2",
        "sentiment analysis on the SST-2 dataset",
    )
elif dataset == "qnli":
    model_name, fizle_task = (
        "textattack/bert-base-uncased-QNLI",
        "natural language inference on the QNLI dataset",
    )

# Stem of the input CSV (read back as input/<input_file>.csv); both supported
# datasets share the same naming scheme.
if dataset in ("sst_2", "qnli"):
    input_file = f"{dataset}-input"


## Create input dataframe


In [None]:
%pwd

In [None]:
import pandas as pd

# Reload the formatted samples from input/<input_file>.csv; df_input is the
# frame every counterfactual method below iterates over.
df_input = pd.read_csv(f"input/{input_file}.csv")
df_input.head()

In [None]:
df_input.shape

## Load models

In [None]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load the classifier to be explained (the SST-2 sentiment model or the QNLI model) and its tokenizer.

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Human-readable names for the two class indices of the selected task.
if dataset == "sst_2":
    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
elif dataset == "qnli":
    id2label = {0: "entailment", 1: "not_entailment"}

if dataset in ("sst_2", "qnli"):
    # The reverse mapping is derived instead of being spelled out twice.
    label2id = {label: index for index, label in id2label.items()}

    # Despite its name, `sentiment_model` is the classifier for whichever
    # dataset is selected; it is moved to `device` right after loading.
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id
    ).to(device)

sentiment_model_tokenizer = AutoTokenizer.from_pretrained(model_name)

Load the GPT-2 model for calculating perplexity.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# GPT-2 language model and tokenizer; calculate_perplexity reads both as
# module-level globals.
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

Load the language model for CLOSS.

In [None]:
import transformers

# Masked-language-model BERT that generate_closs_counterfactual hands to CLOSS.
LM_model = transformers.BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)
# Expose the MLM head under the extra attribute name `lm_head`.
# NOTE(review): presumably this is the attribute CLOSS looks up - confirm in CLOSS/closs.py.
LM_model.lm_head = LM_model.cls

## Helper functions

In [None]:
import re
import textdistance
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import time

def calculate_score(text, sentiment_model_tokenizer, dataset, device):
    """Return the classifier's probability for class index 1 on ``text``.

    The classifier itself is the module-level ``sentiment_model``; only its
    tokenizer is passed in.

    Args:
        text: Either a raw string, a list of token strings, or a list of token
            ids. For list input the first and last elements are dropped before
            decoding back to text (presumably the [CLS]/[SEP] specials -
            confirm against the callers in CLOSS).
        sentiment_model_tokenizer: Tokenizer matching ``sentiment_model``.
        dataset: "sst_2" (single sentence) or "qnli" (question and answer
            joined by a literal "[SEP]").
        device: Torch device the model inputs are moved to.

    Returns:
        Softmax probability of class index 1 as a Python float.

    Raises:
        ValueError: If ``text`` is an empty list or ``dataset`` is unsupported.
    """
    def tokenize_with_correct_token_type_ids(input_text, tokenizer):
        # The text is tokenised as ONE sequence, so the tokenizer assigns
        # segment id 0 everywhere. Truncate to the same 512-token limit as the
        # SST-2 path so long pairs do not overflow the model.
        tokens = tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # Column index of the first [SEP] token in the (single) row.
        sep_pos = (tokens.input_ids == tokenizer.sep_token_id).nonzero()[0, 1].item()

        # Everything after the first [SEP] belongs to the second segment.
        token_type_ids = torch.zeros_like(tokens.input_ids)
        token_type_ids[0, sep_pos+1:] = 1
        tokens['token_type_ids'] = token_type_ids

        return tokens

    if isinstance(text, list):
        if not text:
            raise ValueError("calculate_score received an empty token list")
        if isinstance(text[0], str):
            ids = sentiment_model_tokenizer.convert_tokens_to_ids(text)
            text = sentiment_model_tokenizer.decode(ids[1:-1])
        elif isinstance(text[0], int):
            text = sentiment_model_tokenizer.decode(text[1:-1])

    if dataset == "sst_2":
        inputs = sentiment_model_tokenizer(text, max_length=512, truncation=True, return_tensors="pt").to(device)
    elif dataset == "qnli":
        inputs = tokenize_with_correct_token_type_ids(text, sentiment_model_tokenizer).to(device)
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Only a float leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        logits = sentiment_model(**inputs).logits
    prob_positive = torch.nn.functional.softmax(logits, dim=1)[0][1].item()
    return prob_positive

def calculate_perplexity(text):
    """Return the GPT-2 perplexity of ``text``.

    Uses the module-level ``gpt2_model``, ``gpt2_tokenizer`` and ``device``.
    The perplexity is the exponential of the language-modelling loss the model
    reports when the input ids are also used as labels.

    Args:
        text: The string to score.

    Returns:
        The perplexity as a Python float.
    """
    # Truncate to the model's context length so very long texts do not
    # overflow the position embeddings.
    inputs = gpt2_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=gpt2_model.config.n_positions,
    ).to(device)

    # Pure evaluation: skip building the autograd graph.
    with torch.no_grad():
        loss = gpt2_model(**inputs, labels=inputs["input_ids"]).loss
    perplexity = torch.exp(loss).item()
    return perplexity

def is_flip(original_score, counterfactual_score):
    """Tell whether the counterfactual landed on the other side of 0.5.

    A score of exactly 0.5 counts as the upper (class 1) side. The check is
    binary only; it might need to be updated for AG News.

    Args:
        original_score: Class-1 probability of the original text.
        counterfactual_score: Class-1 probability of the counterfactual text.

    Returns:
        True when the two scores fall on different sides of the threshold.
    """
    threshold = 0.5
    if original_score >= threshold:
        # Started on the upper side: a flip means ending strictly below.
        return counterfactual_score < threshold
    if original_score < threshold:
        # Started on the lower side: a flip means reaching the threshold.
        return counterfactual_score >= threshold
    # Only reachable for scores that compare false both ways (e.g. NaN).
    return False

def truncate_text(text, max_length=100):
    """Cap ``text`` at ``max_length`` whitespace-separated words.

    Args:
        text: The string to shorten.
        max_length: Maximum number of words to keep.

    Returns:
        ``text`` unchanged (original spacing included) when it is short
        enough; otherwise its first ``max_length`` words joined by single
        spaces.
    """
    words = text.split()
    if len(words) <= max_length:
        return text
    return " ".join(words[:max_length])

def get_all_embeddings(model, tokenizer):
    """Return a detached copy of the model's input word-embedding matrix.

    Looking up token id ``i`` in an embedding layer yields row ``i`` of its
    weight matrix, so the table is sliced once instead of running one lookup
    per vocabulary entry. The embedding width is taken from the weights rather
    than being fixed at 768.

    Args:
        model: A model exposing ``model.bert.embeddings.word_embeddings``.
        tokenizer: Tokenizer whose ``vocab_size`` gives the number of rows.

    Returns:
        A float32 tensor with ``tokenizer.vocab_size`` rows that does not
        require gradients. It lives on the same device as the model's
        embedding weights and shares no storage with them.
    """
    embedding_weights = model.bert.embeddings.word_embeddings.weight
    # detach + clone: the result is a constant and must not alias the model.
    all_word_embeddings = embedding_weights[:tokenizer.vocab_size].detach().clone()
    all_word_embeddings = all_word_embeddings.to(torch.float32).requires_grad_(False)
    return all_word_embeddings

def get_levenshtein_similarity_score(original_text, counterfactual_text):
    """Return one minus textdistance's normalised Levenshtein distance.

    Args:
        original_text: The unmodified input text.
        counterfactual_text: The generated counterfactual.

    Returns:
        The similarity score; identical strings give 1.
    """
    normalised_distance = textdistance.levenshtein.normalized_distance(
        original_text, counterfactual_text
    )
    return 1 - normalised_distance

def format_polyjuice_output(polyjuice_output, original_question, original_answer):
    """Coerce a Polyjuice output into "<question> [SEP] <answer>" form.

    Steps, in order:

    1. Output that already contains " [SEP] " is returned unchanged.
    2. Any bracketed word token (e.g. "[sep]", "[BLANK]") is rewritten to
       " [SEP] ". Whitespace around the token is absorbed so the result has
       exactly one space on each side instead of doubled spaces.
    3. Otherwise the output is treated as only a question or only an answer
       and is paired with the original counterpart. It counts as a question
       when it ends with "?" or when its bag-of-words cosine similarity to the
       original question is higher than to the original answer.

    Args:
        polyjuice_output: Text produced by Polyjuice.
        original_question: Question of the original QNLI pair.
        original_answer: Answer sentence of the original QNLI pair.

    Returns:
        A string containing the " [SEP] " separator.
    """
    def get_cosine_similarity(text1, text2):
        # Bag-of-words count vectors for the two texts, then their cosine.
        vectorizer = CountVectorizer().fit_transform([text1, text2])
        return cosine_similarity(vectorizer)[0][1]

    sep_token = " [SEP] "

    # 1. Already valid.
    if sep_token in polyjuice_output:
        return polyjuice_output

    # 2. Replace invalid separator tokens, swallowing adjacent whitespace so
    #    "q ? [sep] a" does not become "q ?  [SEP]  a".
    polyjuice_output = re.sub(r"\s*\[\w+\]\s*", sep_token, polyjuice_output)
    if sep_token in polyjuice_output:
        return polyjuice_output

    # 3. No separator at all: decide which half of the pair was generated.
    similarity_to_question = get_cosine_similarity(polyjuice_output, original_question)
    similarity_to_answer = get_cosine_similarity(polyjuice_output, original_answer)

    if polyjuice_output.strip().endswith("?") or similarity_to_question > similarity_to_answer:
        # Likely a question: pair the new question with the original answer.
        return f"{polyjuice_output} [SEP] {original_answer}"
    else:
        # Likely an answer: pair the original question with the new answer.
        return f"{original_question} [SEP] {polyjuice_output}"

def get_output(df_input, counterfactual_method, args, timeout=10, batch_size=10):
    """Generate and score one counterfactual for every row of ``df_input``.

    For each row the original text is built and formatted, scored with
    calculate_score and calculate_perplexity, handed to
    ``counterfactual_method``, and the returned counterfactual is scored the
    same way. Rows whose processing raises, or for which the method returns
    None, are skipped, so the result may have fewer rows than the input.

    Reads the module-level ``dataset``, ``sentiment_model_tokenizer`` and
    ``device``.

    Args:
        df_input: Frame with "original_text" (SST-2) or "original_question" /
            "original_answer" (QNLI) columns.
        counterfactual_method: Callable invoked as
            ``method(original_text, calculate_score, args)``.
        args: Method-specific settings. A copy extended with the key
            "original_score" is what the method actually receives; the FIZLE
            generators read that key. Methods that do not read it are expected
            to ignore the extra entry.
        timeout: Unused; kept so existing calls stay valid.
        batch_size: Unused; kept so existing calls stay valid.

    Returns:
        A DataFrame with the original and counterfactual text, score and
        perplexity, plus "found_flip" and "levenshtein_similarity_score".
    """
    df_input = df_input.copy()
    output_data = {
        "original_text": [],
        "original_score": [],
        "original_perplexity": [],
        "counterfactual_text": [],
        "counterfactual_score": [],
        "counterfactual_perplexity": [],
        "found_flip": [],
        "levenshtein_similarity_score": []
    }
    # Row index skipped for Polyjuice only.
    # NOTE(review): the reason for excluding row 399 is not recorded here.
    exclude_indices = [399]
    is_polyjuice = counterfactual_method.__name__ == "generate_polyjuice_counterfactual"
    for i, row in df_input.iterrows():
        if is_polyjuice and i in exclude_indices:
            continue
        try:
            if dataset == "sst_2":
                original_text = row["original_text"]
            elif dataset == "qnli":
                original_question, original_answer = row["original_question"], row["original_answer"]
                original_text = f"{original_question} [SEP] {original_answer}"
            else:
                raise ValueError(f"Unsupported dataset: {dataset}")

            original_text = format_sentence(original_text, dataset)
            print(f"Processing input {i + 1}/{len(df_input)}: num tokens: {len(original_text.split())}")

            original_score = calculate_score(original_text, sentiment_model_tokenizer, dataset, device)
            original_perplexity = calculate_perplexity(original_text)

            # Pass the extended copy: handing over the bare `args` made every
            # FIZLE call fail with KeyError('original_score').
            args_with_score = {**args, "original_score": original_score}
            counterfactual_text = counterfactual_method(original_text, calculate_score, args_with_score)

            if counterfactual_text is None:
                print(f"Timeout occurred for sample {i}, skipping")
                continue

            counterfactual_text = format_sentence(counterfactual_text, dataset)
            if dataset == "qnli" and is_polyjuice:
                # Polyjuice may drop or mangle the separator; repair it.
                counterfactual_text = format_polyjuice_output(counterfactual_text, original_question, original_answer)

            label_width = 20
            print(f"\n{'original_text:'.ljust(label_width)} {original_text}")
            print(f"{'counterfactual_text:'.ljust(label_width)} {counterfactual_text}\n")

            counterfactual_score = calculate_score(counterfactual_text, sentiment_model_tokenizer, dataset, device)
            counterfactual_perplexity = calculate_perplexity(counterfactual_text)
            found_flip = is_flip(original_score, counterfactual_score)
            levenshtein_similarity_score = get_levenshtein_similarity_score(original_text, counterfactual_text)

            # Append only after every metric succeeded, so all columns keep
            # the same length.
            output_data["original_text"].append(original_text)
            output_data["original_score"].append(original_score)
            output_data["original_perplexity"].append(original_perplexity)
            output_data["counterfactual_text"].append(counterfactual_text)
            output_data["counterfactual_score"].append(counterfactual_score)
            output_data["counterfactual_perplexity"].append(counterfactual_perplexity)
            output_data["found_flip"].append(found_flip)
            output_data["levenshtein_similarity_score"].append(levenshtein_similarity_score)

        except Exception as e:
            # Best effort: report the failure and move on to the next sample.
            print(f"Exception {e}")
            print(f"Failed to generate counterfactual, skipping sample")
            continue

    df_output = pd.DataFrame(output_data)
    return df_output


In [None]:
# Word-embedding matrix of the classifier; generate_closs_counterfactual
# forwards it to CLOSS as a module-level global.
all_word_embeddings = get_all_embeddings(sentiment_model, sentiment_model_tokenizer).to(device)

In [None]:
from google.colab import userdata
from openai import OpenAI

# OpenAI client used by call_openai_api. The key is read from the Colab
# secret named "API_KEY" instead of being hard-coded in the notebook.
client = OpenAI(api_key=userdata.get("API_KEY"))

Test the accuracy of the model.

In [None]:
# Count how often the thresholded class-1 probability matches the gold label.
correct = 0

for i in range(len(df_input)):
    print(f"i: {i}")
    row = df_input.iloc[i]
    original_label = row["original_label"]

    if dataset == "sst_2":
        original_text = row["original_text"]
    elif dataset == "qnli":
        original_question, original_answer = row["original_question"], row["original_answer"]
        original_text = f"{original_question} [SEP] {original_answer}"

    score = calculate_score(original_text, sentiment_model_tokenizer, dataset, device)
    # A probability of at least 0.5 is read as class 1.
    y_hat = int(score >= 0.5)
    correct += int(y_hat == original_label)

accuracy = correct / len(df_input)
print(f"accuracy: {accuracy}")

## Counterfactual generator functions

In [None]:
from CLOSS.closs import generate_counterfactual
import re

def generate_polyjuice_counterfactual(original_text, _, args):
    """Return the first Polyjuice perturbation of ``original_text``.

    Uses the module-level ``pj`` generator and ``dataset``. QNLI inputs are
    perturbed without a control code; every other dataset uses "negation".

    Args:
        original_text: Text to perturb.
        _: Unused scoring callback (kept for the common method signature).
        args: Unused.

    Returns:
        The first perturbation returned by ``pj.perturb``.
    """
    if dataset == "qnli":
        control_code = None
    else:
        control_code = "negation"

    candidates = pj.perturb(
        orig_sent=original_text,
        ctrl_code=control_code,
        num_perturbations=1,
        perplex_thred=None
    )
    return candidates[0]

def generate_closs_counterfactual(original_text, calculate_score, args):
    """Run CLOSS's generate_counterfactual with the notebook's global models.

    Args:
        original_text: Text to explain.
        calculate_score: Scoring callback forwarded to CLOSS.
        args: CLOSS settings dictionary, forwarded unchanged.

    Returns:
        Whatever ``generate_counterfactual`` returns for this input.
    """
    # Positional arguments in the order generate_counterfactual is called with.
    closs_inputs = (
        original_text,
        sentiment_model,
        LM_model,
        calculate_score,
        sentiment_model_tokenizer,
        all_word_embeddings,
        device,
        args,
    )
    return generate_counterfactual(*closs_inputs)

def call_openai_api(system_prompt, model):
    """Send a single system message to the chat-completions API.

    Uses the module-level ``client``.

    Args:
        system_prompt: Content of the only (system-role) message.
        model: Name of the chat model to query.

    Returns:
        The message content of the first returned choice.
    """
    request = {
        "model": model,
        "messages": [{"role": "system", "content": system_prompt}],
        # Sampling settings used for all FIZLE calls.
        "top_p": 1,
        "temperature": 0.4,
        "frequency_penalty": 1.1,
    }
    response = client.chat.completions.create(**request)
    return response.choices[0].message.content

def generate_naive_fizle_counterfactual(original_text, _, args):
    """Ask the LLM for a counterfactual with a single prompt (FIZLE naive).

    The prompt names the label implied by ``args["original_score"]`` and asks
    for a minimal edit that flips it, wrapped in <new> tags. Up to ten
    attempts are made. An attempt with an empty reply, missing tags, or empty
    tag content is retried instead of raising.

    Reads the module-level ``id2label`` and ``fizle_task``.

    Args:
        original_text: Text to explain.
        _: Unused scoring callback (kept for the common method signature).
        args: Must contain "original_score" (class-1 probability of the
            original text) and "model" (chat model name).

    Returns:
        The text found between <new> tags. If no attempt produced tags, the
        last reply with any stray tag removed. None when there is nothing
        usable (get_output skips such samples).
    """
    original_score, model = args["original_score"], args["model"]
    original_id = 1 if original_score >= 0.5 else 0
    cf_id = 0 if original_id == 1 else 1

    original_label = id2label[original_id]
    cf_label = id2label[cf_id]

    system_prompt = f"""In the task of {fizle_task}, a trained black-box classifier correctly predicted the label '{original_label}' for the following text. Generate a counterfactual explanation by making minimal changes to the input text, so that the label changes from '{original_label}' to '{cf_label}'. Use the following definition of 'counterfactual explanation': "A counterfactual explanation reveals what should have been different in an instance to observe a diverse outcome." Enclose the generated text within <new> tags.
    -
    Text: {original_text}"""

    output = None
    for i in range(10):
        print(f"attempt: {i + 1}")
        output = call_openai_api(system_prompt, model)
        if not output:
            continue
        # re.search returns None when the tags are missing, so test the match
        # before reading its group. DOTALL lets the tagged text span lines.
        match = re.search(r"<new>(.*?)</new>", output, flags=re.DOTALL)
        if match is None:
            continue
        counterfactual_text = match.group(1).strip()
        if counterfactual_text:
            return counterfactual_text

    if not output:
        print("No counterfactual generated.")
        return None

    print("Failed to generate counterfactual surrounded by <new> tags")
    # Fall back to the raw reply, removing only tags that are actually
    # present rather than slicing off a fixed number of characters.
    counterfactual_text = re.sub(r"</?new>", "", output).strip()

    return counterfactual_text or None

def generate_guided_fizle_counterfactual(original_text, _, args):
    """Ask the LLM for a counterfactual in two steps (FIZLE guided).

    Step 1 asks which input words caused the predicted label. Step 2 appends
    that answer to the conversation text and asks for a counterfactual that
    changes only a minimal set of those words, wrapped in <new> tags. Step 2
    is attempted up to ten times. An attempt with an empty reply, missing
    tags, or empty tag content is retried instead of raising.

    Reads the module-level ``id2label`` and ``fizle_task``.

    Args:
        original_text: Text to explain.
        _: Unused scoring callback (kept for the common method signature).
        args: Must contain "original_score" (class-1 probability of the
            original text) and "model" (chat model name).

    Returns:
        The text found between <new> tags. If no attempt produced tags, the
        last reply with any stray tag removed. None when there is nothing
        usable (get_output skips such samples).
    """
    original_score, model = args["original_score"], args["model"]
    original_id = 1 if original_score >= 0.5 else 0
    cf_id = 0 if original_id == 1 else 1

    original_label = id2label[original_id]
    cf_label = id2label[cf_id]

    system_prompt = ""

    # 1. Find important words
    step1_system_prompt = " ".join([
        f"In the task of {fizle_task}, a trained black-box classifier correctly predicted the label '{original_label}' for the following text.",
        f"Explain why the model predicted the '{original_label}' label by identifying the words in the input that caused the label. List ONLY the words as a comma separated list.",
        f"\n-\nText: {original_text}",
        f"\nImportant words identified: "
    ])
    system_prompt += step1_system_prompt
    important_words = call_openai_api(step1_system_prompt, model)
    # The API may return no content; treat that as an empty word list.
    system_prompt += (important_words or "") + "\n"

    # 2. Generate the final counterfactual. The prompt does not change
    #    between attempts, so it is built once.
    step2_system_prompt = " ".join([
        f"Generate a counterfactual explanation for the original text by ONLY changing a minimal set of the words you identified, so that the label changes from '{original_label}' to '{cf_label}'.",
        f"Use the following definition of 'counterfactual explanation': 'A counterfactual explanation reveals what should have been different in an instance to observe a diverse outcome.'",
        f"Enclose the generated text within <new> tags."
    ])
    final_system_prompt = system_prompt + step2_system_prompt

    step2_output = None
    for i in range(10):
        step2_output = call_openai_api(final_system_prompt, model)
        if not step2_output:
            continue
        # re.search returns None when the tags are missing, so test the match
        # before reading its group. DOTALL lets the tagged text span lines.
        match = re.search(r"<new>(.*?)</new>", step2_output, flags=re.DOTALL)
        if match is None:
            continue
        counterfactual_text = match.group(1).strip()
        if counterfactual_text:
            return counterfactual_text

    if not step2_output:
        print("No counterfactual generated.")
        return None

    print("Failed to generate counterfactual surrounded by <new> tags")
    # Fall back to the raw reply, removing only tags that are actually
    # present rather than slicing off a fixed number of characters.
    counterfactual_text = re.sub(r"</?new>", "", step2_output).strip()

    return counterfactual_text or None


## Run CLOSS and HotFlip


Move into the `CLOSS` directory.

In [None]:
%cd "CLOSS"
%pwd

In [None]:
df_input.head()

1. Run HotFlip:

In [None]:
# HotFlip baseline: the CLOSS pipeline is run with both the substitution
# generation and the substitution evaluation set to "hotflip_only".
# NOTE(review): the meaning of beam_width / w / K / tree_depth is defined in
# CLOSS/closs.py, not in this notebook - confirm there before tuning.
args = {
    "beam_width": 15,
    "w": 5,
    "K": 30,
    "tree_depth": 0.3,
    "substitution_evaluation_method": "hotflip_only",
    "substitution_gen_method": "hotflip_only",
    "dataset": dataset
}

df_output_hotflip = get_output(df_input, generate_closs_counterfactual, args)

In [None]:
df_output_hotflip.head()

In [None]:
df_output_hotflip.to_csv(f"./output/hotflip-output-{dataset}-{num_samples}.csv", index=False)

2. Run CLOSS without optimization and without retraining the language modeling head:

In [None]:
df_input.head()

In [None]:
# CLOSS run: same search settings as the HotFlip cell, but substitutions are
# evaluated with "SVs" and generated with "no_opt_lmh"; this run also sets
# "substitutions_after_loc".
# NOTE(review): option semantics live in CLOSS/closs.py - confirm there.
args = {
    "beam_width": 15,
    "w": 5,
    "K": 30,
    "substitutions_after_loc": 0.3,
    "tree_depth": 0.3,
    "substitution_evaluation_method": "SVs",
    "substitution_gen_method": "no_opt_lmh",
    "dataset": dataset
}

df_output_closs = get_output(df_input, generate_closs_counterfactual, args)

In [None]:
df_output_closs.head()

In [None]:
df_output_closs.to_csv(f"./output/closs-output-{dataset}-{num_samples}.csv", index=False)

## Run Polyjuice

### Setup

In [None]:
%cd polyjuice
%pwd

In [None]:
!python -m spacy download en_core_web_sm

In [None]:
%pip install -e .

Make sure the model is being imported properly.

In [None]:
import importlib
import polyjuice

importlib.reload(polyjuice)
print(polyjuice.__file__)

In [None]:
import torch
from polyjuice import Polyjuice

# Request CUDA only when it is actually available, matching the CPU fallback
# used when `device` is chosen for the other models.
pj = Polyjuice(model_path="uw-hai/polyjuice", is_cuda=torch.cuda.is_available())

Test the model.

In [None]:
# Smoke test: perturb a single sentence with the "negation" control code and
# display whatever Polyjuice returns.
text = "julia is played with exasperating blandness by laura regan ."
perturbations = pj.perturb(
    orig_sent=text,
    ctrl_code="negation",
    perplex_thred=None
)
perturbations

Run the model and get the output.

In [None]:
df_input.head()

In [None]:
# generate_polyjuice_counterfactual ignores its `args`, hence the empty dict.
df_output_polyjuice = get_output(df_input, generate_polyjuice_counterfactual, {})

In [None]:
df_output_polyjuice.head()

In [None]:
%cd ..
%pwd

In [None]:
df_output_polyjuice.to_csv(f"./output/polyjuice-output-{dataset}-{num_samples}.csv", index=False)

## FIZLE

* Naive: uses a single prompt.
* Guided: Uses two prompts. The first prompt identifies important words and the second prompt generates the counterfactual.

Hyperparameters: top_p = 1, temperature = 0.4, and a repetition penalty of 1.1 (passed to the OpenAI API as `frequency_penalty`).


### 1. FIZLE naive

In [None]:
df_input.head()

In [None]:
# FIZLE naive: one prompt per sample, answered by the chat model named here.
args = {"model": "gpt-4-turbo"}
df_output = get_output(df_input, generate_naive_fizle_counterfactual, args)

In [None]:
df_output.head()

In [None]:
df_output.to_csv(f"./output/fizlenaive-output-{dataset}-new-2.csv", index=False)

### 2. FIZLE guided

In [None]:
df_input.head()

In [None]:
# FIZLE guided: two prompts per sample (important words first, then the
# counterfactual). This cell must call the guided generator; calling the
# naive one would just repeat the naive run under the guided file name.
args = {"model": "gpt-4-turbo"}
df_output = get_output(df_input, generate_guided_fizle_counterfactual, args)

In [None]:
df_output.head()

In [None]:
df_output.to_csv(f"./output/fizleguided-output-{dataset}.csv", index=False)