In [1]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib

# Run on the first CUDA device when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [2]:
# Read the run settings from the training config.
# The explicit encoding keeps parsing independent of the platform's default
# locale; safe_load only builds plain Python objects from the YAML.
with open("config_train.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [3]:
# Architecture key used below to choose the tokenizer/model classes.
model_name = config["model"]
# Checkpoint location taken from the "eval" section of the config.
trained_checkpoint = config["eval"]["trained_checkpoint"]

# Echo both values so the cell output records what this run is using.
(model_name, trained_checkpoint)

('gemma',
 '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction')

In [4]:
# Build the tokenizer and model for the architecture named in the config,
# then move the model to DEVICE.
#
# Supported values of `model_name`:
#   - exactly "bart"
#   - any name containing "pythia"
#   - any name containing "gemma"
# Anything else raises ValueError instead of falling through to
# `model.to(DEVICE)` with `model` undefined (or left over from an earlier run).
if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer

    # Tokenizer comes from the base checkpoint; weights from the configured one.
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    # Reuse EOS as the padding token and mirror that on the model config below.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): this rebinds `trained_checkpoint` to the base Hub model, so
    # the checkpoint read from the config is ignored for pythia and the value
    # displayed by the earlier cell becomes stale. Confirm this is intended.
    trained_checkpoint = "EleutherAI/pythia-1.4b"
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
elif "gemma" in model_name:
    # Tokenizer comes from the base checkpoint; weights from the configured one.
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        trained_checkpoint,
    )
else:
    raise ValueError(
        f"Unsupported model {model_name!r}: expected 'bart' or a name "
        "containing 'pythia' or 'gemma'"
    )
model = model.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
import nnsight
from nnsight import NNsight

# Wrap the loaded model in NNsight; the tracing cell further down uses this
# wrapper to read and save the output of an internal submodule.
model_nns = NNsight(model)

In [10]:
# Display the wrapped model's module tree; the dotted submodule names shown
# here (e.g. model.embed_tokens) are the paths used when tracing.
model_nns

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [None]:
# Tokenize a short prompt and capture the token-embedding layer's output.
text = "Once upon a time"
inputs = tokenizer(text, return_tensors="pt")

# The tokenizer returns CPU tensors while the model was moved to DEVICE, so
# place the ids on DEVICE explicitly rather than relying on the wrapper to
# handle the transfer. `inputs` itself is left unchanged.
with model_nns.trace(inputs["input_ids"].to(DEVICE)) as tracer:
    # Saving the proxy is what lets `hidden_states` be read after the trace
    # block exits (it is displayed in the following cells).
    hidden_states = model_nns.model.embed_tokens.output.save()

In [24]:
# Show the embedding output saved during the trace (the recorded output has
# one row per prompt token and lives on the model's device).
hidden_states

tensor([[[ 0.1036,  0.0039, -0.0329,  ..., -0.0184, -0.0093, -0.0112],
         [ 0.1889, -0.0468, -0.0250,  ...,  0.0070,  0.0725,  0.0947],
         [ 0.2044,  0.0173, -0.0731,  ..., -0.0622,  0.0289, -0.0328],
         [ 0.2503, -0.0531, -0.1281,  ..., -0.0097,  0.0025,  0.0158],
         [ 0.2926, -0.0773,  0.0107,  ...,  0.0374,  0.0570, -0.0153]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [25]:
# Recorded shape is [1, 5, 2048]: one prompt, five token ids, and the
# 2048-wide embedding reported in the model summary.
hidden_states.shape

torch.Size([1, 5, 2048])

In [22]:
# Inspect the tokenizer output: token ids plus the attention mask.
# NOTE(review): the recorded execution counts around here are out of order
# (10, None, 24, 25, 22); re-run top to bottom on a fresh kernel to confirm
# the displayed outputs still match the code.
inputs

{'input_ids': tensor([[    2, 14326,  3054,   476,  1069]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}