Polyjuice - Intro ML Fall 22 Course Project

## Background

Counterfactuals are conditional statements about what would have been true had conditions been different. In NLP, counterfactual examples are small perturbations of an original text. They are important for training and analyzing a variety of models because they show how a prediction changes when the original text is changed slightly [1].

Counterfactual generation currently relies on either human annotators or automatic text generators [2]. Writing counterfactual examples by hand is costly and time-consuming, and the results can carry human biases. Existing automated generators, on the other hand, are mostly application-specific, so they ignore subsets of perturbations that could be useful for other applications. Neither approach is ideal for generating diverse counterfactuals efficiently.

Polyjuice, on the other hand, is an application-agnostic counterfactual generator: it first generates a pool of general-purpose counterfactual examples, from which users can choose based on their use cases [3].

Polyjuice is expected to produce counterfactual examples that are close to the original text, fluent, robust across applications, and diverse enough to cover a wide variety of use cases [4].

## Working with the existing code

Source: https://github.com/tongshuangwu/polyjuice

Below, we install Polyjuice (published on PyPI as polyjuice_nlp) and the dependencies it requires.

After importing the polyjuice package, we pass a sample prompt to the perturb() function to generate counterfactual examples.

perturb() also accepts additional arguments, including the control code (ctrl_code=) and the number of counterfactual examples to generate (num_perturbations=). The control code can be one of 'resemantic', 'restructure', 'negation', 'insert', 'lexical', 'shuffle', 'quantifier' and 'delete'.
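Since ctrl_code must be one of the eight codes listed above, it is cheaper to catch a misspelled code before a slow generation call than after it. The sketch below validates the code up front. `CTRL_CODES` and `build_perturb_kwargs` are hypothetical names of our own, not part of the polyjuice package; only the keyword names `orig_sent`, `ctrl_code` and `num_perturbations` come from perturb() itself.

```python
# Hypothetical helper, not part of polyjuice: validate control codes and
# build the keyword arguments for Polyjuice.perturb().
from typing import List, Optional, Union

# The eight control codes accepted by perturb(), as listed in the usage example.
CTRL_CODES = frozenset({
    "resemantic", "restructure", "negation", "insert",
    "lexical", "shuffle", "quantifier", "delete",
})


def build_perturb_kwargs(
    orig_sent: str,
    ctrl_code: Optional[Union[str, List[str]]] = None,
    num_perturbations: int = 3,
) -> dict:
    """Return kwargs for pj.perturb(), raising ValueError on unknown control codes."""
    kwargs = {"orig_sent": orig_sent, "num_perturbations": num_perturbations}
    if ctrl_code is not None:
        # perturb() accepts either a single code or a list of codes
        codes = [ctrl_code] if isinstance(ctrl_code, str) else list(ctrl_code)
        unknown = [c for c in codes if c not in CTRL_CODES]
        if unknown:
            raise ValueError(f"unknown control code(s): {unknown}")
        kwargs["ctrl_code"] = ctrl_code
    return kwargs


print(build_perturb_kwargs("It is great for kids.", ctrl_code="negation", num_perturbations=1))
```

With a loaded Polyjuice instance, the returned dictionary would be unpacked into the call, for example `pj.perturb(**build_perturb_kwargs(text, "negation"))`.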

Once counterfactuals are generated, we can identify the control code of a generated text by calling detect_ctrl_code() with the original text and the generated text.

We had to pin the dependencies to their minimum required versions, because a version mismatch caused an error in transformers.

We also had to import the nltk package and download the omw-1.4 corpus before importing Polyjuice to avoid a "package not found" error.
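Version mismatches like the transformers one are easier to spot if the notebook checks what is actually installed before importing Polyjuice. The sketch below is a helper of our own (`check_pins` and `PINS` are hypothetical names). It relies only on the standard library's `importlib.metadata.version()` and looks packages up by their PyPI distribution names, so a package whose distribution name is spelled differently may be reported as missing.

```python
# Hypothetical helper: compare installed versions against the pins used in
# this notebook, using only the standard library (Python 3.8+).
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Pins from the install cell.
PINS = {
    "munch": "2.5.0",
    "spacy": "3.0.6",
    "sentence-transformers": "1.1.0",
    "transformers": "4.5.1",
    "pattern": "3.6.0",
}


def check_pins(pins: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed version or None, pinned version)} for every mismatch."""
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None  # not installed at all
        if installed != pinned:
            mismatches[package] = (installed, pinned)
    return mismatches


for package, (installed, pinned) in check_pins(PINS).items():
    print(f"{package}: installed {installed}, pinned {pinned}")
```

An empty result means every pin is satisfied; anything printed points at the package to reinstall before importing polyjuice.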

In [2]:
!pip install munch==2.5.0 spacy==3.0.6
!pip install sentence-transformers==1.1.0 transformers==4.5.1
!pip install pattern==3.6.0
!pip install nltk scipy zss



In [3]:
!pip install polyjuice_nlp
!pip install torch
# The SpaCy language package
!python -m spacy download en_core_web_sm

Collecting polyjuice_nlp
  Downloading polyjuice_nlp-0.1.5-py3-none-any.whl (30 kB)
Installing collected packages: polyjuice-nlp
Successfully installed polyjuice-nlp-0.1.5
Collecting en-core-web-sm==3.0.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl (13.7 MB)
Installing collected packages: en-core-web-sm
  Attempting uninstall: en-core-web-sm
    Found existing installation: en-core-web-sm 3.4.1
    Uninstalling en-core-web-sm-3.4.1:
      Successfully uninstalled en-core-web-sm-3.4.1
Successfully installed en-core-web-sm-3.0.0

In [4]:
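# nltk's omw-1.4 corpus must be downloaded before polyjuice is imported,
# otherwise the import fails with a "package not found" error
import nltk
nltk.download('omw-1.4')
from polyjuice import Polyjuice

# load the pretrained Polyjuice model; is_cuda=True assumes a GPU runtime
pj = Polyjuice(model_path="uw-hai/polyjuice", is_cuda=True)

# the base sentence
text = "It is great for kids."
# generate counterfactuals without specifying a blank position or control code
perturbations = pj.perturb(text)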

[nltk_data] Downloading package omw-1.4 to /root/nltk_data...

In [5]:
perturbations = pj.perturb(
    orig_sent=text,
    # can specify where to put the blank. Otherwise, it's automatically selected.
    # Can be a list or a single sentence.
    blanked_sent="It is [BLANK] for kids.",
    # can also specify the ctrl code (a list or a single code.)
    # The code should be from 'resemantic', 'restructure', 'negation', 'insert', 'lexical', 'shuffle', 'quantifier', 'delete'.
    ctrl_code="negation",
    # Customize the perplexity threshold used to filter generated sentences.
    perplex_thred=5,
    # number of perturbations to return
    num_perturbations=1,
    # the function also takes in additional arguments for huggingface generators.
    num_beams=3
)

# Example outputs (note that with num_perturbations=1, at most one
# sentence is actually returned):
# 'It is not great for kids.'
# 'It is great for kids but not for anyone.'
# 'It is great for kids but not for any adults.'

In [6]:
pj.detect_ctrl_code(
    "it's great for kids.", 
    "It is great for kids but not for any adults.")
# return: negation

'negation'

## Validating a claim

The claim we chose to validate is that Polyjuice's success rate at generating satisfactory counterfactuals is much higher with a control code than without one [5]. In the paper, using the control code "negation" to perturb the original prompts raised the success rate from 5% to 47%.

We chose this claim for two reasons. First, the claim comes with a clearly described procedure and clear results, which makes it reproducible.

Second, this result is central to the argument that Polyjuice performs significantly better than existing text generators. Here, we compare Polyjuice with GPT-2 fine-tuned without control codes (Polyjuice -a).

To validate the claim, we follow the method described in section A.2.2 of the paper. We use 100 prompts instead of the 300 prompts used in the paper [6], because generating counterfactuals without a control code takes longer than the online notebook allows in a single session. We judged 100 prompts to be large enough to reveal a clear improvement while still small enough to finish within the time limit.
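To make "large enough" concrete, a rough normal-approximation estimate helps: the 95% margin of error of a success rate p measured on n prompts is about 1.96 * sqrt(p(1 - p)/n). This sketch is our own illustration, not part of the paper's procedure, and `margin_of_error` is a hypothetical helper.

```python
import math


def margin_of_error(p: float, n: int, z: float = 1.96) -> float:
    """Half-width of a normal-approximation 95% confidence interval for a proportion."""
    return z * math.sqrt(p * (1 - p) / n)


# The worst case is p = 0.5, where the binomial variance p(1 - p) is largest.
for n in (100, 300):
    print(f"n = {n}: worst-case margin of error = {margin_of_error(0.5, n):.3f}")
```

In the worst case this gives roughly plus or minus 10 percentage points at n = 100 and roughly 6 at n = 300. That is coarser than the paper's setup, but still small next to the 42-point gap (5% to 47%) the paper reports.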

We chose "negation" as the control code because the paper reports that it yields one of the largest improvements when a control code is passed in [7].

We chose the Stanford Sentiment Treebank from Stanford NLP because it is rated as one of the top sentiment analysis datasets for machine learning [8] and its file size is relatively small compared with the alternatives. Moreover, since it is built from Rotten Tomatoes movie reviews, it is a good representation of the naturally occurring sentences [9] that Polyjuice is trained on.

We start by downloading the dataset zip file from the Stanford NLP website and extracting it. We then open the text file, split it into sentences, and use the first 100 sentences as prompts for the models.

We first run the baseline that perturbs sentences without a control code, which serves as our stand-in for GPT-2 fine-tuned without control codes, also called Polyjuice -a [10]. The calls to perturb() and detect_ctrl_code() are adapted from https://github.com/tongshuangwu/polyjuice; the rest of the logic is our own, including the loop over the input sentences, the try/except block that handles exceptions raised during perturbation, and the success-rate calculation.

The success rate is calculated by detecting the control code of each perturbed sentence and comparing it with the desired control code [11]. A prompt counts as a success if at least one of its perturbed sentences is detected as the desired control code, and the success rate is the number of successful prompts divided by the total number of prompts.
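The success criterion can be written as a small pure-Python function, separate from the model calls. This is a sketch: `success_rate` and the toy lists of detected codes are hypothetical, standing in for what pj.detect_ctrl_code() would return for each prompt's perturbations.

```python
from typing import List


def success_rate(detected_codes: List[List[str]], desired: str) -> float:
    """Fraction of prompts with at least one perturbation detected as `desired`.

    detected_codes[i] holds the control codes detected for the perturbations
    of prompt i. An empty list means no perturbation was produced; that prompt
    still counts toward the denominator, as a failure.
    """
    if not detected_codes:
        raise ValueError("need at least one prompt")
    successes = sum(1 for codes in detected_codes if desired in codes)
    return successes / len(detected_codes)


# Toy example with made-up detected codes for four prompts.
example = [
    ["negation", "lexical"],   # success: one perturbation matches
    ["lexical", "insert"],     # failure: no match
    [],                        # failure: perturb() produced nothing
    ["negation", "negation"],  # success, counted once rather than twice
]
print(success_rate(example, "negation"))  # 0.5
```

Counting each prompt at most once is what keeps the rate between 0 and 1 regardless of num_perturbations.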

After obtaining the success rate for Polyjuice -a, we use the same method to calculate the success rate for Polyjuice. The only difference is that this time we pass the control code "negation" to perturb().

Once both success rates are calculated, we take their difference to measure the improvement.

Passing in a control code substantially improves the success rate. Across multiple runs of the experiment, the success rate was about 11% for Polyjuice -a and about 42% for Polyjuice, an increase of about 31 percentage points.
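One way to check that an 11% vs 42% gap on 100 prompts is unlikely to be noise is a two-proportion z-test. This sketch rests on a simplifying assumption: it treats the two runs as independent samples, although both use the same 100 prompts. A paired test such as McNemar's would need per-prompt outcomes, which the notebook does not record. `two_proportion_z` is a hypothetical helper of our own.

```python
import math
from typing import Tuple


def two_proportion_z(successes_a: int, n_a: int, successes_b: int, n_b: int) -> Tuple[float, float]:
    """Two-sided two-proportion z-test using the pooled proportion.

    Returns (z, p_value). Assumes independent samples and a normal approximation.
    """
    p_a = successes_a / n_a
    p_b = successes_b / n_b
    pooled = (successes_a + successes_b) / (n_a + n_b)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    z = (p_b - p_a) / se
    # two-sided p-value 2 * (1 - Phi(|z|)), written via the complementary error function
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


# Counts from the notebook run: 11/100 without a control code, 42/100 with "negation".
z, p = two_proportion_z(11, 100, 42, 100)
print(f"z = {z:.2f}, two-sided p = {p:.2g}")
```

For these counts z comes out at roughly 5, well beyond the usual 1.96 cutoff, so the improvement is unlikely to be a sampling artifact even with only 100 prompts.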

One interesting finding is that Polyjuice is not only more effective but also more efficient. There is a large gap in execution time: in the recorded run, Polyjuice -a took about 2h 15min of wall time while Polyjuice took about 27min, so Polyjuice -a needed roughly 5x as long to generate counterfactuals.
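The size of the gap can be read off the two %%time outputs of the experiment cells. In this small sketch, `to_seconds` is a hypothetical helper of ours that parses the h/min/s format IPython printed in this notebook; it does not handle other units such as ms.

```python
import re

# Multipliers for the units IPython's %%time prints for long-running cells.
UNIT_SECONDS = {"h": 3600.0, "min": 60.0, "s": 1.0}


def to_seconds(duration: str) -> float:
    """Convert a %%time duration such as '2h 14min 55s' or '31.9 s' to seconds."""
    matches = re.findall(r"(\d+(?:\.\d+)?)\s*(h|min|s)\b", duration)
    if not matches:
        raise ValueError(f"unrecognized duration: {duration!r}")
    return sum(float(amount) * UNIT_SECONDS[unit] for amount, unit in matches)


# Wall times recorded in the notebook's two experiment cells.
without_code = to_seconds("2h 14min 55s")  # Polyjuice -a
with_code = to_seconds("27min 3s")         # Polyjuice with "negation"
print(f"Polyjuice -a took {without_code / with_code:.1f}x as long")
```

With the recorded wall times, 8095 s versus 1623 s, the ratio comes out to roughly 5.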

From these results, we conclude that using the control code "negation" to perturb the original prompts significantly increases the success rate of generating satisfactory counterfactual examples.

In [7]:
# getting the dataset from Stanford NLP: https://nlp.stanford.edu/sentiment/code.html
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile
zipurl = 'http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip'
with urlopen(zipurl) as zipresp:
    with ZipFile(BytesIO(zipresp.read())) as zfile:
        zfile.extractall('/tmp')

In [8]:
# use the first 100 sentences from the dataset as prompts
nltk.download('punkt')
with open('/tmp/stanfordSentimentTreebank/original_rt_snippets.txt', 'r') as in_file:
    text = in_file.read()
    sents = nltk.sent_tokenize(text)

sents = sents[:100]
len(sents)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


100

In [9]:
desired_control_code = "negation"

In [10]:
%%time

# Polyjuice -a (without control code)
success_num = 0
size = 0

for text in sents:
    perturbations = []
    try:
        perturbations = pj.perturb(
            orig_sent=text,
            # number of perturbations to return
            num_perturbations=3,
            # the function also takes in additional arguments for huggingface generators.
            num_beams=5,
        )
    except Exception as e:
        print(f"Could not produce counterfactuals for the sentence {text}, skipping to next sentence")
        print(e)

    # every prompt counts toward the denominator, even if no perturbation was produced
    size += 1
    # success if at least one perturbation is detected as the desired control code
    if any(pj.detect_ctrl_code(text, p) == desired_control_code for p in perturbations):
        success_num += 1

success_rate = success_num / size
success_rate

CPU times: user 2h 12min 51s, sys: 2min 20s, total: 2h 15min 12s
Wall time: 2h 14min 55s


0.11

In [11]:
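%%time

# Polyjuice (with control code)
success_num = 0
size = 0

for text in sents:
    perturbations = []
    try:
        perturbations = pj.perturb(
            orig_sent=text,
            ctrl_code=desired_control_code,
            # number of perturbations to return
            num_perturbations=3,
            # the function also takes in additional arguments for huggingface generators.
            num_beams=5,
        )
    except Exception as e:
        # handled as in the Polyjuice -a run, so the two loops differ only in ctrl_code
        print(f"Could not produce counterfactuals for the sentence {text}, skipping to next sentence")
        print(e)

    # every prompt counts toward the denominator, even if no perturbation was produced
    size += 1
    # success if at least one perturbation is detected as the desired control code
    if any(pj.detect_ctrl_code(text, p) == desired_control_code for p in perturbations):
        success_num += 1

success_rate_with_code = success_num / size
success_rate_with_code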

CPU times: user 26min 34s, sys: 31.9 s, total: 27min 6s
Wall time: 27min 3s


0.42

In [12]:
success_rate_increase = success_rate_with_code - success_rate
success_rate_increase

0.31

## References

[1] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6707. arXiv preprint arXiv:2101.00288 (2021).

[2] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6707. arXiv preprint arXiv:2101.00288 (2021).

[3] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6707-6708. arXiv preprint arXiv:2101.00288 (2021).

[4] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6708. arXiv preprint arXiv:2101.00288 (2021).

[5] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6710. arXiv preprint arXiv:2101.00288 (2021).

[6] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6721. arXiv preprint arXiv:2101.00288 (2021).

[7] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6721. arXiv preprint arXiv:2101.00288 (2021).

[8] "Top 12 Free Sentiment Analysis Datasets | Classified and Labeled." Repustate, 27 July 2021, www.repustate.com/blog/sentiment-analysis-datasets.

[9] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6709. arXiv preprint arXiv:2101.00288 (2021).

[10] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6721. arXiv preprint arXiv:2101.00288 (2021).

[11] Wu, Tongshuang, et al. "Polyjuice: Generating counterfactuals for explaining, evaluating, and improving models." Pages 6721. arXiv preprint arXiv:2101.00288 (2021).