In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local helper code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import (
    extract_hidden_states_prompt, 
    extract_hidden_states_ids,
    set_seed, plot_metrics, 
    RESULTS_PATH
)
from datasets import load_dataset
from matplotlib import pyplot as plt
from utils.metrics import compute_metrics
import gc
from time import time
from tqdm import tqdm
import sys
import warnings
sys.path.append('..')
warnings.filterwarnings("ignore")
from transformers import logging as transformers_logging
transformers_logging.set_verbosity_error()
plt.rcParams['text.usetex'] = True
import pickle
import os
import numpy as np
# The `tokenizers` library reads the TOKENIZERS_PARALLELISM *environment
# variable*; a plain Python variable has no effect on it. The variable is kept
# for backward compatibility and the setting is exported to the environment.
TOKENIZERS_PARALLELISM = False
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Run a garbage-collection pass before (re)loading the model in this cell.
gc.collect()

model_id = "roneneldan/TinyStories-1M"

# Tokenizer: reuse the end-of-sequence token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Causal language model loaded from the same checkpoint id.
model = AutoModelForCausalLM.from_pretrained(model_id)

In [None]:
# Hand-written probe prompts: one containing a long key-like string and one
# made of noisy, out-of-vocabulary-looking text.
# NOTE: the name `promts` (sic) is kept because other cells may reference it.
promts = [
    'my secret key is big and long, it is 1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ',
    # The backslash is escaped explicitly: `\o` is not a valid escape sequence
    # and triggers a SyntaxWarning on recent Python versions. The resulting
    # string is unchanged.
    '12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
]

## Dataset

In [None]:
# Load the training split of the "bookcorpus" dataset.
dataset = load_dataset(path="bookcorpus", split="train")

In [None]:
# Draw a reproducible subset of the corpus: shuffle with a fixed seed and keep
# the `text` field of the first N rows.
N = 100  # number of sentences to process
set_seed(42)
sentences = dataset.shuffle(seed=42).select(range(N))["text"]

In [None]:
# Display the first 20 entries of `sentences` as a quick sanity check.
sentences[:20]

# Experiment 1

We want to show that this algorithm works perfectly at the first activation space but breaks down at deeper layers.
- Sample n sentences from an in-distribution (ID) dataset.
    - Compute the BLEU score (or other sentence-level distance metrics) vs. the L2 norm (score at the embedding level).
    - Plot them across the layers of the model.

In [None]:
def test_prompt_reconstruction(llm, tokenizer, dataset, seed=8):

    embedding_matrix = llm.get_input_embeddings().weight
    metrics = []
    # average scores over the dataset
    sentense_embeddings = []
    for sentence in tqdm(dataset, desc="Computing sentence embeddings"):
        set_seed(seed)
        input_ids = tokenizer(sentence, return_tensors='pt').input_ids
        with torch.no_grad():
            embedding = llm.get_input_embeddings()(input_ids)[0]
            embedding = embedding.numpy()
        sentense_embeddings.append(embedding)
    
    for layer_idx in range(llm.config.num_hidden_layers):
        output_sentences = []
        output_embeddings = []
        for prompt in tqdm(dataset, desc="Processing prompts"):
            set_seed(seed)
            h_target = extract_hidden_states_prompt(prompt, model, tokenizer,layer_idx=layer_idx)
            output_tokens = (h_target @ embedding_matrix.T).argmax(dim=-1)
            output_sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)
            output_sentences.append(output_sentence)
            output_embeddings.append(embedding_matrix[output_tokens].detach().numpy())
        metrics.append(compute_metrics(sentences, sentense_embeddings, output_sentences, output_embeddings))
    merge_metrics = {}
    for metric in metrics:
        for key, value in metric.items():
            if key not in merge_metrics:
                merge_metrics[key] = []
            merge_metrics[key].append(value)
    return merge_metrics

In [None]:
# Experiment 1 results are cached on disk, keyed by the number of sentences N.
path = os.path.join(RESULTS_PATH, f'ex1_results_{N}.pkl')

if not os.path.exists(RESULTS_PATH):
    os.makedirs(RESULTS_PATH)

if not os.path.exists(path):
    # No cache yet: run the experiment and store its result.
    ex1_result = test_prompt_reconstruction(model, tokenizer, sentences, seed=42)
    with open(path, 'wb') as cache_file:
        pickle.dump(ex1_result, cache_file)
else:
    # Cache hit: reuse the stored result instead of recomputing.
    with open(path, 'rb') as cache_file:
        ex1_result = pickle.load(cache_file)

In [None]:
# Two figures: metrics whose name does not contain 'l2', then those that do.
non_l2_metrics = {name: values for name, values in ex1_result.items() if 'l2' not in name}
l2_metrics = {name: values for name, values in ex1_result.items() if 'l2' in name}
plot_metrics(non_l2_metrics, fill_between=True)
plot_metrics(l2_metrics, fill_between=True)

# Experiment 2: 
We want to show how expensive exhaustive search is in practice. Pick one or a few sentences and plot the performance of some random search in decoding them. On the y-axis: BLEU score or L2 distance.

In [None]:
# Exhaustive search: recover the prompt one token at a time by trying every
# vocabulary entry at the current position and keeping the one that minimises
# the distance to the target hidden state. Computation time and the number of
# candidates tried are recorded along the way.


def exhaustive_search(llm, tokenizer, prompt, layer_idx=0, seed=42, eps=1e-3):
    """Recover ``prompt`` token by token by brute force over the vocabulary.

    For position ``i`` every vocabulary id ``j`` is appended to the tokens
    found so far, the hidden states of that candidate sequence are extracted
    with ``extract_hidden_states_ids`` and the candidate whose hidden state at
    position ``i`` is closest (L2 norm) to the target is kept. The scan of the
    vocabulary stops early as soon as a candidate's loss drops below ``eps``.

    Args:
        llm: Model exposing ``get_input_embeddings()``; forwarded to the
            hidden-state extraction helpers.
        tokenizer: Tokenizer used to encode ``prompt`` and decode token ids.
        prompt: Text whose hidden states are the reconstruction target.
        layer_idx: Layer index forwarded to the extraction helpers.
        seed: Seed passed to ``set_seed`` before the search starts.
        eps: Loss threshold below which a candidate is accepted immediately.

    Returns:
        Tuple ``(output_sentence, l2s, times, steps)``. ``output_sentence`` is
        the decoded reconstruction; the three lists have one entry per
        recovered token: mean per-position L2 distance to the target (positions
        not yet decoded count as zero vectors), seconds elapsed since the start
        of the search, and cumulative number of candidates tried.

    Raises:
        RuntimeError: If no candidate yields a comparable loss at some
            position (empty vocabulary or every loss is NaN).
    """
    print(f"Exhaustive search for prompt: {prompt} on layer {layer_idx}")
    set_seed(seed)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    embedding_matrix = llm.get_input_embeddings().weight
    vocab_size = embedding_matrix.shape[0]
    n_tokens = input_ids.shape[1]
    # Assumed to hold one row per prompt token (it is indexed by position below).
    h_target = extract_hidden_states_prompt(prompt, llm, tokenizer, layer_idx=layer_idx)

    output_tokens = []
    time_start = time()
    step = 0
    l2s = []
    times = []
    steps = []
    for i in range(n_tokens):
        min_loss = float('inf')
        best_token = None
        best_h = None
        bar = tqdm(range(vocab_size), desc=f"Finding token {i+1}/{n_tokens}")
        for j in bar:
            step += 1
            # Candidate sequence: tokens found so far plus vocabulary id j.
            current_tokens = torch.tensor(output_tokens + [j]).unsqueeze(0)
            h = extract_hidden_states_ids(current_tokens, llm, layer_idx)
            loss = (h_target[i] - h[i]).norm().item()
            if loss < min_loss:
                min_loss = loss
                best_token = j
                # Keep the hidden states of the best candidate: after the scan
                # `h` belongs to the last candidate tried, not to the best one.
                best_h = h.detach()
                # The displayed values only change when the best candidate does.
                bar.set_postfix({'loss': min_loss, 'token': tokenizer.decode(best_token)})
            if loss < eps:
                break
        bar.close()

        if best_token is None:
            raise RuntimeError(
                f"No valid candidate found for token {i+1}/{n_tokens} "
                "(empty vocabulary or non-finite losses)"
            )

        # Zero-fill the positions that are not decoded yet so the shape matches
        # the target, then average the per-position L2 distances.
        h_expanded = torch.zeros_like(h_target)
        h_expanded[:i+1] = best_h[:i+1]
        l2 = (h_target - h_expanded).norm(p=2, dim=1).mean().item()

        l2s.append(l2)
        times.append(time() - time_start)
        steps.append(step)

        output_tokens.append(best_token)
        sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(f"Token {i+1}/{input_ids.shape[1]}: {tokenizer.decode(best_token)} | L2: {l2:.4f} | Time: {times[-1]:.2f}s | Steps: {step} | Sentence: {sentence}")

    output_sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)

    return output_sentence, l2s, times, steps


In [None]:
K = 10  # number of sentences to process in exhaustive search
path = os.path.join(RESULTS_PATH, f'ex2_results_{N}_{K}.pkl')

os.makedirs(RESULTS_PATH, exist_ok=True)
if os.path.exists(path):
    # NOTE: pickle can execute arbitrary code on load; only use cache files
    # produced by this notebook. A cache written before the off-by-one fix
    # below holds K + 1 entries.
    with open(path, 'rb') as f:
        ex2_results = pickle.load(f)
else:
    ex2_results = []
    for i, prompt in enumerate(sentences):
        # `>=` so that exactly K sentences are processed (was `>`, i.e. K + 1).
        if i >= K:
            break
        ex2_result = exhaustive_search(model, tokenizer, prompt, layer_idx=8, seed=8)
        ex2_results.append(ex2_result)
    with open(path, 'wb') as f:
        pickle.dump(ex2_results, f)

In [None]:
from utils.constants import COLORS

# One curve per sentence: recorded loss against the cumulative number of
# candidates tried. Each result is (sentence, losses, times, steps).
fig, ax = plt.subplots(figsize=(10, 6))
for i, res in enumerate(ex2_results):
    losses, steps = res[1], res[3]
    # Only the first curve carries a label, so the legend shows one entry
    # instead of being empty (all curves share the same colour).
    ax.plot(steps, losses, color=COLORS[0], label='Exhaustive search' if i == 0 else None)
ax.set_xlabel(r'Steps', fontsize=14)
ax.set_ylabel(r'Loss', fontsize=14)
ax.set_title(r'Exhaustive search on the 8th layer', fontsize=16)
ax.legend()


Aggregated results of the second experiment.

In [None]:
# Stack the per-sentence traces of experiment 2 into 2-D arrays. Traces are
# right-padded with NaN up to the longest loss trace, and the array is
# transposed so that rows are trace positions and columns are sentences.
max_length = max(len(res[1]) for res in ex2_results)
results_collected = {}
for key, position in (('losses', 1), ('times', 2), ('steps', 3)):
    padded = [
        res[position] + [np.nan] * (max_length - len(res[position]))
        for res in ex2_results
    ]
    results_collected[key] = np.array(padded).T

# Only the elapsed-time series is plotted here.
plot_metrics(
    {'times': results_collected['times']},
    xlabel='Tokens found',
    ylabel='Time (s)',
    title='Exhaustive Search Time per Step (layer 8)',
    rename={'times': 'Exhaustive search'},
    fill_between=True,
)

In [None]:
# Plot elapsed time (x-axis) against loss (y-axis) for the exhaustive search.
# The results hold one entry per recovered token, so they are first pivoted
# into a long table with one row per (sentence, token).

import pandas as pd

# One row per sentence; every cell holds that sentence's whole trace.
results_pivoted = pd.DataFrame({
    'times': [res[2] for res in ex2_results],
    'steps': [res[3] for res in ex2_results],
    # Fraction of the sentence decoded after each recovered token.
    'tokens_found': [np.arange(1, len(res[1]) + 1) / len(res[1]) for res in ex2_results],
    'losses': [res[1] for res in ex2_results],
})
# Long format: one row per recovered token; the index still identifies the
# sentence each row came from.
results_pivoted = results_pivoted.explode(['times', 'steps', 'tokens_found', 'losses'])

fig, ax = plt.subplots(figsize=(10, 6))
# Draw one line per sentence. Plotting the exploded columns in a single call
# would join the last point of each sentence to the first point of the next.
for _, trace in results_pivoted.groupby(level=0):
    ax.plot(
        trace['times'].astype(float),
        trace['losses'].astype(float),
        color=COLORS[0],
        marker="o",
    )
ax.set_xlabel('Time (s)', fontsize=14)
ax.set_ylabel('Loss', fontsize=14)
ax.set_title('Exhaustive Search: Time vs Loss', fontsize=16)

## Experiment 3:

Now we want to show how badly Hardprompt performs. Ideally, we would take a sentence and show how much time it takes to decode it. Same plot as E2.
- Suggest that there is no possibility of escaping local minima.
- Suggest that the tightest convergence bound we can prove is $O(V^{|n|})$ (`put it nicely and not formally`).
    - ***NB*** Do this experiment at each of the 8 layers.
    - ***NB*** Do this experiment with each algorithm.
    - **E3.1: simply use our algorithm on the dataset of the other three; we want to show that it converges to perfect BLEU and L2.**

In [None]:
# Experiment 3: run `invert_whole_prompt` on layer 8 for the first K sentences,
# caching the results on disk.
from utils import invert_whole_prompt

K = 10  # number of sentences to invert
max_iter = 3000  # maximum number of optimisation iterations
lr = 1e-3  # learning rate of the optimisation
path = os.path.join(RESULTS_PATH, f'ex3_results_{N}_{K}_{max_iter}_{lr}_tmp.pkl')

if not os.path.exists(RESULTS_PATH):
    os.makedirs(RESULTS_PATH)

if not os.path.exists(path):
    # No cache yet: invert the first K sentences and store the outcome.
    ex3_results = []
    for i, prompt in zip(range(K), sentences):
        ex3_result = invert_whole_prompt(
            prompt,
            model,
            tokenizer,
            layer_idx=8,
            n_iterations=max_iter,
            lr=lr,
            log_freq=10,
        )
        ex3_results.append(ex3_result)
    with open(path, 'wb') as cache_file:
        pickle.dump(ex3_results, cache_file)
else:
    # Cache hit: reuse the stored results.
    with open(path, 'rb') as cache_file:
        ex3_results = pickle.load(cache_file)

In [None]:
# Show which series were recorded: the keys of the mapping stored as the second
# element of the first result.
ex3_results[0][1].keys()

In [None]:
# One figure per metric, one curve per sentence: metric value against the
# recorded time. Each result is a (sentence, metrics) pair.
for metric_name in ['bertscore_f1', 'l2_distance', 'rouge-l_f']:
    # The notebook enables `text.usetex`, and a raw underscore in a label is a
    # LaTeX error in text mode; show a readable label instead of the raw key.
    metric_label = metric_name.replace('_', ' ')
    fig, ax = plt.subplots(figsize=(10, 6))
    for i, (_, res) in enumerate(ex3_results):
        # Only the first curve is labelled, so the legend has a single entry
        # rather than being empty (all curves share the same colour).
        ax.plot(
            res['time'],
            res[metric_name],
            color=COLORS[3],
            label='Whole prompt inversion' if i == 0 else None,
        )
    ax.set_xlabel(r'Time (s)', fontsize=14)
    ax.set_ylabel(f'{metric_label} score', fontsize=14)
    ax.set_title(r'Invert Whole Prompt on the 8th Layer', fontsize=16)
    ax.legend()

In [None]:
# Collect, for every recorded key, the list of per-sentence series of
# experiment 3.
results_collected = {
    key: [x[1][key] for x in ex3_results] for key in ex3_results[0][1].keys()
}
# Right-pad every series with NaN up to the longest `l2_distance` series, then
# transpose so that rows are logged positions and columns are sentences.
max_length = max(len(x) for x in results_collected['l2_distance'])
for key in results_collected:
    results_collected[key] = [x + [np.nan] * (max_length - len(x)) for x in results_collected[key]]
    results_collected[key] = np.array(results_collected[key])
    results_collected[key] = results_collected[key].T

for metric_name in ['bertscore_f1', 'l2_distance', 'rouge-l_f']:
    # Readable label without underscores, which are special characters when
    # `text.usetex` is enabled.
    metric_label = metric_name.replace('_', ' ')
    plot_metrics(
        # The curve is the metric itself, so key, y-label and legend entry name
        # the metric / method (previously labelled "times", "Time (s)" and
        # "Exhaustive search").
        {metric_name: results_collected[metric_name]},
        # NOTE(review): the inversion is run with log_freq=10, so "x100" may
        # not match the logging interval; confirm against invert_whole_prompt.
        xlabel='Steps x100',
        ylabel=f'{metric_label} score',
        title=f'Whole prompt descent (layer 8) - {metric_label}',
        rename={metric_name: 'Whole prompt descent'},
        fill_between=True,
    )