# whole prompt descent

Can we run gradient descent on the whole prompt at once?

In [None]:
# Auto-reload imported modules so edits to them take effect without restarting the kernel
%load_ext autoreload
%autoreload 2 

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
from utils import compute_last_token_embedding_grad, get_whole

In [None]:
# Load the tokenizer and the model from the same checkpoint; both objects are
# reused by the experiment cells further down.
MODEL_NAME = "roneneldan/TinyStories-1M"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_hidden_states=True asks the model to return the hidden states of
# every layer on each forward pass.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True)

In [None]:
from utils.general import compute_all_token_embeddings_grad


def invert_whole_prompt(prompt, model, tokenizer, layer_idx, n_iterations=1000, gamma=1e-1):
    """
    Invert the entire prompt at once by optimizing embeddings for all tokens simultaneously.

    The search starts from uniformly random token ids. Each iteration:

    1. asks ``compute_all_token_embeddings_grad`` for a gradient and a loss,
       evaluated at the current *discrete* token ids;
    2. applies one descent step of size ``gamma`` to a continuous embedding
       iterate;
    3. projects every position of that iterate onto its nearest row of the
       model's input embedding matrix to obtain the next token ids.

    The loop stops early once the projected tokens coincide (in embedding
    space) with the true prompt tokens.

    Args:
        prompt (str): The input prompt to invert.
        model: The language model.
        tokenizer: The tokenizer for the model.
        layer_idx (int): The layer index to target for inversion.
        n_iterations (int): Maximum number of optimization iterations.
        gamma (float): Step size for gradient descent.

    Returns:
        tuple: ``(reconstructed_prompt, losses, distances)`` where
            ``reconstructed_prompt`` (str) is the decoded final token ids,
            ``losses`` (list) holds the loss returned by the gradient helper
            at each iteration, and ``distances`` (list[float]) holds, per
            iteration, the mean over positions of the L2 distance between the
            embedding of the current token and that of the true prompt token.
    """
    # Tokenize the prompt and get target hidden states
    tokenized = tokenizer(prompt, return_tensors="pt")
    h_target = get_whole(prompt, model, tokenizer, layer_idx, grad=False)

    # The embedding table is a trainable parameter. Detach it so the iterate
    # below does not drag an ever-growing autograd graph through the loop.
    embedding_matrix = model.get_input_embeddings().weight.detach()
    device = embedding_matrix.device
    vocab_size = embedding_matrix.shape[0]

    input_ids = tokenized["input_ids"].squeeze(0).to(device)
    # Ground-truth embeddings are loop-invariant; look them up once.
    target_embeddings = embedding_matrix[input_ids]

    # Initialize random embeddings for the entire sequence. The ids are
    # created on the embedding device, matching the ids produced by argmin in
    # later iterations.
    random_ids = torch.randint(0, vocab_size, (input_ids.size(0),), device=device)
    x_i_plus_1 = embedding_matrix[random_ids]
    losses = []
    distances = []

    with tqdm(total=n_iterations, desc="Inverting prompt") as pbar:
        for _ in range(n_iterations):
            # Compute gradients for the entire sequence.
            # NOTE(review): grad_oracle is assumed to be shaped like
            # x_i_plus_1 (one row per token) - confirm against the helper.
            grad_oracle, loss = compute_all_token_embeddings_grad(
                y=random_ids,
                llm=model,
                layer_idx=layer_idx,
                h_target=h_target,
                tokenizer=tokenizer,
            )
            losses.append(loss)

            # Pure bookkeeping from here on: no gradients are needed.
            with torch.no_grad():
                x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle

                # Project each position onto the nearest vocabulary embedding.
                dist = torch.cdist(x_i_plus_1, embedding_matrix)
                random_ids = torch.argmin(dist, dim=1)

                # Per-token L2 distance to the true embedding, then averaged.
                # Without dim=1 the norm collapses to a single scalar over
                # the whole sequence and the mean would be a no-op.
                dist_from_prompt = torch.norm(
                    embedding_matrix[random_ids] - target_embeddings, dim=1
                )
                average_distance = dist_from_prompt.mean().item()
            distances.append(average_distance)

            pbar.set_postfix({"Loss": loss, "Distance": average_distance})
            pbar.update(1)

            # A (near-)zero distance means every token matches the prompt.
            if average_distance < 1e-3:
                break

    reconstructed_prompt = tokenizer.decode(random_ids.tolist(), skip_special_tokens=True)
    return reconstructed_prompt, losses, distances



In [None]:
# Experiment driver: run the inversion against every layer and plot the results
from utils.plotting import plot_loss


def test_inversion_on_layers(prompt, model, tokenizer, n_iterations=1000, gamma=1e-3):
    """
    Run the whole-prompt inversion once per transformer block and plot the results.

    For every index in ``model.transformer.h`` the prompt is inverted with
    ``invert_whole_prompt``, the reconstruction is printed, and the per-layer
    loss and distance curves are collected. Both sets of curves are then
    plotted on a log scale.

    Args:
        prompt (str): The input prompt to invert.
        model: The language model.
        tokenizer: The tokenizer for the model.
        n_iterations (int): Iteration budget forwarded to each inversion.
        gamma (float): Step size forwarded to each inversion.

    Returns:
        tuple: ``(losses_list, distances_list)``, each holding one list of
            per-iteration values per layer, in layer order.
    """
    print(f"Inverting prompt: {prompt}")

    per_layer_losses = []
    per_layer_distances = []
    for layer_idx in range(len(model.transformer.h)):
        recovered, layer_losses, layer_distances = invert_whole_prompt(
            prompt,
            model,
            tokenizer,
            layer_idx,
            n_iterations=n_iterations,
            gamma=gamma,
        )
        per_layer_losses.append(layer_losses)
        per_layer_distances.append(layer_distances)
        print(f"Layer {layer_idx}: Reconstructed Prompt: {recovered}")

    # One curve per layer on each figure.
    plot_loss(
        per_layer_losses,
        title="Losses during inversion",
        xlabel="Iteration",
        ylabel="Loss",
        log_scale=True,
    )
    plot_loss(
        per_layer_distances,
        title="Distances of the prompt embeddings",
        xlabel="Iteration",
        ylabel="Distance",
        log_scale=True,
    )
    return per_layer_losses, per_layer_distances

In [None]:
# Experiment 1: a short, plain-English sentence.
prompt = "My name is george and I live in Greece."
losses_list, distances_list = test_inversion_on_layers(
    prompt,
    model,
    tokenizer,
    n_iterations=1000,
    gamma=1e-1,
)

In [None]:
# Experiment 2: a long, multi-line prompt, to see how whole-prompt inversion
# behaves when many tokens are optimized at once.
# Short alternative kept for quick runs:
# prompt = "my name is george and I live in Greece."
prompt = """
According to all known laws of aviation,
there is no way a bee should be able to fly.

Its wings are too small to get its fat little body off the ground.
The bee, of course, flies anyway
because bees don’t care what humans think is impossible.

Yellow, black. Yellow, black.
Yellow, black. Yellow, black.
Ooh, black and yellow!
Let’s shake it up a little.

Barry! Breakfast is ready!

Coming!

Hang on a second.
Hello?

Barry?
Adam?

Can you believe this is happening?
I can’t. I’ll pick you up.

Looking sharp.
"""
# Same iteration budget and step size as the other experiment cells.
losses_list, distances_list = test_inversion_on_layers(prompt, model, tokenizer, n_iterations=1000, gamma=1e-1)

In [None]:
# Experiment 3: a prompt containing a high-entropy hex string, i.e. text that
# cannot be guessed from natural-language context.
prompt = "Here is my secrete key: b4e3cfe16a409f237a91c778e5f82b1493d546bc3adbd268cb346f8e2f55e72c. Do not share it with anyone!!!"
losses_list, distances_list = test_inversion_on_layers(
    prompt,
    model,
    tokenizer,
    n_iterations=1000,
    gamma=1e-1,
)