In [20]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mistral.modeling_mistral import *
from transformers.cache_utils import DynamicCache
import torch
from torch import nn
import copy
from types import MethodType
import pandas as pd

In [2]:
# Tokenizer first (small download), then the inverse vocabulary used to label
# token ids, then the 7B model itself moved onto the GPU.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
Token = {token_id: token_str for token_str, token_id in tokenizer.get_vocab().items()}
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").to('cuda')

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

In [15]:
# Inference-only notebook: turn autograd off globally. The trailing ';' hides the
# returned context-manager repr (which otherwise prints a memory address).
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7ff6902f6b90>

In [18]:
def topk(v, k=40, aux=None, vocab=None):
    """Return the ``k`` highest-scoring entries of ``v`` as a DataFrame of tokens.

    Args:
        v: Scores (typically logits) as a torch tensor or numpy array. It is
            flattened, and each flat index is treated as a token id.
        k: Number of entries to return, highest score first. ``k=0`` returns an
            empty frame; ``k`` larger than ``v`` returns every entry.
        aux: Optional extra per-token values, indexable as ``aux[token_id]`` (e.g.
            a list of tuples or a 2-D numpy array). ``aux[i]`` is appended to token
            ``i``'s row, and the extra columns are named ``0, 1, ...``.
        vocab: Mapping from token id to token string. Defaults to the notebook's
            global ``Token`` mapping.

    Returns:
        DataFrame with columns ``['token', 'logit']`` (plus the ``aux`` columns),
        sorted by descending score.
    """
    if vocab is None:
        vocab = Token
    if isinstance(v, torch.Tensor):
        # Upcast first: numpy has no bfloat16, so .numpy() fails on bf16 tensors.
        v = v.detach().float().cpu().numpy()
    v = v.flatten()
    # Sort descending, then keep the first k. Unlike ``argsort()[-k:]``, this
    # selects nothing (rather than everything) when k == 0.
    idxs = v.argsort()[::-1][:k]
    columns = ['token', 'logit']
    # Explicit length check: ``if aux:`` raises for numpy-array aux.
    if aux is not None and len(aux) > 0:
        columns += list(range(len(aux[0])))
        rows = [(vocab[int(i)], v[i]) + tuple(aux[i]) for i in idxs]
    else:
        rows = [(vocab[int(i)], v[i]) for i in idxs]
    return pd.DataFrame(rows, columns=columns)

In [40]:
# Token ids of the first (only) prompt as plain Python ints.
# NOTE(review): `input` is the BatchEncoding from the tokenization cell (In [4]),
# which sits further down the file; in file order on a fresh kernel this name is
# still the Python builtin.
input['input_ids'][0].tolist()

[1, 1984, 6656, 2442, 302, 272, 23404, 2401, 349, 549, 28250]

In [47]:
# Which tensors the tokenizer produced; these are what `model(**input, ...)` receives.
input.keys()

dict_keys(['input_ids', 'attention_mask'])

In [53]:
# Number of per-layer key/value cache entries (expected to match the number of decoder layers).
# NOTE(review): needs `out` to be the raw model output; `out = lens(...)` later rebinds it to a DataFrame.
len(out.past_key_values)

32

In [54]:
# Inspect one decoder layer's submodules (attention projections, MLP, norms).
model.model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [57]:
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. For ``n_rep == 1``
    the input tensor object itself is returned.
    """
    # Unpack up front so a non-4-D input is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis after the head axis, broadcast it (no copy yet), then
        # fold it into the head axis; the reshape materializes the copy.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states

In [77]:
def _norm_row(label, x):
    """Return ``[label, |x[0, 0]|, |x[0, 1]|, ...]``: L2 norm at every position of batch item 0."""
    # One vectorized norm plus a single .tolist() avoids a device sync per token.
    return [label] + torch.linalg.vector_norm(x[0].float(), dim=-1).tolist()


def lens(model, prompt):
    """Split every decoder layer's residual-stream update into attention and remainder.

    Runs ``prompt`` (a single example) through ``model`` once. For each decoder
    layer it reports the per-token L2 norm of:

    * ``A``  - the attention module's output, rebuilt from the returned attention
      probabilities and the cached values, then passed through ``o_proj``;
    * ``F``  - ``h - h_prev - A``, the rest of the layer's update (for Mistral's
      pre-norm residual layers, the MLP output);
    * ``Hd`` - the total update ``h - h_prev``;
    * ``H``  - the residual stream after the layer.

    Args:
        model: Mistral-style causal LM exposing ``model.model.layers`` and
            ``layer.self_attn.o_proj``.
        prompt: Text to analyse.

    Returns:
        DataFrame whose first row is ``['E', tok_0, tok_1, ...]``, followed by
        four rows (A, F, Hd, H) per layer. Column ``j + 1`` holds token ``j``.

    Uses the notebook globals ``tokenizer``, ``Token`` (id -> token string) and
    ``repeat_kv``.
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    cfg = model.config
    # Query heads per key/value head; 4 for Mistral-7B (q_proj has 4096 outputs,
    # k_proj/v_proj have 1024).
    kv_heads = getattr(cfg, 'num_key_value_heads', None) or cfg.num_attention_heads
    n_rep = cfg.num_attention_heads // kv_heads
    layers = model.model.layers

    # In HF Mistral/Llama models, the last `hidden_states` entry is taken *after*
    # the final `model.norm`, so it is not the last decoder layer's raw output and
    # would break h - h_prev for that layer. Capture the raw output with a hook.
    captured = []

    def _capture(module, args, output):
        captured.append(output[0] if isinstance(output, tuple) else output)

    handle = layers[-1].register_forward_hook(_capture)
    try:
        with torch.no_grad():
            out = model(**inputs, output_hidden_states=True,
                        output_attentions=True, use_cache=True)
    finally:
        handle.remove()
    hidden = tuple(out.hidden_states[:-1]) + (captured[-1],)

    token_ids = inputs['input_ids'][0].tolist()
    seq_len = len(token_ids)
    rows = [['E'] + [Token[t] for t in token_ids]]

    with torch.no_grad():
        for h_prev, h, attn, (_, values), decoder in zip(
            hidden[:-1], hidden[1:], out.attentions, out.past_key_values, layers
        ):
            # Apply the attention probabilities to the head-expanded cached values,
            # merge heads back into the hidden dimension, then project with o_proj.
            values = repeat_kv(values, n_rep)
            a = torch.matmul(attn, values).transpose(1, 2).reshape(1, seq_len, -1)
            a = decoder.self_attn.o_proj(a)
            hd = h - h_prev
            rows += [
                _norm_row('A', a),
                _norm_row('F', hd - a),  # what is left of the update after attention
                _norm_row('Hd', hd),
                _norm_row('H', h),
            ]
    return pd.DataFrame(data=rows)

In [78]:
# NOTE(review): this rebinds `out` (the raw model output elsewhere) to the lens DataFrame.
out = lens(model=model, prompt=prompt)

In [81]:
# Last 30 rows (the final layers' A/F/Hd/H rows): the label column plus
# columns 7-11, i.e. tokens 6-10 (column j + 1 holds token j).
out.iloc[-30:, [0, *range(7, 12)]]

Unnamed: 0,0,7,8,9,10,11
99,Hd,5.58987,3.881607,4.543558,5.443723,4.543383
100,H,19.137266,17.948874,18.324638,17.267794,16.813477
101,A,0.783993,1.215868,0.592541,0.581104,0.838376
102,F,4.603269,3.90245,6.33707,6.682795,4.82113
103,Hd,4.71511,4.126276,6.411267,6.7142,4.950005
104,H,20.087797,19.162827,20.183628,19.750025,18.357061
105,A,0.730662,1.450283,1.222855,0.704192,1.69537
106,F,5.34078,4.049546,6.573922,5.84883,4.495585
107,Hd,5.422429,4.279365,6.709745,5.923417,4.827442
108,H,21.611397,20.505722,22.066854,20.677406,19.252426


In [4]:
prompt = 'My favorite element of the periodic table is platinum'
# Put the encoded prompt on whatever device the model lives on.
# NOTE(review): `input` shadows the Python builtin; kept because other cells read it.
input = tokenizer(prompt, return_tensors='pt').to(model.device)

In [5]:
# Inference only: no_grad keeps this cell from building an autograd graph (which
# would hold on to every layer's activations) if grad mode is still enabled.
with torch.no_grad():
    out = model(**input, output_hidden_states=True)

In [9]:
# Number of hidden-state tensors returned: the embedding output plus one per decoder layer.
len(out.hidden_states)

33

In [10]:
# Repeats the earlier inspection of decoder layer 0's submodules.
model.model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [25]:
# Logit lens on the embedding output (hidden_states[0]): top-1 token read out at
# position 5. Only that position is projected, not the whole sequence.
# NOTE(review): no final norm (model.model.norm) is applied before lm_head; confirm intended.
topk(model.lm_head(out.hidden_states[0][0, 5, :]), k=1).iloc[0, 0]

'▁same'

In [17]:
# Check what kind of object each hidden_states entry is (torch.Tensor per the output below).
type(out.hidden_states[0])

torch.Tensor