In [1]:
import torch
import copy, math
from transformers import LlamaForCausalLM, LlamaTokenizer
# Architecture constants; they must match the LLaMA-7B checkpoint loaded below.
LAYER_NUM = 32   # number of decoder layers
HEAD_NUM = 32    # attention heads per layer
HEAD_DIM = 128   # width of a single attention head
HIDDEN_DIM = HEAD_NUM * HEAD_DIM  # width of the residual stream
# From here on every newly created tensor is allocated on the GPU by default.
torch.set_default_device("cuda")
# All-zero vector used to blank out columns of an o_proj weight matrix.
# Sized from HIDDEN_DIM so it cannot drift from the constants above.
zero_tensor = torch.zeros(HIDDEN_DIM)

In [2]:
def transfer_output(model_output):
    """Regroup per-layer model records into per-field lists.

    ``model_output[layer_i]`` is read at indices 0..7 for each of the first
    ``LAYER_NUM`` layers. Fields 0..6 (layer input, attention output, residual
    output, FFN output, layer output, last-position attention subvalues,
    coefficient scores) are reduced to their first element along dim 0 and
    converted to nested Python lists. Field 7 (attention weights) is kept
    as-is.

    Returns:
        An 8-tuple of lists, one entry per layer, in the field order above.
    """
    n_list_fields = 7  # fields 0..6 go through [0].tolist()
    per_field = [[] for _ in range(n_list_fields)]
    all_attn_scores = []
    for layer_i in range(LAYER_NUM):
        layer_record = model_output[layer_i]
        for field_idx, bucket in enumerate(per_field):
            # keep only the first entry along dim 0, as plain Python lists
            bucket.append(layer_record[field_idx][0].tolist())
        all_attn_scores.append(layer_record[7])
    return (*per_field, all_attn_scores)
def get_fc2_params(model, layer_num):
    """Return the raw ``down_proj`` weight tensor of the MLP in decoder layer ``layer_num``."""
    mlp_block = model.model.layers[layer_num].mlp
    return mlp_block.down_proj.weight.data
def get_bsvalues(vector, model, final_var):
    """Project a hidden-space vector (or batch of vectors) onto the vocabulary.

    The vector is scaled by ``rsqrt(final_var + 1e-6)``, multiplied by the
    weight of the model's final norm, and passed through ``model.lm_head``.

    Args:
        vector: tensor whose last dimension matches the final norm weight.
        model: object exposing ``model.norm.weight`` and ``lm_head``.
        final_var: mean-square statistic used for the scaling; it is supplied
            by the caller rather than recomputed from ``vector``.

    Returns:
        The detached ``lm_head`` output for the normalized vector.
    """
    # Only a detached result is returned, so skip building an autograd graph
    # for the (potentially very large) vocabulary projection.
    with torch.no_grad():
        vector = vector * torch.rsqrt(final_var + 1e-6)
        vector_rmsn = vector * model.model.norm.weight.data
        vector_bsvalues = model.lm_head(vector_rmsn).data
    return vector_bsvalues
def get_layernorm_weight(model, layer_num):
    """Return the raw weight of ``post_attention_layernorm`` in decoder layer ``layer_num``."""
    decoder_layer = model.model.layers[layer_num]
    return decoder_layer.post_attention_layernorm.weight.data
def get_prob(vector):
    """Turn scores into probabilities with a softmax over the last dimension."""
    return torch.softmax(vector, dim=-1)
def get_log_increase(model, all_ffn_subvalues, all_pos_residual_output, final_var, predict_index):
    """Measure how much each FFN subvalue raises the log-probability of a token.

    For every layer, the residual vector at the last position is projected to
    the vocabulary to get a baseline log-probability of ``predict_index``.
    Each row of ``all_ffn_subvalues[layer]`` is then added to that residual
    and the log-probability is recomputed.

    Args:
        model: passed through to ``get_bsvalues``.
        all_ffn_subvalues: per-layer 2-D tensors, one row per subvalue.
        all_pos_residual_output: per-layer nested lists; the last entry of
            each layer is used as the residual vector.
        final_var: normalization statistic forwarded to ``get_bsvalues``.
        predict_index: vocabulary index of the token being tracked.

    Returns:
        A list with ``LAYER_NUM`` entries; entry ``i`` is a list holding, for
        every row of ``all_ffn_subvalues[i]``, its log-probability minus the
        baseline.
    """
    all_ffn_log_increase = []
    for layer_i in range(LAYER_NUM):
        cur_ffn_subvalues = all_ffn_subvalues[layer_i]
        cur_residual = torch.tensor(all_pos_residual_output[layer_i][-1])
        # log_softmax stays finite where log(softmax(x)) would underflow to
        # -inf (and then yield NaN after the subtraction below).
        origin_logits = get_bsvalues(cur_residual, model, final_var)
        origin_prob_log = torch.log_softmax(origin_logits, dim=-1)[predict_index]
        # Broadcasting adds the residual to every subvalue row.
        plus_logits = get_bsvalues(cur_ffn_subvalues + cur_residual, model, final_var)
        plus_prob_log = torch.log_softmax(plus_logits, dim=-1)[:, predict_index]
        all_ffn_log_increase.append((plus_prob_log - origin_prob_log).tolist())
    return all_ffn_log_increase
def get_pos_vector(vector, pos_embed_var, model, layer_num):
    """Normalize ``vector`` the way layer ``layer_num``'s input norm would, using a given variance.

    The vector is scaled by ``rsqrt(pos_embed_var + 1e-6)`` and then multiplied
    by the layer's ``input_layernorm`` weight. The statistic is supplied by the
    caller rather than recomputed from ``vector``.
    """
    norm_weight = model.model.layers[layer_num].input_layernorm.weight.data
    inv_rms = torch.rsqrt(pos_embed_var + 1e-6)
    return (vector * inv_rms) * norm_weight

In [3]:
# Replace this path with the directory where your LLaMA-7B checkpoint is saved.
# If you have not downloaded it yet, use "huggyllama/llama-7b" to fetch it automatically from Hugging Face.
modelname = "../../scratch/save_models/llama-7b" 
tokenizer = LlamaTokenizer.from_pretrained(modelname)
model = LlamaForCausalLM.from_pretrained(modelname)
# Switch to evaluation mode, then move the weights to the GPU.
# The last expression also displays the module tree as the cell output.
model.eval()
model.cuda()

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
 

In [4]:
# Compute the model's final prediction for the input prompt.
test_sentence = "3+5="
indexed_tokens = tokenizer.encode(test_sentence)
# Decode every token id on its own so positions can be read as text later.
tokens = [tokenizer.decode(x) for x in indexed_tokens]
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]
# Ten highest-scoring next-token candidates at the last position.
predicted_top10 = torch.argsort(predictions[0][-1], descending=True)[:10]
predicted_text = [tokenizer.decode(x) for x in predicted_top10]
print(test_sentence, "=>", predicted_text)
# NOTE(review): outputs[1] is expected to hold the per-layer records consumed by
# transfer_output; this presumably requires a patched LlamaForCausalLM - confirm.
all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output, all_pos_layer_output, \
all_last_attn_subvalues, all_pos_coefficient_scores, all_attn_scores = transfer_output(outputs[1])
# Mean of squares of the last layer's output at the last position; reused as the
# normalization statistic whenever a vector is projected to the vocabulary.
final_var = torch.tensor(all_pos_layer_output[-1][-1]).pow(2).mean(-1, keepdim=True)
pos_len = len(tokens)
print(tokens)

3+5= => ['8', '1', '6', '7', '2', '?', '', '3', '4', '5']
['<s>', '', '3', '+', '5', '=']


In [5]:
# Track the model's top-1 next token and its probability under the unmodified model.
predict_index = predicted_top10[0].item()
print(predict_index, tokenizer.decode(predict_index))
last_pos_probs = get_prob(predictions[0][-1])
cur_prob = last_pos_probs[predict_index].item()
print("prob: ", cur_prob)
# The "old" key stores the baseline probability and log-probability.
all_head_dict_prob = {"old": cur_prob}
all_head_dict_logprob = {"old": math.log(cur_prob)}

29947 8
prob:  0.3775949776172638


In [6]:
# The memory is broken when the most important head is zero-intervened:
# build a copy of the model in which the selected heads contribute nothing
# to the attention output, then compare its prediction with the original.
all_attn_o_weights = [
    model.model.layers[layer_index].self_attn.o_proj.weight.data
    for layer_index in range(LAYER_NUM)
]
model1 = copy.deepcopy(model)
modify_heads = ['17_22']  # "<layer>_<head>" identifiers of the heads to ablate
# Group the head indices by layer so each layer's weight is rewritten once.
modify_heads_dict = {}
for l_h in modify_heads:
    l, h = l_h.split("_")
    modify_heads_dict.setdefault(l, []).append(h)
for l, hs in modify_heads_dict.items():
    layer_index = int(l)
    all_attn_o_weights_new = all_attn_o_weights[layer_index].clone()
    for h in hs:
        head_index = int(h)
        start, end = HEAD_DIM*head_index, HEAD_DIM*head_index+HEAD_DIM
        # Columns [start, end) of o_proj read this head's slice of the
        # concatenated head outputs; one slice assignment zeroes them all.
        all_attn_o_weights_new[:, start:end] = 0.0
    new_parameters = torch.nn.Parameter(all_attn_o_weights_new)
    model1.model.layers[layer_index].self_attn.o_proj.weight = new_parameters
    del all_attn_o_weights_new
with torch.no_grad():
    outputs1 = model1(tokens_tensor)
    predictions1 = outputs1[0]
cur_prob1 = get_prob(predictions1[0][-1])[predict_index].item()
predicted_top10_1 = torch.argsort(predictions1[0][-1], descending=True)[:10]
predicted_text1 = [tokenizer.decode(x) for x in predicted_top10_1]
# The value printed after "prob: " is the CHANGE in probability (ablated - original).
print("prob: ", cur_prob1-cur_prob, test_sentence, "=>", predicted_text1)
all_pos_layer_input1, all_pos_attn_output1, all_pos_residual_output1, all_pos_ffn_output1, all_pos_layer_output1, \
all_last_attn_subvalues1, all_pos_coefficient_scores1, all_attn_scores1 = transfer_output(outputs1[1])

prob:  -0.315999049693346 3+5= => ['6', '1', '7', '8', '?', '2', '5', '3', '4', '']


In [7]:
# Identify the deep FFN neurons with the comparable neuron analysis (CNA) method:
# compare every neuron's log-probability contribution in the original model
# with its contribution in the head-ablated model (model1).
# Original model: each neuron's subvalue is its down_proj vector scaled by the
# neuron's coefficient score at the last position.
all_ffn_subvalues = []
for layer_i in range(LAYER_NUM):
    coefficient_scores = torch.tensor(all_pos_coefficient_scores[layer_i][-1])
    fc2_vectors = get_fc2_params(model, layer_i)
    ffn_subvalues = (coefficient_scores * fc2_vectors).T
    all_ffn_subvalues.append(ffn_subvalues)
# Same quantities for the ablated model.
all_ffn_subvalues1 = []
for layer_i in range(LAYER_NUM):
    coefficient_scores1 = torch.tensor(all_pos_coefficient_scores1[layer_i][-1])
    fc2_vectors1 = get_fc2_params(model1, layer_i)
    ffn_subvalues1 = (coefficient_scores1 * fc2_vectors1).T
    all_ffn_subvalues1.append(ffn_subvalues1)
# NOTE(review): the ablated run is normalized with final_var from the original
# run as well - confirm this is intended.
all_ffn_log_increase = get_log_increase(model, all_ffn_subvalues, all_pos_residual_output, final_var, predict_index)
all_ffn_log_increase1 = get_log_increase(model1, all_ffn_subvalues1, all_pos_residual_output1, final_var, predict_index)
# One tuple per neuron:
#   ("layer_neuron", log-increase drop, log-increase (original), log-increase (ablated),
#    coefficient drop, coefficient (original), coefficient (ablated))
all_ffn_scores = []
for layer_i in range(LAYER_NUM):
    for neuron_index in range(len(all_ffn_subvalues[0])):
        all_ffn_scores.append((str(layer_i)+"_"+str(neuron_index), 
                               round(all_ffn_log_increase[layer_i][neuron_index] - all_ffn_log_increase1[layer_i][neuron_index], 4), 
                               round(all_ffn_log_increase[layer_i][neuron_index], 4), 
                               round(all_ffn_log_increase1[layer_i][neuron_index], 4), 
                               round(all_pos_coefficient_scores[layer_i][-1][neuron_index]-all_pos_coefficient_scores1[layer_i][-1][neuron_index], 4), 
                               round(all_pos_coefficient_scores[layer_i][-1][neuron_index], 4), 
                               round(all_pos_coefficient_scores1[layer_i][-1][neuron_index], 4)))
# Rank neurons by how much of their log-increase disappears after the ablation.
all_ffn_scores_sort = sorted(all_ffn_scores, key=lambda x: x[1])[::-1]
# For the top 10 neurons, print the tokens their subvalue promotes most.
for x in all_ffn_scores_sort[:10]:
    layer = int(x[0].split("_")[0])
    neuron = int(x[0].split("_")[1])
    cur_embed = all_ffn_subvalues[layer][neuron]
    cur_embed_bsvalue = get_bsvalues(cur_embed, model, final_var)
    cur_embed_bsvalue_sort = torch.argsort(cur_embed_bsvalue, descending=True)
    print(x, [tokenizer.decode(a) for a in cur_embed_bsvalue_sort[:10]])

('28_3696', 0.6964, 0.8229, 0.1265, -5.2503, -6.2005, -0.9502) ['8', 'eight', '⁸', 'VIII', 'huit', 'acht', '₈', '八', 'otto', 'e']
('25_7164', 0.2327, 0.3097, 0.077, 6.3519, 8.4364, 2.0845) ['six', 'eight', 'acht', 'Four', 'Six', 'twelve', 'six', 'four', 'vier', 'four']
('29_10957', 0.1384, 0.1464, 0.008, -6.2541, -6.6118, -0.3577) ['ві', 'desc', 'ги', 'rien', 'rach', 'Њ', 'usa', 'eight', 'Ott', 'Source']
('19_5769', 0.1348, 0.2041, 0.0692, 2.5066, 3.7902, 1.2837) ['eight', 'VIII', '⁸', '8', 'III', 'huit', '₈', 'acht', 'XVIII', '<0x88>']
('29_6563', 0.1322, 0.1812, 0.049, -1.5336, -2.0934, -0.5598) ['8', 'eight', '₈', '⁸', 'VIII', 'huit', 'otto', '7', 'acht', '₇']
('23_2618', 0.1293, 0.2058, 0.0766, 2.426, 3.8565, 1.4305) ['eight', 'sevent', '8', 'huit', 'acht', '7', 'sept', '⁷', 'seven', 'VIII']
('28_5475', 0.0965, -0.0921, -0.1886, -1.4833, 1.4104, 2.8937) ['Ά', 'ма', '↳', 'blatt', '⇔', '∘', '{`', 'temp', 'ℓ', 'чи']
('30_5933', 0.0923, 0.0727, -0.0197, 1.1264, 0.8947, -0.2317) ['eight

In [8]:
# Compute the coefficient score of the selected FFN neuron (layer 19, neuron 5769).
test_layer_ffn, test_index_ffn = 19, 5769
print("increase: ", all_ffn_log_increase[test_layer_ffn][test_index_ffn], 
      "coefficient score: ", torch.tensor(all_pos_coefficient_scores[test_layer_ffn][-1])[test_index_ffn].item())
# The neuron's down_proj vector, projected to the vocabulary to see which
# tokens it promotes (top) and suppresses (last).
fc2_vector = get_fc2_params(model, test_layer_ffn).T[test_index_ffn]
fc2_vector_bsvalue = get_bsvalues(fc2_vector, model, final_var)
fc2_vector_bsvalue_sort = torch.argsort(fc2_vector_bsvalue, descending=True)
print("top value: ", [tokenizer.decode(x) for x in fc2_vector_bsvalue_sort[:10]])
print("last value: ", [tokenizer.decode(x) for x in fc2_vector_bsvalue_sort.tolist()[::-1][:10]])
# The neuron's rows of up_proj / gate_proj, each multiplied element-wise by the
# post-attention layer-norm weight. Later cells reuse cur_ffn_key_up and
# cur_ffn_key_gate as the directions against which vectors are scored.
fc1_vector_up = model.model.layers[test_layer_ffn].mlp.up_proj.weight.data[test_index_ffn].data
fc1_vector_gate = model.model.layers[test_layer_ffn].mlp.gate_proj.weight.data[test_index_ffn].data
cur_layernorm_weight = get_layernorm_weight(model, test_layer_ffn)
cur_ffn_key_up = cur_layernorm_weight * fc1_vector_up
cur_ffn_key_gate = cur_layernorm_weight * fc1_vector_gate
# Inner products of the last-position residual output with both keys.
up_score_all = torch.sum(torch.tensor(all_pos_residual_output[test_layer_ffn][-1]) * cur_ffn_key_up).item()
gate_score_all = torch.sum(torch.tensor(all_pos_residual_output[test_layer_ffn][-1]) * cur_ffn_key_gate).item()
# Signs of the two scores (+1.0 when non-negative, otherwise -1.0).
up_sym = 1.0 if up_score_all >= 0.0 else -1.0
gate_sym = 1.0 if gate_score_all >= 0.0 else -1.0
print("up score: ", up_score_all, "gate score: ", gate_score_all)
print("up symbol: ", up_sym, "gate symbol: ", gate_sym)

increase:  0.2040548324584961 coefficient score:  3.790210723876953
top value:  ['eight', 'VIII', '⁸', '8', 'III', 'huit', '₈', 'acht', 'XVIII', '<0x88>']
last value:  ['io', 'Vier', 'ied', 'ilo', 'fourth', 'Four', '四', 'ier', 'ismo', 'nin']
up score:  2.997763156890869 gate score:  1.7075436115264893
up symbol:  1.0 gate symbol:  1.0


In [9]:
# Analyze layer 17.
# The last position has the largest inner product with the subkey of the selected deep FFN neuron.
attn_test_layer = 17
# Layer input at every position, projected to the vocabulary for inspection.
test_layer_input = torch.tensor(all_pos_layer_input[attn_test_layer])
test_layer_input_bsvalues = get_bsvalues(test_layer_input, model, final_var)
test_layer_input_sort = torch.argsort(test_layer_input_bsvalues, descending=True)
# assumes all_last_attn_subvalues[layer] is (HEAD_NUM, positions, HEAD_DIM) - TODO confirm
cur_v_heads = torch.tensor(all_last_attn_subvalues[attn_test_layer])
# View the transposed o_proj weight as one (HEAD_DIM, out) matrix per head.
cur_attn_o_split = model.model.layers[attn_test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)
# Apply each head's output projection, then move the head axis to dim 1.
cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)
# Sum over heads to get one attention subvalue per entry along dim 0.
test_layer_subvalues = torch.sum(cur_attn_subvalues_headrecompute, 1)
test_layer_subvalues_bsvalues = get_bsvalues(test_layer_subvalues, model, final_var)
test_layer_subvalues_sort = torch.argsort(test_layer_subvalues_bsvalues, descending=True)
# Score every subvalue against the up / gate keys of the selected FFN neuron.
attn_inner_products_up = torch.sum(test_layer_subvalues*cur_ffn_key_up, -1)
attn_inner_products_gate = torch.sum(test_layer_subvalues*cur_ffn_key_gate, -1)
for pos in range(len(test_layer_input_sort)):
    print(pos, "up: ", attn_inner_products_up[pos].item(), "gate: ", attn_inner_products_gate[pos].item())
    print("input", [tokenizer.decode(x) for x in test_layer_input_sort[pos][:10]])
    print("subvalue", [tokenizer.decode(x) for x in test_layer_subvalues_sort[pos][:10]])

0 up:  0.019892985001206398 gate:  0.128106951713562
input ['<s>', 'sime', 'SERT', 'bolds', 'multicol', 'engelsk', 'UIView', 'ALSE', 'Ľ', 'kaf']
subvalue ['', '<0x0A>', ',', '(', '-', '.', 'of', '\xa0', ':', 's']
1 up:  -0.006079650484025478 gate:  -0.005465351976454258
input ['\ufeff', '←', '<0xE2>', 'ℚ', '\u200b', '1', 'Â', '��', 'ï', '\xad']
subvalue ['nom', 'nom', 'cat', 'ULT', 'cat', 'Nom', 'mans', 'ten', 'ход', 'Cat']
2 up:  -0.03674646466970444 gate:  -0.028093719854950905
input ['rd', 'thoughts', 'D', '0', 'new', 'things', 'bed', '6', 'M', 'Sister']
subvalue ['fourth', '⁴', '4', 'YY', 'uchar', 'papers', 'aterra', 'apers', 'ASE', 'UNION']
3 up:  -0.01766836829483509 gate:  0.011767487972974777
input ['/-', 'Sample', 'hours', 'hour', 'years', '-+', 'illa', '⁄', 'hour', 'Cleveland']
subvalue ['minus', '(-', '(-', 'között', 'minus', 'esser', 'onto', '++', 'üll', 'uno']
4 up:  0.05119187384843826 gate:  0.06610628962516785
input ['+', 'rule', 'dual', '=', '″', '=', 'aci', 'yr', 'Cha

In [10]:
# Compute head scores on the last position: head 22 is the most important head in layer 17.
# Derive the last position from the tokenized prompt instead of hard-coding it,
# so the cell keeps working when test_sentence changes.
test_pos = pos_len - 1
cur_layer_input = torch.tensor(all_pos_layer_input[attn_test_layer])
cur_v_heads = torch.tensor(all_last_attn_subvalues[attn_test_layer])
# One (HEAD_DIM, out) slice of the transposed o_proj weight per head.
cur_attn_o_split = model.model.layers[attn_test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)
cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)
# Per-head output vectors for the chosen position.
cur_attn_subvalues_head_curpos = cur_attn_subvalues_headrecompute[test_pos]
cur_layer_input_last = cur_layer_input[-1]
# Baseline: log-probability of the predicted token from the layer input alone.
origin_prob = torch.log(get_prob(get_bsvalues(cur_layer_input_last, model, final_var))[predict_index])
# Log-probability after adding each head's vector to the layer input.
cur_attn_subvalues_head_plus = cur_attn_subvalues_head_curpos + cur_layer_input_last
cur_attn_plus_probs = torch.log(get_prob(get_bsvalues(
    cur_attn_subvalues_head_plus, model, final_var))[:, predict_index])
cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob
# Inner products of each head's vector with the FFN neuron's up / gate keys.
cur_pos_heads_inner_products_up = torch.sum(cur_attn_subvalues_head_curpos * cur_ffn_key_up, -1)
cur_pos_heads_inner_products_gate = torch.sum(cur_attn_subvalues_head_curpos * cur_ffn_key_gate, -1)
# Tuples of (head, log-prob increase, up score, gate score).
cur_attn_plus_probs_increase_zip = []
for i in range(len(cur_attn_plus_probs_increase)):
    cur_attn_plus_probs_increase_zip.append((i, round(cur_attn_plus_probs_increase[i].item(), 4),
                                             round(cur_pos_heads_inner_products_up[i].item(), 4),
                                             round(cur_pos_heads_inner_products_gate[i].item(), 4)))
# Rank heads by up score + gate score, highest first.
cur_attn_plus_probs_increase_sort = sorted(cur_attn_plus_probs_increase_zip, key=lambda x: x[3]+x[2])[::-1]
for head_index0, _, up_score, gate_score in cur_attn_plus_probs_increase_sort:
    print("head: ", head_index0, "up_score: ", up_score, "gate_score: ", gate_score)

head:  22 up_score:  0.4484 gate_score:  0.2736
head:  10 up_score:  0.0052 gate_score:  0.0079
head:  27 up_score:  0.004 gate_score:  0.0052
head:  7 up_score:  -0.0011 gate_score:  0.0103
head:  25 up_score:  0.0071 gate_score:  0.0014
head:  23 up_score:  0.0065 gate_score:  0.0004
head:  1 up_score:  0.0031 gate_score:  0.0025
head:  9 up_score:  0.0055 gate_score:  -0.0022
head:  15 up_score:  0.0036 gate_score:  -0.0018
head:  30 up_score:  -0.0014 gate_score:  0.003
head:  26 up_score:  -0.0003 gate_score:  0.0016
head:  29 up_score:  0.0009 gate_score:  0.0004
head:  14 up_score:  0.0007 gate_score:  0.0004
head:  28 up_score:  0.0005 gate_score:  0.0001
head:  11 up_score:  0.0005 gate_score:  -0.0002
head:  3 up_score:  -0.0031 gate_score:  0.0032
head:  24 up_score:  0.0001 gate_score:  -0.0001
head:  2 up_score:  -0.0001 gate_score:  -0.0
head:  16 up_score:  0.0006 gate_score:  -0.0008
head:  5 up_score:  -0.0018 gate_score:  0.0015
head:  13 up_score:  -0.0001 gate_score

In [11]:
# Calculate which layer-level vector is the most important in head 17_22's last-position vector.
# Select the head explicitly; relying on `head_index` left over from the
# ablation cell's loop makes this cell depend on hidden notebook state.
head_index = 22
# Mean of squares of the layer-17 input at test_pos, used as the normalization statistic.
pos_embed_var = torch.tensor(all_pos_layer_input[attn_test_layer][test_pos]).pow(2).mean(-1, keepdim=True)
# Rows of v_proj and of the transposed o_proj that belong to the selected head.
# `.data` keeps the scoring below out of the autograd graph.
curhead_v = model.model.layers[attn_test_layer].self_attn.v_proj.weight.data.split(HEAD_DIM)[head_index]
curhead_o = model.model.layers[attn_test_layer].self_attn.o_proj.weight.data.T.split(HEAD_DIM)[head_index]
# Vectors that were added to the residual stream at test_pos before layer 17:
# index 0 is the input of layer 0, followed by (attention output, FFN output)
# for every earlier layer. The nested lists are indexed first so that only one
# vector is converted to a tensor at a time.
previous_vectors = [torch.tensor(all_pos_layer_input[0][test_pos])]
for layer_i in range(attn_test_layer):
    previous_vectors.append(torch.tensor(all_pos_attn_output[layer_i][test_pos]))
    previous_vectors.append(torch.tensor(all_pos_ffn_output[layer_i][test_pos]))
all_scores, all_o = [], []
for i, embed in enumerate(previous_vectors):
    # Push each vector through the layer-17 input norm and the head's value /
    # output projections, then score it against the FFN neuron's keys.
    embed = get_pos_vector(embed, pos_embed_var, model, attn_test_layer)
    cur_embed_v = torch.sum(curhead_v * embed, 1)
    cur_embed_o = torch.sum(curhead_o.T * cur_embed_v, 1)
    all_o.append(cur_embed_o)
    cur_embed_inner_products_up = torch.sum(cur_embed_o * cur_ffn_key_up, -1)
    cur_embed_inner_products_gate = torch.sum(cur_embed_o * cur_ffn_key_gate, -1)
    all_scores.append((cur_embed_inner_products_up.item(), cur_embed_inner_products_gate.item()))
# Labels: -0.5 is the layer-0 input, k.0 the attention output of layer k,
# k.5 the FFN output of layer k. Sorted by up score, highest first.
all_scores_sort = sorted(zip([i/2-0.5 for i in range(len(all_scores))], all_scores), key=lambda x: x[1][0])[::-1]
print(all_scores_sort[:10])

[(15.0, (0.12116342782974243, 0.0976787805557251)), (13.0, (0.09483273327350616, 0.09762558341026306)), (14.0, (0.052968062460422516, 0.013312157243490219)), (8.0, (0.023074448108673096, -0.005840588361024857)), (7.0, (0.019325416535139084, 0.015491408295929432)), (6.0, (0.019316932186484337, -0.009192920289933681)), (11.0, (0.019167589023709297, -0.008673514239490032)), (16.5, (0.019051995128393173, -0.006628106348216534)), (7.5, (0.018388744443655014, -0.011169372126460075)), (15.5, (0.0173991359770298, 0.0006903298199176788))]


In [12]:
# See which position's attention subvalue in the previous layer is useful
# (find the position of the previous subvalue).
previous_layer = 15
cur_v_heads = torch.tensor(all_last_attn_subvalues[previous_layer])
# One (HEAD_DIM, out) slice of the transposed o_proj weight per head.
cur_attn_o_split = model.model.layers[previous_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)
cur_attn_subvalues_headrecompute = torch.bmm(cur_v_heads, cur_attn_o_split).permute(1, 0, 2)
# Sum over heads: one layer-15 attention subvalue per entry along dim 0.
previous_layer_attnvalues = torch.sum(cur_attn_subvalues_headrecompute, 1)
for pos in range(len(previous_layer_attnvalues)):
    cur_embed = previous_layer_attnvalues[pos]
    cur_embed_bsvalue = get_bsvalues(cur_embed, model, final_var)
    cur_embed_bsvalue_sort = torch.argsort(cur_embed_bsvalue, descending=True)
    # Pass the subvalue through layer 17's input norm and the selected head's
    # value / output projections, then score it against the FFN neuron's keys.
    cur_embed_ln = get_pos_vector(cur_embed, pos_embed_var, model, attn_test_layer)
    cur_embed_v = torch.sum(curhead_v * cur_embed_ln, 1)
    cur_embed_o = torch.sum(curhead_o.T * cur_embed_v, 1)
    cur_embed_inner_products_up = torch.sum(cur_embed_o * cur_ffn_key_up, -1)
    cur_embed_inner_products_gate = torch.sum(cur_embed_o * cur_ffn_key_gate, -1)
    print("pos: ", pos, "up: ", cur_embed_inner_products_up.item(), "gate: ", cur_embed_inner_products_gate.item())

pos:  0 up:  -0.0012112583499401808 gate:  0.0016715569654479623
pos:  1 up:  -0.0005289626424200833 gate:  -0.000134551664814353
pos:  2 up:  0.06307642161846161 gate:  0.09762759506702423
pos:  3 up:  0.014674997888505459 gate:  0.0052227722480893135
pos:  4 up:  0.0697443038225174 gate:  0.006158954929560423
pos:  5 up:  -0.024592075496912003 gate:  -0.012867532670497894


In [13]:
# See which previous layer is useful for computing attention subvalue 15_2 (layer 15, position 2).
mask_layer, mask_pos = 15, 2
test_embed1 = torch.tensor(all_pos_layer_input[0][mask_pos])
# Mean of squares of the layer-15 input at mask_pos, used as the normalization statistic.
pos_embed_var1_mask_layer = torch.tensor(all_pos_layer_input[mask_layer][mask_pos]).pow(2).mean(-1, keepdim=True)
# Vectors collected at mask_pos before layer 15: the layer-0 input, followed by
# (attention output, FFN output) for every earlier layer.
previous_masklayer = [torch.tensor(all_pos_layer_input[0][mask_pos])]
for i in range(mask_layer):
    previous_masklayer.append(torch.tensor(all_pos_attn_output[i][mask_pos]))
    previous_masklayer.append(torch.tensor(all_pos_ffn_output[i][mask_pos]))

all_previous_vo = []
for i, embed in enumerate(previous_masklayer):
    # First hop: layer-15 input norm, then the full value and output projections of layer 15.
    embed_ln = get_pos_vector(embed, pos_embed_var1_mask_layer, model, mask_layer)
    embed_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(embed_ln))
    # Second hop: layer-17 input norm, then the selected head's value / output projections.
    embed_ln2 = get_pos_vector(embed_ov, pos_embed_var, model, attn_test_layer)
    embed_v2 = torch.sum(curhead_v * embed_ln2, 1)
    embed_o2 = torch.sum(curhead_o.T * embed_v2, 1)
    all_previous_vo.append(embed_o2)

# Tuples of (label, up score, gate score); label -0.5 is the layer-0 input,
# k.0 the attention output of layer k, k.5 the FFN output of layer k.
all_previous_vo_scores = []
for i, x in enumerate(all_previous_vo):
    all_previous_vo_scores.append((i/2-0.5, torch.sum(x * cur_ffn_key_up, -1).item(), torch.sum(x * cur_ffn_key_gate, -1).item()))

# Sorted by up score, highest first.
all_previous_vo_scores_sort = sorted(all_previous_vo_scores, key=lambda x: x[1])[::-1]
print(all_previous_vo_scores_sort[:10])

[(12.5, 0.03383322432637215, 0.07460989058017731), (10.5, 0.019165528938174248, 0.015703681856393814), (14.5, 0.01742423139512539, 0.00963421631604433), (6.5, 0.016220010817050934, 0.02717570960521698), (0.5, 0.01072070561349392, 0.01012455578893423), (1.5, 0.010294288396835327, 0.01426645927131176), (8.0, 0.009111599996685982, 0.00042461464181542397), (11.0, 0.007322901394218206, 0.009814698249101639), (-0.5, 0.006614874117076397, 0.010071543976664543), (12.0, 0.006437157280743122, -0.008164125494658947)]


In [14]:
# Inspect the FFN subvalues of layer 12 at position `mask_pos`. With
# mask_pos = 2 this is position 2; the `pos4_` prefix in the variable names is
# historical and kept only so the names stay stable.
pos4_useful_layer_ffn = 12
# Index the nested list first so only this position's coefficients become a tensor.
pos4_useful_layer_ffn_coeffs = torch.tensor(all_pos_coefficient_scores[pos4_useful_layer_ffn][mask_pos])
pos4_useful_layer_fc2 = get_fc2_params(model, pos4_useful_layer_ffn)
# One row per neuron: its down_proj vector scaled by its coefficient.
pos4_useful_layer_ffn_subvalues = (pos4_useful_layer_ffn_coeffs * pos4_useful_layer_fc2).T
# Nothing below is differentiated, so do not record an autograd graph for
# the projections applied to every neuron's subvalue.
with torch.no_grad():
    # First hop: mask_layer input norm, then its full value and output projections.
    pos4_useful_layer_ffn_subvalues_ln = get_pos_vector(pos4_useful_layer_ffn_subvalues, pos_embed_var1_mask_layer, model, mask_layer)
    pos4_useful_layer_ffn_subvalues_ln_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(pos4_useful_layer_ffn_subvalues_ln))
    # Second hop: attn_test_layer input norm, then the selected head's value / output projections.
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2 = get_pos_vector(pos4_useful_layer_ffn_subvalues_ln_ov, pos_embed_var, model, attn_test_layer)
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2 = torch.bmm(pos4_useful_layer_ffn_subvalues_ln_ov_ln2.unsqueeze(0), curhead_v.T.unsqueeze(0))
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 = torch.bmm(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2, curhead_o.unsqueeze(0)).squeeze(0)

    # Score every transformed subvalue against the deep FFN neuron's keys.
    all_ffn_pos_scores_up = torch.sum(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 * cur_ffn_key_up, 1)
    all_ffn_pos_scores_gate = torch.sum(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 * cur_ffn_key_gate, 1)
    # A single tolist() per tensor replaces one .item() call per neuron.
    all_ffn_pos_scores_upgate = [
        (i, up, gate)
        for i, (up, gate) in enumerate(zip(all_ffn_pos_scores_up.tolist(), all_ffn_pos_scores_gate.tolist()))
    ]

    # Rank neurons by up score + gate score, highest first, and inspect the best one.
    all_ffn_pos_scores_upgate_sort = sorted(all_ffn_pos_scores_upgate, key=lambda x: x[1]+x[2])[::-1]
    for index, up_score, gate_score in all_ffn_pos_scores_upgate_sort[:1]:
        print(index, up_score, gate_score)
        embed = pos4_useful_layer_ffn_subvalues[index]
        # Tokens promoted by the raw subvalue ...
        embed_bsvalue = get_bsvalues(embed, model, final_var)
        embed_bsvalue_sort = torch.argsort(embed_bsvalue, descending=True)
        print("value original: ", [tokenizer.decode(x) for x in embed_bsvalue_sort[:10]])
        # ... and after it has passed through mask_layer's value / output projections.
        embed_ln = get_pos_vector(embed, pos_embed_var1_mask_layer, model, mask_layer)
        embed_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(embed_ln))
        embed_bsvalue = get_bsvalues(embed_ov, model, final_var)
        embed_bsvalue_sort = torch.argsort(embed_bsvalue, descending=True)
        print("value transform: ", [tokenizer.decode(x) for x in embed_bsvalue_sort[:10]])

4072 0.02357804775238037 0.055049389600753784
value original:  ['rd', 'quarters', 'PO', 'peat', '⅓', 'Constraint', 'ran', 'avas', 'ր', 'angol']
value transform:  ['III', 'three', 'Three', '三', '<0x84>', '₃', '3', 'three', 'Three', 'triple']


In [15]:
# See which previous layer is useful for computing attention subvalue 15_4 (layer 15, position 4).
# NOTE(review): this comment used to say "13_4", but the code below uses
# mask_layer = 15 - confirm which of the two was intended.
mask_layer, mask_pos = 15, 4
test_embed1 = torch.tensor(all_pos_layer_input[0][mask_pos])
# Mean of squares of the mask_layer input at mask_pos, used as the normalization statistic.
pos_embed_var1_mask_layer = torch.tensor(all_pos_layer_input[mask_layer][mask_pos]).pow(2).mean(-1, keepdim=True)
# Vectors collected at mask_pos before mask_layer: the layer-0 input, followed
# by (attention output, FFN output) for every earlier layer.
previous_masklayer = [torch.tensor(all_pos_layer_input[0][mask_pos])]
for i in range(mask_layer):
    previous_masklayer.append(torch.tensor(all_pos_attn_output[i][mask_pos]))
    previous_masklayer.append(torch.tensor(all_pos_ffn_output[i][mask_pos]))

all_previous_vo = []
for i, embed in enumerate(previous_masklayer):
    # First hop: mask_layer input norm, then its full value and output projections.
    embed_ln = get_pos_vector(embed, pos_embed_var1_mask_layer, model, mask_layer)
    embed_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(embed_ln))
    # Second hop: attn_test_layer input norm, then the selected head's value / output projections.
    embed_ln2 = get_pos_vector(embed_ov, pos_embed_var, model, attn_test_layer)
    embed_v2 = torch.sum(curhead_v * embed_ln2, 1)
    embed_o2 = torch.sum(curhead_o.T * embed_v2, 1)
    all_previous_vo.append(embed_o2)

# Tuples of (label, up score, gate score); label -0.5 is the layer-0 input,
# k.0 the attention output of layer k, k.5 the FFN output of layer k.
all_previous_vo_scores = []
for i, x in enumerate(all_previous_vo):
    all_previous_vo_scores.append((i/2-0.5, torch.sum(x * cur_ffn_key_up, -1).item(), torch.sum(x * cur_ffn_key_gate, -1).item()))

# Sorted by up score, highest first.
all_previous_vo_scores_sort = sorted(all_previous_vo_scores, key=lambda x: x[1])[::-1]
print(all_previous_vo_scores_sort[:10])

[(14.5, 0.06068086624145508, 0.00012453319504857063), (11.5, 0.05477098375558853, 0.020668024197220802), (5.5, 0.0283675380051136, 0.006406363565474749), (7.5, 0.02060963399708271, 0.02500823885202408), (12.5, 0.017112277448177338, 0.04150316119194031), (0.5, 0.012670803815126419, 0.008029316551983356), (14.0, 0.01023088488727808, -0.012319373898208141), (9.0, 0.009961117058992386, -0.0029868069104850292), (2.5, 0.008939795196056366, -0.0016081184148788452), (11.0, 0.008081584237515926, 0.009449582546949387)]


In [16]:
# Inspect the FFN subvalues of layer 11 at position `mask_pos` (position 4 when
# run after the cell that sets mask_layer, mask_pos = 15, 4).
pos4_useful_layer_ffn = 11
# Index the nested list first so only this position's coefficients become a tensor.
pos4_useful_layer_ffn_coeffs = torch.tensor(all_pos_coefficient_scores[pos4_useful_layer_ffn][mask_pos])
pos4_useful_layer_fc2 = get_fc2_params(model, pos4_useful_layer_ffn)
# One row per neuron: its down_proj vector scaled by its coefficient.
pos4_useful_layer_ffn_subvalues = (pos4_useful_layer_ffn_coeffs * pos4_useful_layer_fc2).T
# Nothing below is differentiated, so do not record an autograd graph for
# the projections applied to every neuron's subvalue.
with torch.no_grad():
    # First hop: mask_layer input norm, then its full value and output projections.
    pos4_useful_layer_ffn_subvalues_ln = get_pos_vector(pos4_useful_layer_ffn_subvalues, pos_embed_var1_mask_layer, model, mask_layer)
    pos4_useful_layer_ffn_subvalues_ln_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(pos4_useful_layer_ffn_subvalues_ln))
    # Second hop: attn_test_layer input norm, then the selected head's value / output projections.
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2 = get_pos_vector(pos4_useful_layer_ffn_subvalues_ln_ov, pos_embed_var, model, attn_test_layer)
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2 = torch.bmm(pos4_useful_layer_ffn_subvalues_ln_ov_ln2.unsqueeze(0), curhead_v.T.unsqueeze(0))
    pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 = torch.bmm(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2, curhead_o.unsqueeze(0)).squeeze(0)

    # Score every transformed subvalue against the deep FFN neuron's keys.
    all_ffn_pos_scores_up = torch.sum(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 * cur_ffn_key_up, 1)
    all_ffn_pos_scores_gate = torch.sum(pos4_useful_layer_ffn_subvalues_ln_ov_ln2_v2_o2 * cur_ffn_key_gate, 1)
    # A single tolist() per tensor replaces one .item() call per neuron.
    all_ffn_pos_scores_upgate = [
        (i, up, gate)
        for i, (up, gate) in enumerate(zip(all_ffn_pos_scores_up.tolist(), all_ffn_pos_scores_gate.tolist()))
    ]

    # Rank neurons by up score + gate score, highest first, and inspect the best one.
    all_ffn_pos_scores_upgate_sort = sorted(all_ffn_pos_scores_upgate, key=lambda x: x[1]+x[2])[::-1]
    for index, up_score, gate_score in all_ffn_pos_scores_upgate_sort[:1]:
        print(index, up_score, gate_score)
        embed = pos4_useful_layer_ffn_subvalues[index]
        # Tokens promoted by the raw subvalue ...
        embed_bsvalue = get_bsvalues(embed, model, final_var)
        embed_bsvalue_sort = torch.argsort(embed_bsvalue, descending=True)
        print("value original: ", [tokenizer.decode(x) for x in embed_bsvalue_sort[:10]])
        # ... and after it has passed through mask_layer's value / output projections.
        embed_ln = get_pos_vector(embed, pos_embed_var1_mask_layer, model, mask_layer)
        embed_ov = model.model.layers[mask_layer].self_attn.o_proj(model.model.layers[mask_layer].self_attn.v_proj(embed_ln))
        embed_bsvalue = get_bsvalues(embed_ov, model, final_var)
        embed_bsvalue_sort = torch.argsort(embed_bsvalue, descending=True)
        print("value transform: ", [tokenizer.decode(x) for x in embed_bsvalue_sort[:10]])

2258 0.02762100100517273 0.012637176550924778
value original:  ['ös', 'enz', 'Trace', 'lis', 'vid', 'suite', 'HT', 'ung', 'ane', 'icano']
value transform:  ['XV', 'fifth', 'Fif', 'avas', 'Five', 'five', 'abase', '五', '₅', 'fif']
