In [1]:
import torch
from transformers import AutoTokenizer, BloomForCausalLM
from bloom_for_node_attribution import BloomForCausalLMForNodeAttribution

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the BLOOM-560m tokenizer and the attribution-instrumented model from the same checkpoint.
MODEL_NAME = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BloomForCausalLMForNodeAttribution.from_pretrained(MODEL_NAME)

In [3]:
# Sample one continuation token for the prompt; the returned object also carries the
# activations that the attribution cells read (outputs.activations).
# Sampling (do_sample=True with top-k / top-p) is stochastic, so seed torch first so that
# "Restart & Run All" reproduces the same generated token.
SEED = 0
torch.manual_seed(SEED)

inputs = tokenizer("Hello, I am an AlexPrize chatbot", return_tensors="pt")
# Number of prompt tokens (input_ids has shape (batch=1, seq_length)).
seq_length = inputs["input_ids"].shape[1]
outputs = model.generate(
    input_ids=inputs["input_ids"],
    max_new_tokens=1,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    return_dict_in_generate=True,
)

In [4]:
# Decode prompt + sampled continuation back to text (one string per sequence in the batch).
generated_sequences = outputs.sequences
tokenizer.batch_decode(generated_sequences, skip_special_tokens=True)

['Hello, I am an AlexPrize chatbot.\n']

In [26]:
# Final hidden states to LM head contributions.
#
# In a causal LM the hidden state at position t produces the logits for the token at
# position t + 1. To attribute the token each position actually predicts, we therefore need
# the LM-head rows of the *next* tokens, sequences[0][1 : seq_length + 1]. (Slicing up to
# seq_length + 1 also keeps this aligned with the prompt pass if more than one token is
# generated.)

# lm_head maps hidden states to vocabulary logits; assumes it is an nn.Linear, so its weight
# has one row per vocabulary token — TODO confirm for the attribution subclass.
lm_head_weights = model.lm_head.weight

# Token ids predicted at each prompt position (position t predicts token t + 1).
predicted_token_ids = outputs.sequences[0][1 : seq_length + 1]

# Only keep the LM-head rows of those tokens: (seq_length, hidden_size).
seq_token_weights = torch.index_select(lm_head_weights, 0, predicted_token_ids)

# Recorded ln_f activations have shape (1, seq_length, hidden_size); drop only the batch axis
# so a one-token prompt keeps its sequence axis.
hidden_state_activations = outputs.activations[0]["transformer"]["ln_f"].squeeze(0)

# contribution1[t, h] = W_lm[token_{t+1}, h] * ln_f(h_t)[h]. Summed over h this gives the
# predicted token's logit, assuming lm_head has no bias.
contribution1 = torch.mul(seq_token_weights, hidden_state_activations)

In [27]:
# Contribution of the final feed-forward layer (layer 23, mlp.dense_4h_to_h) inputs to each
# predicted token's logit.
#
# Notation, per sequence position t:
#   x = input node of dense_4h_to_h (4096), y = its output node (1024), z = token logit
#   contribution2[t, x] = sum_y(w_xy * w_yz) * delta_x
# The inner sum over y is a single matrix product, so the (seq, 1024, 4096) product tensor
# never has to be materialised.
#
# NOTE(review): the residual add and ln_f between dense_4h_to_h and lm_head are treated as
# identity in this derivation (unchanged from the original approach).

# delta_x: the input to dense_4h_to_h is GELU(dense_h_to_4h(...)). The recorded activation is
# the dense_h_to_4h *output* (shape (1, seq, 4096)), i.e. before the GELU, so apply it here.
# BLOOM's MLP uses the tanh approximation of GELU.
mlp_dense_h_to_4h_output = outputs.activations[0]["transformer"]["h"]["23"]["mlp"]["dense_h_to_4h"].squeeze(0)
mlp_dense_4h_to_h_activations = torch.nn.functional.gelu(mlp_dense_h_to_4h_output, approximate="tanh")  # (seq, 4096)

# w_xy: nn.Linear stores its weight as (out_features, in_features) = (1024, 4096).
mlp_dense_4h_to_h_weight = model.transformer.h[23].mlp.dense_4h_to_h.weight

# w_yz is seq_token_weights, (seq, 1024). It is only read here, never reshaped in place, so
# this cell can be re-run without corrupting it.
# sum_y(w_xy * w_yz): (seq, 1024) @ (1024, 4096) -> (seq, 4096); used by the next computation.
weight_product_sum = seq_token_weights @ mlp_dense_4h_to_h_weight

# contribution2[t, x] = sum_y(w_xy * w_yz) * delta_x
contribution2 = weight_product_sum * mlp_dense_4h_to_h_activations

In [28]:
# Contribution of mlp.dense_h_to_4h inputs (layer 23) to each predicted token's logit.
#
# Notation, per sequence position t:
#   w = input node of dense_h_to_4h (1024; the post_attention_layernorm output)
#   x = its output node (4096)
#   contribution3[t, w] = sum_x(w_wx * sum_y(w_xy * w_yz)) * delta_w
# NOTE(review): the GELU between dense_h_to_4h and dense_4h_to_h is treated as identity in this
# weight chain — confirm that this linear approximation is intended.

# delta_w: recorded post_attention_layernorm output, (1, seq, 1024) -> (seq, 1024).
post_attention_layernorm_activations = outputs.activations[0]["transformer"]["h"]["23"]["post_attention_layernorm"].squeeze(0)

# w_wx: nn.Linear stores its weight as (out_features, in_features) = (4096, 1024).
mlp_dense_h_to_4h_weight = model.transformer.h[23].mlp.dense_h_to_4h.weight

# weight_product_sum must still be the (seq, 4096) result of the dense_4h_to_h cell. This cell
# overwrites it, so re-running this cell requires re-running that one first.
if weight_product_sum.shape[-1] != mlp_dense_h_to_4h_weight.shape[0]:
    raise RuntimeError(
        f"weight_product_sum has shape {tuple(weight_product_sum.shape)}; "
        "re-run the dense_4h_to_h contribution cell before this one."
    )

# sum_x(w_wx * sum_y(w_xy * w_yz)): (seq, 4096) @ (4096, 1024) -> (seq, 1024), for the next computation.
weight_product_sum = weight_product_sum @ mlp_dense_h_to_4h_weight

# contribution3[t, w] = sum_x(w_wx * sum_y(w_xy * w_yz)) * delta_w
contribution3 = weight_product_sum * post_attention_layernorm_activations

torch.Size([10, 4096])
torch.Size([10, 4096, 1])
torch.Size([10, 4096, 1024])
torch.Size([4096, 1024])
torch.Size([1024])
tensor([0.0098, 0.0098, 0.0098,  ..., 0.0098, 0.0098, 0.0098],
       grad_fn=<SelectBackward0>)
tensor([0.0047, 0.0047, 0.0047,  ..., 0.0047, 0.0047, 0.0047],
       grad_fn=<SelectBackward0>)


In [None]:
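# TODO: contribution of the attention output projection (self_attention.dense, applied after
# the attention heads) to the FFN inputs — not implemented yet.

# Reference shapes for layer 23:
# transformer.h.23.self_attention.query_key_value.weight torch.Size([3072, 1024])
#   query_key_value activation shape: torch.Size([1, 10, 3072])
# transformer.h.23.self_attention.dense.weight torch.Size([1024, 1024])

# delta_v: input activations of self_attention.dense (still to be filled in)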



In [29]:
# For every non-bias parameter, print its shape next to the shape of the activation recorded
# for the module that owns it. nn.Linear computes output = x @ W.T + b and stores W as
# (out_features, in_features), so input and output dims appear swapped in the weight shape.
for param_name, param_tensor in model.named_parameters():
    if "bias" in param_name:
        continue

    print(param_name, param_tensor.shape)

    # Walk the nested activation dict along the module path, stopping at the parameter leaf.
    module_path = param_name.split(".")
    activation = outputs.activations[0]
    for key in module_path:
        if key in ("weight", "bias"):
            break
        activation = activation[key]

    print(f"Activation shape: {activation.shape}")

transformer.word_embeddings.weight torch.Size([250880, 1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([3072, 1024])
Activation shape: torch.Size([1, 10, 3072])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([4096, 1024])
Activation shape: torch.Size([1, 10, 4096])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 4096])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.1.input_layernorm.weight torch.Size([1024])
Activation shape: torch.Size([1, 10, 1024])
transformer.h.1

In [None]:
# Dump the layer-23 MLP output-projection weight (dense_4h_to_h) for inspection.
target_param_name = "transformer.h.23.mlp.dense_4h_to_h.weight"
for param_name, param_tensor in model.named_parameters():
    if param_name != target_param_name:
        continue
    print(param_tensor)