## Controlling LLM output with Logit Bias


<a href="https://help.openai.com/en/articles/5247780-using-logit-bias-to-define-token-probability">OpenAI Logit Bias Article</a>

Here you'll work through some basic exercises to understand logit bias with the OpenAI API.

We use the LogitBias class defined in logit_bias.py to create the bias map.


<img src="/Users/samuel.shapley/projects/AI/SemanticGPT/images/definition.png" width="750" height="300">
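
The internals of <code>LogitBias</code> live in logit_bias.py and aren't shown here. As a minimal sketch, assuming the class tokenizes each suppressed phrase and assigns the same bias to every resulting token ID, a bias map could be built as follows. <code>build_bias_map</code>, <code>TOY_VOCAB</code> and <code>toy_encode</code> are hypothetical; with the real GPT-4 tokenizer you could pass <code>tiktoken.encoding_for_model("gpt-4").encode</code> as <code>encode</code>.

```python
from typing import Callable, Dict, Iterable, List

def build_bias_map(
    phrases: Iterable[str],
    encode: Callable[[str], List[int]],
    bias: int = -100,
) -> Dict[int, int]:
    """Hypothetical helper: map every token ID in `phrases` to the same bias value."""
    # The OpenAI API accepts bias values from -100 to 100
    if not -100 <= bias <= 100:
        raise ValueError("bias must be between -100 and 100")
    bias_map: Dict[int, int] = {}
    for phrase in phrases:
        # Every token of a multi-token phrase is biased, so those tokens are
        # also affected wherever they appear inside other words
        for token_id in encode(phrase):
            bias_map[token_id] = bias
    return bias_map

# Toy whole-string "tokenizer" standing in for a real BPE encoder (IDs are made up)
TOY_VOCAB = {"2": 0, "two": 1, " two": 2, "hands": 3}

def toy_encode(text: str) -> List[int]:
    return [TOY_VOCAB[text]]

print(build_bias_map(["2", "two", " two"], toy_encode))  # {0: -100, 1: -100, 2: -100}
```

With a BPE tokenizer, variants such as " two" (leading space) and "Two" are separate tokens from "two", so they only get suppressed if they are listed explicitly.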




In [None]:
## Import necessary libraries

import numpy as np
import pandas as pd
import json
import textwrap
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
import time
import random
import os
import re

# Get the current working directory
cwd = os.getcwd()
# Get the parent directory of the current directory
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)

# Now Python knows where to find logit_bias package
from logit_bias import LogitBias

# Define the API key (read from the OPENAI_API_KEY environment variable if it is set)
api_key = os.environ.get('OPENAI_API_KEY', 'API_KEY')

The most basic use of logit bias is to suppress a given phrase outright (bias = -100) and explore how LLMs alter their responses once that part of the probability distribution has been restricted.
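
The OpenAI API documentation describes the bias as being added to the model's logits before sampling. The sketch below uses made-up logits and a hypothetical <code>softmax_with_bias</code> helper to show why a bias of -100 effectively bans a token: after the softmax its probability is vanishingly small, and the probability mass shifts to the remaining candidates.

```python
import math
from typing import Dict

def softmax_with_bias(logits: Dict[str, float], bias: Dict[str, float]) -> Dict[str, float]:
    """Add a per-token bias to the logits, then convert them to probabilities."""
    biased = {tok: value + bias.get(tok, 0.0) for tok, value in logits.items()}
    # Subtract the largest logit before exponentiating for numerical stability
    peak = max(biased.values())
    exps = {tok: math.exp(value - peak) for tok, value in biased.items()}
    total = math.fsum(exps.values())
    return {tok: e / total for tok, e in exps.items()}

# Made-up next-token logits for "How many hands do humans have?"
logits = {"2": 8.0, "two": 6.5, "Two": 5.0, "Humans": 4.0}

before = softmax_with_bias(logits, {})
after = softmax_with_bias(logits, {"2": -100, "two": -100})
print(max(before, key=before.get))  # 2
print(max(after, key=after.get))    # Two
```

This is why suppressing "2" and "two" alone may not be enough: a variant such as "Two" can simply take over as the most likely token.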

In [None]:
# Initialize the LogitBias class
logit_bias_generator = LogitBias(
    api_key=api_key, 
    model='gpt-4', 
    suppressed_phrases=['2','two'],
    bias=-100,
    request_timeout=100
)

# Generate a single response using the generate_response method
logit_bias_generator.generate_response(
    prompt = "How many hands do humans have?",
    system_message="PAY ATTENTION TO THE QUESTION",
    temperature=0
)
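
For reference, a call like the one above presumably ends up as a Chat Completions request in which the bias map is passed through the <code>logit_bias</code> field, keyed by token ID. The helper below is a hypothetical sketch of that request body (it makes no network call), and the token ID 12345 is a placeholder, not the real ID of "2".

```python
import json
from typing import Dict, Optional

def build_chat_request(model: str, prompt: str, bias_map: Dict[int, int],
                       system_message: Optional[str] = None, temperature: float = 0) -> dict:
    """Hypothetical helper: assemble a Chat Completions request body with a logit bias map."""
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": prompt})
    return {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        # JSON object keys must be strings, so token IDs are serialised as strings
        "logit_bias": {str(token_id): bias for token_id, bias in bias_map.items()},
    }

request = build_chat_request("gpt-4", "How many hands do humans have?", {12345: -100},
                             system_message="PAY ATTENTION TO THE QUESTION")
print(json.dumps(request, indent=2))
```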

We can now use this class to run quick experiments with language models.

#### Experiment 1: Arithmetic

What if we stop the language model from giving the right answer to basic addition questions?
How close does it get?

Here is a high-level overview of the experiment:

1. Specify the letters and special characters to suppress.
2. Draw random pairs of integers between 1 and 100,000, keeping only pairs whose sums are distinct.
3. Collect <code>num_samples</code> pairs (2,000 here; the demo used a 10K sample for larger statistics, which took ~1.5 hr and cost ~£5).
4. Prompt the model with "a + b =" and suppress the correct answer, both operands, their concatenation, and all letters and special characters.
5. Record the response, extracting the digits and recording any errors that occur.
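
Step 5, together with the error metrics used in the analysis, can be sketched for a single trial as a pure function. <code>score_response</code> is a hypothetical helper, not part of logit_bias.py; it mirrors the digit extraction in the experiment loop and the difference and difference-proportion columns computed later.

```python
from typing import Optional, Tuple

def score_response(response: str, correct_sum: int) -> Tuple[Optional[int], Optional[int], Optional[float]]:
    """Hypothetical helper: extract the digits from a response and measure the error.

    Returns (answer, absolute difference, difference / correct_sum),
    or (None, None, None) if the response contains no digits.
    """
    # Keep only the digits, as the experiment loop does
    digits = "".join(c for c in response if c.isdigit())
    if not digits:
        return None, None, None
    answer = int(digits)
    difference = abs(correct_sum - answer)
    return answer, difference, difference / correct_sum

print(score_response("113580", 13580))
```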

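The results explored in the subsequent cells are loaded from <code>arithmetic_pairs_100k.json</code>. The loop below checkpoints to <code>arithmetic_pairs_100K_test.json</code>, so rename one of the two files if you want to analyse your own run.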

In [None]:
### Sample number pairs uniformly at random, keeping only pairs with distinct sums

num_samples = 2000
sums = {}

while len(sums) < num_samples:
    a = random.randint(1, 100000)
    b = random.randint(1, 100000)
    sum_ab = a + b
    # Key on the sum so that every sampled pair has a different answer
    if sum_ab not in sums:
        sums[sum_ab] = [a, b, sum_ab]

# Convert to a list of [a, b, a + b] triples in random order
sums = random.sample(list(sums.values()), num_samples)

# Extract just the sums for plotting
sum_values = [triple[2] for triple in sums]

plt.figure(figsize=(6,4))
plt.hist(sum_values, bins=50, edgecolor='black')
plt.title('Distribution of Sums')
plt.xlabel('Sum')
plt.ylabel('Frequency')
plt.show()

In [None]:
import string

# Letters to suppress (A-Z then a-z), so the model cannot spell out an answer
alphabet = list(string.ascii_uppercase + string.ascii_lowercase)

special_characters = [
    '{', '}', ',', '?', '*', '=', '.', '(', ')', '[', ']', '!', '+', '/', '%', '^', '_', '"', "'"
]


system_message = """You add two numbers together. Under no circumstances should you use any other characters than 0-9.
Respond with a single number.
"""

results = []
# Name the answer `correct_sum` so the built-in sum() is not shadowed for later cells
for i, (a, b, correct_sum) in enumerate(tqdm(sums)):
    prompt = f"{a} + {b} ="
    try:
        # Suppress the answer, both operands, their concatenation, letters and punctuation
        logit_bias_generator = LogitBias(
            api_key=api_key,
            model='gpt-4',
            suppressed_phrases=[f"{correct_sum}", f"{a}", f"{b}", f"{a}{b}"] + alphabet + special_characters,
            bias=-100,
            request_timeout=10)
        response = logit_bias_generator.generate_response(
            prompt=prompt,
            system_message=system_message,
            temperature=0)
    except Exception as e:
        # Record failed requests as a model breakdown and move on
        results.append({
            'first_number': a,
            'second_number': b,
            'correct_sum': correct_sum,
            'model_response': None,
            'extracted_numbers': None,
            'model_breakdown': True,
            'error': str(e)
        })
        continue

    # Keep only the digits in the response
    numbers = "".join(c for c in response if c.isdigit())

    results.append({
        'first_number': a,
        'second_number': b,
        'correct_sum': correct_sum,
        'model_response': response,
        'extracted_numbers': numbers,
        'model_breakdown': False,
        'error': None
    })

    # Checkpoint results every 5 samples
    if i % 5 == 0:
        with open('arithmetic_pairs_100K_test.json', 'w') as json_file:
            json.dump(results, json_file)

# Final save so the last few results are not lost
with open('arithmetic_pairs_100K_test.json', 'w') as json_file:
    json.dump(results, json_file)


Now that we have the arithmetic results JSON, load it into a DataFrame.

In [None]:
# load the data from the JSON file
with open('arithmetic_pairs_100k.json', 'r') as json_file:
    data = json.load(json_file)

# create a DataFrame from the data
df = pd.DataFrame(data)

In [None]:
def convert_to_int(x):
    # Return None for empty strings or missing values instead of raising
    try:
        return int(x)
    except (TypeError, ValueError):
        return None

# make a copy of the original dataframe
df_clean = df.copy()

# create a mask for entries without a model breakdown
mask = ~df_clean['model_breakdown']

# convert the extracted numbers to integers
df_clean.loc[mask, 'extracted_numbers'] = df_clean.loc[mask, 'extracted_numbers'].apply(convert_to_int)

# absolute error, and the error relative to the size of the correct answer
df_clean.loc[mask, 'difference'] = abs(df_clean.loc[mask, 'correct_sum'] - df_clean.loc[mask, 'extracted_numbers'])

df_clean.loc[mask, 'difference_proportion'] = df_clean.loc[mask, 'difference'] / df_clean.loc[mask, 'correct_sum']

# plot the difference against the correct sum
plt.figure(figsize=(9, 4))
plt.scatter(df_clean['correct_sum'], df_clean['difference'],marker='.')
plt.xlabel('Correct Sum')
plt.ylabel('Difference')
plt.title('Error Size against True Answer Size')
plt.grid(True)
plt.show()


### Strangely, the differences follow three distinct lines.

In [None]:
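import random
import statistics

def average_difference(n_samples=10_000_000, upper=200_000):
    """Average relative difference between numbers and the same numbers with their first digit duplicated."""
    # Generate n_samples (default 10M) random numbers up to `upper` (default 200K)
    orig_nums = [random.randint(1, upper) for _ in range(n_samples)]

    # Create a list of numbers by duplicating the first digit, e.g. 12345 -> 112345
    trans_nums = [int(str(num)[0] + str(num)) for num in orig_nums]

    # Calculate and return the average relative difference
    diff = [(trans - orig) / orig for trans, orig in zip(trans_nums, orig_nums)]
    return statistics.fmean(diff)

print(average_difference())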
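
The cell above tests one candidate error pattern: the model writing the leading digit of the answer twice. The arithmetic can be checked directly. If the sum n has k digits and leading digit d, duplicating d adds exactly d·10^k, so the relative error is d·10^k / n. Because d·10^(k-1) ≤ n < (d+1)·10^(k-1), this ratio always lies in (10d/(d+1), 10], so it is always greater than 5 and at most 10. The hypothetical helper below checks individual cases.

```python
def leading_digit_duplication_error(n: int) -> float:
    """Relative error (duplicated - n) / n when the leading digit of n is written twice."""
    s = str(n)
    duplicated = int(s[0] + s)
    return (duplicated - n) / n

# Duplicating the leading digit d of a k-digit number adds exactly d * 10**k
n = 13580
d, k = int(str(n)[0]), len(str(n))
assert int(str(n)[0] + str(n)) - n == d * 10**k

print(leading_digit_duplication_error(10000))  # 10.0
print(leading_digit_duplication_error(19999))  # just above 5
```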


In [None]:
# Show all the ones where difference/sum is greater than 1
df_clean[df_clean['difference_proportion'] > 1]

### The error size grows predictably with the size of the true answer, suggesting that GPT-4's errors are systematic rather than random.

In [None]:
# Show all the ones where difference is the same as the correct sum
df_clean[df_clean['difference'] == df_clean['correct_sum']]
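
For a positive correct sum s and a model answer x, the difference |s - x| equals s only when x = 0 or x = 2s, so every row returned above is either a zero answer or a doubled answer. The hypothetical helper below splits records into those two cases; it assumes records shaped like the experiment's JSON output (the list loaded as <code>data</code>), where <code>extracted_numbers</code> is a digit string or None.

```python
from typing import Dict, List

def split_equal_difference_rows(records: List[Dict]) -> Dict[str, List[Dict]]:
    """Hypothetical helper: split records where |correct_sum - answer| == correct_sum.

    For a positive sum s, |s - x| == s only when x == 0 or x == 2 * s.
    """
    groups: Dict[str, List[Dict]] = {"zero": [], "doubled": []}
    for record in records:
        raw = record.get("extracted_numbers")
        # Skip breakdowns and responses with no digits
        if raw is None or raw == "":
            continue
        answer, s = int(raw), record["correct_sum"]
        if answer == 0:
            groups["zero"].append(record)
        elif answer == 2 * s:
            groups["doubled"].append(record)
    return groups

# Made-up records in the same shape as the experiment's JSON output
records = [
    {"correct_sum": 10, "extracted_numbers": "0"},
    {"correct_sum": 10, "extracted_numbers": "20"},
    {"correct_sum": 10, "extracted_numbers": "11"},
    {"correct_sum": 10, "extracted_numbers": None},
]
groups = split_equal_difference_rows(records)
print(len(groups["zero"]), len(groups["doubled"]))  # 1 1
```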

#### Experiment 2: Language

What does an LLM say when it can't say what it wants to?

Ask GPT to repeat a word and it will do so without hesitation. However, if the word to repeat is suppressed,
the model is forced to choose a different answer.

This is set up in <code>suppression_loop</code>. Choose a new suppression word and run the cell to see which words GPT treats as similar.
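
At temperature 0 the model returns (approximately) its most likely answer, so each pass of the loop effectively walks one step down the model's ranking of substitutes for the word. The sketch below simulates this with a fixed table of made-up scores and a hypothetical <code>greedy_suppression_walk</code> helper. Real logits are recomputed on every call and depend on the full context, so this is only an illustration of the mechanism.

```python
from typing import Dict, List

def greedy_suppression_walk(logits: Dict[str, float], start: str, steps: int) -> List[str]:
    """Hypothetical simulation: repeatedly pick the best candidate that has not been suppressed."""
    suppressed = {start}
    chain: List[str] = []
    for _ in range(steps):
        # A -100 bias effectively removes a candidate, so only unsuppressed ones remain
        remaining = {tok: v for tok, v in logits.items() if tok not in suppressed}
        if not remaining:
            break
        choice = max(remaining, key=remaining.get)
        chain.append(choice)
        # Suppress this answer too, as suppression_loop does
        suppressed.add(choice)
    return chain

# Made-up scores for words the model might produce when asked to repeat "world"
logits = {"world": 9.0, "World": 8.5, "earth": 7.0, "globe": 6.2, "planet": 6.0}
print(greedy_suppression_walk(logits, "world", steps=3))  # ['World', 'earth', 'globe']
```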


In [None]:
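from typing import Optional

def suppression_loop(API_KEY: str, MODEL: str, suppression_word: str, temperature: float = 0,
                     request_timeout: int = 10, max_steps: Optional[int] = None) -> list:
    """Ask the model to repeat `suppression_word`, suppressing every answer it has already given.
    Runs for `max_steps` responses, or indefinitely if None (interrupt the kernel to stop)."""
    BIAS = -100
    system_message = "You can only produce real single words. Repeat the word you see in the prompt."

    # Start by suppressing the word the model is asked to repeat
    suppressed_phrases = [suppression_word]
    responses = []

    while max_steps is None or len(responses) < max_steps:
        # Rebuild the generator so the bias map covers every phrase suppressed so far
        logit_bias_generator = LogitBias(
            api_key=API_KEY, model=MODEL, suppressed_phrases=suppressed_phrases,
            bias=BIAS, request_timeout=request_timeout)
        response = logit_bias_generator.generate_response(
            prompt=suppression_word, temperature=temperature, system_message=system_message)
        print(response)
        responses.append(response)
        # Suppress this answer too, forcing a new substitute on the next pass
        suppressed_phrases.append(response)
    return responses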

MODEL = "gpt-4"
suppression_word = "world"
suppression_loop(api_key, MODEL, suppression_word, request_timeout=10)


Create your own word network!

Type in a word and watch the network grow as GPT explores similar possibilities.
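
One way to record the network is as an adjacency list in which each word points to the substitute the model produced after it was suppressed. The sketch below assumes you collect the printed responses from each run into a list yourself; <code>add_chain</code> and the example words are hypothetical, not real GPT-4 output.

```python
from collections import defaultdict
from typing import DefaultDict, List

def add_chain(network: DefaultDict[str, List[str]], start: str, responses: List[str]) -> None:
    """Hypothetical helper: link each word in a suppression chain to the substitute that followed it."""
    previous = start
    for word in responses:
        word = word.strip()  # responses may carry stray whitespace
        if word not in network[previous]:
            network[previous].append(word)
        previous = word

network: DefaultDict[str, List[str]] = defaultdict(list)
# Made-up chains, standing in for the words printed by two runs of suppression_loop
add_chain(network, "world", ["World", "earth", "globe"])
add_chain(network, "earth", ["Earth", "world"])

for word, neighbours in network.items():
    print(word, "->", ", ".join(neighbours))
```

Merging several runs this way turns individual chains into a graph, since a word reached from different starting points accumulates several outgoing links.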