In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import gc
import os
import pickle
import sys
import warnings
from collections import defaultdict
from time import time

import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging as transformers_logging

# Make the parent directory importable *before* importing the local `utils` package.
sys.path.append('..')
from utils import compute_last_token_embedding_grad_emb, get_whole, set_seed, plot_metrics, RESULTS_PATH
from utils.metrics import compute_metrics

warnings.filterwarnings("ignore")
transformers_logging.set_verbosity_error()
plt.rcParams['text.usetex'] = True

In [None]:
# Collect garbage first (presumably to release a previously loaded model when
# this cell is re-run).
gc.collect()

model_id = 'roneneldan/TinyStories-1M'

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token (assumes the tokenizer has no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_id)

In [None]:
# Hand-written hard prompts: a long alphanumeric "secret" and a noisy sentence
# mixing words, digits and punctuation.
promts = [
    'my secret key is big and long, it is 1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ',
    # The backslash is escaped explicitly: a bare "\o" is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+). The string value is unchanged.
    '12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
]

## Dataset

In [None]:
# Full BookCorpus training split; a shuffled subset is sampled from it below.
dataset = load_dataset(path="bookcorpus", split="train")

In [None]:
N = 100  # number of sentences to process
set_seed(42)
# Reproducible random subset: shuffle the whole split with a fixed seed, keep
# the first N rows, and extract their raw text.
sentences = dataset.shuffle(seed=42).select(range(N))['text']

In [None]:
# Peek at the first 20 sampled sentences.
sentences[0:20]

# Experiment 1

We want to show that this algorithm works perfectly in the first activation space but breaks down in the later layers:
- sample $n$ sentences from an in-distribution (ID) dataset;
  - compute the BLEU score (or another sentence-level distance metric) vs. the L2 norm (a score at the embedding level);
  - plot both across the layers of the model.

In [None]:
def test_prompt_reconstruction(llm, tokenizer, dataset, seed=8):
    """Decode every sentence back from each layer's hidden states and score it.

    For each layer index in ``range(llm.config.num_hidden_layers)``, the hidden
    states returned by ``get_whole`` are projected onto the input-embedding
    matrix. At every position the token with the highest dot product is kept.
    The decoded sentences, and the input embeddings of the decoded tokens, are
    compared with the originals via ``compute_metrics``.

    Args:
        llm: causal LM whose input embeddings are used for the projection.
        tokenizer: tokenizer matching ``llm``.
        dataset: iterable of sentence strings to reconstruct (iterated once for
            the reference embeddings and once per layer).
        seed: value passed to ``set_seed`` before processing each sentence.

    Returns:
        dict mapping each metric name returned by ``compute_metrics`` to a list
        holding one value per layer, in layer order.
    """
    embedding_layer = llm.get_input_embeddings()
    embedding_matrix = embedding_layer.weight

    # Reference input embeddings; computed once and reused for every layer.
    sentence_embeddings = []
    for sentence in tqdm(dataset, desc="Computing sentence embeddings"):
        set_seed(seed)
        input_ids = tokenizer(sentence, return_tensors='pt').input_ids
        with torch.no_grad():
            sentence_embeddings.append(embedding_layer(input_ids)[0].numpy())

    merged_metrics = defaultdict(list)
    for layer_idx in range(llm.config.num_hidden_layers):
        output_sentences = []
        output_embeddings = []
        for prompt in tqdm(dataset, desc="Processing prompts"):
            set_seed(seed)
            # Use the `llm` argument; the previous version read the global `model`.
            h_target = get_whole(prompt, llm, tokenizer, layer_idx=layer_idx)
            # Only an argmax is needed here, so avoid building an autograd graph.
            with torch.no_grad():
                output_tokens = (h_target @ embedding_matrix.T).argmax(dim=-1)
                output_embeddings.append(embedding_matrix[output_tokens].detach().numpy())
            output_sentences.append(tokenizer.decode(output_tokens, skip_special_tokens=True))
        # Score against the `dataset` argument, not the global `sentences`.
        layer_metrics = compute_metrics(dataset, sentence_embeddings, output_sentences, output_embeddings)
        for key, value in layer_metrics.items():
            merged_metrics[key].append(value)
    return dict(merged_metrics)

In [None]:
# Experiment 1 results are cached on disk, keyed only by N.
# NOTE: the key does not include `model_id`; delete the file after changing the model.
path = os.path.join(RESULTS_PATH, f'ex1_results_{N}.pkl')

os.makedirs(RESULTS_PATH, exist_ok=True)
if os.path.exists(path):
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(path, 'rb') as f:
        ex1_result = pickle.load(f)
else:
    ex1_result = test_prompt_reconstruction(model, tokenizer, sentences, seed=42)
    with open(path, 'wb') as f:
        pickle.dump(ex1_result, f)

In [None]:
# Give the L2-based metrics their own figure, separate from the text-level ones.
l2_metrics = {}
other_metrics = {}
for metric_name, per_layer in ex1_result.items():
    bucket = l2_metrics if 'l2' in metric_name else other_metrics
    bucket[metric_name] = per_layer
plot_metrics(other_metrics, fill_between=True)
plot_metrics(l2_metrics, fill_between=True)

# Experiment 2: cost of exhaustive search
We want to show how expensive exhaustive search is from a practical standpoint. Pick one or a few sentences and plot the decoding performance of the search (e.g. an exhaustive or random search) on them. On the y-axis: the BLEU score or the L2 loss.

In [None]:
# exhaustive search is we search for each token one by one, minimizing the loss
# search for each token in the sentence
# mesure computation time and steps used 
def exhaustive_search(llm, tokenizer, prompt, layer_idx=0, seed=42, eps=1e-3):
    """Greedily invert ``prompt`` from its hidden states by trying every vocabulary token.

    The search proceeds one position at a time. Each token id in the
    input-embedding vocabulary is appended to the tokens decoded so far and run
    through ``get_whole``. The candidate whose hidden state at the current
    position is closest (L2 norm) to the target's is kept. The vocabulary scan
    stops early once a candidate is within ``eps`` of the target.

    Args:
        llm: causal LM to invert.
        tokenizer: tokenizer matching ``llm``.
        prompt: sentence whose hidden states are the search target.
        layer_idx: layer index passed to ``get_whole``.
        seed: value passed to ``set_seed`` once before the search.
        eps: per-position L2 distance below which the vocabulary scan stops.

    Returns:
        ``(output_sentence, losses, times, steps)``. The last three lists hold
        one entry per decoded position:
        - the normalized full-sequence loss;
        - the elapsed wall-clock seconds;
        - the cumulative number of candidates evaluated.
    """
    print(f"Exhaustive search for prompt: {prompt} on layer {layer_idx}")
    set_seed(seed)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    n_positions = input_ids.shape[1]
    vocab_size = llm.get_input_embeddings().weight.shape[0]
    h_target = get_whole(prompt, llm, tokenizer, layer_idx=layer_idx)

    output_tokens = []
    time_start = time()
    step = 0
    losses = []
    times = []
    steps = []
    for i in range(n_positions):
        min_loss = float('inf')
        best_token = None
        best_h = None
        bar = tqdm(range(vocab_size), desc=f"Finding token {i+1}/{n_positions}")
        for j in bar:
            step += 1
            # Candidate sequence: the tokens decoded so far followed by token j.
            candidate = torch.tensor(output_tokens + [j]).unsqueeze(0)
            h = get_whole("", llm, tokenizer, layer_idx=layer_idx, input_ids=candidate)
            loss = (h_target[i] - h[i]).norm().item()
            if loss < min_loss:
                min_loss = loss
                best_token = j
                # Remember the best candidate's hidden states. After the loop,
                # `h` holds the *last* candidate tried, not the best one.
                best_h = h.detach()
                # Refresh the progress bar only when the best candidate changes.
                bar.set_postfix({'loss': min_loss, 'token': tokenizer.decode(best_token)})
            if loss < eps:
                break

        # Zero-pad the not-yet-decoded positions so the shape matches the target,
        # then normalize the full-sequence L2 distance by the number of elements.
        h_expanded = torch.zeros_like(h_target)
        h_expanded[:i+1] = best_h[:i+1]
        loss = (h_target - h_expanded).norm().item() / (h_target.shape[0] * h_target.shape[1])

        losses.append(loss)
        times.append(time() - time_start)
        steps.append(step)

        output_tokens.append(best_token)
        sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(f"Token {i+1}/{input_ids.shape[1]}: {tokenizer.decode(best_token)} | Loss: {loss:.4f} | Time: {times[-1]:.2f}s | Steps: {step} | Sentence: {sentence}")

    output_sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)

    return output_sentence, losses, times, steps

    

In [None]:

K = 10  # number of sentences to process in exhaustive search
LAYER_IDX = 8  # layer whose hidden states are inverted (the plot below is titled "8th layer")
# f-string so the cache file actually depends on N and K; the previous plain
# string produced the literal name 'ex2_results_{N}_{K}.pkl'.
path = os.path.join(RESULTS_PATH, f'ex2_results_{N}_{K}.pkl')

os.makedirs(RESULTS_PATH, exist_ok=True)
if os.path.exists(path):
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(path, 'rb') as f:
        ex2_results = pickle.load(f)
else:
    ex2_results = []
    for i, prompt in enumerate(sentences):
        if i >= K:  # `>=` so exactly K sentences are searched
            break
        ex2_result = exhaustive_search(model, tokenizer, prompt, layer_idx=LAYER_IDX, seed=8)
        ex2_results.append(ex2_result)
    with open(path, 'wb') as f:
        pickle.dump(ex2_results, f)

In [None]:
from utils.constants import COLORS

# One curve per searched sentence: normalized reconstruction loss vs. the
# cumulative number of candidates evaluated by the exhaustive search.
fig, ax = plt.subplots(figsize=(10, 6))
for res in ex2_results:
    # Each result is the (output_sentence, losses, times, steps) tuple.
    losses, times, steps = res[1], res[2], res[3]
    ax.plot(steps, losses, color=COLORS[0])
ax.set_xlabel(r'Steps', fontsize=14)
ax.set_ylabel(r'Loss', fontsize=14)
ax.set_title(r'Exhaustive search on the 8th layer', fontsize=16)
# No `ax.legend()`: the curves carry no labels, so it would only draw an empty
# legend and log a "No artists with labels" warning.
plt.show()
