In [1]:
import torch
from transformers import AutoTokenizer, BloomForCausalLM
from bloom_for_node_attribution import BloomForCausalLMForNodeAttribution

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer and the attribution-instrumented model from the same
# checkpoint; naming the checkpoint once keeps the two from drifting apart.
checkpoint = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BloomForCausalLMForNodeAttribution.from_pretrained(checkpoint)

In [3]:
# Generation below samples (do_sample=True), so fix the RNG first: otherwise the
# generated token -- and every attribution cell that depends on it -- changes
# from run to run and the notebook cannot be reproduced with Restart & Run All.
torch.manual_seed(0)

inputs = tokenizer("Hello, I am an AlexPrize chatbot", return_tensors="pt")

# Number of prompt tokens. The batch holds a single prompt, so dim 1 of
# input_ids is the prompt length.
seq_length = inputs["input_ids"].shape[1]

# Generate exactly one new token; return_dict_in_generate=True is needed so the
# result exposes `.sequences` (and the `.activations` record read by later cells).
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_new_tokens=1,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    return_dict_in_generate=True,
)

In [4]:
# Decode prompt + generated token back to text; the last expression is displayed.
decoded_sequences = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
decoded_sequences

['Hello, I am an AlexPrize chatbot.\n']

In [22]:
# Final hidden states to LM head contributions

# First (weight) parameter of the LM head.
lm_head_weights = next(iter(model.lm_head.parameters()))

# Keep only the LM-head rows that belong to the tokens of the sequence,
# dropping the last (newly generated) token id.
# NOTE(review): row i is selected by the token id *at* position i and is later
# multiplied with the activation at row i -- confirm this is the intended
# pairing rather than using the row of the token predicted from position i.
sequence_token_ids = outputs.sequences[0][:-1]
seq_token_weights = lm_head_weights.index_select(0, sequence_token_ids)

# Layer-normed final hidden states with all size-1 dims removed.
hidden_state_activations = outputs.activations[0]["transformer"]["ln_f"].squeeze()

# Element-wise weight * activation: contribution of each layer-normed hidden
# unit to the token logit.
contribution1 = seq_token_weights * hidden_state_activations
print(contribution1)

tensor([[ 4.5872e-02, -4.6421e-02, -4.1495e-02,  ...,  4.6246e+01,
          7.3020e-03,  9.8865e-02],
        [-4.6934e-02, -2.8784e-02, -3.4557e-02,  ...,  3.9431e+01,
         -1.2985e-03, -3.0629e-02],
        [ 8.1281e-03, -9.8136e-03, -1.6750e-02,  ...,  4.3258e+01,
         -1.4852e-02,  4.6938e-02],
        ...,
        [-9.3010e-02, -1.9035e-03,  5.5965e-02,  ...,  3.8051e+01,
         -3.8928e-02,  5.1447e-06],
        [ 1.8649e-02,  2.7172e-02, -3.7876e-03,  ...,  4.0829e+01,
         -5.3049e-03,  2.8129e-01],
        [ 3.1953e-04,  7.5846e-02,  2.1626e-02,  ...,  4.2024e+01,
         -9.2955e-03,  3.6013e-02]], grad_fn=<MulBackward0>)


In [23]:
# Contribution of final feed forward layer to hidden states

# delta_x
# Give activations shape (seq length, 1, input size) so they broadcast against
# every row of the (output size x input size) weight matrices below.
# NOTE(review): this reads the "dense_h_to_4h" entry of the activation record as
# the input of dense_4h_to_h -- confirm it holds the values dense_4h_to_h
# actually receives (i.e. after the MLP non-linearity).
mlp_dense_4h_to_h_activations = outputs.activations[0]["transformer"]["h"]["23"]["mlp"]["dense_h_to_4h"]
mlp_dense_4h_to_h_activations = torch.transpose(mlp_dense_4h_to_h_activations, 0, 1)

# w_xy
# Direct attribute access (same parameter that named_parameters() reports as
# "transformer.h.23.mlp.dense_4h_to_h.weight"): it raises immediately if the
# path is wrong instead of scanning every parameter and silently leaving None.
mlp_dense_4h_to_h_weight = model.transformer.h[23].mlp.dense_4h_to_h.weight

# Linear weights are stored as (output size, input size).
output_size, input_size = mlp_dense_4h_to_h_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, each row holds the weights from
# every input node to the output node with that row's index.
mlp_dense_4h_to_h_weight = mlp_dense_4h_to_h_weight.expand(seq_length, output_size, input_size)

# Essentially the previous weight product
# w_yz
# Use a separate name for the expanded view: overwriting seq_token_weights with
# its 3-D expansion made this cell fail on a second run (expand() cannot shrink
# the number of dimensions).
seq_token_weights_expanded = seq_token_weights.unsqueeze(-1).expand(seq_length, output_size, input_size)

# Multiply current layer weights with previous layer weights
# Element-wise multiply each column by the output layer's weights to the final
# layer to get the input layer's contribution to the final prediction
# w_xy * w_yz
weight_product = torch.mul(mlp_dense_4h_to_h_weight, seq_token_weights_expanded)

# Element-wise multiply each weight row by the input node's activation
# Each column in contribution2 contains one input node's weights to every output node
# w_xy * w_yz * delta_x
contribution2 = torch.mul(weight_product, mlp_dense_4h_to_h_activations)

# Sum over each column's elements to get the contribution of an input node to
# the final token output
# sum(w_xy * w_yz * delta_x) over all y
contribution2 = torch.sum(contribution2, 1)

# Sum over each column's elements (all weights from one input node to all
# output nodes) to obtain the weight matrix for the next layer
# sum(w_xy * w_yz) over all y, used for next computation
weight_product_sum = torch.sum(weight_product, 1)
print(contribution2)

tensor([[-1.3841e-03,  6.3697e-04, -1.5406e-04,  ...,  4.3465e-04,
          4.0150e-05, -5.9891e-04],
        [ 4.7142e-03, -1.7368e-03,  2.8305e-04,  ..., -2.0748e-04,
         -1.3106e-03,  1.1136e-03],
        [ 1.4183e-03, -2.4820e-04, -1.8910e-05,  ...,  8.2431e-06,
         -3.1495e-04,  7.1449e-04],
        ...,
        [ 1.7915e-03,  2.5799e-04,  2.0766e-04,  ..., -3.4075e-05,
          4.3839e-04,  7.2800e-04],
        [ 7.3433e-04,  2.0247e-04, -1.6178e-03,  ..., -4.9226e-05,
         -1.0721e-03,  2.0690e-04],
        [-5.4961e-04, -1.2362e-04, -6.1327e-04,  ..., -1.6303e-04,
         -3.7989e-04, -7.8312e-04]], grad_fn=<SumBackward1>)


In [24]:
# Contribution of mlp.dense_h_to_4h inputs to the predicted token
#
# This cell consumes weight_product_sum from the previous cell and overwrites
# it with this layer's result, so it must run exactly once after that cell.

# delta_w
# Swap dims 0 and 1 so the activations broadcast against every row of the
# (output size x input size) weight matrices below.
post_attention_layernorm_activations = outputs.activations[0]["transformer"]["h"]["23"]["post_attention_layernorm"]
post_attention_layernorm_activations = torch.transpose(post_attention_layernorm_activations, 0, 1)

# w_wx
# Direct attribute access (same parameter that named_parameters() reports as
# "transformer.h.23.mlp.dense_h_to_4h.weight"): it raises immediately if the
# path is wrong instead of scanning every parameter and silently leaving None.
mlp_dense_h_to_4h_weight = model.transformer.h[23].mlp.dense_h_to_4h.weight

# Linear weights are stored as (output size, input size).
output_size, input_size = mlp_dense_h_to_4h_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, each row holds the weights from
# every input node to the output node with that row's index.
mlp_dense_h_to_4h_weight = mlp_dense_h_to_4h_weight.expand(seq_length, output_size, input_size)

# Reshape previous weight product sum for multiplication with current weights
# sum(w_xy * w_yz) over all y, carried over from the previous layer
previous_weight_product_sum = weight_product_sum.unsqueeze(-1).expand(seq_length, output_size, input_size)

# w_wx * sum(w_xy * w_yz) over all y
weight_product = torch.mul(mlp_dense_h_to_4h_weight, previous_weight_product_sum)

# w_wx * delta_w * sum(w_xy * w_yz) over all y
contribution3 = torch.mul(weight_product, post_attention_layernorm_activations)

# sum(w_wx * delta_w * sum(w_xy * w_yz) over all y) over all x
contribution3 = torch.sum(contribution3, 1)

# sum(w_wx * sum(w_xy * w_yz) over all y) over all x, for next computation
weight_product_sum = torch.sum(weight_product, 1)
print(contribution3)

tensor([[-1.6548e-03, -2.2128e-02, -8.0217e-03,  ..., -2.2217e+00,
          6.2512e-03, -4.4189e-02],
        [-3.0524e-02,  1.9464e-03, -9.5641e-02,  ..., -1.6648e+00,
         -1.9307e-02,  2.3634e-02],
        [ 9.8880e-03,  3.1073e-02, -2.9036e-02,  ..., -1.9420e+00,
          1.3735e-03, -1.0021e-02],
        ...,
        [-8.0712e-03,  9.5216e-03,  1.3890e-02,  ..., -1.5980e+00,
          2.2731e-04, -2.1542e-03],
        [ 1.0310e-03,  2.0804e-02,  1.6554e-02,  ..., -2.0859e+00,
         -2.4728e-03, -1.0520e-01],
        [-4.2761e-03,  5.1319e-02,  2.1998e-02,  ..., -2.0985e+00,
          3.5365e-04, -4.2339e-02]], grad_fn=<SumBackward1>)


In [25]:
# Dense output to merged attention heads
#
# This cell consumes weight_product_sum from the previous cell and overwrites
# it with this layer's result, so it must run exactly once after that cell.

# delta_v
# Swap dims 0 and 1 so the activations broadcast against every row of the
# (output size x input size) weight matrices below.
merged_head_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["merge_heads"]
merged_head_activations = torch.transpose(merged_head_activations, 0, 1)

# w_vw
# Direct attribute access (same parameter that named_parameters() reports as
# "transformer.h.23.self_attention.dense.weight"): it raises immediately if the
# path is wrong instead of scanning every parameter and silently leaving None.
self_attention_dense_weight = model.transformer.h[23].self_attention.dense.weight

# Linear weights are stored as (output size, input size).
output_size, input_size = self_attention_dense_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, each row holds the weights from
# every input node to the output node with that row's index.
self_attention_dense_weight = self_attention_dense_weight.expand(seq_length, output_size, input_size)

# Reshape previous weight product sum for multiplication with current weights
previous_weight_product_sum = weight_product_sum.unsqueeze(-1).expand(seq_length, output_size, input_size)
weight_product = torch.mul(self_attention_dense_weight, previous_weight_product_sum)

# Weight product times the merged-head activations, summed over the output
# nodes: contribution of each merged-head unit to the final token output.
contribution4 = torch.mul(weight_product, merged_head_activations)
contribution4 = torch.sum(contribution4, 1)

# Accumulated weight product handed to the attention cells that follow.
weight_product_sum = torch.sum(weight_product, 1)
print(weight_product_sum.shape)

torch.Size([10, 1024])


In [26]:
# Value layer contribution to merged head output

layer23_self_attention = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]

# Value activations act as the weights for the query/key output ...
value_layer_activations = layer23_self_attention["value_layer"]

# ... and the query/key output (attention probabilities) acts as the weights
# for the value activations.
attention_probs = layer23_self_attention["attention_probs_reshaped"]

# Head geometry is read from block 0's attention module.
attention_module = model.transformer.h[0].self_attention
num_heads = attention_module.num_heads
head_dim = attention_module.head_dim

# Split the merged-head weight product sum into heads and move the head axis
# first: (seq length, num heads * head dim) -> (num heads, seq length, head dim).
# Then insert an extra axis at dim 1 and broadcast it, so every slice of the
# attention probabilities is paired with the whole weight product sum.
#
# NOTE(review): the weight product is broadcast along dim 1 and the reduction
# below runs over dim 2. If attention_probs is laid out as (head, query, key)
# -- not verifiable from this cell -- this weights probs[h, q, k] by the
# downstream weight at position k and sums over keys; confirm that is intended
# rather than weighting by position q and summing over queries.
value_layer_weight_product_sum = (
    weight_product_sum
    .view(seq_length, num_heads, head_dim)
    .transpose(0, 1)
    .unsqueeze(1)
    .expand(num_heads, seq_length, seq_length, head_dim)
)

# Give the attention probabilities a trailing head-dim axis so the two tensors
# can be multiplied element-wise.
expanded_attention_probs = attention_probs.unsqueeze(-1).expand(
    num_heads, seq_length, seq_length, head_dim
)

# softmax(query x key) are the weights for the value layer: multiply them by
# the next layer's weight product sum to get the value layer's weight product.
value_layer_weight_product = value_layer_weight_product_sum * expanded_attention_probs

# Collapse the extra axis again to obtain the value layer's weight product sum.
value_layer_weight_product_sum = value_layer_weight_product.sum(dim=2)

# Value activations times that weight product sum: the value layer output's
# contribution to the final prediction.
value_layer_contributions = value_layer_activations * value_layer_weight_product_sum

In [None]:
# Query_Key attention weight contribution
query_key_attn_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["query_key_attn_weights"]

# Same reshaping that was needed for the value contribution calculation:
# split the merged-head weight product sum into heads, then move the head axis
# first -> (num heads, seq length, head dim).
# The transpose must be applied to the tensor just created from
# weight_product_sum; transposing value_layer_weight_product_sum here (as the
# cell previously did) discarded the view above and produced a tensor with the
# sequence axis first.
key_query_atten_product_sum = weight_product_sum.view(seq_length, num_heads, head_dim)
key_query_atten_product_sum = key_query_atten_product_sum.transpose(0, 1)







In [21]:
# Contribution of dense layer after attention heads to FFN

# transformer.h.23.self_attention.query_key_value.weight torch.Size([3072, 1024])
# Activation shape: torch.Size([1, 10, 3072])
# transformer.h.23.self_attention.dense.weight torch.Size([1024, 1024])

#delta_v 



In [None]:
# Print every non-bias parameter's shape next to the shape of the activation
# recorded under the same module path.
for name, param in model.named_parameters():
    if "bias" not in name:
        # output = xW(transpose) + b
        # The model weights are already transposed, so input and output dims appear swapped
        print(name, param.shape)

        # Walk the nested activation record along the dotted parameter path,
        # stopping at the trailing "weight"/"bias" component.
        hierarchy = name.split(".")
        curr_act = outputs.activations[0]
        for level in hierarchy:
            if level in ["weight", "bias"]:
                break
            curr_act = curr_act[level]

        print(f"Activation shape: {curr_act.shape}")

In [None]:
# Print the weight of the last block's second MLP projection.
for name, param in model.named_parameters():
    if name != "transformer.h.23.mlp.dense_4h_to_h.weight":
        continue
    print(param)