In [None]:
import json, os
import torch
import bertviz, uuid
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
from collections import Counter
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.core.display import display, HTML, Javascript
from bertviz.util import format_special_chars, format_attention, num_layers, num_heads
# Architecture constants of the Qwen3 checkpoint analysed in this notebook.
LAYER_NUM, HEAD_NUM, HEAD_DIM = 36, 32, 128  # decoder layers, query heads per layer, width of one head
# NOTE(review): HEAD_DIM * HEAD_NUM is the attention projection width; it equals the model's
# hidden size only if config.hidden_size == 4096 — confirm against the checkpoint config.
HIDDEN_DIM = HEAD_DIM * HEAD_NUM
# Tensors created without an explicit device (e.g. torch.tensor(list)) land on GPU 1.
torch.set_default_device("cuda:1")

In [None]:
def transfer_output(model_output):
    """Unpack per-layer records into eight parallel per-layer lists.

    `model_output` is indexed by layer. Each record is expected to hold at least 8 entries:
    entries 0-6 are tensors whose batch index 0 is converted to nested Python lists, and
    entry 7 (attention weights) is kept unchanged. Unpacking stops at the first record
    with fewer than 8 entries.

    Returns:
        Tuple (layer_input, attn_output, residual_output, ffn_output, layer_output,
        last_attn_subvalues, coefficient_scores, attn_scores); each element is a list
        with one item per unpacked layer.
    """
    n_list_fields = 7
    per_field = [[] for _ in range(n_list_fields)]
    attn_weights = []
    for layer_idx in range(len(model_output)):
        record = model_output[layer_idx]
        if len(record) < 8:
            break
        for field_idx in range(n_list_fields):
            per_field[field_idx].append(record[field_idx][0].tolist())
        attn_weights.append(record[7])
    return (*per_field, attn_weights)

def get_fc2_params(model, layer_num):
    """Return the MLP down-projection weight tensor (the "fc2" matrix) of decoder layer `layer_num`."""
    down_proj = model.model.layers[layer_num].mlp.down_proj
    return down_proj.weight.data
def get_bsvalues(vector, model, final_var, eps=1e-6):
    """Project residual-stream vector(s) to vocabulary logits via the final RMSNorm gain and lm_head.

    The RMS term is supplied by the caller (`final_var`) rather than computed from
    `vector`. Every component of the residual stream is therefore projected on one
    shared scale instead of being renormalised on its own.

    Args:
        vector: tensor whose last dimension is the model hidden size; leading dims are kept.
        model: causal LM exposing `model.model.norm.weight` and `model.lm_head`.
        final_var: mean-square term broadcast against `vector`; rsqrt(final_var + eps) is applied.
        eps: added to `final_var` before the rsqrt (default 1e-6, as before).
            NOTE(review): should match the checkpoint's `rms_norm_eps` — confirm.

    Returns:
        Logits tensor detached from autograd (`.data`); last dimension = vocabulary size.
    """
    scaled = vector * torch.rsqrt(final_var + eps)
    normed = scaled * model.model.norm.weight.data
    return model.lm_head(normed).data
def get_prob(vector):
    """Softmax over the last dimension (e.g. vocabulary logits -> probabilities)."""
    return torch.softmax(vector, dim=-1)
def transfer_l(l):
    """Split an iterable of pairs into two lists: (first elements, second elements)."""
    pairs = list(l)
    return [pair[0] for pair in pairs], [pair[1] for pair in pairs]
def plt_bar(x, y, yname="log increase"):
    """Bar chart of interleaved per-sublayer scores.

    Entries at even indices of `x`/`y` are drawn as attention sublayers and entries at odd
    indices as FFN sublayers. Entry i is placed at x[i] / 2, so with x = range(n) attention
    layer k sits at k and FFN layer k at k + 0.5.
    """
    plt.figure(figsize=(8, 3))
    axis = plt.gca()
    axis.xaxis.set_major_locator(MultipleLocator(1))
    plt.xlim(-0.5, x[-1] / 2 + 0.49)
    even_idx = range(0, len(x), 2)
    odd_idx = range(1, len(x), 2)
    plt.bar([x[i] / 2 for i in even_idx], [y[i] for i in even_idx], color="darksalmon", label="attention layers")
    plt.bar([x[i] / 2 for i in odd_idx], [y[i] for i in odd_idx], color="lightseagreen", label="FFN layers")
    plt.xlabel("layer")
    plt.ylabel(yname)
    plt.legend()
    plt.show()
def plt_heatmap(data, xlabel="layer", ylabel="head", title="attn head log increase heatmap", colorbar=False):
    """Show a 2-D nested list (or array) as a heatmap with one labelled tick per row and column.

    Args:
        data: 2-D numbers; rows go on the y axis, columns on the x axis. In this notebook it
            is called with rows = heads and columns = layers, hence the default labels.
        xlabel, ylabel: axis labels.
        title: figure title (default is the original hard-coded title).
        colorbar: draw a colorbar next to the image when True (off by default, as before).
    """
    n_rows, n_cols = len(data), len(data[0])
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(data, cmap=plt.cm.hot_r)
    # Pin the tick positions explicitly. The old code called set_yticklabels on the automatic
    # locator, which attached labels 0, 1, 2, ... to ticks at 0, 5, 10, ... and mislabelled rows.
    ax.set_xticks(range(n_cols))
    ax.set_yticks(range(n_rows))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if colorbar:
        fig.colorbar(im, ax=ax)
    plt.show()

In [None]:
# Load the tokenizer and the causal LM. Later cells call the model with output_attentions=True;
# eager attention is the implementation that materialises attention weights.
modelname = "Qwen"  # NOTE(review): local directory / hub id of the Qwen3 checkpoint — confirm
tokenizer = AutoTokenizer.from_pretrained(modelname)
model = AutoModelForCausalLM.from_pretrained(modelname, attn_implementation="eager")
model.eval()  # inference mode (disables dropout)
# Trailing ';' keeps Jupyter from printing the full module repr as the cell output.
model.to("cuda:1");  # Use GPU 1

In [None]:
# Sanity check: the model's top-10 next-token predictions for the probe sentence.
test_sentence = "Tim Duncan plays the sport of"
indexed_tokens = tokenizer.encode(test_sentence)
tokens = list(map(tokenizer.decode, indexed_tokens))
# No explicit device: relies on torch.set_default_device from the config cell.
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]
last_logits = predictions[0][-1]
predicted_top10 = torch.argsort(last_logits, descending=True)[:10]
predicted_text = [tokenizer.decode(token_id) for token_id in predicted_top10]
print(test_sentence, "=>", predicted_text)


In [None]:
# Debug cell: inspect what the backbone returns as `past_key_values` and check that
# transfer_output can unpack it.
test_sentence = "Tim Duncan plays the sport of"
indexed_tokens = tokenizer.encode(test_sentence)
tokens = [tokenizer.decode(x) for x in indexed_tokens]
tokens_tensor = torch.tensor([indexed_tokens]).to("cuda:1")

with torch.no_grad():
    base_outputs = model.model(tokens_tensor, use_cache=True, output_attentions=True)

# NOTE(review): this notebook expects a modified model whose past_key_values holds, per layer,
# at least 8 recorded tensors (see transfer_output) rather than a standard KV cache — confirm.
base_cache = base_outputs.past_key_values
# Report both a missing (None) and an empty cache; previously an empty one printed nothing.
if base_cache is not None and len(base_cache) > 0:
    print(type(base_cache[0]))
    print(len(base_cache[0]))
    try:
        (all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output,
         all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores,
         all_attn_scores) = transfer_output(base_cache)
    except Exception:
        # Best-effort in this debug cell: show the traceback but let Run All continue.
        # The next cell repeats the extraction without this guard.
        import traceback
        traceback.print_exc()
else:
    print("Base past_key_values empty")

In [None]:
# Main forward passes on the probe sentence: next-token prediction from the full LM, then the
# backbone's per-layer records (read from past_key_values and unpacked by transfer_output).
test_sentence = "Tim Duncan plays the sport of"
indexed_tokens = tokenizer.encode(test_sentence)
tokens = [tokenizer.decode(token_id) for token_id in indexed_tokens]
tokens_tensor = torch.tensor([indexed_tokens]).to("cuda:1")

with torch.no_grad():
    outputs = model(tokens_tensor, use_cache=True, output_attentions=True)
    predictions = outputs[0]
predicted_top10 = predictions[0][-1].argsort(descending=True)[:10]
predicted_text = [tokenizer.decode(token_id) for token_id in predicted_top10]
print(test_sentence, "=>", predicted_text)

with torch.no_grad():
    base_outputs = model.model(tokens_tensor, use_cache=True, output_attentions=True)

(all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output,
 all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores,
 all_attn_scores) = transfer_output(base_outputs.past_key_values)

# Mean square of the last layer's output at the last position. get_bsvalues uses it as the
# shared RMS term, so every projected vector is scaled the same way.
last_layer_last_pos = torch.tensor(all_pos_layer_output[-1][-1])
final_var = last_layer_last_pos.pow(2).mean(-1, keepdim=True)
pos_len = len(tokens)
print(tokens)

In [None]:
# Token id of the top-1 prediction; every "increase" score below is measured for this token.
predict_index = int(predicted_top10[0])
print(predict_index, tokenizer.decode(predict_index))

In [None]:
# Layer-level increase ("value layers"): how much each attention / FFN sublayer's output at the
# last position raises log p(predict_index) when added to the residual stream it read from.
def _log_prob_increase(base_vector, delta_vector):
    """Return log p(predict_index | base + delta) - log p(predict_index | base) as a Python float.

    Both vectors go through get_bsvalues with the shared `final_var`, so they are projected
    to the vocabulary on the same scale.
    """
    base_log_prob = torch.log(get_prob(get_bsvalues(base_vector, model, final_var))[predict_index])
    plus_log_prob = torch.log(get_prob(get_bsvalues(delta_vector + base_vector, model, final_var))[predict_index])
    return (plus_log_prob - base_log_prob).item()


# Attention sublayer k reads the layer input; FFN sublayer k reads the post-attention residual.
all_attn_log_increase = [
    _log_prob_increase(torch.tensor(all_pos_layer_input[layer_i][-1]),
                       torch.tensor(all_pos_attn_output[layer_i][-1]))
    for layer_i in range(LAYER_NUM)
]
all_ffn_log_increase = [
    _log_prob_increase(torch.tensor(all_pos_residual_output[layer_i][-1]),
                       torch.tensor(all_pos_ffn_output[layer_i][-1]))
    for layer_i in range(LAYER_NUM)
]

# Print the FFN sublayer scores in layer order.
for layer_i, score in enumerate(all_ffn_log_increase):
    print(f"Layer {layer_i}: {score:.6f}")

attn_list = [[str(layer_i), score] for layer_i, score in enumerate(all_attn_log_increase)]
ffn_list = [[str(layer_i), score] for layer_i, score in enumerate(all_ffn_log_increase)]
attn_list_sort = sorted(attn_list, key=lambda x: x[-1])[::-1]
ffn_list_sort = sorted(ffn_list, key=lambda x: x[-1])[::-1]
attn_increase_compute = [(indx, round(increase, 3)) for indx, increase in attn_list_sort]
ffn_increase_compute = [(indx, round(increase, 3)) for indx, increase in ffn_list_sort]
# Sum the unrounded scores; summing the 3-decimal values would accumulate rounding error.
print("attn sum: ", round(sum(all_attn_log_increase), 3),
      "ffn sum: ", round(sum(all_ffn_log_increase), 3))
print("attn: ", attn_increase_compute)
print("ffn: ", ffn_increase_compute)
# Interleave as attn0, ffn0, attn1, ffn1, ... — the layout plt_bar expects.
all_increases_draw = [score for pair in zip(all_attn_log_increase, all_ffn_log_increase) for score in pair]
plt_bar(range(len(all_increases_draw)), all_increases_draw)

In [None]:
# Head-level increase ("value heads"): for every (layer, head), how much that head's output at
# the last position (summed over source positions) raises log p(predict_index) when added to
# the layer input.
all_head_increase = []
for test_layer in range(LAYER_NUM):
    layer_in_last = torch.tensor(all_pos_layer_input[test_layer][-1])
    v_heads = torch.tensor(all_last_attn_subvalues[test_layer])
    # o_proj's input columns split per head: (HEAD_NUM, HEAD_DIM, hidden)
    o_proj_by_head = model.model.layers[test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)
    per_pos_head = torch.bmm(v_heads, o_proj_by_head).permute(1, 0, 2)  # (pos, head, hidden)
    head_outputs = per_pos_head.sum(0)
    base_log_prob = torch.log(get_prob(get_bsvalues(layer_in_last, model, final_var))[predict_index])
    head_log_probs = torch.log(get_prob(get_bsvalues(
        head_outputs + layer_in_last, model, final_var))[:, predict_index])
    head_increase = head_log_probs - base_log_prob
    all_head_increase.extend(
        [f"{test_layer}_{head_i}", round(inc.item(), 4)] for head_i, inc in enumerate(head_increase)
    )

all_head_increase_sort = sorted(all_head_increase, key=lambda entry: entry[-1])[::-1]
print(all_head_increase_sort[:30])
all_head_increase_list = [entry[1] for entry in all_head_increase]
# (LAYER_NUM, HEAD_NUM) -> transpose so rows are heads and columns are layers.
all_head_increase_list_split = torch.tensor(all_head_increase_list).view(LAYER_NUM, HEAD_NUM).T.tolist()
plt_heatmap(all_head_increase_list_split)

In [None]:
# Position-level increase inside one head. For head `test_head` of layer `test_layer`, measure
# how much each source position's contribution raises log p(predict_index). Also show
# logit-lens views of the layer input, the "key" residual and the head's value output.
test_layer, test_head = 15, 15
layer_in = torch.tensor(all_pos_layer_input[test_layer])
layer_in_last = layer_in[-1]
o_proj_by_head = model.model.layers[test_layer].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)
v_heads = torch.tensor(all_last_attn_subvalues[test_layer])
# (pos, head, hidden): each source position's contribution through each head's slice of o_proj
per_pos_head = torch.bmm(v_heads, o_proj_by_head).permute(1, 0, 2)
head_contrib = per_pos_head[:, test_head, :]

base_log_prob = torch.log(get_prob(get_bsvalues(layer_in_last, model, final_var))[predict_index])
plus_log_probs = torch.log(get_prob(get_bsvalues(
    head_contrib + layer_in_last, model, final_var))[:, predict_index])
pos_increase = plus_log_probs - base_log_prob
pos_increase_sorted = sorted(
    zip(range(len(pos_increase)), tokens, pos_increase.tolist()),
    key=lambda item: item[-1],
)[::-1]

layer_in_ranked = torch.argsort(get_bsvalues(layer_in, model, final_var), descending=True)
value_ranked = torch.argsort(get_bsvalues(head_contrib, model, final_var), descending=True)

# "Key" view: the layer input minus the embeddings and every earlier FFN output. Assuming
# a plain residual sum, what remains is what earlier attention layers wrote.
key_in = layer_in - torch.tensor(all_pos_layer_input[0])
for earlier_layer in range(test_layer):
    key_in = key_in - torch.tensor(all_pos_ffn_output[earlier_layer])
key_ranked = torch.argsort(get_bsvalues(key_in, model, final_var), descending=True)

print(list(zip(range(len(tokens)), tokens)))
for pos, word, increase in pos_increase_sorted:
    attn_weight = all_attn_scores[test_layer][0][test_head][-1][pos].item()
    print("\n", pos, word, "increase: ", round(increase, 4), "attn: ", round(attn_weight, 4))
    print("layer input: ", [tokenizer.decode(t) for t in layer_in_ranked[pos][:20]])
    print("key: ", [tokenizer.decode(t) for t in key_ranked[pos][:20]])
    print("value: ", [tokenizer.decode(t) for t in value_ranked[pos][:10]])

In [None]:
# FFN neuron increase ("value FFN neurons"). For every neuron n of every layer, add its sub-value
# (coefficient_n * down_proj[:, n]) at the last position to that layer's post-attention residual
# and measure the change in log p(predict_index).
# Sub-value matrices are built per layer and projected in chunks instead of being kept for all
# layers (the former `all_ffn_subvalues` held LAYER_NUM full intermediate x hidden matrices).
FFN_VOCAB_CHUNK = 1024  # neurons projected to the vocabulary per batch; lower it if the GPU runs out of memory


def _ffn_neuron_log_increase(layer_i, chunk_size=FFN_VOCAB_CHUNK):
    """Return a 1-D tensor with the log-prob increase of predict_index for each FFN neuron of layer_i."""
    coefficients = torch.tensor(all_pos_coefficient_scores[layer_i][-1])
    neuron_values = get_fc2_params(model, layer_i).T  # row n = down_proj column n
    residual = torch.tensor(all_pos_residual_output[layer_i][-1])
    base_log_prob = torch.log(get_prob(get_bsvalues(residual, model, final_var))[predict_index])
    chunks = []
    for start in range(0, neuron_values.shape[0], chunk_size):
        stop = start + chunk_size
        subvalues = coefficients[start:stop, None] * neuron_values[start:stop]
        probs = get_prob(get_bsvalues(subvalues + residual, model, final_var))[:, predict_index]
        chunks.append(torch.log(probs) - base_log_prob)
    return torch.cat(chunks)


ffn_subvalue_list = []
for layer_i in range(LAYER_NUM):
    # One .tolist() per layer instead of .item() per neuron: one device sync instead of thousands.
    for index, ffn_increase in enumerate(_ffn_neuron_log_increase(layer_i).tolist()):
        ffn_subvalue_list.append([f"{layer_i}_{index}", ffn_increase])
ffn_subvalue_list_sort = sorted(ffn_subvalue_list, key=lambda x: x[-1])[::-1]

# Logit-lens view of the raw down_proj column (without its coefficient) of the top-10 neurons.
for name, score in ffn_subvalue_list_sort[:10]:
    print(name, round(score, 4))
    layer, neuron = map(int, name.split("_"))
    cur_vector = get_fc2_params(model, layer).T[neuron]
    cur_vector_bsvalue_sort = torch.argsort(get_bsvalues(cur_vector, model, final_var), descending=True)
    print("top10: ", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[:10]])
    print("last10: ", [tokenizer.decode(a) for a in cur_vector_bsvalue_sort[-10:].tolist()[::-1]])

In [None]:
# How many of the top-300 value FFN neurons fall in each layer.
FFN_value_neurons = [entry[0] for entry in ffn_subvalue_list_sort[:300]]
layer_counts = Counter(int(name.split("_")[0]) for name in FFN_value_neurons)
FFN_layer_count_value = sorted(layer_counts.items())  # [(layer, count), ...] by layer

gpt_FFN_value_x, gpt_FFN_value_y = transfer_l(FFN_layer_count_value)

plt.figure(figsize=(6, 3))
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.plot(gpt_FFN_value_x, gpt_FFN_value_y, "bo-", label="qwen3 FFN value neurons")
plt.xlabel("layer", fontsize=10)
plt.ylabel("count", fontsize=10)
plt.legend(fontsize=10, loc="upper right")
plt.show()

In [None]:
# Query layers for FFN neurons. For each of the top-30 value FFN neurons, split the inner product
# between its key and the residual-stream pieces feeding that FFN: the embedding, the attn/FFN
# outputs of earlier layers, and the same layer's attention output. Each piece is credited in
# proportion to its share, weighted by the neuron's increase score.
all_residual_scores = [0.0] * (1 + 2 * LAYER_NUM)  # index 0 = embedding; 2k+1 = attn k; 2k+2 = FFN k
for l_n, increase_score in ffn_subvalue_list_sort[:30]:
    ffn_layer, ffn_neuron = map(int, l_n.split("_"))
    decoder_layer = model.model.layers[ffn_layer]
    # The neuron's key is its input-side row of up_proj (length = hidden size). down_proj column
    # ffn_neuron is the neuron's output (value) direction, not what the residual is matched against.
    # NOTE(review): in the gated MLP, gate_proj row ffn_neuron also drives the activation; up_proj
    # is used here as the linear key — confirm this matches the intended attribution.
    ffn_neuron_key = decoder_layer.mlp.up_proj.weight.data[ffn_neuron]
    # Fold in the RMSNorm gain applied before the MLP.
    ffn_neuron_key_new = ffn_neuron_key * decoder_layer.post_attention_layernorm.weight.data
    pieces = [torch.tensor(all_pos_layer_input[0][-1])]
    for layer_i in range(ffn_layer):
        pieces.append(torch.tensor(all_pos_attn_output[layer_i][-1]))
        pieces.append(torch.tensor(all_pos_ffn_output[layer_i][-1]))
    pieces.append(torch.tensor(all_pos_attn_output[ffn_layer][-1]))
    inner_products = (torch.stack(pieces) * ffn_neuron_key_new).sum(-1).tolist()
    sum_inner_product = sum(inner_products)
    for l, inner in enumerate(inner_products):
        all_residual_scores[l] += inner / sum_inner_product * increase_score
all_residual_scores_zip = list(enumerate(all_residual_scores))
all_residual_scores_zip_sort = sorted(all_residual_scores_zip, key=lambda x: x[-1])[::-1]
# index -> position: i/2 - 0.5 gives k for attn k, k + 0.5 for FFN k, and -0.5 for the embedding.
print([(a[0] / 2 - 0.5, round(a[1], 4)) for a in all_residual_scores_zip_sort])
plt_bar(range(len(all_residual_scores[1:])), all_residual_scores[1:])

In [None]:
# Query layers for attention neurons. Each top attention neuron's key is its v_proj row, with
# the input RMSNorm gain folded in. That key is matched against the residual-stream pieces at
# the source position the neuron reads from. Each piece is credited in proportion to its share
# of the inner product, weighted by the neuron's increase score.
N_TOP_VALUE_HEADS = 10  # only used when the attention-neuron ranking has to be rebuilt below

# `cur_file_attn_neuron_list_sort` (entries ["layer_head_neuron_pos", score], best first) is not
# created by any earlier cell, so Restart & Run All would stop here with a NameError. When it is
# missing, rebuild it from the top value heads of the head-level cell. Neuron n of head h at
# source position p contributes all_last_attn_subvalues[layer][h][p][n] times row (h*HEAD_DIM + n)
# of o_proj.weight.T — the same decomposition the head-level cell sums per head.
if "cur_file_attn_neuron_list_sort" not in globals():
    attn_neuron_list = []
    for head_name, _ in all_head_increase_sort[:N_TOP_VALUE_HEADS]:
        layer_id, head_id = map(int, head_name.split("_"))
        head_subvalues = torch.tensor(all_last_attn_subvalues[layer_id])[head_id]  # (pos, HEAD_DIM)
        o_rows = model.model.layers[layer_id].self_attn.o_proj.weight.data.T.view(HEAD_NUM, HEAD_DIM, -1)[head_id]
        # (pos * HEAD_DIM, hidden); row p * HEAD_DIM + n is neuron n at position p
        neuron_vectors = (head_subvalues.unsqueeze(-1) * o_rows).reshape(-1, o_rows.shape[-1])
        layer_in_last = torch.tensor(all_pos_layer_input[layer_id][-1])
        base_log_prob = torch.log(get_prob(get_bsvalues(layer_in_last, model, final_var))[predict_index])
        plus_log_probs = torch.log(get_prob(get_bsvalues(
            neuron_vectors + layer_in_last, model, final_var))[:, predict_index])
        for flat_idx, score in enumerate((plus_log_probs - base_log_prob).tolist()):
            pos_id, neuron_id = divmod(flat_idx, HEAD_DIM)
            attn_neuron_list.append([f"{layer_id}_{head_id}_{neuron_id}_{pos_id}", score])
    cur_file_attn_neuron_list_sort = sorted(attn_neuron_list, key=lambda x: x[-1])[::-1]

all_residual_scores = [0.0] * (1 + 2 * LAYER_NUM)  # index 0 = embedding; 2k+1 = attn k; 2k+2 = FFN k
avg_attn_layer_curdir = []
for l_h_n_p, increase_score in cur_file_attn_neuron_list_sort[:30]:
    attn_layer, attn_head, attn_neuron, attn_pos = map(int, l_h_n_p.split("_"))
    avg_attn_layer_curdir.append(attn_layer)
    decoder_layer = model.model.layers[attn_layer]
    v_weight = decoder_layer.self_attn.v_proj.weight.data
    # v_proj has (num_kv_heads * HEAD_DIM) rows. Under grouped-query attention, query head h reads
    # KV head h // (HEAD_NUM // num_kv_heads). Indexing with the query head directly runs past the
    # end of v_weight for h >= num_kv_heads. The group size is 1 when there is no GQA.
    kv_group_size = HEAD_NUM // (v_weight.shape[0] // HEAD_DIM)
    cur_attn_neuron = (attn_head // kv_group_size) * HEAD_DIM + attn_neuron
    attn_neuron_key_new = v_weight[cur_attn_neuron] * decoder_layer.input_layernorm.weight.data
    pieces = [torch.tensor(all_pos_layer_input[0][attn_pos])]
    for layer_i in range(attn_layer):
        pieces.append(torch.tensor(all_pos_attn_output[layer_i][attn_pos]))
        pieces.append(torch.tensor(all_pos_ffn_output[layer_i][attn_pos]))
    inner_products = (torch.stack(pieces) * attn_neuron_key_new).sum(-1).tolist()
    sum_inner_product = sum(inner_products)
    for l, inner in enumerate(inner_products):
        all_residual_scores[l] += inner / sum_inner_product * increase_score
all_residual_scores_zip = list(enumerate(all_residual_scores))
all_residual_scores_zip_sort = sorted(all_residual_scores_zip, key=lambda x: x[-1])[::-1]
print("avg attn layer: ", sum(avg_attn_layer_curdir) / len(avg_attn_layer_curdir))
print([(a[0] / 2 - 0.5, a[1]) for a in all_residual_scores_zip_sort[:10]])
plt_bar(range(len(all_residual_scores[1:])), all_residual_scores[1:])