In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the project's .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# --- standard library ---
import gc
import sys
import warnings
from time import time

# The path tweak must happen BEFORE the project-local `utils` imports below;
# otherwise, on a fresh kernel, they cannot be resolved through the parent directory.
sys.path.append('..')

# --- third-party ---
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging as transformers_logging

# --- project-local ---
from utils.general import compute_last_token_embedding_grad_emb, get_whole, set_seed
from utils.metrics import compute_metrics

# Keep the notebook output quiet: silence Python warnings and lower the
# transformers logger to errors only.
warnings.filterwarnings("ignore")
transformers_logging.set_verbosity_error()

In [None]:
# Release objects left over from a previous run of this cell before loading again.
gc.collect()

model_id = 'roneneldan/TinyStories-1M'

# Load the causal LM and its matching tokenizer from the same checkpoint id.
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Hand-written probe prompts: one sentence carrying a long alphanumeric key and
# one string of noisy, mixed characters. The variable name's spelling is kept
# as-is because it is a module-level name.
promts = [
    'my secret key is big and long, it is 1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ',
    # The backslash is written as `\\`: a bare `\o` is an invalid escape sequence
    # (warning today, error in a future Python). The resulting string is unchanged.
    '12autoZeinai ena~~ !poli, a1212kiro pr33-=ompt tao op"\\oio ;::/>elpizo na d1212isko1212leyt5646ei na ma77ntepsei to montelo',
]

## Dataset

In [None]:
# Load the "train" split of the "bookcorpus" dataset; sentences are sampled from it below.
dataset = load_dataset("bookcorpus", split="train")

In [None]:
set_seed(42)
N = 100  # number of sentences to process

# Shuffle with a fixed seed, keep the first N rows, and extract their 'text' field.
sentences = (
    dataset
    .shuffle(seed=42)
    .select(range(N))
)['text']

In [None]:
# Preview the first 20 sampled sentences.
sentences[:20]

# Experiment 1

We want to show that this algorithm reconstructs the prompt perfectly at the first activation space but breaks down at the layers after it:

- sample n sentences from an in-distribution (ID) dataset
    - compute the BLEU score (or another sentence-level distance metric) vs. the L2 norm (the score at the embedding level)
    - plot them across the layers of the model

In [None]:
def test_prompt_reconstruction(llm, tokenizer, dataset, seed=8):

    embedding_matrix = llm.get_input_embeddings().weight
    metrics = []
    # average scores over the dataset
    for layer_idx in range(llm.config.num_hidden_layers):
        output_sentences = []
        output_embeddings = []
        for prompt in tqdm(dataset, desc="Processing prompts"):
            set_seed(seed)
            h_target = get_whole(prompt, model, tokenizer,layer_idx=layer_idx)
            embedding_matrix.shape
            output_tokens = (h_target @ embedding_matrix.T).argmax(dim=-1)
            output_sentence = tokenizer.decode(output_tokens, skip_special_tokens=True)
            output_sentences.append(output_sentence)
            output_embeddings.append(embedding_matrix[output_tokens])
        metrics.append(compute_metrics(sentences, output_sentences, output_embeddings))
    merge_metrics = {}
    for metric in metrics:
        for key, value in metric.items():
            if key not in merge_metrics:
                merge_metrics[key] = []
            merge_metrics[key].append(value)
    print(merge_metrics)
    return merge_metrics
# Run Experiment 1 on the sampled sentences with the loaded model and tokenizer.
ex1_result = test_prompt_reconstruction(model, tokenizer, sentences, seed=42)

In [None]:
from utils.constants import COLORS

def plot_metrics(metrics):
    """Draw one line per metric, with the position in each value list on the x-axis.

    Args:
        metrics: Mapping from a metric name to a sequence of values. Each
            entry becomes one labelled curve; colours are taken from
            ``COLORS`` and cycle when there are more metrics than colours.
    """
    figure, axes = plt.subplots(figsize=(10, 6))

    for series_idx, item in enumerate(metrics.items()):
        metric_name, per_layer_values = item
        line_colour = COLORS[series_idx % len(COLORS)]
        x_positions = range(len(per_layer_values))
        axes.plot(
            x_positions,
            per_layer_values,
            label=metric_name,
            color=line_colour,
            linewidth=2,
            marker='o',
        )

    # Axis labels and title so the figure can be read on its own.
    axes.set_xlabel('Layer Index', fontsize=14)
    axes.set_ylabel('Metric Value', fontsize=14)
    axes.set_title('Metrics per Layer', fontsize=16)
    axes.legend()
    plt.show()
# Plot the Experiment 1 metrics, one curve per metric.
plot_metrics(ex1_result)