In [1]:
import matplotlib
import nnsight
import torch
import yaml
from datasets import load_dataset
from IPython.display import display, HTML
from nnsight import NNsight
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the first CUDA GPU when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [3]:
# Experiment settings (model family, checkpoint paths) are read from YAML.
# safe_load avoids constructing arbitrary Python objects from the file.
CONFIG_PATH = "config_train.yaml"
with open(CONFIG_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Which model family to load, and which fine-tuned checkpoint to evaluate.
model_name = config['model']
eval_section = config['eval']
trained_checkpoint = eval_section['trained_checkpoint']
(model_name, trained_checkpoint)

('gemma',
 '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_one_direction')

In [34]:
# Load the tokenizer and model for the configured family, then move the model
# to DEVICE as `model_a`. `model_name` and `trained_checkpoint` come from the
# config cell above.
if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    # Reuse EOS as the pad token (assumes the Pythia tokenizer defines none).
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): this overwrites the configured checkpoint with the *base*
    # Pythia model, so fine-tuned weights are never loaded for this family.
    # Confirm this is intentional before comparing against other families.
    trained_checkpoint = "EleutherAI/pythia-1.4b"
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
elif "gemma" in model_name:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")
    model = AutoModelForCausalLM.from_pretrained(trained_checkpoint)
else:
    # Fail loudly instead of silently reusing a `model` left over from an
    # earlier execution of this cell (hidden kernel state).
    raise ValueError(f"Unsupported model family in config: {model_name!r}")
model_a = model.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [35]:
# Second fine-tuned checkpoint ("both directions" per its directory name),
# loaded for comparison against model_a.
both_directions = '/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_both_directions'
model_b = AutoModelForCausalLM.from_pretrained(both_directions)
model_b = model_b.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [66]:
# Un-fine-tuned Gemma model: the same checkpoint the tokenizer is loaded from.
# Presumably a baseline for the fine-tuned models.
# The checkpoint id gets its own name so `model_c` always refers to a model.
base_checkpoint = "google/gemma-1.1-2b-it"
model_c = AutoModelForCausalLM.from_pretrained(base_checkpoint).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [36]:
# Wrap both fine-tuned models with NNsight so intermediate module outputs
# (e.g. the token embeddings) can be saved and overwritten inside `.trace()`.
# The nnsight imports live in the top import cell.
model_a_nns = NNsight(model_a)
model_b_nns = NNsight(model_b)

In [37]:
# Show the wrapped module tree to find the submodule paths traced below
# (e.g. `model.embed_tokens`, `lm_head`).
model_a_nns

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [88]:
# Run model_a once on the prompt. From that single forward pass, save the
# token-embedding output (patched into model_b later) and the full model output.
# no_grad (not inference_mode) avoids building an autograd graph while keeping
# the saved tensors usable as inputs to another model's trace.
text = "Peter, Paul and"
inputs = tokenizer(text, return_tensors="pt").to(DEVICE)

with torch.no_grad(), model_a_nns.trace(inputs['input_ids']):
    embeddings_a = model_a_nns.model.embed_tokens.output.save()
    outputs_a_saved = model_a_nns.output.save()
outputs_a = outputs_a_saved.value
first_embedding = embeddings_a.value[0, 1]  # batch 0, token position 1
# Greedy prediction at every position of the prompt.
predicted_ids = torch.argmax(outputs_a.logits, dim=-1)
tokenizer.decode(predicted_ids[0], skip_special_tokens=True)

'The Mac a and Mary'

In [89]:
# Run model_b once on a second prompt. From that single forward pass, save its
# token-embedding output and its lm_head logits, instead of running the model
# a second time outside the trace.
text = "Mary, Paul and"
inputs = tokenizer(text, return_tensors="pt").to(DEVICE)

with torch.no_grad(), model_b_nns.trace(inputs['input_ids']):
    # clone() presumably snapshots the embeddings before any later in-trace edits.
    embeddings_b = model_b_nns.model.embed_tokens.output.clone().save()
    logits_b = model_b_nns.lm_head.output.save()
predicted_ids = torch.argmax(logits_b.value, dim=-1)
tokenizer.decode(predicted_ids[0], skip_special_tokens=True)

'The Pat a and their'

In [None]:
# Embedding patch: run model_b on the current `inputs` (the "Mary, Paul and"
# prompt from the previous cell), but replace its entire token-embedding output
# with model_a's saved embeddings. Then decode model_b's greedy predictions.
# NOTE(review): `first_embedding` (token 1 of embeddings_a) is computed earlier
# but unused here; to patch only that position, assign into `[0, 1]` instead.
patch_embeddings = embeddings_a.value  # concrete tensor, not a proxy from another trace
assert patch_embeddings.shape[:2] == inputs['input_ids'].shape, (
    "embeddings_a must cover the same number of tokens as the prompt being patched"
)
with torch.no_grad(), model_b_nns.trace(inputs['input_ids']):
    model_b_nns.model.embed_tokens.output = patch_embeddings
    outputs_b = model_b_nns.lm_head.output.save()
predicted_ids = torch.argmax(outputs_b.value, dim=-1)
tokenizer.decode(predicted_ids[0], skip_special_tokens=True)

tensor([[[-17.6485, -10.8599, -51.9013,  ..., -15.6539, -16.4370, -17.2726],
         [-27.0210,  -7.6944, -14.5683,  ..., -29.8582, -29.5331, -26.6342],
         [-30.8921,  -7.1539,  -0.1174,  ..., -27.9815, -28.4202, -30.5455],
         [-26.9390,  -4.6514, -23.7419,  ..., -18.7072, -18.6069, -26.5479],
         [-19.4025,   2.0000,  -7.2791,  ...,  -8.7191, -12.3956, -19.0418]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)


'The Mac a and Mary'

In [45]:
# Inspect model_b's saved (cloned) token-embedding output for "Mary, Paul and".
embeddings_b.value

tensor([[[ 0.1037,  0.0039, -0.0329,  ..., -0.0184, -0.0093, -0.0113],
         [ 0.1891, -0.0469, -0.0249,  ...,  0.0071,  0.0725,  0.0949],
         [ 0.2044,  0.0173, -0.0730,  ..., -0.0622,  0.0289, -0.0329],
         [ 0.2503, -0.0532, -0.1282,  ..., -0.0097,  0.0025,  0.0159],
         [ 0.2924, -0.0775,  0.0106,  ...,  0.0373,  0.0572, -0.0153]]],
       device='cuda:0', grad_fn=<CloneBackward0>)