# Perplexity Filters

This notebook will be rather short, only covering how to make perplexity filters for LLM inputs. While the concept is rather simple, we believe that implementing the filters in code gives good practice working with tensors and some very useful PyTorch functions. 

We'll start with our inputs:

In [None]:
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
import xlab

DEVICE = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

## Task 1: Tokenize IDs and Labels

This should hopefully be a quick warmup exercise!

<details>
<summary>💡 <b>Hint for Task #1</b></summary>

Set the input labels with `inputs["labels"] = inputs["input_ids"]`

</details>


<details>
<summary>🔐 <b>Solution for Task #1</b></summary>

```python
def tokenize_inputs(tokenizer: AutoTokenizer, prompt: str) -> BatchEncoding:
    """
    Tokenizes the prompt and sets the prompt's input ids as the labels.

    Args:
        tokenizer: the model's tokenizer
        prompt: the prompt to be evaluated
    
    Returns: the dictionary-like BatchEncoding object with the labels set as 
        the input ids.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    inputs["labels"] = inputs["input_ids"]
    return inputs
```

</details>

In [None]:
def tokenize_inputs(tokenizer: AutoTokenizer, prompt: str) -> BatchEncoding:
    """
    Tokenizes the prompt and sets the prompt's input ids as the labels.

    Args:
        tokenizer: the model's tokenizer
        prompt: the prompt to be evaluated

    Returns: the dictionary-like BatchEncoding object with the labels set as
        the input ids.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task1(tokenize_inputs)

## Task 2: Get Logits

Now we want to get the logits from our model from a single forward pass.

<details>
<summary>💡 <b>Hint for Task #2</b></summary>

Call `model()` with the unpacked inputs.

</details>

<details>
<summary>💡 <b>Hint for Task #2</b></summary>

Extract the logits from the output with `output.logits`.

</details>


<details>
<summary>🔐 <b>Solution for Task #2</b></summary>

```python
def get_logits(model: AutoModelForCausalLM, inputs: BatchEncoding) -> torch.Tensor:
    """
    Passes `inputs` to model, returns the logits.

    Args:
        model: the model
        inputs: the BatchEncoding input object, passed to the model
    
    Returns [batch_size, seq_len, vocab_size]: output logits tensor.
    """
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits
```

</details>

In [None]:
def get_logits(model: AutoModelForCausalLM, inputs: BatchEncoding) -> torch.Tensor:
    """
    Passes `inputs` to model, returns the logits.

    Args:
        model: the model
        inputs: the BatchEncoding input object, passed to the model

    Returns [batch_size, seq_len, vocab_size]: output logits tensor.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task2(get_logits)

## Task 3: Log and Softmax the Logits

Recall that the perplexity equation is
$$
\text{PPL}(x_{1:n}) = \exp \left( -\frac{1}{n} \sum_{i = 1}^n \log p(x_i | x_{< i})  \right).
$$
You'll notice that we'll need the negative log-likelihood of each token for this equation. In this step, we'll do the first half of that, applying log-softmax to the sequence (over the vocabulary dimension). As a note, we'll be *excluding* the first token, because it has no previous token for the model to base its prediction on.

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Exclude the first token and get rid of the batch dimension with `logits[:, 1:, :].squeeze(0)`.

</details>

<details>
<summary>💡 <b>Hint for Task #3</b></summary>

Use `torch.nn.functional.log_softmax()` over the vocab dimension.

</details>


<details>
<summary>🔐 <b>Solution for Task #3</b></summary>

```python
def get_log_softmax_tokens(logits: torch.Tensor) -> torch.Tensor:
    """
    Applies the log-softmax operation to the model's logits to turn each token
    position into a probability distribution.

    Args:
        logits [batch_size, seq_len, vocab_size]: the output logits tensor
    
    Returns [seq_len, vocab_size]: the probability distribution for all tokens
        over each sequence position.
    """
    logits = logits[:, -1:, :].squeeze(0)
    log_softmaxed_toks = torch.nn.functional.log_softmax(logits, dim=1)
    return log_softmaxed_toks
```

</details>

In [None]:
def get_log_softmax_tokens(logits: torch.Tensor) -> torch.Tensor:
    """
    Applies the log-softmax operation to the model's logits to turn each token
    position into a probability distribution.

    Args:
        logits [batch_size, seq_len, vocab_size]: the output logits tensor

    Returns [seq_len, vocab_size]: the probability distribution for all tokens
        over each sequence position.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task3(get_log_softmax_tokens)

## Task 4: Get the NLL for the Label Tokens

Now that we've applied the log-softmax over the vocabulary dimension for each token in our sequence, we want to select *only* the label token's log-likelihood for each position in the sequence. This uses the `torch.gather()` operation, which we've looked at in the earlier GCG notebook, however it's somewhat unintuitive, so feel free to look at the hints!

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Remove the label for the first token and the batch dimension with `labels = labels[:, 1:].squeeze(0)`.

</details>

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Call `torch.gather()` over the vocab dimension. The operation doesn't get rid of this operation, so also be sure to squeeze it at the end! 

We also suggest reading [this stackoverflow response](https://stackoverflow.com/questions/50999977/what-does-gather-do-in-pytorch-in-layman-terms) to understand `torch.gather()`.

</details>

<details>
<summary>💡 <b>Hint for Task #4</b></summary>

Make sure you return the *negative* logprobs!

</details>


<details>
<summary>🔐 <b>Solution for Task #4</b></summary>

```python
def extract_nll(log_softmaxed_toks: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Extracts the negative log-likelihood for each label token from the logprobs.

    Args:
        log_softmaxed_toks [seq_len - 1, vocab_size]: the logit-derived logprobs
        labels [batch_size, seq_len]: the label tokens to extract

    Returns [seq_len]: tensor of the probability of each label token in the
        sequence.
    """
    labels = labels[:, 1:].squeeze(0)  # remove batch dimension, exclude first token
    nll = -torch.gather(
        input=log_softmaxed_toks,  # Over the log softmax,
        dim=1,  # in dim = 1 (vocab dimension),
        index=labels.unsqueeze(-1),  # index using the labels (with "fake" vocab dim),
    ).squeeze(-1)  # then remove the vocab direction.
    return nll
```

</details>

In [None]:
def extract_nll(log_softmaxed_toks: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Extracts the negative log-likelihood for each label token from the logprobs.

    Args:
        log_softmaxed_toks [seq_len - 1, vocab_size]: the logit-derived logprobs
        labels [batch_size, seq_len]: the label tokens to extract

    Returns [seq_len]: tensor of the probability of each label token in the
        sequence.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task4(extract_nll)

## Task 5: The Full Function

You've now built all the parts we need to create our `get_per_token_NLL()` function! This will just involve calling all the previous function's you've written in order.

<details>
<summary>🔐 <b>Solution for Task #5</b></summary>

```python
def get_per_token_NLL(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str
) -> torch.Tensor:
    """
    Computes the per-token cross-entropy loss of the input sequence.

    Args:
        model: the language model
        tokenizer: the model's tokenizer
        prompt: the prompt whose perplexity will be modeled
    
    Returns [seq_len - 1]: tensor of loss for each token position, excluding
        the first.
    """
    inputs = tokenize_inputs(tokenizer=tokenizer, prompt=prompt)
    logits = get_logits(model=model, inputs=inputs)
    log_softmaxed_toks = get_log_softmax_tokens(logits=logits)
    nll = extract_nll(log_softmaxed_toks=log_softmaxed_toks, labels=inputs["labels"])

    return nll
```

</details>

In [None]:
def get_per_token_NLL(
    model: AutoModelForCausalLM, tokenizer: AutoTokenizer, prompt: str
) -> torch.Tensor:
    """
    Computes the per-token cross-entropy loss of the input sequence.

    Args:
        model: the language model
        tokenizer: the model's tokenizer
        prompt: the prompt whose perplexity will be modeled

    Returns [seq_len - 1]: tensor of loss for each token position, excluding
        the first.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task5(get_per_token_NLL)

## Task 6: Naive Perplexity Extraction

With our negative logprobs, we can now calculate perplexity! First, you'll implement a simple function that calculates the perplexity over the whole sequence.

<details>
<summary>💡 <b>Hint for Task #6</b></summary>

This is a one-liner.

</details>

<details>
<summary>💡 <b>Hint for Task #6</b></summary>

Use `.mean()` and `torch.exp()`.

</details>


<details>
<summary>🔐 <b>Solution for Task #6</b></summary>

```python
def get_seq_ppl(nll: torch.Tensor) -> float:
    """
    Returns the perplexity of the whole sequence.

    Args:
        nll [seq_len - 1]: tensor of the nll 
    
    Returns: the perplexity over the whole sequence.
    """
    return torch.exp(nll.mean())
```

</details>

In [None]:
def get_seq_ppl(nll: torch.Tensor) -> float:
    """
    Returns the perplexity of the whole sequence.

    Args:
        nll [seq_len - 1]: tensor of the nll

    Returns: the perplexity over the whole sequence.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task6(get_seq_ppl)

## Task 7: Sliding Window Perplexity

To perhaps better measure the impact of the adversarial suffix on perplexity, we can also use a sliding window to calculate the perplexity over a slice of the input tokens. In this task, you'll implement such a function.

<details>
<summary>💡 <b>Hint for Task #7</b></summary>

You should probably keep track of the `start` and `end` of the slice.

</details>

<details>
<summary>💡 <b>Hint for Task #7</b></summary>

Calculate each sliding window's perplexity in a `for` or `while` loop.

</details>


<details>
<summary>🔐 <b>Solution for Task #7</b></summary>

```python
def get_max_sliding_window_ppl(nll: torch.Tensor, window_size) -> float:
    """
    Returns the max perplexity over the `nll` tensor with the fixed window size.

    Args:
        nll [seq_len - 1]: tensor of the nll
        window_size: the size of the sliding window
    
    Returns: the maximum perplexity evaluated over the sliding window.
    """
    seq_len = len(nll)
    assert window_size <= seq_len, "Window size must be no greater than the sequence length"
    
    max_ppl = -float("inf")
    start = 0
    end = window_size

    while end <= seq_len:
        ppl = torch.exp(nll[start:end].mean())
        max_ppl = ppl if ppl > max_ppl else max_ppl
        start, end = start + 1, end + 1
    
    return max_ppl
```

</details>

In [None]:
def get_max_sliding_window_ppl(nll: torch.Tensor, window_size) -> float:
    """
    Returns the max perplexity over the `nll` tensor with the fixed window size.

    Args:
        nll [seq_len - 1]: tensor of the nll
        window_size: the size of the sliding window

    Returns: the maximum perplexity evaluated over the sliding window.
    """
    raise NotImplementedError()

In [None]:
_ = xlab.tests.ppl_filters.task7(get_max_sliding_window_ppl)

## Task 8: The `prompt_is_perplexing()` Function

Finally, we'll create the `prompt_is_perplexing()` function, which returns `True` if a prompt's perplexity is higher than a passed threshold (in addition to the prompt's perplexity; the return value is a tuple). 

<details>
<summary>💡 <b>Hint for Task #8</b></summary>

This function is only a few lines.

</details>

<details>
<summary>💡 <b>Hint for Task #8</b></summary>

First call `get_per_token_NLL()`, then determine which perplexity measurement function to call.

</details>


<details>
<summary>🔐 <b>Solution for Task #8</b></summary>

```python
def prompt_is_perplexing(
    prompt: str,
    threshold: float,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    windowed: bool = False,
    window_size: Optional[int] = None
) -> tuple[bool, float]:
    """
    Returns `True` if the perplexity of the prompt is above `threshold`, 
    returns `False` otherwise. To be used in a perplexity filter as an LLM 
    safeguard.

    Args:
        prompt: the prompt to evaluate
        threshold: the threshold above which to classify a prompt as perplexing
        model: the model
        tokenizer: the model's tokenizer
        window: whether or not to use a sliding window to measure perplexity
        window_size [Optional]: the size of the sliding window, if applicable
    
    Returns: tuple of (bool, perplexity), where perplexity is the prompt's 
        perplexity, with the bool `True` or `False` depending on if the PPL is 
        greater than `threshold`.
    """
    if windowed:
        assert window_size is not None, "Must pass window size if using sliding window!"

    nll = get_per_token_NLL(model=model, tokenizer=tokenizer, prompt=prompt)
    ppl = get_max_sliding_window_ppl(nll, window_size) if windowed else get_seq_ppl(nll)

    return (False, ppl) if ppl > threshold else (True, ppl)
```

</details>

In [None]:
def prompt_is_perplexing(
    prompt: str,
    threshold: float,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    windowed: Optional[bool] = False,
    window_size: Optional[int] = None,
) -> tuple[bool, float]:
    """
    Returns `True` if the perplexity of the prompt is above `threshold`,
    returns `False` otherwise. To be used in a perplexity filter as an LLM
    safeguard.

    Args:
        prompt: the prompt to evaluate
        threshold: the threshold above which to classify a prompt as perplexing
        model: the model
        tokenizer: the model's tokenizer
        window: whether or not to use a sliding window to measure perplexity
        window_size [Optional]: the size of the sliding window, if applicable

    Returns: tuple of (bool, perplexity), where perplexity is the prompt's
        perplexity, with the bool `True` or `False` depending on if the PPL is
        greater than `threshold`.
    """
    raise NotImplementedError()

Now that the filter function is working, we've provided a barebones example of using it. Hopefully, running the cell below you'll see that the perplexity of the first prompt is rather low (most likely less than 100), wheraeas the second prompt with its adversarial suffix will have a very high perplexity.

Feel free to play around more with these functions, perhaps also comparing the simple vs. sliding window implementations!

In [None]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


clean_prompt = "Make new friends, but keep the old; one is silver and the other's gold"
perplexing, ppl = prompt_is_perplexing(
    prompt=clean_prompt, threshold=500, model=model, tokenizer=tokenizer, windowed=False
)
if perplexing:
    print(f"The prompt '{clean_prompt}' is perplexing, with PPL {ppl}")
else:
    print(f"The prompt '{clean_prompt}' is not perplexing, with PPL {ppl}")


adv_prompt = "Make new friends, but keep the old describing. + similarlyNow write oppositeley.]( Me giving**ONE please? revert with --Two"
perplexing, ppl = prompt_is_perplexing(
    prompt=adv_prompt, threshold=500, model=model, tokenizer=tokenizer, windowed=False
)
if perplexing:
    print(f"The prompt '{adv_prompt}' is perplexing, with PPL {ppl}")
else:
    print(f"The prompt '{adv_prompt}' is not perplexing, with PPL {ppl}")