In [1]:
# import from huggingface roneneldan/TinyStories-1M
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from nnsight import LanguageModel
import torch as t
# load garbage collection and empty cache 
import gc

# UNEMBEDDING BASELINE

In [2]:
def clean():
    gc.collect()
    # empty_cache()

model_id = "roneneldan/TinyStories-1M"
# model_id = "EleutherAI/gpt-j-6b"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

load_in_8bit = False

try:
    del llm
    clean()
    llm = LanguageModel(model_id, device_map=device, load_in_8bit=load_in_8bit)
    tokenizer = llm.tokenizer
except:
    llm = LanguageModel(model_id, device_map=device, load_in_8bit=load_in_8bit)
    tokenizer = llm.tokenizer

prompt_trial = "1:10,2:20,3:"
prompt_trial = "my name is "

In [3]:

def embeddings_to_texts_baseline(embeddings, model, tokenizer, skip_special_tokens=True):
    """
    Map input embeddings (batch, seq_len, emb_dim) → list of decoded strings.
    """
    # 1) Project embeddings to vocab logits
    logits = model.lm_head(embeddings.to(device))           # (batch, seq_len, vocab_size)
    # 2) Greedy decode: pick highest logit per position
    token_ids = torch.argmax(logits, dim=-1)     # (batch, seq_len)
    # 3) Transform each sequence of IDs into text
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
    return logits, texts

def get_next_token_prediction(embeddings, model, tokenizer, skip_special_tokens=True):
    # 1) Project embeddings to vocab logits
    logits = model.lm_head(embeddings)           # (batch, seq_len, vocab_size)
    # 2) Greedy decode: pick highest logit per position
    token_ids = torch.argmax(logits, dim=-1)[:,-1]     # (batch, seq_len)
    # 3) Transform each sequence of IDs into text
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
    return texts

@t.inference_mode()
def get_residual_output(prompt, layer_idx, llm, normalize = False):

    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'h'), "The transformer does not have a 'h' attribute for layers."


    with llm.trace(prompt):
        residual_output = llm.transformer.h[layer_idx].output[0].save()  # Save the output of the layer for inspection

    if normalize: # FIXME: check if this is the correct way to normalize
        residual_output = llm.transformer.ln_f(residual_output)
    
    if llm.device.type == "cuda":
        residual_output = residual_output.detach().cpu()
        clean()
    
    return residual_output

@t.inference_mode()
def get_embeddings(prompt,llm):

    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'drop'), "The transformer does not have a 'drop' attribute for input embeddings."
    
    with llm.trace(prompt):
        input_embeddings = llm.transformer.drop.input.save()

    if llm.device.type == "cuda":
        input_embeddings = input_embeddings.detach().to("cpu")
        clean()
    return input_embeddings
        


In [4]:
@t.inference_mode()
def get_one_token_h(prompt, layer_idx, model, tokenizer):
    llm = LanguageModel(model, tokenizer=tokenizer) if not isinstance(model, LanguageModel) else model

    # check that the prompt is a single token
    assert len(tokenizer(prompt)["input_ids"]) == 1, "The prompt must be a single token."
    # get the hidden representation of the prompt at layer_idx
    with llm.trace(prompt):
        hidden = llm.transformer.h[layer_idx].output.save()
    if llm.device.type == "cuda":
        hidden = hidden.detach().cpu()
        clean()
    return hidden

@t.inference_mode()
def get_hidden_representation(prompt, layer_idx, model, tokenizer):
    llm = LanguageModel(model, tokenizer=tokenizer) if not isinstance(model, LanguageModel) else model
    with llm.trace(prompt):
        hidden = llm.transformer.h[layer_idx].output.save()
    if llm.device.type == "cuda":
        hidden = hidden.detach().cpu()
        clean()
    return hidden[0].squeeze()
    

### TRYING THE DECODING

In [5]:
# make a for loop where for each layer you try the embeddings_to_texts_baseline for each hidden state output
full_text = False
print(f"{prompt_trial=} \n\n")
if full_text:   
    reversed_embeddings = get_embeddings(prompt_trial, llm)
    logits, texts = embeddings_to_texts_baseline(reversed_embeddings, llm, tokenizer)
    print(f"Reversed Embeddings for the prompt: {texts= } \n\n")
else:
    next_token_prediction = get_next_token_prediction(get_embeddings(prompt_trial, llm), llm, tokenizer)
    print(f"Next token prediction for the prompt: {next_token_prediction= } \n\n")


print(f"Iterating through each layer's output for the prompt: {prompt_trial}\n")
for layer_idx in range(len(llm.transformer.h)):
    residual_output = get_residual_output(prompt_trial, layer_idx, llm, True)
    if full_text:
        logits, texts = embeddings_to_texts_baseline(residual_output, llm, tokenizer)
        print(f"Layer {layer_idx} output texts: {texts}")
    else:
        next_token_prediction = get_next_token_prediction(residual_output, llm, tokenizer)
        print(f"Layer {layer_idx} next token prediction: {next_token_prediction}")

prompt_trial='my name is ' 


Next token prediction for the prompt: next_token_prediction= [' '] 


Iterating through each layer's output for the prompt: my name is 

Layer 0 next token prediction: [' ']
Layer 1 next token prediction: [' ']
Layer 2 next token prediction: [' upon']
Layer 3 next token prediction: ['.']
Layer 4 next token prediction: [' own']
Layer 5 next token prediction: [' I']
Layer 6 next token prediction: ['.']
Layer 7 next token prediction: ['\n']


In [6]:
import torch

def get_whole1(
    prompt: str,
    llm: torch.nn.Module,
    tokenizer,
    layer_idx: int
) -> torch.Tensor:
    # 1) Tokenize (and move to same device as the model).
    device = next(llm.parameters()).device
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)      # shape (1, seq_len)
    attention_mask = encoded.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # 2) Forward pass under no_grad, requesting hidden_states
    with torch.no_grad():
        outputs = llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states
        # hidden_states is a tuple: (layer0, layer1, ..., layerN),
        # where layer0 = embeddings output, layer1 = first transformer block, etc.

        # 3) Index into [0, -1, :] to get the last token for this single‐batch example
        #    hidden_states[layer_idx] has shape (batch_size=1, seq_len, hidden_size).
        return hidden_states[layer_idx][0]


In [7]:
# make a for loop where for each layer you try the embeddings_to_texts_baseline for each hidden state output
prompt = 'my name is george'
prefix = 'my name'

for layer_idx in range(len(llm.transformer.h)):
    # residual_output1 = get_residual_output(prompt, layer_idx, llm, True)
    # residual_output2 = get_residual_output(prefix, layer_idx, llm, True)

    # residual_output1 = get_hidden_representation(prompt, layer_idx, llm, tokenizer)
    # residual_output2 = get_hidden_representation(prefix, layer_idx, llm, tokenizer)

    residual_output1 = get_whole1(prompt, llm, tokenizer, layer_idx)
    residual_output2 = get_whole1(prefix, llm, tokenizer, layer_idx)

    print(residual_output1.shape)
    print(residual_output2.shape)
    print(torch.mean((residual_output1[:2,:] - residual_output2) ** 2))
    print()

    

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)

torch.Size([5, 64])
torch.Size([2, 64])
tensor(0.)



# GRADIENT BASED ALGORITHM

In [70]:
import torch
import torch.nn.functional as F

def compute_last_token_embedding_grad(
    y: torch.LongTensor,
    llm: torch.nn.Module,
    layer_idx: int,
    h: torch.Tensor
) -> torch.Tensor:
    """
    Given:
      • y: 1D LongTensor of token IDs with shape (t,)
      • llm: a HuggingFace‐style language model (e.g. GPT2, BERT), already loaded
      • layer_idx: integer layer at which to “prune” and extract the hidden state
      • h: target hidden vector of shape (hidden_size,) or (1, hidden_size)

    Returns:
      • A torch.Tensor of shape (hidden_size,) which is
        ∂/∂e_last ( ‖z_last_layer - h‖₂² ), where z_last_layer is the model’s hidden state
        for the last token at the requested layer.  Only the embedding of the last token
        receives a nonzero gradient.
    """

    # 1) Make sure the model is in eval() mode (not strictly needed for gradients,
    #    but sets layers like dropout to eval).  Then freeze all parameters.
    llm.eval()
    for p in llm.parameters():
        p.requires_grad_(False)

    # 2) Prepare input IDs: ensure shape is (1, t)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=torch.long)
    if y.dim() == 1:
        input_ids = y.unsqueeze(0)  # now shape is (1, t)
    else:
        input_ids = y  # assume the user already shaped it as (1, t)

    device = next(llm.parameters()).device
    input_ids = input_ids.to(device)
    h = h.to(device)
    # If h has shape (hidden_size,), unsqueeze to (1, hidden_size)
    if h.dim() == 1:
        h = h.unsqueeze(0)

    # 3) Run the embedding layer under no_grad to get embeddings, then detach.
    #    We want to treat embeddings as a leaf variable that can require_grad for the last token.
    with torch.no_grad():
        # Assuming llm has a method get_input_embeddings() that returns the embedding layer
        embed_layer = llm.get_input_embeddings()
        full_embeddings = embed_layer(input_ids)  
        # full_embeddings: (1, t, hidden_size)

    # 4) Detach and set requires_grad=True only on full_embeddings ↦ we'll mask out everything
    #    except the last token’s embedding when we take gradients.
    embeddings = full_embeddings.detach().requires_grad_(True)
    # shape: (1, t, hidden_size)

    # 5) Forward‐pass through the model “up to” layer_idx.
    #    We rely on the model’s ability to accept inputs_embeds and to return hidden_states.
    #    Many HF models put hidden_states[0] = embeddings,
    #                                                  hidden_states[1] = after layer 0, etc.
    outputs = llm(
        inputs_embeds=embeddings,
        output_hidden_states=True,
        return_dict=True
    )
    # outputs.hidden_states is a tuple of length (num_layers+1).  Indexing might differ by model:
    #    - hidden_states[0]: the embedding output
    #    - hidden_states[1]: after the first Transformer block
    #    - …
    #    - hidden_states[L]: after the L-th block, etc.

    all_hidden_states = outputs.hidden_states
    # Sanity‐check: layer_idx must be < len(all_hidden_states)
    if layer_idx >= len(all_hidden_states):
        raise ValueError(
            f"layer_idx={layer_idx} is too large; model only returns "
            f"{len(all_hidden_states)} hidden states."
        )

    # 6) Extract the last‐token hidden vector at the requested layer.
    #    hidden_states[layer_idx] has shape (1, t, hidden_size).
    z_at_layer = all_hidden_states[layer_idx]  # shape (1, t, hidden_size)
    z_last_token = z_at_layer[:, -1, :]        # shape (1, hidden_size)

    # 7) Compute the (squared) L2 loss: sum((z_last_token - h)^2) 
    #    (If you want the non‐squared L2 distance, use torch.norm(z - h, p=2), but squared is simpler.)
    #    Here we choose squared L2 because its gradient is 2*(z - h), which is fine.
    diff = z_last_token - h             # shape (1, hidden_size)
    loss = torch.sum(diff * diff)       # scalar

    # 8) Backward pass: only embeddings require grad; all model parameters are frozen.
    llm.zero_grad()       # clear any stored gradients (just in case)
    loss.backward()

    # 9) Now `embeddings.grad` has shape (1, t, hidden_size).  We only care about the last token.
    grad_embeddings = embeddings.grad        # (1, t, hidden_size)
    grad_last = grad_embeddings[:, -1, :]    # (1, hidden_size)
    # Squeeze to make it (hidden_size,)
    grad_last = grad_last.squeeze(0)

    return grad_last, loss.item()


In [8]:
# function that takes as input a tokenizer and an embedding matrix and returns a random token and its embedding
def get_random_token_and_embedding(tokenizer, embedding_matrix):
    # Get the vocabulary size
    vocab_size = embedding_matrix.shape[0]
    
    # Generate a random token ID
    random_token_id = torch.randint(0, vocab_size, (1,)).item()
    
    # Get the corresponding embedding
    embedding = embedding_matrix[random_token_id]
    
    # Decode the token ID to get the token string
    token_string = tokenizer.decode(random_token_id)
    
    return token_string, embedding

# function that takes as input a tokenizer and an embedding matrix and returns a random token and its embedding
def get_random_token_id_and_embedding(embedding_matrix):
    # Get the vocabulary size
    vocab_size = embedding_matrix.shape[0]
    
    # Generate a random token ID
    random_token_id = torch.randint(0, vocab_size, (1,)).item()
    
    # Get the corresponding embedding
    embedding = embedding_matrix[random_token_id]

    
    return random_token_id, embedding

In [9]:
from tqdm import tqdm
# take as input llm, tokenizer, embedding matrix, an hidden representation, layer_idx, n_iterations, step_size, 


tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M", output_hidden_states=True)
tokenizer.pad_token = tokenizer.eos_token



In [73]:
prompt = "George"
h = get_one_token_h(prompt, 4, model,tokenizer)
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)
gamma = 1  # step size for gradient descent
layer_idx = 4  # layer at which we want to compute the gradient
n_iterations = 100  # number of iterations for gradient descent

In [None]:
prompt = "George"
h = get_one_token_h(prompt, 4, model,tokenizer)
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)
gamma = 1  # step size for gradient descent
layer_idx = 4  # layer at which we want to compute the gradient
n_iterations = 100  # number of iterations for gradient descent

# make a guess for (e_0,y_0)
y_i, x_i_plus_1= get_random_token_and_embedding(tokenizer, embedding_matrix)

# loop through the n_iterations
for iteration in (bar:=tqdm(range(n_iterations))):
    bar.set_postfix_str(f"Current guess: {y_i}")
    # compute the embedding of the current guess
    # P_hat_i = embedding_matrix[tokenizer.encode(y_i)[0]]  # (hidden_size,)
    
    # compute the gradient of the oracle with respect to e_i -> inner for
    grad_oracle = compute_last_token_embedding_grad(
        y=torch.tensor(tokenizer.encode(y_i), dtype=torch.long), # turn into x_i_plus_1
        llm=model,
        layer_idx=layer_idx,
        h=h
    )  # (hidden_size,)
    bar.set_postfix_str(f"Current guess: {y_i}, Gradient norm: {grad_oracle.norm().item():.4f}")

    # update the guess using gradient descent
    x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
    
    # find the closest token in the embedding space to e_i_plus_1
    distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
    closest_token_id = torch.argmin(distances).item()
    y_i = tokenizer.decode(closest_token_id)

    # make a summary via prints
    print(f"Iteration {iteration + 1}/{n_iterations}:")
    print(f"  Current guess: {y_i}")
    print(f"  Gradient norm: {grad_oracle.norm().item():.4f}")





In [88]:
prompt = "George"
n = 100
# tokenizer it but return text
h = get_hidden_representation(prompt, layer_idx, model, tokenizer)
size = 1 if len(h.shape)==1 else h.shape[0]

# discovered_tokens = []

In [89]:
# tokenizer.tokenize("my name is George")
# tokenizer.encode("name")
# discovered_tokens = ["my", " name", " is"]
# # merge discovered tokens and y_i in a list and a string
# lista = discovered_tokens + [y_i]
# stringa = "".join(lista)
# y_i, discovered_tokens, lista, stringa, tokenizer.tokenize(stringa)
# y_extend = "".join((discovered_tokens + [y_i]))
# print(f"Extended guess: {y_extend}")
# encoded_y = torch.tensor(tokenizer.encode(y_extend), dtype=torch.long) 
# grad_oracle = compute_last_token_embedding_grad(
#         y=encoded_y, # turn into x_i_plus_1
#         llm=model,
#         layer_idx=layer_idx,
#         h=h
#     )

In [11]:
import torch
import torch.nn.functional as F

import torch
from typing import Optional

def get_whole(
    prompt: str,
    llm: torch.nn.Module,
    tokenizer,
    layer_idx: int,
    input_ids: Optional[torch.Tensor] = None,
    grad: bool = False
) -> torch.Tensor:
    """
    Tokenize `prompt`, do a forward pass with output_hidden_states=True,
    and return the hidden vector of the *last* token at layer `layer_idx`.

    Args:
        prompt (str): the input string, e.g. "Harry".
        llm (nn.Module): a HuggingFace‐style model (with embeddings + hidden_states).
        tokenizer: the corresponding tokenizer for `llm`.
        layer_idx (int): which hidden‐layer index to extract (0=embeddings, 1=first block, etc.)

    Returns:
        Tensor of shape (hidden_size,) = the last‐token hidden state at `layer_idx`,
        computed under torch.no_grad().
    """
    # 1) Tokenize (and move to same device as the model).
    device = next(llm.parameters()).device
    if input_ids is None:
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)      # shape (1, seq_len)
    
    if not grad:
        # Forward under no_grad and detach before returning
        with torch.no_grad():
            outputs = llm(
                input_ids=input_ids,
                output_hidden_states=True
            )
            hidden_states = outputs.hidden_states
            h = hidden_states[layer_idx][0]  # shape = (seq_len, hidden_size)

        return h.detach()
    else:
        # Forward normally, so gradients can flow
        outputs = llm(
            input_ids=input_ids,
            output_hidden_states=True
        )
        hidden_states = outputs.hidden_states
        h = hidden_states[layer_idx][0]  # shape = (seq_len, hidden_size)
        return h


def compute_last_token_embedding_grad(
    y: torch.LongTensor,
    llm: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor
):
    device = next(llm.parameters()).device
    y = y.to(device)
    h_target = h_target.to(device)

    emb_layer = llm.get_input_embeddings()
    if not emb_layer.weight.requires_grad:
        emb_layer.weight.requires_grad_(True)

    llm.zero_grad()
    emb_layer.zero_grad()

    with torch.set_grad_enabled(True):
        h_last = get_whole('', llm, tokenizer, layer_idx, y.unsqueeze(0), grad=True)[-1]
        diff = h_last - h_target
        loss = torch.dot(diff, diff)
        loss.backward()

    last_token_id = y[-1].item()
    grad_last_embedding = emb_layer.weight.grad[last_token_id].detach().clone() # TODO: Is this gradient with respect to the input or not?

    llm.zero_grad()
    emb_layer.zero_grad()

    return grad_last_embedding, loss.item()


## Vanilla Algo

In [12]:
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)

prompt = "my name is george and i am living here in greece, i am 20 years old and my secret is that i am in love"
layer_idx = 8  # layer at which we want to compute the gradient

# h = get_hidden_representation(prompt, layer_idx, model, tokenizer)
h = get_whole(prompt, model, tokenizer, layer_idx)
if h.dim() == 1:
    h = h.unsqueeze(0)

gamma = 1e-0  # step size for gradient descent
n_iterations = 5000  # number of iterations for gradient descent

discovered_ids = []
for j in range(h.size(0)):
    # an idea here is to initialize with feeding an LLM with the prompt so far
    # and getting the next token, more expensive, likely to work much better
    y_i_id, x_i_plus_1 = get_random_token_id_and_embedding(embedding_matrix)

    for iteration in (bar := tqdm(range(n_iterations), desc=f'Token [{j:2d}/{h.size(0):2d}]')):
        input_ids = torch.tensor(discovered_ids + [y_i_id], dtype=torch.long)

        h_pred = get_whole('', model, tokenizer, layer_idx, input_ids.unsqueeze(0), grad=False)[-1]

        if torch.sum((h_pred - h[:h_pred.size(0),:]) ** 2) <= 1e-10:
            print('Early stopping')
            break

        grad_oracle, loss = compute_last_token_embedding_grad(
            y=input_ids, # turn into x_i_plus_1
            llm=model,
            layer_idx=layer_idx,
            h_target=h[j]
        ) # TODO: Dont use tokens, rather use previous embeddings + x_i_plus_1
        # TODO: non-zero loss even when having the correct y_i_id
        
        string_so_far = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)

        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_oracle.norm().item():.2e} - String: {string_so_far}")
        # print('', flush=True)
        
        if string_so_far == prompt:
            break

        if loss < 1e-6:
            # print(f"Found token {y_i} with small L2 norm: {torch.norm(h[j] - embedding_matrix[closest_token_id]).item():.4f}")
            break

        # update the guess using gradient descent
        x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
        
        # find the closest token in the embedding space to e_i_plus_1
        distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
        y_i_id = int(torch.argmin(distances))


    
    discovered_ids.append(y_i_id)

final_string = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)
print(f"Final discovered tokens: {final_string}")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Token [ 0/29]:   0%|          | 17/5000 [00:00<00:56, 88.34it/s, Loss: 1.94e-09 - Gradient norm: 5.92e-03 - String: my]     
Token [ 1/29]:   0%|          | 6/5000 [00:00<00:59, 84.02it/s, Loss: 2.97e-09 - Gradient norm: 1.94e-02 - String: my name]
Token [ 2/29]:  27%|██▋       | 1353/5000 [00:13<00:36, 101.17it/s, Loss: 8.78e+02 - Gradient norm: 4.16e+03 - String: my name and]      

Token [ 2/29]:  27%|██▋       | 1353/5000 [00:13<00:36, 101.17it/s, Loss: 1.03e+03 - Gradient norm: 3.54e+03 - String: my name he]
Token [ 2/29]:  27%|██▋       | 1358/5000 [00:13<00:35, 101.49it/s, Loss: 6.09e-09 - Gradient norm: 1.91e-02 - String: my name is]
Token [ 3/29]:   0%|          | 8/5000 [00:00<00:54, 92.37it/s, Los

KeyboardInterrupt: 

In [82]:
import torch
from tqdm import tqdm

embedding_matrix = model.get_input_embeddings().weight  # shape (vocab_size, hidden_size)

prompt = "Helenaki is hereeee and is a cutie"
layer_idx = 4  # layer at which we want to compute the gradient

# 1) Get the “target” hidden states for each token in `prompt`
h = get_whole(prompt, model, tokenizer, layer_idx, grad=False)
if h.dim() == 1:
    h = h.unsqueeze(0)  # now shape (seq_len, hidden_size)

# 2) Adam hyperparameters
gamma = 1e-1        # base learning rate (you can tune this)
beta1 = 0.9
beta2 = 0.999
eps = 1e-8
n_iterations = 5000

discovered_ids = []

for j in range(h.size(0)):
    # ——— Initialize a random token and its “fake” embedding guess ———
    y_i_id, x_i_plus_1 = get_random_token_id_and_embedding(embedding_matrix)
    # x_i_plus_1 has shape (hidden_size,), on same device as embedding_matrix

    # 3) Initialize Adam buffers m, v (shape = hidden_size) to zero
    m = torch.zeros_like(x_i_plus_1)
    v = torch.zeros_like(x_i_plus_1)

    for iteration in (bar := tqdm(range(n_iterations), desc=f"Token [{j:2d}/{h.size(0):2d}]")):
        # Build the current hypothesis of token‐IDs:
        input_ids = torch.tensor(discovered_ids + [y_i_id], dtype=torch.long, device=embedding_matrix.device)

        # 4) Compute gradient ∇_x L₂( h_last, h_target ) at layer_idx
        grad_oracle, loss = compute_last_token_embedding_grad(
            y=input_ids,          # tensor of shape (seq_len,)
            llm=model,
            layer_idx=layer_idx,
            h_target=h[j]
        )

        # 5) Logging / progress
        string_so_far = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)
        bar.set_postfix_str(
            f"Loss: {loss:.4f} | ‖grad‖: {grad_oracle.norm().item():.4f} | '{string_so_far}'"
        )

        if string_so_far == prompt:
            break
        if grad_oracle.norm(p=2) < 1e-4:
            # gradient is essentially zero → we’re “close enough”
            break

        # ——— Adam update on x_i_plus_1 ———
        t = iteration + 1
        # 6a) Update biased first moment estimate
        m = beta1 * m + (1 - beta1) * grad_oracle
        # 6b) Update biased second moment estimate (elementwise square)
        v = beta2 * v + (1 - beta2) * (grad_oracle * grad_oracle)
        # 6c) Compute bias‐corrected m_hat, v_hat
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        # 6d) Take Adam step
        x_i_plus_1 = x_i_plus_1 - gamma * (m_hat / (v_hat.sqrt() + eps))

        # 7) Project x_i_plus_1 back to the nearest token in embedding space
        #    (Choose the token whose embedding row is closest in L₂ distance)
        distances = torch.norm(embedding_matrix - x_i_plus_1.unsqueeze(0), dim=1)  # (vocab_size,)
        y_i_id = int(torch.argmin(distances))

    # End of iteration loop for this token j
    discovered_ids.append(y_i_id)

# After all tokens are “discovered”:
final_string = tokenizer.decode(torch.tensor(discovered_ids, dtype=torch.long), skip_special_tokens=True)
print(f"Final discovered tokens: {final_string}")

Token [ 0/12]:   0%|          | 11/5000 [00:00<00:35, 139.53it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Hel']
Token [ 1/12]:   1%|▏         | 72/5000 [00:00<00:32, 151.39it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helen']         
Token [ 2/12]:   0%|          | 10/5000 [00:00<00:38, 129.42it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki']
Token [ 3/12]:   1%|          | 39/5000 [00:00<00:33, 147.71it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki is']       
Token [ 4/12]:   0%|          | 14/5000 [00:00<00:36, 136.61it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki is he']
Token [ 5/12]:   0%|          | 5/5000 [00:00<00:40, 124.05it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki is heree']
Token [ 6/12]:   0%|          | 12/5000 [00:00<00:36, 135.82it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki is hereeee']
Token [ 7/12]:   0%|          | 20/5000 [00:00<00:42, 116.88it/s, Loss: 0.0000 | ‖grad‖: 0.0000 | 'Helenaki is hereeee and'] 
Token [ 8/12]:   1%|▏         | 74/5000 [00:00<00:35, 137.09it/s, Lo

Final discovered tokens: Helenaki is hereeee and is a cutie





In [83]:

h = get_whole(prompt, model, tokenizer, layer_idx, grad=False)

h.shape

torch.Size([12, 64])

In [None]:
prompt = "George"
h = get_one_token_h(prompt, 4, model,tokenizer)[0]
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)
gamma = 1  # step size for gradient descent
layer_idx = 4  # layer at which we want to compute the gradient
n_iterations = 100  # number of iterations for gradient descent

# make a guess for (e_0,y_0)
y_i, x_i_plus_1= get_random_token_and_embedding(tokenizer, embedding_matrix)

# loop through the n_iterations
for iteration in (bar:=tqdm(range(n_iterations))):
    bar.set_postfix_str(f"Current guess: {y_i}")
    # compute the embedding of the current guess
    # P_hat_i = embedding_matrix[tokenizer.encode(y_i)[0]]  # (hidden_size,)
    
    # compute the gradient of the oracle with respect to e_i -> inner for
    grad_oracle, _ = compute_last_token_embedding_grad(
        y=torch.tensor(tokenizer.encode(y_i), dtype=torch.long), # turn into x_i_plus_1
        llm=model,
        layer_idx=layer_idx,
        h=h
    )  # (hidden_size,)
    # bar.set_postfix_str(f"Current guess: {y_i}, Gradient norm: {grad_oracle.norm().item():.4f}")

    # update the guess using gradient descent
    x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
    
    # find the closest token in the embedding space to e_i_plus_1
    distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
    closest_token_id = torch.argmin(distances).item()
    y_i = tokenizer.decode(closest_token_id)

    # make a summary via prints
    print(f"Current guess: {y_i} - Target: {prompt} - Gradient norm: {grad_oracle.norm().item():.4f}")





## Boosted Algo

In [None]:
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)

prompt = "my name is george and i am living here in greece, i am 20 years old and my secret is that i am in love"
layer_idx = 8  # layer at which we want to compute the gradient

# h = get_hidden_representation(prompt, layer_idx, model, tokenizer)
h = get_whole(prompt, model, tokenizer, layer_idx)
if h.dim() == 1:
    h = h.unsqueeze(0)

gamma = 1e-0  # step size for gradient descent
n_iterations = 5000  # number of iterations for gradient descent

discovered_ids = []
for j in range(h.size(0)):
    # an idea here is to initialize with feeding an LLM with the prompt so far
    # and getting the next token, more expensive, likely to work much better
    y_i_id, x_i_plus_1 = get_random_token_id_and_embedding(embedding_matrix)

    for iteration in (bar := tqdm(range(n_iterations), desc=f'Token [{j:2d}/{h.size(0):2d}]')):
        input_ids = torch.tensor(discovered_ids + [y_i_id], dtype=torch.long)

        h_pred = get_whole('', model, tokenizer, layer_idx, input_ids.unsqueeze(0), grad=False)[-1]

        if torch.sum((h_pred - h[:h_pred.size(0),:]) ** 2) <= 1e-10:
            print('Early stopping')
            break

        grad_oracle, loss = compute_last_token_embedding_grad(
            y=input_ids, # turn into x_i_plus_1
            llm=model,
            layer_idx=layer_idx,
            h_target=h[j]
        ) # TODO: Dont use tokens, rather use previous embeddings + x_i_plus_1
        # TODO: non-zero loss even when having the correct y_i_id
        
        string_so_far = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)

        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_oracle.norm().item():.2e} - String: {string_so_far}")
        
        if string_so_far == prompt:
            break

        if loss < 1e-8:
            # print(f"Found token {y_i} with small L2 norm: {torch.norm(h[j] - embedding_matrix[closest_token_id]).item():.4f}")
            break

        # update the guess using gradient descent
        x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
        
        # find the closest token in the embedding space to e_i_plus_1
        distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
        y_i_id = int(torch.argmin(distances))


    
    discovered_ids.append(y_i_id)

final_string = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)
print(f"Final discovered tokens: {final_string}")


## Debugging

In [114]:
embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)

prompt = "my name is george and i am living here in greece, i am 20 years old and my secret is that i am in love"
layer_idx = 8  # layer at which we want to compute the gradient


tokenized = tokenizer(prompt)

# h = get_hidden_representation(prompt, layer_idx, model, tokenizer)
h = get_whole(prompt, model, tokenizer, layer_idx)
if h.dim() == 1:
    h = h.unsqueeze(0)

gamma = 1e-0  # step size for gradient descent
n_iterations = 5000  # number of iterations for gradient descent

discovered_ids = []
for j in range(h.size(0)):
    # an idea here is to initialize with feeding an LLM with the prompt so far
    # and getting the next token, more expensive, likely to work much better
    y_i_id, x_i_plus_1 = get_random_token_id_and_embedding(embedding_matrix)

    for iteration in (bar := tqdm(range(n_iterations), desc=f'Token [{j:2d}/{h.size(0):2d}]')):
        input_ids = torch.tensor(discovered_ids + [y_i_id], dtype=torch.long)

        h_pred = get_whole('', model, tokenizer, layer_idx, input_ids.unsqueeze(0), grad=False)[-1]

        # if torch.sum((h_pred - h[:h_pred.size(0),:]) ** 2) <= 1e-10:
        #     print('Early stopping')
        #     break

        grad_oracle, loss = compute_last_token_embedding_grad(
            y=input_ids, # turn into x_i_plus_1
            llm=model,
            layer_idx=layer_idx,
            h_target=h[j]
        ) # TODO: Dont use tokens, rather use previous embeddings + x_i_plus_1
        # TODO: non-zero loss even when having the correct y_i_id
        
        string_so_far = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)

        bar.set_postfix_str(f"Loss: {loss:.2e} - Gradient norm: {grad_oracle.norm().item():.2e} - String: {string_so_far}")
        
        # if string_so_far == prompt:
        #     break

        if loss < 1e-8:
            # print(f"Found token {y_i} with small L2 norm: {torch.norm(h[j] - embedding_matrix[closest_token_id]).item():.4f}")
            break

        # update the guess using gradient descent
        x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
        
        # find the closest token in the embedding space to e_i_plus_1
        distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
        y_i_id = int(torch.argmin(distances))


    
    discovered_ids.append(y_i_id)

final_string = tokenizer.decode(input_ids.cpu().tolist(), skip_special_tokens=True)
print(f"Final discovered tokens: {final_string}")


Token [ 0/ 1]:   0%|          | 4/5000 [00:00<01:04, 78.00it/s, Loss: 0.00e+00 - Gradient norm: 0.00e+00 - String: George]

Final discovered tokens: George



