# whole prompt descent

Can we run gradient descent on the whole prompt at once?

In [None]:
# Automatically reload imported modules so that code changes take effect without restarting the kernel
%load_ext autoreload
%autoreload 2 

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
from utils import compute_last_token_embedding_grad, get_whole

In [None]:
# Load the tokenizer and the causal LM from the same TinyStories-1M checkpoint.
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
# output_hidden_states=True is set at load time; the inversion code further
# down targets the hidden states of a chosen layer (see `layer_idx`).
model = AutoModelForCausalLM.from_pretrained(
    "roneneldan/TinyStories-1M",
    output_hidden_states=True,
)

In [None]:
from utils.general import compute_all_token_embeddings_grad


def invert_whole_prompt(prompt, model, tokenizer, layer_idx, n_iterations=1000, gamma=1e-1):
    """
    Invert the entire prompt at once by optimizing embeddings for all tokens simultaneously.

    Starting from uniformly random vocabulary tokens, every iteration:
      1. asks ``compute_all_token_embeddings_grad`` for a gradient and a loss of
         the current token guess against the target hidden states,
      2. takes one gradient-descent step on the continuous embeddings, and
      3. snaps each position to its nearest vocabulary embedding to obtain the
         next token guess.

    Args:
        prompt (str): The input prompt to invert.
        model: The language model.
        tokenizer: The tokenizer for the model.
        layer_idx (int): The layer index to target for inversion.
        n_iterations (int): Number of optimization iterations.
        gamma (float): Step size for gradient descent.

    Returns:
        str: The reconstructed prompt, decoded from the last token guess. If
        ``n_iterations`` is 0 this is the decoding of the random initial guess.
    """
    # Tokenize only to learn the sequence length; the target hidden states
    # come from the project helper.
    tokenized = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized["input_ids"].squeeze(0)
    seq_len = input_ids.size(0)
    h_target = get_whole(prompt, model, tokenizer, layer_idx, grad=False)

    # Initialize the guess with random tokens and their embeddings
    # (embeddings has shape [seq_len, hidden_size]).
    embedding_matrix = model.get_input_embeddings().weight
    vocab_size = embedding_matrix.size(0)
    token_ids = torch.randint(0, vocab_size, (seq_len,))
    embeddings = embedding_matrix[token_ids]

    # Defined up front so the function still returns a string when the loop
    # body never runs (n_iterations == 0).
    reconstructed_prompt = tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)

    with tqdm(total=n_iterations, desc="Inverting prompt") as pbar:
        for _ in range(n_iterations):
            # One gradient evaluation for the whole sequence. The previous
            # per-token loop over compute_last_token_embedding_grad produced
            # values that were immediately overwritten by this call, so it
            # has been dropped.
            # NOTE(review): grad_oracle is assumed to be broadcast-compatible
            # with the [seq_len, hidden_size] embeddings -- confirm against
            # utils.general.compute_all_token_embeddings_grad.
            grad_oracle, loss = compute_all_token_embeddings_grad(
                y=token_ids,
                llm=model,
                layer_idx=layer_idx,
                h_target=h_target,
                tokenizer=tokenizer,
            )

            # The update and the projection are pure bookkeeping: run them
            # without autograd so no graph is accumulated across iterations
            # (embeddings originates from a model parameter).
            with torch.no_grad():
                # Update embeddings using gradient descent
                embeddings = embeddings - gamma * grad_oracle

                # Find the closest tokens in the embedding space
                distances = torch.cdist(embeddings, embedding_matrix)
                token_ids = torch.argmin(distances, dim=1)

            # Decode the current reconstruction
            reconstructed_prompt = tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)
            pbar.set_postfix({"Loss": loss, "Prompt": reconstructed_prompt})
            pbar.update(1)

            # Early stopping if the reconstruction matches the original prompt
            if reconstructed_prompt == prompt:
                break

    return reconstructed_prompt

# Example usage: invert a short sentence from the hidden states of layer 8
# and compare the reconstruction with the original text.
prompt = "my name is george and I live in Greece."
layer_idx = 8
reconstructed_prompt = invert_whole_prompt(
    prompt=prompt,
    model=model,
    tokenizer=tokenizer,
    layer_idx=layer_idx,
)
print(f"Original: {prompt}")
print(f"Reconstructed: {reconstructed_prompt}")