In [1]:
import torch
from transformers import AutoTokenizer, BloomForCausalLM
from bloom_for_node_attribution import BloomForCausalLMForNodeAttribution

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The tokenizer and the attribution-instrumented model are both loaded from
# the same checkpoint name.
checkpoint = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BloomForCausalLMForNodeAttribution.from_pretrained(checkpoint)

In [3]:
# Sampling (do_sample=True) draws from torch's global RNG. Seed it so the
# generated token -- and every attribution computed from it below -- is the
# same on each "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

inputs = tokenizer("Hello, I am an AlexPrize chatbot", return_tensors="pt")

# input_ids is (batch, sequence); both sizes are reused by later cells when
# reshaping weights and activations.
batch_size, seq_length = inputs["input_ids"].shape

# Generate exactly one new token and return the generation output as an
# object (later cells read .sequences and .activations from it).
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_new_tokens=1,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    return_dict_in_generate=True,
)

In [4]:
# Turn the generated token ids (prompt + new token) back into text; the bare
# expression is the cell's displayed output.
tokenizer.batch_decode(
    outputs.sequences,
    skip_special_tokens=True,
)

['Hello, I am an AlexPrize chatbot.']

In [21]:
# Final hidden states -> LM head contributions.
# The LM head weight matrix (250880 x 1024) is taken as the head's first
# parameter.
lm_head_weights = next(iter(model.lm_head.parameters()))

# Keep only the LM-head rows for the tokens of the sequence: every id in the
# generated sequence except its last entry.
# NOTE(review): the hidden state at position i is what scores token i + 1, so
# the rows matching the *predicted* tokens would be outputs.sequences[0][1:].
# Confirm that pairing each position with its own input token is intended.
sequence_token_ids = outputs.sequences[0][:-1]
seq_token_weights = lm_head_weights.index_select(0, sequence_token_ids)

# Layer-normed final hidden states recorded for the first generation step,
# with singleton dimensions dropped.
hidden_state_activations = outputs.activations[0]["transformer"]["ln_f"].squeeze()

# Element-wise weight * activation: contribution of each final hidden unit to
# the selected token's logit.
contribution1 = seq_token_weights * hidden_state_activations

In [22]:
# Contribution of final feed forward layer to hidden states

# delta_x: the values fed into mlp.dense_4h_to_h, read from the activations
# recorded under "dense_h_to_4h" (the layer whose output feeds dense_4h_to_h).
# NOTE(review): whether the recorded activation is taken before or after the
# MLP non-linearity is not visible in this notebook -- confirm.
mlp_dense_4h_to_h_activations = outputs.activations[0]["transformer"]["h"]["23"]["mlp"]["dense_h_to_4h"]

# w_xy
# get_parameter raises AttributeError on a wrong name instead of silently
# leaving None behind for a later, less obvious failure.
mlp_dense_4h_to_h_weight = model.get_parameter("transformer.h.23.mlp.dense_4h_to_h.weight")

# Linear weights are stored as (output size x input size).
output_size, input_size = mlp_dense_4h_to_h_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, row y holds the weights from
# every input node to output node y.
mlp_dense_4h_to_h_weight = mlp_dense_4h_to_h_weight.expand(seq_length, output_size, input_size)

# Essentially the previous weight product
# w_yz
# A new name is used on purpose: rebinding seq_token_weights to its expanded
# 3-D view would make this cell fail when run a second time.
expanded_seq_token_weights = seq_token_weights.unsqueeze(-1).expand(seq_length, output_size, input_size)

# Multiply current layer weights with previous layer weights
# Element-wise multiply each column by the output node's weight to the final
# layer to get the input node's weight towards the final prediction
# w_xy * w_yz
weight_product = torch.mul(mlp_dense_4h_to_h_weight, expanded_seq_token_weights)

# Sum over each column's elements (all weights from one input node to all
# output nodes) to obtain the weight matrix used by the next layer down
# sum(w_xy * w_yz) over all y, used for next computation
weight_product_sum = torch.sum(weight_product, 1)

# Element-wise multiply the summed weights by each input node's activation
# sum(w_xy * w_yz) over all y, times delta_x
contribution2 = torch.mul(weight_product_sum, mlp_dense_4h_to_h_activations)

In [7]:
# Contribution of mlp.dense_h_to_4h inputs to the predicted token

# delta_w
post_attention_layernorm_activations = outputs.activations[0]["transformer"]["h"]["23"]["post_attention_layernorm"]

# w_wx
# get_parameter raises AttributeError on a wrong name instead of silently
# leaving None behind for a later, less obvious failure.
mlp_dense_h_to_4h_weight = model.get_parameter("transformer.h.23.mlp.dense_h_to_4h.weight")

# Linear weights are stored as (output size x input size).
output_size, input_size = mlp_dense_h_to_4h_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, row x holds the weights from
# every input node to output node x.
mlp_dense_h_to_4h_weight = mlp_dense_h_to_4h_weight.expand(seq_length, output_size, input_size)

# This cell reads weight_product_sum from the previous step and then
# overwrites it, so it is only valid when run exactly once after that step.
assert weight_product_sum.shape == (seq_length, output_size), (
    "weight_product_sum must be the (seq length x output size) result of the "
    "dense_4h_to_h step; re-run that cell before this one"
)

# Reshape previous weight product sum for multiplication with current weights
# sum(w_xy * w_yz) over all y, one value per output node of this layer
upstream_weight_product_sum = weight_product_sum.unsqueeze(-1).expand(seq_length, output_size, input_size)

# w_wx * sum(w_xy * w_yz) over all y
weight_product = torch.mul(mlp_dense_h_to_4h_weight, upstream_weight_product_sum)

# sum(w_wx * sum(w_xy * w_yz) over all y) over all x, for next computation
weight_product_sum = torch.sum(weight_product, 1)

# delta_w * sum(w_wx * sum(w_xy * w_yz) over all y) over all x
# The summed (seq length x input size) weights are used here, as in the other
# contribution cells; the un-summed (seq length x output size x input size)
# weight_product does not line up with the per-position activations.
contribution3 = torch.mul(weight_product_sum, post_attention_layernorm_activations)

In [8]:
# Dense output to merged attention heads

# delta_v
merged_head_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["merge_heads"]

# w_vw
# get_parameter raises AttributeError on a wrong name instead of silently
# leaving None behind for a later, less obvious failure.
self_attention_dense_weight = model.get_parameter("transformer.h.23.self_attention.dense.weight")

# Linear weights are stored as (output size x input size).
output_size, input_size = self_attention_dense_weight.shape

# Give weights shape (seq length x output size x input size)
# In each (output size x input size) matrix, row w holds the weights from
# every input node to output node w.
self_attention_dense_weight = self_attention_dense_weight.expand(seq_length, output_size, input_size)

# Reshape previous weight product sum for multiplication with current weights.
# NOTE: this cell reads weight_product_sum from the previous step and then
# overwrites it, so it must be run exactly once after that step.
upstream_weight_product_sum = weight_product_sum.unsqueeze(-1).expand(seq_length, output_size, input_size)

# Chain this layer's weights with everything downstream of it, then sum over
# this layer's output nodes to get one weight per (position, input node).
weight_product = torch.mul(self_attention_dense_weight, upstream_weight_product_sum)
weight_product_sum = torch.sum(weight_product, 1)

# The summed (seq length x input size) weights are used here, as in the other
# contribution cells; the un-summed (seq length x output size x input size)
# weight_product does not line up with the per-position activations.
contribution4 = torch.mul(weight_product_sum, merged_head_activations)

In [9]:
# Value layer contribution to merged head output

layer_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]

# Value activations act as the weights for the query-key output ...
value_layer_activations = layer_activations["value_layer"]

# ... and the query-key output (attention probabilities) acts as the weights
# for the value activations.
attention_probs = layer_activations["attention_probs_reshaped"]

attention_module = model.transformer.h[0].self_attention
num_heads = attention_module.num_heads
head_dim = attention_module.head_dim

# Split the merged-head weight product sum into heads
# (seq_length, num_heads, head_dim) -> (num_heads, seq_length, head_dim),
# then add an axis so it can be repeated once per attention-prob column:
# final shape (num_heads, seq_length, seq_length, head_dim).
value_weight_product_sum = (
    weight_product_sum.view(seq_length, num_heads, head_dim)
    .transpose(0, 1)
    .unsqueeze(1)
    .expand(num_heads, seq_length, seq_length, head_dim)
)

# Transpose the attention probs and repeat them along a trailing head_dim axis
# so the two tensors can be multiplied element-wise.
expanded_attention_probs = (
    attention_probs.transpose(1, 2)
    .unsqueeze(-1)
    .expand(num_heads, seq_length, seq_length, head_dim)
)

# softmax(query x key) are the weights of the value layer: multiplying them by
# the downstream weight product sum gives the value layer's weight product.
value_weight_product = value_weight_product_sum * expanded_attention_probs

# Collapse the extra axis (dim 2) that was introduced for the multiplication.
value_weight_product_sum = value_weight_product.sum(dim=2)

# Value activation * its weight product sum = the value layer output's
# contribution to the final prediction.
value_contribution = value_layer_activations * value_weight_product_sum

In [10]:
# Query_Key attention weight contribution
query_key_attn_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["query_key_attn_weights"]

# Same per-head split as for the value contribution, but the new axis goes in
# position 2 instead of position 1: every row of the product sum is repeated
# once per row of the value matrix (seq_length times).
query_key_attn_weight_product_sum = (
    weight_product_sum.view(seq_length, num_heads, head_dim)
    .transpose(0, 1)
    .unsqueeze(2)
    .expand(num_heads, seq_length, seq_length, head_dim)
)

# The value matrix is repeated along a new axis 1 to match that shape.
expanded_value_layer_activations = (
    value_layer_activations.unsqueeze(1)
    .expand(num_heads, seq_length, seq_length, head_dim)
)

# Each row of the weight product sum times every row of the value matrix ...
query_key_attn_weight_product = query_key_attn_weight_product_sum * expanded_value_layer_activations

# ... summed over head_dim, which plays the role of the output dimension here.
query_key_attn_weight_product_sum = query_key_attn_weight_product.sum(dim=-1)

# Multiplying by the query-key activations gives the query-key contribution.
# These contribution values are not actually needed downstream.
query_key_attn_contribution = query_key_attn_activations.squeeze() * query_key_attn_weight_product_sum

In [11]:
# Query contributions to the attention weights

# The query activations double as the key's weights when the key contribution
# is calculated in the next cell.
query_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["query_layer"]
print(query_activations.shape)

# The key activations are the query's weights here.
key_activations = outputs.activations[0]["transformer"]["h"]["23"]["self_attention"]["key_layer"]
print(key_activations.shape)

# Repeat the query-key weight product sum along a new head_dim axis (dim 2),
# mirroring how the weight product sum was handled for the query-key step.
query_weight_product_sum = (
    query_key_attn_weight_product_sum.unsqueeze(2)
    .expand(num_heads, seq_length, head_dim, seq_length)
)

# Repeat the key activations along a new axis 1, mirroring how the value
# activations were handled for the query-key step.
expanded_key_layer_activations = (
    key_activations.unsqueeze(1)
    .expand(num_heads, seq_length, head_dim, seq_length)
)

# Each row of the weight product sum times every row of the key matrix,
# then summed over the last axis.
query_weight_product = query_weight_product_sum * expanded_key_layer_activations
query_weight_product_sum = query_weight_product.sum(dim=-1)

# Query activation * its weight product sum = query contribution.
query_contribution = query_activations * query_weight_product_sum
print(query_contribution.shape)

torch.Size([16, 10, 64])
torch.Size([16, 64, 10])
torch.Size([16, 10, 64])


In [12]:
# Key contributions to the attention weights

# Handle the key matrix the way the value matrix was handled for the value
# contribution: repeat the query-key weight product sum along a new axis 1.
key_weight_product_sum = (
    query_key_attn_weight_product_sum.unsqueeze(1)
    .expand(num_heads, head_dim, seq_length, seq_length)
)

# The query activations are the key's weights: transpose them to
# (num_heads, head_dim, seq_length) and repeat along a new trailing axis.
expanded_query_activations = (
    query_activations.transpose(1, 2)
    .unsqueeze(-1)
    .expand(num_heads, head_dim, seq_length, seq_length)
)

# Multiply element-wise and collapse axis 2.
key_weight_product = key_weight_product_sum * expanded_query_activations
key_weight_product_sum = key_weight_product.sum(dim=2)

# Key activation * its weight product sum = key contribution.
key_contribution = key_activations * key_weight_product_sum

In [13]:
# Sanity check: shapes of the query / key / value weight product sums, one per line.
print(
    query_weight_product_sum.shape,
    key_weight_product_sum.shape,
    value_weight_product_sum.shape,
    sep="\n",
)

torch.Size([16, 10, 64])
torch.Size([16, 64, 10])
torch.Size([16, 10, 64])


In [14]:
# Now arrange and combine the key, query, and value weight product sums and
# contributions to form the weight product sum and contributions for the
# fused qkv layer output.

# In the forward pass they did the following, which needs to be undone:
# query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
# key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
# value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

# Bring each tensor to [batch_size, seq_length, num_heads, 1, head_dim]; the
# singleton axis is where query, key and value get stacked.
# The rearranged tensors get their own names on purpose: rebinding the inputs
# would change their rank, so a second run of this cell would fail and the
# shape printout in the previous cell would no longer describe them.
query_contribution_for_fusion = query_contribution.unsqueeze(0).transpose(1, 2).unsqueeze(-2)
query_weight_product_sum_for_fusion = query_weight_product_sum.unsqueeze(0).transpose(1, 2).unsqueeze(-2)

# Keys are stored as (num_heads, head_dim, seq_length), hence permute instead
# of transpose.
key_contribution_for_fusion = key_contribution.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(-2)
key_weight_product_sum_for_fusion = key_weight_product_sum.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(-2)

value_contribution_for_fusion = value_contribution.unsqueeze(0).transpose(1, 2).unsqueeze(-2)
value_weight_product_sum_for_fusion = value_weight_product_sum.unsqueeze(0).transpose(1, 2).unsqueeze(-2)

# fuse'em: stack query / key / value along the singleton axis, then flatten
# (num_heads, 3, head_dim) into one feature axis.
fused_qkv_contribution = torch.cat(
    [query_contribution_for_fusion, key_contribution_for_fusion, value_contribution_for_fusion],
    -2,
)
fused_qkv_weight_product_sum = torch.cat(
    [query_weight_product_sum_for_fusion, key_weight_product_sum_for_fusion, value_weight_product_sum_for_fusion],
    -2,
)
fused_qkv_contribution = fused_qkv_contribution.view(batch_size, seq_length, num_heads * 3 * head_dim)
fused_qkv_weight_product_sum = fused_qkv_weight_product_sum.view(batch_size, seq_length, num_heads * 3 * head_dim).squeeze()
print(fused_qkv_contribution.shape)

torch.Size([1, 10, 3072])


In [15]:
# Contribution of transformer block input to query, key, and value output

# Hidden states entering the attention block (output of the input layernorm).
block_input = outputs.activations[0]["transformer"]["h"]["23"]["input_layernorm"]

# Weights from the hidden-state input to the fused qkv output, looked up by
# parameter name (None if no parameter carries that name).
qkv_weight_name = "transformer.h.23.self_attention.query_key_value.weight"
fused_qkv_weights = next(
    (weight for key, weight in model.named_parameters() if key == qkv_weight_name),
    None,
)
print(fused_qkv_weights.shape)

# Stored as (output size x input size).
output_size, input_size = fused_qkv_weights.shape

# Repeat the layer's weights once per sequence position:
# (seq length x output size x input size).
fused_qkv_weights = fused_qkv_weights.expand(seq_length, output_size, input_size)

# Repeat the fused weight product sum along a trailing input-size axis so each
# input node's weights are scaled by the product sum of their output node.
weight_product_sum = (
    fused_qkv_weight_product_sum.unsqueeze(-1)
    .expand(seq_length, output_size, input_size)
)

# Chain with the downstream weights, then sum over the output nodes.
weight_product = fused_qkv_weights * weight_product_sum
weight_product_sum = weight_product.sum(dim=1)

# Summed weights * block input activations = block input contribution.
block_input_contribution = weight_product_sum * block_input

torch.Size([3072, 1024])


In [16]:
# Contribution of dense layer after attention heads to FFN

# transformer.h.23.self_attention.query_key_value.weight torch.Size([3072, 1024])
# Activation shape: torch.Size([1, 10, 3072])
# transformer.h.23.self_attention.dense.weight torch.Size([1024, 1024])

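# delta_v
# TODO: this cell is an unfinished stub -- it only holds shape notes and does not compute a contribution yet.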



In [17]:
# For every non-bias parameter, print its name and shape followed by the shape
# of the activation recorded for the module that owns it.
for name, param in model.named_parameters():
    if "bias" not in name:
        # output = xW(transpose) + b
        # The model weights are already transposed, so input and output dims appear swapped
        print(name, param.shape)

        # Walk the nested activation dict along the dotted parameter path,
        # stopping at the trailing "weight" / "bias" component.
        curr_act = outputs.activations[0]
        for level in name.split("."):
            if level in ("weight", "bias"):
                break
            curr_act = curr_act[level]

        print(f"Activation shape: {curr_act.shape}")

transformer.word_embeddings.weight torch.Size([250880, 1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([3072, 1024])
Activation shape: torch.Size([1, 10, 3072])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([4096, 1024])
Activation shape: torch.Size([1, 10, 4096])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 4096])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.1.input_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.1

In [18]:
# Print the weight matrix of the last block's second MLP projection.
wanted_name = "transformer.h.23.mlp.dense_4h_to_h.weight"
for name, param in model.named_parameters():
    if name != wanted_name:
        continue
    print(param)

Parameter containing:
tensor([[-0.0016, -0.0079,  0.0035,  ..., -0.0129,  0.0022, -0.0012],
        [ 0.0247,  0.0062, -0.0100,  ..., -0.0147, -0.0129, -0.0123],
        [-0.0165,  0.0039,  0.0336,  ...,  0.0026, -0.0098, -0.0336],
        ...,
        [ 0.0010, -0.0039, -0.0060,  ..., -0.0028, -0.0009, -0.0025],
        [ 0.0130,  0.0023, -0.0083,  ..., -0.0064,  0.0052, -0.0009],
        [ 0.0005, -0.0489, -0.0112,  ...,  0.0155, -0.0183,  0.0062]],
       requires_grad=True)
