In [1]:
# Standard library: garbage collection, used together with empty_cache to free GPU memory.
import gc

# Third-party: torch (also aliased as `t`), nnsight for tracing, and Hugging Face
# transformers for loading checkpoints such as roneneldan/TinyStories-1M.
import torch
import torch as t
from nnsight import LanguageModel
from torch.cuda import empty_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


# Trying with TinyStories (currently overridden by GPT-J 6B in the cell below)

In [None]:
def clean():
    """Run Python garbage collection, then release unoccupied cached CUDA memory."""
    # Collect first so tensors that just became unreachable are freed before
    # the CUDA cache is emptied.
    gc.collect()
    empty_cache()

# Checkpoint to inspect. TinyStories-1M is the lightweight alternative; swap the
# two lines below to use it instead of GPT-J 6B.
# model_id = "roneneldan/TinyStories-1M"
model_id = "EleutherAI/gpt-j-6b"

# When this cell is re-run, drop the previously loaded model before loading a
# new one. On a fresh kernel `llm` does not exist yet; that NameError is the
# only error that is safe to ignore here. Any failure while loading the model
# (out of memory, bad checkpoint id, Ctrl-C) now surfaces instead of being
# swallowed and blindly retried.
try:
    del llm
except NameError:
    pass
clean()

llm = LanguageModel(model_id, device_map="cuda", load_in_8bit=True)
tokenizer = llm.tokenizer

prompt_trial = "1:10,2:20,3:"

In [None]:

def embeddings_to_texts_baseline(embeddings, model, tokenizer, skip_special_tokens=True):
    """
    Map input embeddings (batch, seq_len, emb_dim) → list of decoded strings.

    Args:
        embeddings: Tensor of hidden states to decode.
        model: Object exposing an ``lm_head`` callable that maps the last
            dimension of ``embeddings`` to vocabulary logits.
        tokenizer: Object exposing ``batch_decode``.
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.

    Returns:
        Tuple ``(logits, texts)``: the vocabulary logits produced by
        ``model.lm_head`` and one greedily decoded string per batch row.
    """
    # Run the projection on whatever device the head's weights live on, so the
    # function also works for CPU-resident models. If the device cannot be
    # determined (no tensor `weight`, or a meta placeholder), keep the previous
    # behaviour of moving the embeddings to "cuda".
    head_weight = getattr(model.lm_head, "weight", None)
    if isinstance(head_weight, torch.Tensor) and head_weight.device.type != "meta":
        target_device = head_weight.device
    else:
        target_device = "cuda"

    # 1) Project embeddings to vocab logits
    logits = model.lm_head(embeddings.to(target_device))  # (batch, seq_len, vocab_size)
    # 2) Greedy decode: pick highest logit per position
    token_ids = torch.argmax(logits, dim=-1)  # (batch, seq_len)
    # 3) Transform each sequence of IDs into text
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
    return logits, texts

def get_next_token_prediction(embeddings, model, tokenizer, skip_special_tokens=True):
    """
    Greedily decode the token predicted at the last sequence position.

    Args:
        embeddings: Tensor of hidden states, (batch, seq_len, emb_dim).
        model: Object exposing an ``lm_head`` callable that maps the last
            dimension of ``embeddings`` to vocabulary logits.
        tokenizer: Object exposing ``batch_decode``.
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the last-position token ids
        (one entry per batch row).
    """
    # Move the embeddings next to the head's weights when their device can be
    # determined; otherwise leave them where they are. The other helpers in
    # this notebook hand back CPU tensors, which would not match a GPU head.
    head_weight = getattr(model.lm_head, "weight", None)
    if isinstance(head_weight, torch.Tensor) and head_weight.device.type != "meta":
        embeddings = embeddings.to(head_weight.device)

    # 1) Project embeddings to vocab logits
    logits = model.lm_head(embeddings)  # (batch, seq_len, vocab_size)
    # 2) Greedy decode for the last position only
    token_ids = torch.argmax(logits[:, -1], dim=-1)  # (batch,)
    # 3) Transform each token ID into text
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
    return texts

@t.inference_mode()
def get_residual_output(prompt, layer_idx, llm, normalize = False):
    """
    Capture the output of one transformer block for ``prompt``.

    Args:
        prompt: Input passed to ``llm.trace``.
        layer_idx: Index into ``llm.transformer.h`` selecting the block.
        llm: Traced model exposing ``transformer.h``, ``transformer.ln_f``,
            ``trace`` and ``device``.
        normalize: When True, pass the captured tensor through
            ``llm.transformer.ln_f`` before returning it.

    Returns:
        The saved block output (first element of the block's output). When the
        model is on CUDA the tensor is detached and moved to the CPU.

    Raises:
        AssertionError: If ``llm`` has no ``transformer`` or ``transformer.h``.
    """
    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'h'), "The transformer does not have a 'h' attribute for layers."

    with llm.trace(prompt):
        # Keep only element 0 of the block's output (presumably the hidden
        # states; confirm for architectures other than the ones used here).
        residual_output = llm.transformer.h[layer_idx].output[0].save()

    if normalize:  # FIXME: check if this is the correct way to normalize
        residual_output = llm.transformer.ln_f(residual_output)

    if llm.device.type == "cuda":
        # Hand the result back on the CPU and release cached GPU memory.
        # There is deliberately no `del llm` here: it would only unbind this
        # function's local name, while the caller keeps the model alive.
        residual_output = residual_output.detach().to("cpu")
        clean()

    return residual_output

@t.inference_mode()
def get_embeddings(prompt,llm):
    """
    Capture the input of ``llm.transformer.drop`` for ``prompt``.

    Args:
        prompt: Input passed to ``llm.trace``.
        llm: Traced model exposing ``transformer.drop``, ``trace`` and ``device``.

    Returns:
        The saved input of the ``drop`` module. When the model is on CUDA the
        tensor is detached and moved to the CPU.

    Raises:
        AssertionError: If ``llm`` has no ``transformer`` or ``transformer.drop``.
    """
    assert hasattr(llm, 'transformer'), "The model does not have a transformer attribute."
    assert hasattr(llm.transformer, 'drop'), "The transformer does not have a 'drop' attribute for input embeddings."

    with llm.trace(prompt):
        # Whatever is fed into the `drop` module is treated as the input embeddings.
        input_embeddings = llm.transformer.drop.input.save()

    if llm.device.type == "cuda":
        # Hand the result back on the CPU and release cached GPU memory.
        # There is deliberately no `del llm` here: it would only unbind this
        # function's local name, while the caller keeps the model alive.
        input_embeddings = input_embeddings.detach().to("cpu")
        clean()
    return input_embeddings
        


### TRYING THE DECODING

In [19]:
# Decode the prompt's input embeddings first, then the (ln_f-normalized) output
# of every transformer block, so the decoded text can be compared layer by layer.
full_text = True  # True: decode every position; False: only the last position's token
print(f"{prompt_trial=} \n\n")

if not full_text:
    next_token_prediction = get_next_token_prediction(
        get_embeddings(prompt_trial, llm), llm, tokenizer
    )
    print(f"Next token prediction for the prompt: {next_token_prediction= } \n\n")
else:
    reversed_embeddings = get_embeddings(prompt_trial, llm)
    logits, texts = embeddings_to_texts_baseline(reversed_embeddings, llm, tokenizer)
    print(f"Reversed Embeddings for the prompt: {texts= } \n\n")


print(f"Iterating through each layer's output for the prompt: {prompt_trial}\n")
num_layers = len(llm.transformer.h)
for layer_idx in range(num_layers):
    residual_output = get_residual_output(prompt_trial, layer_idx, llm, normalize=True)
    if not full_text:
        next_token_prediction = get_next_token_prediction(residual_output, llm, tokenizer)
        print(f"Layer {layer_idx} next token prediction: {next_token_prediction}")
        continue
    logits, texts = embeddings_to_texts_baseline(residual_output, llm, tokenizer)
    print(f"Layer {layer_idx} output texts: {texts}")

prompt_trial='1:10,2:20,3:' 




Reversed Embeddings for the prompt: texts= ['----------'] 


Iterating through each layer's output for the prompt: 1:10,2:20,3:

Layer 0 output texts: ['st sometimes controlled sometimesnd also Winc sometimesrd also']
Layer 1 output texts: ['.. Introductionpm sometimesnd South Winc sometimesrd South']
Layer 2 output texts: ['. Introductionpm000nd350}}}000rd Extract']
Layer 3 output texts: ['\nlayoutpm000nd200HI000rd200']
Layer 4 output texts: ['\n500pm000:#500}200:200']
Layer 5 output texts: ["\n500pm000:200});3:'200"]
Layer 6 output texts: ["\n00pm000:500)?3:'200"]
Layer 7 output texts: ["\n00pm000:'500]=3rd20"]
Layer 8 output texts: ["\n000pm000:'00],[3:'20"]
Layer 9 output texts: ["\n000pm000:'00}3rd20"]
Layer 10 output texts: ["\n000pm000:{00}etc:'20"]
Layer 11 output texts: ["\n000pm000:'20}etcrd20"]
Layer 12 output texts: ['\n000pm000)</10,...etc:20']


KeyboardInterrupt: 

# Trying With Mistral

In [None]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

# Load variables from a local .env file (if any), then authenticate with the
# Hugging Face Hub using HUGGINGFACE_TOKEN. The token is read from the
# environment so it never appears in the notebook itself.
load_dotenv()
login(token=os.environ.get("HUGGINGFACE_TOKEN"))

In [5]:
# Load Mistral-7B with 8-bit weights. The quantization settings are passed via
# `quantization_config`; the bare `load_in_8bit=` keyword is deprecated.
model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",       # auto-slice layers across GPUs/CPU
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # or load_in_4bit=True
    torch_dtype="auto",      # keep LayerNorm etc. in fp16/32
)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.06s/it]


In [11]:
def get_embeddings_mistral(prompt):
    """
    Return the token embeddings of ``prompt`` from ``model.model.embed_tokens``.

    Relies on the notebook-level globals ``tokenizer`` and ``model``, so it
    embeds with whichever checkpoint is currently loaded under those names.

    Args:
        prompt: Text passed to the global ``tokenizer``.

    Returns:
        Tensor of embeddings, batch x sequence x embedding_dim, computed
        without gradient tracking.
    """
    inputs = tokenizer(prompt, return_tensors="pt")  # holds input_ids (and attention_mask)
    # Only the token ids are needed for an embedding lookup; the attention mask
    # is not used, so it is no longer copied to the model's device.
    input_ids = inputs.input_ids.to(model.device)
    with torch.no_grad():
        embeddings = model.model.embed_tokens(input_ids)

    return embeddings


In [52]:
prompt_trial = "Once upon a time, in a land far away, there lived a"
embeddings = get_embeddings_mistral(prompt_trial)
# Display the embedding shape: (batch, seq_len, emb_dim).
embeddings.shape

torch.Size([1, 15, 4096])

In [54]:
# Project the raw token embeddings through lm_head and greedily decode every position.
logits, text = embeddings_to_texts_baseline(embeddings, model, tokenizer)
text 

['ocker Groupisenerial候puislcerialuvudstock awaypuisafteronnaerial']

In [8]:
# Compares shapes only: a True result shows the two matrices have the same
# dimensions, not that they hold the same (tied) values.
model.model.embed_tokens.weight.shape == model.lm_head.weight.shape

True

# Trying with GPT-2 (XL by default; set `SMALL = True` for GPT-2 Small)

In [5]:
# load with gpt2 
# Pick the GPT-2 checkpoint: the small model when SMALL is True, GPT-2 XL otherwise.
SMALL = False
model_id = "openai-community/gpt2" if SMALL else "openai-community/gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

In [9]:
prompt_trial = "Once upon a time, in a land far away, there lived a"
# NOTE(review): `get_embeddings_gpt` is not defined in any visible cell of this
# notebook; on a fresh kernel this line would raise NameError unless it is
# defined elsewhere. TODO: add (or restore) its definition above.
embeddings = get_embeddings_gpt(prompt_trial)
# Display the embedding shape.
embeddings.shape

torch.Size([1, 14, 1600])

In [10]:
# Project the embeddings through lm_head and greedily decode every position.
logits, text = embeddings_to_texts_baseline(embeddings, model, tokenizer)
text 


['Once upon a time, in a land far away, there lived a']

In [11]:
# Display the shapes of the output head and the token-embedding matrix side by side.
model.lm_head.weight.shape, model.transformer.wte.weight.shape

(torch.Size([50257, 1600]), torch.Size([50257, 1600]))

In [12]:
# Compares shapes only: a True result shows the two matrices have the same
# dimensions, not that they hold the same (tied) values.
model.lm_head.weight.shape == model.transformer.wte.weight.shape

True

# Trying With Gemma 2B

In [9]:
# Load Gemma 2B with 8-bit weights. The quantization settings are passed via
# `quantization_config`; the bare `load_in_8bit=` keyword is deprecated.
model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",       # auto-slice layers across GPUs/CPU
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # or load_in_4bit=True
    torch_dtype="auto",      # keep LayerNorm etc. in fp16/32
)
# Display the module tree to see the names of the embedding, norm and head layers.
model

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Downloading shards: 100%|██████████| 2/2 [01:47<00:00, 53.61s/it] 
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.54s/it]


GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear8bitLt(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear8bitLt(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear8bitLt(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear8bitLt(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear8bitLt(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention

In [19]:
prompt_trial = "Once upon a time, in a land far away, there lived a"
# get_embeddings_mistral reads the global `model` / `tokenizer`, so despite its
# name it embeds with whichever checkpoint is currently loaded under those names.
embeddings = get_embeddings_mistral(prompt_trial)
# Also run the embeddings through `model.model.norm` (presumably the model's
# final normalization layer; confirm against the module tree).
normalized_embeddings = model.model.norm(embeddings)
# Display both shapes: (batch, seq_len, emb_dim).
embeddings.shape, normalized_embeddings.shape

(torch.Size([1, 15, 2048]), torch.Size([1, 15, 2048]))

In [15]:
# Project the raw (un-normalized) embeddings through lm_head and greedily decode every position.
logits, text = embeddings_to_texts_baseline(embeddings, model, tokenizer)
text 

[' increa increa increa increa increa increa increa increa increa increa increa increa increa increa']

In [20]:
# Same decoding, this time on the embeddings that went through `model.model.norm`.
logits_normalized, text_normalized = embeddings_to_texts_baseline(normalized_embeddings, model, tokenizer)
text_normalized 


[' increa increa increa increa increa increa increa increa increa increa increa increa increa increa']