In [1]:
import torch
from transformers import AutoModelForCausalLM

target_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    torch_dtype = torch.float16,
    device_map = "auto"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
target_model.model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [3]:
from typing import Optional, Tuple
from collections import defaultdict

multiplied = defaultdict([])
cos = torch.nn.CosineSimilarity(dim=2)

def partial_forward(
    model,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

    residual = hidden_states

    hidden_states = model.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = model.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    
    #hidden_states = residual + hidden_states

    # Fully Connected
    #residual = hidden_states
    #hidden_states = model.post_attention_layernorm(hidden_states)
    #hidden_states = model.mlp(hidden_states)
    
    return residual, hidden_states


def get_vectors(name):
    def hook(model, input):
        residual, hidden_states = partial_forward(model, *input)
        pre = torch.flatten(residual.to("cuda:0"))
        post = torch.flatten(hidden_states.to("cuda:0"))
        similarity = cos(pre, post).detach().to("cpu").numpy()[0]
        print(residual.shape)
        print(hidden_states.shape)
        print(similarity)
        multiplied.append({
            "name": name,
            "similarity": similarity
        })
    return hook

for idx, layer in enumerate(target_model.model.layers):
    layer.register_forward_pre_hook(get_vectors(f"layer_{idx}"))

### Test

In [4]:
from transformers import pipeline, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

pipe = pipeline(
    "text-generation",
    model = target_model,
    tokenizer = tokenizer
)

pipe.tokenizer.pad_token_id = target_model.config.eos_token_id

pipe(
    "What's ML?",
    do_sample = False,
    max_new_tokens = 100,
    return_full_text = False
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.0697  -0.05185 -0.08405 -0.00817 -0.059  ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.145   -0.2213  -0.135   -0.04764 -0.1221 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[ 0.0362   -0.04446  -0.1294   -0.011024  0.03418 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.02858 -0.03052 -0.01547  0.0621   0.1142 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.2563   0.01846  0.04468  0.12006  0.1287 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.3762   0.01949  0.05408  0.05707  0.1572 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.1494   0.01839 -0.01212  0.07513  0.1069 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.2418  -0.00982 -0.07196  0.02216  0.05692]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.2006  -0.04428 -0.2222   0.07477  0.1273 ]
torch.Size([1, 5, 4096])
torch.Size([1, 5, 4096])
[-0.3142  -0.1664  -0.1606  -0.02585  0.0328 ]
torch.Size([1, 5, 4096])


[{'generated_text': '\n\nComment: @JamesK. I\'m not sure what you mean by "ML" - could you clarify?\n\nComment: @JamesK I\'m not sure what you mean by "ML" - could you clarify?\n\nComment: @JamesK I\'m not sure what you mean by "ML" - could you clarify?\n\nComment: @JamesK I\'m not sure what you mean by "ML" -'}]

In [6]:
multiplied

[{'name': 'layer_0',
  'similarity': array([-0.1403,  0.2295, -0.572 , ..., -0.6567,  0.487 ,  0.454 ],
        dtype=float16)},
 {'name': 'layer_1',
  'similarity': array([-0.1584,  0.1466,  0.2856, ..., -0.2483, -0.4639, -0.2537],
        dtype=float16)},
 {'name': 'layer_2',
  'similarity': array([-0.6396, -0.0687,  0.9043, ...,  0.6064,  0.1501, -0.159 ],
        dtype=float16)},
 {'name': 'layer_3',
  'similarity': array([0.2903, 0.1943, 0.944 , ..., 0.1914, 0.454 , 0.51  ], dtype=float16)},
 {'name': 'layer_4',
  'similarity': array([0.582 , 0.099 , 0.1788, ..., 0.808 , 0.3196, 0.2686], dtype=float16)},
 {'name': 'layer_5',
  'similarity': array([ 0.663 ,  0.631 ,  0.303 , ..., -0.4353,  0.8247,  0.67  ],
        dtype=float16)},
 {'name': 'layer_6',
  'similarity': array([-0.769 , -0.5737,  0.7285, ..., -0.1206, -0.1823, -0.6084],
        dtype=float16)},
 {'name': 'layer_7',
  'similarity': array([-0.10376,  0.1144 ,  0.4019 , ...,  0.2208 ,  0.556  ,  0.05396],
        dtype=f

In [None]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split = "train_sft")
subset = dataset.select(range(1000))

In [None]:
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm


for gen in tqdm(
    pipe(
        KeyDataset(subset, "prompt"),
        do_sample = True,
        batch_size = 8,
        max_new_tokens = 100,
        temperature = 0.1,
        top_p = 0.95,
        top_k = 50,
        seed = 4321,
        return_full_text = False
    )
):
    _ = None

In [None]:
multiplied

In [None]:
from tqdm import tqdm
from torch.nn import functional as F

sim_arr = []
cos = torch.nn.CosineSimilarity(dim=0)

for item in tqdm(multiplied):
    for idx, pre in enumerate(item["pre"]):
        sim_dict = {
            "id": idx,
            "layer": item["name"],
            "similarity": cos(pre, item["post"][idx])
        }
        sim_arr.append(sim_dict)

In [None]:
sim_arr[0]

In [None]:
import pandas as pd

sim_pd = []
idx = 0

for sim_ in sim_arr:
    sim_pd.append(
        {
            "layer": sim_["layer"],
            "similarity": sim_["similarity"].detach().cpu().to(torch.float16).numpy()[0]
        }
    )

In [None]:
sim_arr[0]