In [None]:
# import from huggingface roneneldan/TinyStories-1M
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from nnsight import LanguageModel
import torch as t
# load garbage collection and empty cache 
import gc
from torch.cuda import empty_cache

# UNEMBEDDING BASELINE

In [None]:
def clean():
    """Free memory: run Python's garbage collector, then release cached CUDA blocks.

    Collecting first lets tensors that just became unreachable return their blocks to
    PyTorch's caching allocator, so that `empty_cache()` can hand them back.
    """
    for release_step in (gc.collect, empty_cache):
        release_step()

model_id = "roneneldan/TinyStories-1M"
# model_id = "EleutherAI/gpt-j-6b"

# On a re-run, drop the previously loaded model and free its GPU memory before loading
# again. Checking for the name explicitly, instead of wrapping `del llm` in a bare
# `except:`, means genuine loading errors are reported instead of silently retried.
if "llm" in globals():
    del llm
    clean()

llm = LanguageModel(model_id, device_map="cuda", load_in_8bit=True)
tokenizer = llm.tokenizer

# Alternative prompt for experiments: "1:10,2:20,3:"
prompt_trial = "my name is "

In [None]:

def embeddings_to_texts_baseline(embeddings, model, tokenizer, skip_special_tokens=True):
    """
    Greedily decode hidden vectors through the model's unembedding (`lm_head`).

    Args:
        embeddings: tensor of shape (batch, seq_len, emb_dim).
        model: object exposing `lm_head` (e.g. the nnsight `LanguageModel` loaded above)
            and, optionally, `device`. When `device` is present the embeddings are moved
            there first, instead of assuming CUDA.
        tokenizer: tokenizer providing `batch_decode`.
        skip_special_tokens: forwarded to `tokenizer.batch_decode`.

    Returns:
        (logits, texts): logits of shape (batch, seq_len, vocab_size), computed without
        autograd tracking, and one decoded string per batch row (argmax at each position).
    """
    # Follow the model's device. The helpers in this notebook return CPU tensors, and a
    # hard-coded "cuda" fails for CPU models.
    device = getattr(model, "device", None)
    if device is not None:
        embeddings = embeddings.to(device)
    # Pure read-out: don't build an autograd graph over a (batch, seq, vocab) tensor.
    with torch.no_grad():
        logits = model.lm_head(embeddings)       # (batch, seq_len, vocab_size)
    token_ids = torch.argmax(logits, dim=-1)      # (batch, seq_len)
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
    return logits, texts

def get_next_token_prediction(embeddings, model, tokenizer, skip_special_tokens=True):
    """
    Greedy next-token read-out from the last position of each sequence.

    Args:
        embeddings: tensor of shape (batch, seq_len, emb_dim); a 2-D (seq_len, emb_dim)
            tensor is treated as a batch of one.
        model: object exposing `lm_head` and, optionally, `device` (embeddings are moved
            there so CPU tensors returned by the capture helpers work with a CUDA model).
        tokenizer: tokenizer providing `batch_decode`.
        skip_special_tokens: forwarded to `tokenizer.batch_decode`.

    Returns:
        List with one decoded token string per batch row.
    """
    device = getattr(model, "device", None)
    if device is not None:
        embeddings = embeddings.to(device)
    if embeddings.dim() == 2:
        embeddings = embeddings.unsqueeze(0)     # (1, seq_len, emb_dim)
    with torch.no_grad():
        # Only the last position is decoded, so only project that one.
        last_logits = model.lm_head(embeddings[:, -1, :])   # (batch, vocab_size)
    token_ids = torch.argmax(last_logits, dim=-1)           # (batch,)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)

@t.inference_mode()
def get_residual_output(prompt, layer_idx, llm, normalize = False):
    """
    Capture what transformer block `layer_idx` outputs for `prompt`.

    Parameters
    ----------
    prompt : forwarded unchanged to `llm.trace`.
    layer_idx : index into `llm.transformer.h`.
    llm : nnsight-style model exposing `transformer.h`, `transformer.ln_f`,
        `trace` and `device`.
    normalize : when truthy, the captured activation is passed through
        `llm.transformer.ln_f`. FIXME (carried over): confirm this is the right
        normalisation.

    Returns
    -------
    Element 0 of the block's output as saved in the trace. When
    `llm.device.type == "cuda"` it is detached, copied to CPU, and `clean()` is called.
    """
    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'h'), "The transformer does not have a 'h' attribute for layers."

    with llm.trace(prompt):
        # `.save()` keeps the value available after the `with` block exits.
        activation = llm.transformer.h[layer_idx].output[0].save()

    if normalize:
        activation = llm.transformer.ln_f(activation)

    if llm.device.type != "cuda":
        return activation

    activation = activation.detach().to("cpu")
    clean()
    return activation

@t.inference_mode()
def get_embeddings(prompt,llm):
    """
    Capture the tensor fed into `llm.transformer.drop` while tracing `prompt`.

    Parameters
    ----------
    prompt : forwarded unchanged to `llm.trace`.
    llm : nnsight-style model exposing `transformer.drop`, `trace` and `device`.

    Returns
    -------
    The saved input of `transformer.drop`. When `llm.device.type == "cuda"` it is
    detached, copied to CPU, and `clean()` is called.
    """
    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'drop'), "The transformer does not have a 'drop' attribute for input embeddings."

    with llm.trace(prompt):
        captured = llm.transformer.drop.input.save()

    if llm.device.type != "cuda":
        return captured

    captured_cpu = captured.detach().to("cpu")
    clean()
    return captured_cpu
        


In [None]:
@t.no_grad()
def get_one_token_h(prompt, layer_idx, model, tokenizer):
    """
    Hidden state produced by transformer block `layer_idx` for a single-token prompt.

    Args:
        prompt: text that must tokenize to exactly one token.
        layer_idx: index into `transformer.h` of the wrapped model.
        model: Hugging Face model; wrapped here in an nnsight `LanguageModel`.
        tokenizer: tokenizer matching `model`.

    Returns:
        The hidden vector flattened to shape (hidden_size,), which is the `h` shape that
        `compute_last_token_embedding_grad` accepts. Moved to CPU if the model runs on CUDA.

    Raises:
        AssertionError: if `prompt` is not exactly one token.

    Notes:
        Runs under no_grad rather than inference_mode. The result is later used as a
        target inside an autograd computation, and inference tensors are restricted there.
    """
    llm = LanguageModel(model, tokenizer=tokenizer)

    # check that the prompt is a single token
    assert len(tokenizer(prompt)["input_ids"]) == 1, "The prompt must be a single token."

    with llm.trace(prompt):
        # The block output is indexed with [0] to get the hidden-state tensor (presumably
        # a (hidden_states, ...) tuple for HF GPT-Neo blocks). Saving the whole tuple
        # breaks `.detach()` below and `h.to(device)` in the gradient oracle.
        hidden = llm.transformer.h[layer_idx].output[0].save()
    if llm.device.type == "cuda":
        hidden = hidden.detach().to("cpu")
        clean()
    # One prompt of one token: (1, 1, hidden_size) -> (hidden_size,)
    return hidden.reshape(-1)

@t.no_grad()
def get_hidden_representation(prompt, layer_idx, model, tokenizer):
    """
    Hidden states produced by transformer block `layer_idx` for every token of `prompt`.

    Args:
        prompt: input text.
        layer_idx: index into `transformer.h` of the wrapped model.
        model: Hugging Face model; wrapped here in an nnsight `LanguageModel`.
        tokenizer: tokenizer matching `model`.

    Returns:
        The block's hidden-state tensor with size-1 dimensions squeezed out. That is
        (n_tokens, hidden_size) for a multi-token prompt and (hidden_size,) for a
        one-token prompt, assuming the block output is (1, n_tokens, hidden_size)
        (TODO confirm). Moved to CPU if the model runs on CUDA.

    Notes:
        Runs under no_grad, so no autograd graph is built through the model weights.
        The result is a normal (non-inference) tensor usable as an autograd target.
    """
    llm = LanguageModel(model, tokenizer=tokenizer)
    with llm.trace(prompt):
        # Element 0 of the block output is the hidden-state tensor. Saving the whole
        # output tuple made `.detach()` fail on the CUDA path.
        hidden = llm.transformer.h[layer_idx].output[0].save()
    if llm.device.type == "cuda":
        hidden = hidden.detach().to("cpu")
        clean()
    return hidden.squeeze()
    

### Decoding each layer's residual stream through the unembedding

In [None]:
# Decode the prompt's input embeddings, then the (ln_f-normalised) output of every
# transformer block, by projecting them through the unembedding.
full_text = True  # True: decode every position; False: only the last-position prediction
print(f"{prompt_trial=} \n\n")

reversed_embeddings = get_embeddings(prompt_trial, llm)
if not full_text:
    next_token_prediction = get_next_token_prediction(reversed_embeddings, llm, tokenizer)
    print(f"Next token prediction for the prompt: {next_token_prediction= } \n\n")
else:
    logits, texts = embeddings_to_texts_baseline(reversed_embeddings, llm, tokenizer)
    print(f"Reversed Embeddings for the prompt: {texts= } \n\n")


print(f"Iterating through each layer's output for the prompt: {prompt_trial}\n")
n_layers = len(llm.transformer.h)
for layer_idx in range(n_layers):
    residual_output = get_residual_output(prompt_trial, layer_idx, llm, normalize=True)
    if not full_text:
        next_token_prediction = get_next_token_prediction(residual_output, llm, tokenizer)
        print(f"Layer {layer_idx} next token prediction: {next_token_prediction}")
        continue
    logits, texts = embeddings_to_texts_baseline(residual_output, llm, tokenizer)
    print(f"Layer {layer_idx} output texts: {texts}")

# GRADIENT-BASED ALGORITHM

In [None]:
import torch
import torch.nn.functional as F

def compute_last_token_embedding_grad(
    y: torch.LongTensor,
    llm: torch.nn.Module,
    layer_idx: int,
    h: torch.Tensor
) -> torch.Tensor:
    """
    Gradient of the squared L2 distance ‖z_last - h‖₂² with respect to the input
    embedding of the last token of `y`.

    Args:
      • y: token IDs. Either a 1-D LongTensor/sequence of shape (t,), or a tensor
        already shaped (1, t).
      • llm: an already-loaded Hugging Face-style model that provides
        `get_input_embeddings()` and accepts `inputs_embeds=`,
        `output_hidden_states=True` and `return_dict=True`.
      • layer_idx: index into `outputs.hidden_states`. Under the usual Hugging Face
        convention, hidden_states[0] is the embedding output and hidden_states[k + 1]
        is the output of block k. To match `transformer.h[k].output`, pass k + 1
        (verify for your model). Negative indices count from the end.
      • h: target hidden vector with exactly hidden_size elements, e.g. shape
        (hidden_size,), (1, hidden_size) or (1, 1, hidden_size).

    Returns:
      • For a single sequence, a tensor of shape (hidden_size,) holding
        ∂/∂e_last ‖z_last - h‖₂². Here z_last is the hidden state of the last position
        at `layer_idx`. Earlier positions can also receive gradient (e.g. through
        attention); only the last position's gradient is returned.

    Raises:
      • ValueError: if `layer_idx` is out of range, or if `h` does not have hidden_size
        elements. A multi-row `h` used to be silently broadcast against z_last.

    Side effects:
      • Calls `llm.eval()` and sets `requires_grad=False` on every parameter of `llm`.
        Neither change is undone.
    """
    # Freeze the model so that only the input embeddings collect gradients.
    llm.eval()
    for param in llm.parameters():
        param.requires_grad_(False)

    # Token IDs -> (1, t) LongTensor on the model's device.
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y, dtype=torch.long)
    input_ids = y.unsqueeze(0) if y.dim() == 1 else y  # otherwise assume already (1, t)
    device = next(llm.parameters()).device
    input_ids = input_ids.to(device)

    # Look up the embeddings without tracking, then make them a fresh leaf that
    # requires grad so the forward pass below is differentiated w.r.t. them.
    with torch.no_grad():
        full_embeddings = llm.get_input_embeddings()(input_ids)  # (1, t, hidden_size)
    embeddings = full_embeddings.detach().requires_grad_(True)

    outputs = llm(
        inputs_embeds=embeddings,
        output_hidden_states=True,
        return_dict=True
    )
    all_hidden_states = outputs.hidden_states
    n_states = len(all_hidden_states)
    if layer_idx >= n_states:
        raise ValueError(
            f"layer_idx={layer_idx} is too large; model only returns "
            f"{n_states} hidden states."
        )
    if layer_idx < -n_states:
        raise ValueError(
            f"layer_idx={layer_idx} is too small; model only returns "
            f"{n_states} hidden states."
        )

    z_last_token = all_hidden_states[layer_idx][:, -1, :]  # (1, hidden_size)
    hidden_size = z_last_token.shape[-1]

    # clone() turns a possible inference-mode tensor into a normal tensor that autograd
    # may use freely. The size check rejects targets that would otherwise be broadcast.
    target = h.detach().to(device).clone()
    if target.numel() != hidden_size:
        raise ValueError(
            f"h must contain exactly hidden_size={hidden_size} elements, "
            f"got shape {tuple(h.shape)}."
        )
    target = target.reshape(1, hidden_size)

    # Squared L2 loss; its gradient w.r.t. z is 2 * (z - h).
    diff = z_last_token - target
    loss = torch.sum(diff * diff)

    llm.zero_grad()  # parameters are frozen; kept as a harmless safeguard
    loss.backward()

    # embeddings.grad is (1, t, hidden_size); keep only the last position.
    return embeddings.grad[:, -1, :].squeeze(0)


In [None]:
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import torch

# # 1) Load a pretrained LLM (e.g. GPT-2)
# tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
# tokenizer.pad_token = tokenizer.eos_token
# model     = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M", output_hidden_states=True)


# # 2) Suppose y = ["Hello", "world"]  → token IDs:
# tokens = tokenizer(["Hello", "world"], return_tensors="pt", add_special_tokens=False)
# y_ids  = tokens["input_ids"].squeeze(0)   # e.g. tensor([15496, 995])

# # 3) Suppose we want the target vector h to be some random “desired” hidden representation 
# #    at layer 4 (for illustration).
# hidden_size = model.config.hidden_size          # 768 for GPT-2 small
# h_target    = torch.randn(1, hidden_size)       # pretend this is our “objective”

# # 4) Call the function:
# grad_last_embedding = compute_last_token_embedding_grad(
#     y=y_ids,
#     llm=model,
#     layer_idx=4,
#     h=h_target
# )

# print("Grad shape:", grad_last_embedding.shape)  # → torch.Size([768])


In [None]:
def get_random_token_and_embedding(tokenizer, embedding_matrix, generator=None):
    """
    Pick a vocabulary entry uniformly at random and return its text and embedding.

    Args:
        tokenizer: tokenizer providing `decode(token_id)`.
        embedding_matrix: tensor of shape (vocab_size, hidden_size).
        generator: optional `torch.Generator` for reproducible draws. The default
            (None) uses torch's global RNG, as before.

    Returns:
        (token_string, embedding). `embedding` is a detached copy of row `token_id`,
        so it neither aliases the model's weights nor carries autograd history into
        the optimisation loop that updates it.
    """
    vocab_size = embedding_matrix.shape[0]
    random_token_id = torch.randint(0, vocab_size, (1,), generator=generator).item()
    embedding = embedding_matrix[random_token_id].detach().clone()
    token_string = tokenizer.decode(random_token_id)
    return token_string, embedding

In [None]:
from tqdm import tqdm
# Plain Hugging Face tokenizer and model for the gradient-based search below.
# `compute_last_token_embedding_grad` calls the model directly with `inputs_embeds=`,
# so it receives this transformers model rather than the nnsight wrapper `llm`.
grad_model_id = "roneneldan/TinyStories-1M"

tokenizer = AutoTokenizer.from_pretrained(grad_model_id)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
model = AutoModelForCausalLM.from_pretrained(grad_model_id, output_hidden_states=True)

In [None]:
layer_idx = 4  # transformer block (index into model.transformer.h) whose output we invert
prompt = "George"
h = get_one_token_h(prompt, layer_idx, model, tokenizer)
# detach(): the search only reads the embedding table, so skip autograd tracking.
embedding_matrix = model.get_input_embeddings().weight.detach()  # (vocab_size, hidden_size)
gamma = 1  # step size for gradient descent
n_iterations = 100  # number of iterations for gradient descent

In [None]:
prompt = "George"
layer_idx = 4  # transformer block (index into model.transformer.h) whose output we invert
h = get_one_token_h(prompt, layer_idx, model, tokenizer)
# detach(): the search only reads the embedding table, so skip autograd tracking.
embedding_matrix = model.get_input_embeddings().weight.detach()  # (vocab_size, hidden_size)
gamma = 1  # step size for gradient descent
n_iterations = 100  # number of iterations for gradient descent
# compute_last_token_embedding_grad indexes `outputs.hidden_states`, where index 0 is the
# embedding output. The output of block `layer_idx` is therefore at layer_idx + 1.
hidden_state_idx = layer_idx + 1

# make a guess for (e_0, y_0); detach so the iterate never carries autograd history
y_i, x_i_plus_1 = get_random_token_and_embedding(tokenizer, embedding_matrix)
x_i_plus_1 = x_i_plus_1.detach()

for iteration in (bar := tqdm(range(n_iterations))):
    bar.set_postfix_str(f"Current guess: {y_i}")

    # NOTE(review): the gradient is evaluated at the embedding of the re-encoded decoded
    # guess y_i, not at the continuous iterate x_i_plus_1.
    grad_oracle = compute_last_token_embedding_grad(
        y=torch.tensor(tokenizer.encode(y_i), dtype=torch.long),  # TODO: turn into x_i_plus_1
        llm=model,
        layer_idx=hidden_state_idx,
        h=h,
    )  # (hidden_size,)
    bar.set_postfix_str(f"Current guess: {y_i}, Gradient norm: {grad_oracle.norm().item():.4f}")

    with torch.no_grad():
        # gradient step, then snap to the closest token in embedding space
        x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
        distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
        closest_token_id = torch.argmin(distances).item()
    y_i = tokenizer.decode(closest_token_id)

    # make a summary via prints
    print(f"Iteration {iteration + 1}/{n_iterations}:")
    print(f"  Current guess: {y_i}")
    print(f"  Gradient norm: {grad_oracle.norm().item():.4f}")





In [None]:
prompt = "George"
n_iterations = 100  # gradient-descent steps per prompt token in the search below
# `layer_idx` is the transformer block chosen in the configuration cells above.
h = get_hidden_representation(prompt, layer_idx, model, tokenizer)
# get_hidden_representation squeezes a one-token prompt down to (hidden_size,). Restore
# the token axis so h is always (n_tokens, hidden_size) and h[j] is a full hidden vector.
if h.dim() == 1:
    h = h.unsqueeze(0)
size = h.shape[0]  # number of prompt tokens to recover

# discovered_tokens = []

In [None]:
# tokenizer.tokenize("my name is George")
# tokenizer.encode("name")
# discovered_tokens = ["my", " name", " is"]
# # merge discovered tokens and y_i in a list and a string
# lista = discovered_tokens + [y_i]
# stringa = "".join(lista)
# y_i, discovered_tokens, lista, stringa, tokenizer.tokenize(stringa)
# y_extend = "".join((discovered_tokens + [y_i]))
# print(f"Extended guess: {y_extend}")
# encoded_y = torch.tensor(tokenizer.encode(y_extend), dtype=torch.long) 
# grad_oracle = compute_last_token_embedding_grad(
#         y=encoded_y, # turn into x_i_plus_1
#         llm=model,
#         layer_idx=layer_idx,
#         h=h
#     )

In [None]:

# Token-by-token search. For each prompt position j, take gradient steps in embedding
# space (snapping to the nearest token after each step) so that the model's hidden state
# at the last position matches the target h[j].
discovered_tokens = []
for j in range(size):
    # Target for prompt token j; a one-token prompt may leave `h` 1-D instead of (1, hidden_size).
    h_target = h[j] if h.dim() > 1 else h

    # make a guess for (e_0, y_0); detach so the iterate never carries autograd history
    y_i, x_i_plus_1 = get_random_token_and_embedding(tokenizer, embedding_matrix)
    x_i_plus_1 = x_i_plus_1.detach()

    for iteration in (bar := tqdm(range(n_iterations))):
        # Condition on the tokens found so far plus the current guess.
        y_extend = "".join(discovered_tokens + [y_i])
        # NOTE(review): re-tokenising the joined text can merge or split tokens, so the
        # last ID is not guaranteed to be the ID of y_i.
        encoded_y = torch.tensor(tokenizer.encode(y_extend), dtype=torch.long)
        grad_oracle = compute_last_token_embedding_grad(
            y=encoded_y,  # TODO: evaluate at x_i_plus_1 instead of the decoded guess
            llm=model,
            # hidden_states[0] is the embedding output, so block `layer_idx` is at layer_idx + 1
            layer_idx=layer_idx + 1,
            h=h_target,
        )
        bar.set_postfix_str(
            f"Extended guess: {y_extend!r}, Gradient norm: {grad_oracle.norm().item():.4f}"
        )

        with torch.no_grad():
            # gradient step, then snap to the closest token in embedding space
            x_i_plus_1 = x_i_plus_1 - gamma * grad_oracle  # (hidden_size,)
            distances = torch.norm(embedding_matrix - x_i_plus_1, dim=1)  # (vocab_size,)
            closest_token_id = torch.argmin(distances).item()
            # FIXME: this compares a hidden state at block `layer_idx` with an *input*
            # token embedding. They live in different spaces, so this threshold may never
            # be met; consider comparing the guess's hidden state with h_target instead.
            gap = torch.norm(h_target - embedding_matrix[closest_token_id]).item()
        y_i = tokenizer.decode(closest_token_id)
        if gap < 1e-3:
            print(f"Found token {y_i} with small L2 norm: {gap:.4f}")
            break

    # add the discovered token to the list
    discovered_tokens.append(y_i)
    print(f"Discovered tokens so far: {discovered_tokens}")

# print the final discovered tokens
print(f"Final discovered tokens: {discovered_tokens}")
