In [None]:
# Re-import edited modules (such as the project's `utils`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import pickle
import sys
import warnings

import numpy as np
from matplotlib import pyplot as plt

# Must run before `from utils import ...` below, in case `utils` lives in the
# parent directory (appending afterwards cannot help that import).
sys.path.append('..')

# HF `tokenizers` reads this as an *environment variable*; a plain Python global
# has no effect. Set it before any tokenizer is created.
TOKENIZERS_PARALLELISM = False
os.environ['TOKENIZERS_PARALLELISM'] = str(TOKENIZERS_PARALLELISM).lower()

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging as transformers_logging
from utils import (
    load_sentences,
    load_or_run,
    plot_metrics,
    plot_metrics_over_time,
    RESULTS_PATH,
)

# Keep cell output readable: silence Python warnings and transformers logging.
warnings.filterwarnings("ignore")
transformers_logging.set_verbosity_error()
# Render all figure text with LaTeX (requires a LaTeX installation). Labels must
# then be LaTeX-safe: e.g. a raw `_` outside math mode makes rendering fail.
plt.rcParams['text.usetex'] = True

In [None]:
# --- Experiment configuration -------------------------------------------------
# N: sample-size tag. NOTE(review): in this notebook it only appears in the cache
#    file names; it is not passed to load_sentences() — confirm that is intended.
# SEED: random seed handed to Experiment 1.
# FORCE_RECOMPUTE: forwarded as `force=` to load_or_run (presumably bypasses the cache).
N, SEED, FORCE_RECOMPUTE = 100, 42, False

In [None]:
# Release memory held by objects from a previous run of this cell before reloading.
gc.collect()

# Tokenizer first; reuse EOS as the padding token (presumably the tokenizer
# defines none by default).
model_id = 'roneneldan/TinyStories-1M'
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Causal LM whose hidden states the experiments below try to invert.
model = AutoModelForCausalLM.from_pretrained(model_id)

# Sentence pool shared by every experiment.
sentences = load_sentences()

# Experiment 1

We want to show that this algorithm works perfectly on the first activation space but degrades sharply at later layers:

- Sample $N$ sentences from an in-distribution (ID) dataset.
    - Compute the BLEU score (and other sentence-level similarity metrics) alongside the L2 distance (a score at the embedding level).
    - Plot them across the model's layers.

In [None]:
from utils import test_prompt_reconstruction

# Experiment 1: reconstruct every sentence from the hidden states of each layer.
# `load_or_run` is expected to reuse the pickle at `path` when it exists and
# FORCE_RECOMPUTE is False (TODO confirm against utils).
path = os.path.join(RESULTS_PATH, f'ex1_results_{N}.pkl')
ex1_result = load_or_run(
    path, test_prompt_reconstruction,
    model, tokenizer, sentences,
    seed=SEED, force=FORCE_RECOMPUTE,
)

In [None]:
# Display names for the legend. text.usetex is on, so raw keys containing `_`
# (e.g. 'l2_distance') would break LaTeX rendering if shown verbatim.
rename_dict = {
    'bleu' : 'BLEU',
    'rouge-l_f' : 'ROUGE-L-F',
    'bertscore_f1' : 'BERTScore-F1',
    'l2_distance' : 'L2 Distance',
}

# Similarity scores and L2 distances live on different scales, so split the
# per-layer results and give each group its own figure.
ex1_scores = {key: value for key, value in ex1_result.items() if 'l2' not in key}
ex1_l2 = {key: value for key, value in ex1_result.items() if 'l2' in key}

plot_metrics(ex1_scores, fill_between=True, rename=rename_dict,
             xlabel='Layer', ylabel='Score', title='Prompt Reconstruction From Hidden States')
plot_metrics(ex1_l2, fill_between=True, rename=rename_dict,
             xlabel='Layer', ylabel='L2 Distance', title='L2 Distance From Hidden States')

# Experiment 2: Cost of Exhaustive Search
We want to show how expensive Exhaustive Search is in practice. Pick one or a few sentences and plot how well a (random) search run decodes them over time. The y-axis shows the BLEU score or the L2 distance.

In [None]:
from utils import run_experiment, exhaustive_search
from utils.general import consolidate_logs  # NOTE(review): unused in this notebook

# Experiment 2: exhaustive token search against the hidden states of layer 8,
# on sentences[1:K+1] (index 0 is skipped).
K = 10  # number of sentences to process in exhaustive search
path = os.path.join(RESULTS_PATH, f'ex2_results_{N}_{K}.pkl')

# NOTE(review): seed=8 is hard-coded, differs from SEED used in Experiment 1, and
# is not part of the cache file name — confirm this is intentional.
ex2_results = load_or_run(
    path, run_experiment, exhaustive_search,
    sentences[1:K+1], K, model, tokenizer,
    layer_idx=8, seed=8, force=FORCE_RECOMPUTE,
)

In [None]:
from utils.constants import COLORS

# One curve per inverted sentence: L2 distance to the target hidden state as the
# exhaustive search progresses, plotted against the logged elapsed time.
fig, ax = plt.subplots(figsize=(10, 6))
for _, res in ex2_results:
    l2, times, steps = res['l2_distance'], res['time'], res['step']
    # Skip runs whose logs are misaligned; they cannot be plotted point-for-point.
    if not (len(l2) == len(times) == len(steps)):
        continue
    ax.plot(times, l2, color=COLORS[0])
# The x-axis is `times`, so label it as time (it was mislabelled "Steps"), and the
# y-axis is the L2 distance, not a generic "Loss". No legend: curves are unlabelled.
ax.set_xlabel(r'Time (s)', fontsize=14)
ax.set_ylabel(r'L2 Distance', fontsize=14)
ax.set_title(r'Exhaustive search on the 8th layer', fontsize=16)
plt.show()


## Aggregated results of Experiment 2

In [None]:
# Stack the per-sentence logs of Experiment 2 into NaN-padded arrays of shape
# (n_log_entries, n_sentences), the same transposed layout used for Experiment 3.
assert ex2_results, "Experiment 2 produced no results"
ex2_logs = [res for _, res in ex2_results]
results_collected = {key: [list(log[key]) for log in ex2_logs] for key in ex2_logs[0]}
# Pad to the longest log under *any* key so every stacked array is rectangular.
max_length = max(len(x) for runs in results_collected.values() for x in runs)
for key in results_collected:
    results_collected[key] = np.array(
        [x + [np.nan] * (max_length - len(x)) for x in results_collected[key]]
    ).T  # rows = log entries, columns = sentences

# LaTeX-safe display names (text.usetex is on; the raw keys contain `_`).
metric_labels = {
    'bertscore_f1': 'BERTScore-F1',
    'l2_distance': 'L2 Distance',
    'rouge-l_f': 'ROUGE-L-F',
}

for metric_name, label in metric_labels.items():
    plot_metrics_over_time(
        {
            metric_name: results_collected[metric_name],
            f'{metric_name}_time': results_collected['time'],
        },
        xlabel='Time (s)',
        ylabel=label,
        # These are exhaustive-search results, not whole-prompt GD.
        title=f'Exhaustive Search: {label} vs Time',
        rename={metric_name: 'Exhaustive Search'},
        fill_between=True,
        window_size=None,
    )

In [None]:
def tokens_over_time(results, title='Exhaustive Search Time per Step (layer 8)',
                     label='Exhaustive search'):
    """Plot elapsed search time against the number of tokens found so far.

    Args:
        results: sequence of ``(sentence, log)`` pairs (like ``ex2_results``); each
            ``log`` must have a ``'time'`` list. Assumes one entry per recovered
            token, since the x-axis is labelled "Tokens found" — TODO confirm.
        title: figure title.
        label: legend label for the aggregated curve.

    Returns:
        None. Does nothing when ``results`` is empty.
    """
    # Use the argument (previously the global `ex2_results` was read instead).
    times = [list(log['time']) for _, log in results]
    if not times:
        return
    max_length = max(len(t) for t in times)
    # NaN-pad shorter runs, then transpose to (n_log_entries, n_runs) for plot_metrics.
    padded_times = np.array([t + [np.nan] * (max_length - len(t)) for t in times]).T
    plot_metrics(
        {'times': padded_times},
        xlabel='Tokens found',
        ylabel='Time (s)',
        title=title,
        rename={'times': label},
        fill_between=True,
    )

tokens_over_time(ex2_results)


# Experiment 3

Now we want to show how poorly HardPrompt performs. Ideally, we would take a sentence and show how much time it takes to decode it (same plot as Experiment 2).

- Suggest that there is no way to escape local minima.
- Suggest that the tightest convergence bound we can prove is $O(V^{|n|})$ (state this informally, not as a formal result).
    - ***NB*** Run this experiment at each of the 8 layers.
    - ***NB*** Run this experiment with each algorithm.
    - **E3.1: simply run our algorithm on the dataset of the other three and show that it converges to perfect BLEU and L2.**

In [None]:
# Experiment 3: gradient descent over all prompt tokens at once (gd_all_tokens),
# inverting the hidden states of layer 8 on the same sentences as Experiment 2.
from utils import gd_all_tokens, run_experiment

K = 10  # number of sentences to invert (sentences[1:K+1], as in Experiment 2)
max_iter = 10000  # maximum number of gradient-descent iterations (n_iterations)
lr = 1e-3  # learning rate for the optimization
path = os.path.join(RESULTS_PATH, f'ex3_results_{N}_{K}_{max_iter}_{lr}.pkl')

# NOTE(review): unlike Experiment 2, no seed is passed here — confirm whether
# gd_all_tokens is stochastic and should receive one for reproducibility.
ex3_results = load_or_run(
    path,
    run_experiment,
    gd_all_tokens,
    sentences[1:K+1],
    K,
    model,
    tokenizer,
    layer_idx=8,
    lr=lr,
    log_freq=100,  # presumably: record metrics every 100 iterations
    n_iterations=max_iter,
    force=FORCE_RECOMPUTE,
)

In [None]:
from utils.constants import COLORS

# LaTeX-safe display names: text.usetex is on, so putting the raw keys (which
# contain `_`) into an axis label would make LaTeX rendering fail.
metric_labels = {
    'bertscore_f1': 'BERTScore-F1',
    'l2_distance': 'L2 Distance',
    'rouge-l_f': 'ROUGE-L-F',
}

# One figure per metric, one curve per inverted sentence, against elapsed time.
for metric_name, label in metric_labels.items():
    fig, ax = plt.subplots(figsize=(10, 6))
    for _, res in ex3_results:
        ax.plot(res['time'], res[metric_name], color=COLORS[3])
    ax.set_xlabel(r'Time (s)', fontsize=14)
    ax.set_ylabel(label, fontsize=14)
    ax.set_title(r'Invert Whole Prompt on the 8th Layer', fontsize=16)
    # No legend: the curves are unlabelled (one per sentence).
    plt.show()

In [None]:
# Stack the per-sentence GD logs of Experiment 3 into NaN-padded arrays of shape
# (n_log_entries, n_sentences).
assert ex3_results, "Experiment 3 produced no results"
ex3_logs = [res for _, res in ex3_results]
results_collected = {key: [list(log[key]) for log in ex3_logs] for key in ex3_logs[0]}
# Pad to the longest log under *any* key so every stacked array is rectangular.
max_length = max(len(x) for runs in results_collected.values() for x in runs)
for key in results_collected:
    results_collected[key] = np.array(
        [x + [np.nan] * (max_length - len(x)) for x in results_collected[key]]
    ).T  # rows = log entries, columns = sentences

# LaTeX-safe display names (text.usetex is on; the raw keys contain `_`).
metric_labels = {
    'bertscore_f1': 'BERTScore-F1',
    'l2_distance': 'L2 Distance',
    'rouge-l_f': 'ROUGE-L-F',
}

for metric_name, label in metric_labels.items():
    plot_metrics_over_time(
        {
            metric_name: results_collected[metric_name],
            f'{metric_name}_time': results_collected['time'],
        },
        xlabel='Time (s)',
        ylabel=label,
        title=f'Invert Whole Prompt: {label} vs Time',
        rename={metric_name: 'Whole Prompt GD'},
        fill_between=True,
        window_size=None,
    )

## Experiment 3.1 - comparison with exhaustive search


In [None]:
def collect_padded_logs(results):
    """Stack per-sentence experiment logs into rectangular, NaN-padded arrays.

    Args:
        results: sequence of ``(sentence, log)`` pairs (like ``ex2_results`` /
            ``ex3_results``) where each ``log`` maps a metric name to a list of
            logged values.

    Returns:
        dict mapping each metric name to an array of shape
        ``(n_log_entries, n_runs)``; shorter runs are padded with NaN at the end.

    Raises:
        ValueError: if ``results`` is empty.
    """
    logs = [log for _, log in results]
    if not logs:
        raise ValueError("no experiment results to collect")
    collected = {key: [list(log[key]) for log in logs] for key in logs[0]}
    # Pad to the longest log under *any* key so every stacked array is rectangular.
    max_length = max(len(run) for runs in collected.values() for run in runs)
    return {
        key: np.array([run + [np.nan] * (max_length - len(run)) for run in runs]).T
        for key, runs in collected.items()
    }


ex3_results_collected = collect_padded_logs(ex3_results)
ex2_results_collected = collect_padded_logs(ex2_results)

# LaTeX-safe display names (text.usetex is on; the raw keys contain `_`).
metric_labels = {
    'bertscore_f1': 'BERTScore-F1',
    'l2_distance': 'L2 Distance',
    'rouge-l_f': 'ROUGE-L-F',
}

for metric_name, label in metric_labels.items():
    plot_metrics_over_time(
        {
            f'ex2_{metric_name}': ex2_results_collected[metric_name],
            f'ex3_{metric_name}': ex3_results_collected[metric_name],
            f'ex2_{metric_name}_time': ex2_results_collected['time'],
            f'ex3_{metric_name}_time': ex3_results_collected['time'],
        },
        xlabel='Time (s)',
        ylabel=label,
        title=f'Exhaustive Search vs GD: {label} vs Time',
        rename={
            f'ex2_{metric_name}': 'Exhaustive Search',
            f'ex3_{metric_name}': 'GD - Full Prompt',
        },
        fill_between=True,
        window_size=None,
    )