<a target="_blank" href="https://colab.research.google.com/github/soheeyang/unified-prompt-selection/blob/main/notebooks/prompt_selection.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" align="left"/>
</a>&nbsp;Run this notebook on Google Colab or in a local notebook.

```python
import os
import sys

try:
    # Setting up an environment for Google Colab.
    import google.colab  # noqa: F401  (only importable inside Colab)

    install_script = """#!/usr/bin/bash

    ! (stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
    cd /content && rm -rf /content/unified-prompt-selection
    git clone https://github.com/soheeyang/unified-prompt-selection.git unified-prompt-selection > install.log 2>&1
    pip install -r /content/unified-prompt-selection/requirements.txt >> install.log 2>&1
    pip install --upgrade google-cloud-storage >> install.log 2>&1"""

    with open("/content/install.sh", "w") as f:
        f.write(install_script)

    os.system("bash /content/install.sh")
    os.chdir("/content/unified-prompt-selection")
    sys.path.append("/content/unified-prompt-selection")
except ModuleNotFoundError:
    # Running locally: `/content` only exists on Colab, so move to the repository
    # root instead (this notebook lives in the repository's `notebooks/` directory).
    if os.path.basename(os.getcwd()) == "notebooks":
        os.chdir("..")
    os.system("pip install -r requirements.txt > install.log 2>&1")
```

# Prompt Selection

By running `run_prompt_selection.py`, you can compute $p(y|x,t)$, the probability an LLM assigns to answer choice $y$ given input $x$ and prompt $t$, and use it to select a prompt.
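As a rough illustration of what $p(y|x,t)$ means here: each answer choice $y$ receives a log-likelihood from the LLM, and normalizing those scores across the choices gives a distribution over answers. The sketch below assumes the per-choice log-likelihoods are already computed (the numbers are made up) and uses a plain softmax; how the repository actually aggregates token log-probabilities (see the `first_token` and `sum_log_prob` options) may differ.

```python
import math


def answer_distribution(choice_logliks):
    """Turn per-answer-choice log-likelihoods into p(y|x,t) with a softmax.

    choice_logliks: dict mapping each answer choice y to a log-likelihood
    score, assumed to be precomputed by an LLM (hypothetical input).
    """
    # Subtract the max before exponentiating for numerical stability.
    max_ll = max(choice_logliks.values())
    exp_scores = {y: math.exp(ll - max_ll) for y, ll in choice_logliks.items()}
    total = sum(exp_scores.values())
    return {y: s / total for y, s in exp_scores.items()}


# Made-up log-likelihoods for one input x under one prompt t.
p = answer_distribution({"positive": -1.2, "negative": -2.3})
print(p)  # "positive" gets the larger share; the values sum to 1
```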

```bash
python run_prompt_selection.py
```

You can obtain four different results depending on whether calibration is applied:

 * **X**: without applying any calibration

 * **A**: applying calibration only for **A**nswer selection

 * **P**: applying calibration only for **P**rompt selection

 * **PA**: applying calibration for both **P**rompt selection and **A**nswer selection
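To make the four settings concrete, the sketch below shows which probabilities feed each step. It works on a single input for brevity (prompt selection actually aggregates over many inputs), and `contextual_calibrate` is a hypothetical stand-in modeled on the idea of contextual calibration (Zhao et al., 2021), not a reimplementation of the repository's `cbm`, `cc`, or `pmi` options.

```python
def contextual_calibrate(p, p_content_free):
    """Hypothetical stand-in calibration: divide by content-free probabilities
    and renormalize. Not necessarily what the repository's options do."""
    scaled = [pi / ci for pi, ci in zip(p, p_content_free)]
    total = sum(scaled)
    return [s / total for s in scaled]


def probs_for_setting(setting, raw, calibrated):
    """Return (probs used for prompt selection, probs used for answer selection).

    setting is one of "X", "A", "P", "PA", following the list above.
    """
    if setting not in {"X", "A", "P", "PA"}:
        raise ValueError(f"unknown setting: {setting}")
    prompt_probs = calibrated if "P" in setting else raw
    answer_probs = calibrated if "A" in setting else raw
    return prompt_probs, answer_probs


raw = [0.7, 0.3]
calibrated = contextual_calibrate(raw, p_content_free=[0.8, 0.2])
for setting in ["X", "A", "P", "PA"]:
    prompt_probs, answer_probs = probs_for_setting(setting, raw, calibrated)
    # The predicted answer is the argmax of the answer-selection probabilities.
    pred = max(range(len(answer_probs)), key=answer_probs.__getitem__)
    print(setting, pred)
```

Here calibration flips the predicted answer, so only the settings that calibrate answer selection (**A** and **PA**) change the prediction, while **P** and **PA** change what the prompt selection score sees.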

Notes:
- When a dataset has no ground-truth labels, prompt selection still runs, but its evaluation results cannot be verified. **Before applying prompt selection to a dataset, check whether ground-truth labels are available**, or, for datasets without ground-truth labels, **check whether the reported evaluation results are computed against the Target of the prompt selection results.**

Running the command above executes prompt selection with the predefined default arguments.

We use [Hydra](https://hydra.cc/docs/intro/) to manage the configurations. You can inspect them in [`./conf`](./conf/), and you can change arguments either by overriding them on the command line or by editing [`./conf/config.yaml`](./conf/config.yaml).

You can also sweep over combinations of arguments by adding `-m` (or `--multirun`), as follows:

```bash
python run_prompt_selection.py -m \
    method=MI,MI_G,MI_L,MI_GL,GE,GE_M,LE,MDL,MDL_M,ZLP,ZPM,ZMV,PPL,PPL_L \
    calibration=cbm-softmax,cbm-mean,cc-softmax,cc-mean,pmi-softmax,pmi-mean \
    decoder=opt-1.3b,opt-2.7b,opt-6.7b,opt-30b,opt-66b,gpt-neo-1.3b,gpt-neo-2.7b,gpt-j-6b,gpt2-xl,bloom-3b \
    dataset=sst2,ag_news,cb,imdb,newspop,rte,sst5,tweet_emotion,tweet_irony,piqa,copa,hellaswag,story_cloze \
    prompt=base_prompts,v12_prompts,v2_prompts,fewshot_prompt \
    first_token=false,true \
    sum_log_prob=false,true \
    fewshot=null,'1,2,4' \
    filter=false,true \
    unbalance=false,true
```
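Hydra's basic sweeper launches one job for every combination of the comma-separated values, so the number of runs grows multiplicatively. The sketch below estimates the job count for a small, hypothetical sweep; `parse_sweep` is an illustrative helper that simply splits on commas and ignores the quoting rules of both the shell and Hydra's override grammar.

```python
import itertools
import math


def parse_sweep(overrides):
    """Map each `key=v1,v2,...` override to its list of values.

    Simplified illustration only: it splits on commas and does not handle
    quoting as the shell or Hydra would.
    """
    sweep = {}
    for override in overrides:
        key, values = override.split("=", 1)
        sweep[key] = values.split(",")
    return sweep


sweep = parse_sweep([
    "method=MI,MI_G,MI_L,MI_GL,GE,GE_M,LE,MDL,MDL_M,ZLP,ZPM,ZMV,PPL,PPL_L",
    "calibration=cbm-softmax,cbm-mean,cc-softmax,cc-mean,pmi-softmax,pmi-mean",
    "dataset=sst2,rte",
])

# One job per element of the Cartesian product of all swept values
# (illustrative ordering; Hydra's own job order may differ).
jobs = [dict(zip(sweep, combo)) for combo in itertools.product(*sweep.values())]
n_jobs = math.prod(len(values) for values in sweep.values())
print(n_jobs)  # 14 * 6 * 2 = 168
print(jobs[0])
```

Applying the same arithmetic to the full command above gives well over a million jobs, so in practice you will usually sweep only a few arguments at a time.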

```bash
python run_prompt_selection.py -m \
    method=MI,MI_G,MI_L,MI_GL,GE,GE_M,LE,MDL,MDL_M,ZLP,ZPM,ZMV,PPL,PPL_L
```

### Probability-based Prompt Selection Method

<center><img src="../images/ps_methods.png" width="100%" height="100%"></center>

The following probability-based prompt selection methods are available: 'MI', 'GE', 'LE', 'MDL', 'ZLP', 'ZPM', 'ZMV', and 'PPL'.

To use a specific prompt selection method, set `method` to its name (for example, `method=MI`). Detailed descriptions of each method are in Section 2.2 (Existing Approaches) of the [paper](https://arxiv.org/pdf/2305.14877.pdf).
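For intuition, here is a minimal sketch of two of these scores computed from $p(y|x,t)$ alone. MI follows the mutual-information criterion of Sorensen et al. (2022), $H\big(\frac{1}{N}\sum_x p(y|x,t)\big) - \frac{1}{N}\sum_x H\big(p(y|x,t)\big)$, and for GE we assume the GlobalE idea of Lu et al. (2022), the entropy of the distribution of predicted (argmax) labels; both pick the prompt with the highest score. The repository's implementation may normalize or aggregate differently, and the probabilities are made up.

```python
import math


def entropy(dist):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in dist if p > 0)


def mutual_information(probs):
    """MI score of one prompt: H(mean_x p(y|x,t)) - mean_x H(p(y|x,t)).

    probs: one list of answer probabilities p(y|x,t) per input x.
    """
    n, k = len(probs), len(probs[0])
    marginal = [sum(row[j] for row in probs) / n for j in range(k)]
    conditional = sum(entropy(row) for row in probs) / n
    return entropy(marginal) - conditional


def global_entropy(probs):
    """GE score (assumed definition): entropy of the predicted-label distribution."""
    n, k = len(probs), len(probs[0])
    counts = [0] * k
    for row in probs:
        counts[max(range(k), key=row.__getitem__)] += 1
    return entropy([c / n for c in counts])


# Made-up p(y|x,t) for two prompts on the same three inputs.
prompt_probs = {
    "prompt_a": [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],
    "prompt_b": [[0.6, 0.4], [0.55, 0.45], [0.6, 0.4]],
}

for name, score_fn in [("MI", mutual_information), ("GE", global_entropy)]:
    best = max(prompt_probs, key=lambda t: score_fn(prompt_probs[t]))
    print(name, "selects", best)
```

`prompt_b` always predicts the same label with low confidence, so both scores favor `prompt_a`, whose predictions are confident and spread across labels.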

### Variants of Prompt Selection Methods

<p align="center" width="100%">
<img src="../images/variants.png" width="55%" height="55%">
</p>

The following methods are variants that modify the score calculation of the existing probability-based prompt selection methods: 'MI_G', 'MI_L', 'MI_GL', 'GE_M', 'MDL_M', and 'PPL_L'.

You can check the arguments specific to these probability-based prompt selection methods in the [`./conf/method`](./conf/method/) directory.

If a method name is suffixed with '_L', [`select_for_each_x`](./conf/method/MI_L.yaml) is set to `True`, and prompt selection is performed instance-wise, i.e., a prompt is selected separately for each input $x$. Instance-wise prompt selection is supported by 'MDL', 'MI', and 'PPL'.
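A minimal sketch of instance-wise selection, under stated assumptions: `select_per_instance` and `mdl_score` are hypothetical helpers, and the per-instance score is an MDL-style one in the spirit of Wu et al. (2022), assuming that a lower entropy of $p(y|x,t)$ means a shorter code length and is therefore better. The repository's per-instance scores for 'MDL', 'MI', and 'PPL' may be defined differently.

```python
import math


def entropy(dist):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in dist if p > 0)


def select_per_instance(probs_by_prompt, score_fn):
    """Instance-wise selection: for each input x, pick the prompt with the best score.

    probs_by_prompt: dict prompt -> list of p(y|x,t), one per input x (same order).
    score_fn: maps one distribution p(y|x,t) to a score; higher is better.
    """
    prompts = list(probs_by_prompt)
    n_inputs = len(probs_by_prompt[prompts[0]])
    return [
        max(prompts, key=lambda t: score_fn(probs_by_prompt[t][i]))
        for i in range(n_inputs)
    ]


def mdl_score(dist):
    """Assumed MDL-style score: negative entropy (shorter code length is better)."""
    return -entropy(dist)


probs_by_prompt = {
    "prompt_a": [[0.9, 0.1], [0.5, 0.5]],
    "prompt_b": [[0.6, 0.4], [0.2, 0.8]],
}
print(select_per_instance(probs_by_prompt, mdl_score))
```

Unlike dataset-level selection, different inputs can end up with different prompts: here the first input uses `prompt_a` and the second uses `prompt_b`.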

If a method name is suffixed with '_G', [`one_hot`](./conf/method/MI_G.yaml) is set to `True`, and a one-hot version of $p(y|x,t)$ is used in the GE calculation.
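Reading the paragraph above literally, 'MI_G' keeps MI's conditional-entropy term but computes the GE (marginal entropy) term from one-hot predictions. The sketch below assumes exactly that; `mi_score` and its `one_hot_ge` flag are hypothetical names, so check [`./conf/method/MI_G.yaml`](./conf/method/MI_G.yaml) and the code for the actual definition.

```python
import math


def entropy(dist):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(p * math.log(p) for p in dist if p > 0)


def one_hot(dist):
    """Replace p(y|x,t) with a one-hot vector on its argmax."""
    best = max(range(len(dist)), key=dist.__getitem__)
    return [1.0 if j == best else 0.0 for j in range(len(dist))]


def mi_score(probs, one_hot_ge=False):
    """H(mean_x q(y|x,t)) - mean_x H(p(y|x,t)), where q is p or its one-hot version.

    one_hot_ge=True is our assumed reading of the '_G' variant.
    """
    n, k = len(probs), len(probs[0])
    rows = [one_hot(row) for row in probs] if one_hot_ge else probs
    marginal = [sum(row[j] for row in rows) / n for j in range(k)]
    return entropy(marginal) - sum(entropy(row) for row in probs) / n


probs = [[0.6, 0.4], [0.45, 0.55], [0.7, 0.3], [0.4, 0.6]]
print(mi_score(probs), mi_score(probs, one_hot_ge=True))
```

With one-hot predictions the marginal term reflects how evenly the *predicted labels* are spread (here two of each, so $\ln 2$), rather than how the averaged probabilities are spread.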