In [1]:
import torch
import numpy as np
import pickle as pkl

from node_attribution.node_attribution import get_attributions

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
# --- Experiment configuration -------------------------------------------
# Checkpoint size string handed to get_attributions, and the prompt(s) the
# attributions are computed on.
model_size = "560m"
data = ["Hello, I am an AlexPrize chatbot"]

# Fraction of the total parameter count to remove. 816,115,712 is the
# "Total Param Count (w/ LM Head)" reported when the model is loaded.
prune_percent = 0.10
num_params_to_prune = prune_percent * 816_115_712

# Attention geometry; these match the values printed by the loader.
head_dim, num_heads = 64, 16

In [33]:
# Show the absolute parameter budget to prune (a float: total count * prune_percent).
num_params_to_prune

81611571.2

In [34]:
# Get attributions
# get_attributions is a project helper; it returns four values that the rest of
# the notebook consumes as follows:
#   - avg_contributions: iterated as a mapping of layer name -> tensor of
#     per-node scores (presumably 1-D, one score per node -- TODO confirm)
#   - max_contributions: not used in the cells visible here
#   - model: later queried via model.named_parameters()
#   - model_params: mapping of parameter name -> tensor (used as a state dict)
avg_contributions, max_contributions, model, model_params = get_attributions(model_size, data)

Finished loading bigscience/bloom-560m.
                    Num Transformer Blocks: 24
                    Num Attention Heads: 16
                    Head Dim: 64
                    Hidden Size: 1024
                    Base Model Param Count: 559,214,592
                    Total Param Count (w/ LM Head): 816,115,712


In [35]:
# Number the non-bias parameters in state-dict order so that the parameter
# preceding a given weight can be found by subtracting one from its index.
weight_param_names = [name for name in model_params.keys() if "bias" not in name]

# Forward lookup: parameter name -> position among non-bias parameters.
params_to_index = {name: position for position, name in enumerate(weight_param_names)}

# Reverse lookup: position -> parameter name.
index_to_params = {position: name for name, position in params_to_index.items()}

In [36]:
# For every attributed layer, record its own width plus the widths of its two
# neighbours in avg_contributions order. The entry *before* a layer in that
# ordering is treated as its "next" layer and the entry *after* it as its
# "prev" layer; the two ends fall back to fixed sizes (250880 and 1024).
layer_shape_map = {}
layer_names, contribution_tensors = zip(*avg_contributions.items())
num_layers = len(layer_names)

for position, (name, tensor) in enumerate(zip(layer_names, contribution_tensors)):
    is_first = position == 0
    is_last = position == num_layers - 1

    # First entry has no predecessor in the ordering: use the fixed size 250880.
    following_size = 250880 if is_first else contribution_tensors[position - 1].shape[0]

    # Last entry has no successor in the ordering: use the fixed size 1024.
    preceding_size = 1024 if is_last else contribution_tensors[position + 1].shape[0]

    layer_shape_map[name] = {
        "prev_layer_size": preceding_size,
        "next_layer_size": following_size,
        "current_layer_size": tensor.shape[0],
    }

In [37]:
# Flatten every layer's contribution tensor into (node_name, score) pairs,
# where node_name is "<layer name>.<node index within that layer>".
all_nodes = [
    (f"{layer_name}.{node_id}", node)
    for layer_name, contribution_tensor in avg_contributions.items()
    for node_id, node in enumerate(contribution_tensor.tolist())
]
    

In [38]:
# Order the (node_name, score) pairs in place, lowest contribution first.
all_nodes.sort(key=lambda name_and_score: name_and_score[1])

In [39]:
# Figure out which layers have the most to prune.
#
# Layers whose name contains one of these markers are never selected directly:
#   - self_attention.dense.weight: the value part of the fused QKV output
#     decides what gets dropped from it
#   - mlp.dense_h_to_4h.weight and self_attention.query_key_value.weight are
#     likewise left out of the candidate list
excluded_layer_markers = (
    "self_attention.dense.weight",
    "mlp.dense_h_to_4h.weight",
    "self_attention.query_key_value.weight",
)

# One (layer_name, mean contribution over the whole layer) pair per candidate.
all_layers = [
    (layer_name, torch.mean(contribution_tensor, 0).item())
    for layer_name, contribution_tensor in avg_contributions.items()
    if not any(marker in layer_name for marker in excluded_layer_markers)
]

In [40]:
# Put the layer with the lowest mean contribution at the front of the list.
all_layers.sort(key=lambda name_and_mean: name_and_mean[1])

In [41]:
# Greedy structured pruning.
#
# Repeatedly take the layer with the lowest mean contribution (all_layers[0]),
# mask its lowest-contributing remaining node(s), recompute that layer's mean
# over the still-unmasked nodes, and re-sort. Stops once the estimated number
# of removed parameters reaches num_params_to_prune (or no candidates remain).
#
# layer_masks[layer_name] = (mask, sorted_contributions, num_pruned) where
#   mask                 -- 1 for pruned nodes, 0 for kept nodes
#   sorted_contributions -- argsort of the contributions (ascending, last dim)
#   num_pruned           -- how many pruning steps were applied to this layer
layer_masks = {}
num_params_pruned = 0
node_num = 0

while all_layers and num_params_pruned < num_params_to_prune:
    lowest_contr_layer_name = all_layers[0][0]
    shapes = layer_shape_map[lowest_contr_layer_name]
    is_fused_qkv = "query_key_value_fused_output" in lowest_contr_layer_name

    # Fetch this layer's pruning state, creating it on first visit.
    if lowest_contr_layer_name in layer_masks:
        mask, sorted_contributions, num_pruned = layer_masks[lowest_contr_layer_name]
    else:
        layer_contributions = avg_contributions[lowest_contr_layer_name]
        if is_fused_qkv:
            # View the fused output as (head, q/k/v, head_dim) so that nodes
            # can be ranked inside each head's query, key and value slice.
            layer_contributions = layer_contributions.view(num_heads, 3, head_dim)
        mask = torch.zeros_like(layer_contributions)
        sorted_contributions = torch.argsort(layer_contributions)
        num_pruned = 0

    if is_fused_qkv:
        # Prune 3 * num_heads nodes at once: the next-lowest node of every
        # (head, q/k/v) slice. This shrinks the head dim by 1 for each head.
        lowest_per_slice = sorted_contributions[..., num_pruned:num_pruned + 1]
        mask.scatter_(-1, lowest_per_slice, 1)
        nodes_removed = num_heads * 3
    else:
        # Prune one node at a time.
        mask[sorted_contributions[num_pruned]] = 1
        nodes_removed = 1

    num_pruned += 1
    layer_masks[lowest_contr_layer_name] = (mask, sorted_contributions, num_pruned)

    # Each removed node is counted as dropping one weight per unit of both
    # neighbouring layers.
    num_params_pruned += (shapes["prev_layer_size"] + shapes["next_layer_size"]) * nodes_removed
    node_num += nodes_removed

    if num_pruned >= sorted_contributions.shape[-1]:
        # Nothing left to prune here: a fully masked mean is undefined and a
        # further visit would index past the end, so retire the layer.
        all_layers.pop(0)
        continue

    # Apply mask and update the layer mean in "all_layers"
    # (np.ma reshapes the mask to the data's shape when the sizes match).
    mean_array = np.ma.array(data=avg_contributions[lowest_contr_layer_name], mask=mask)
    new_layer_contr_score = mean_array.mean()
    all_layers[0] = (lowest_contr_layer_name, new_layer_contr_score)

    # Re-sort layers now that this one has been pruned and pick the lowest
    # contributing layer again.
    all_layers.sort(key=lambda x: x[1])
    



In [42]:
# Line up weights to prune and weights in the state dict
mask_index = 0
sorted_weight_index = 1
pruned_model_params = model_params.copy()

for layer_name in layer_masks.keys():
    if layer_name == "transformer.h.0.self_attention.query_key_value.weight":
        continue
    elif "query_key_value_fused_output" in layer_name:   
        # Prune as input
        # Look at value matrix to decide what should be droped in "self_attention.dense.weight"
        value_reshape_mask = layer_masks[layer_name][mask_index].transpose(0, 1)[-1].reshape(num_heads * head_dim)
        num_nodes_to_drop = int(sum(value_reshape_mask).item())
        value_indices = layer_masks[layer_name][sorted_weight_index].transpose(0, 1)[-1].reshape(num_heads * head_dim)
        value_keep_index = torch.sort(value_indices[num_nodes_to_drop:]).values
        
        dense_layer_name = layer_name.replace("query_key_value_fused_output", "dense")
        pruned_input_weights = torch.index_select(pruned_model_params[dense_layer_name], -1, value_keep_index)
        pruned_model_params[dense_layer_name] = pruned_input_weights
        
        # Re-arrange mask to flatten shape
        reshaped_mask = layer_masks[layer_name][mask_index].view(num_heads * 3 * head_dim)
        rehsaped_indices = layer_masks[layer_name][sorted_weight_index].view(num_heads * 3 * head_dim)
        num_nodes_to_drop = int(sum(reshaped_mask).item())
        keep_index = torch.sort(rehsaped_indices[num_nodes_to_drop:]).values
        
        # Prune as output
        prev_layer_name = layer_name.replace("query_key_value_fused_output", "query_key_value")
        pruned_output_weights = torch.index_select(pruned_model_params[prev_layer_name], 0, keep_index)
        pruned_model_params[prev_layer_name] = pruned_output_weights
        
        # Also do bias term
        bias_layer_name = prev_layer_name.replace("weight", "bias")
        pruned_bias_weights = torch.index_select(pruned_model_params[bias_layer_name], 0, keep_index)
        pruned_model_params[bias_layer_name] = pruned_bias_weights
        
    else:
        # Prune when nodes are the input
        num_nodes_to_drop = int(sum(layer_masks[layer_name][mask_index]).item())
        keep_index = torch.sort(layer_masks[layer_name][sorted_weight_index][num_nodes_to_drop:]).values
        pruned_input_weights = torch.index_select(pruned_model_params[layer_name], -1, keep_index)
        pruned_model_params[layer_name] = pruned_input_weights
        
        # Go to previous layer and prune when nodes are the output
        prev_layer_index = params_to_index[layer_name] - 1
        prev_layer_name = index_to_params[prev_layer_index]
        
        if "layernorm" in prev_layer_name:
            pruned_layer_norm_weights = torch.index_select(pruned_model_params[prev_layer_name], 0, keep_index)
            pruned_model_params[prev_layer_name] = pruned_layer_norm_weights 
            
            # Also do bias term
            bias_layer_name = prev_layer_name.replace("weight", "bias")
            pruned_bias_weights = torch.index_select(pruned_model_params[bias_layer_name], 0, keep_index)
            pruned_model_params[bias_layer_name] = pruned_bias_weights
            
            prev_layer_index = prev_layer_index - 1
            prev_layer_name = index_to_params[prev_layer_index]
            
        pruned_output_weights = torch.index_select(pruned_model_params[prev_layer_name], 0, keep_index)
        pruned_model_params[prev_layer_name] = pruned_output_weights
        
        # Also do bias term
        bias_layer_name = prev_layer_name.replace("weight", "bias")
        pruned_bias_weights = torch.index_select(pruned_model_params[bias_layer_name], 0, keep_index)
        pruned_model_params[bias_layer_name] = pruned_bias_weights

In [43]:
# Persist the pruned state dict (parameter name -> tensor) to disk.
torch.save(pruned_model_params, "pruned_560m_bloom.pt")

In [44]:
# Will map each parameter name to the shape of its (possibly pruned) tensor.
state_dict_shapes = {}

In [45]:
# Record and print the shape of every tensor in the pruned state dict.
for param_name, param in pruned_model_params.items():
    state_dict_shapes[param_name] = param.shape
    print(param_name, param.shape)

# Use a context manager so the pickle file is flushed and closed deterministically
# (a bare open() passed to pkl.dump leaves the handle to the garbage collector).
with open("state_dict_shapes.pkl", "wb") as shapes_file:
    pkl.dump(state_dict_shapes, shapes_file)

transformer.word_embeddings.weight torch.Size([250880, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
transformer.word_embeddings_layernorm.bias torch.Size([1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
transformer.h.0.input_layernorm.bias torch.Size([1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([3072, 1024])
transformer.h.0.self_attention.query_key_value.bias torch.Size([3072])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 1024])
transformer.h.0.self_attention.dense.bias torch.Size([1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
transformer.h.0.post_attention_layernorm.bias torch.Size([1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([2197, 1024])
transformer.h.0.mlp.dense_h_to_4h.bias torch.Size([2197])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 2197])
transformer.h.0.mlp.dense_4h_to_h.bias torch.Size([1024])
transformer.h.1.input_layernorm.weight

In [46]:
# Fraction of all attributed nodes that were pruned.
node_num / len(all_nodes)

0.1614772432572614

In [47]:
# Absolute number of nodes pruned by the loop above.
node_num

39850

In [48]:
# List the candidate layers that were never selected for pruning
# (i.e. that have no entry in layer_masks), in all_layers order.
untouched_layer_names = (name for name, _score in all_layers if name not in layer_masks)
for untouched_name in untouched_layer_names:
    print(untouched_name)

transformer.h.4.mlp.dense_4h_to_h.weight
lm_head.weight
transformer.h.6.self_attention.query_key_value_fused_output.weight
transformer.h.14.self_attention.query_key_value_fused_output.weight
transformer.h.0.self_attention.query_key_value_fused_output.weight
transformer.h.9.self_attention.query_key_value_fused_output.weight
transformer.h.13.self_attention.query_key_value_fused_output.weight
transformer.h.7.self_attention.query_key_value_fused_output.weight
transformer.h.11.self_attention.query_key_value_fused_output.weight
transformer.h.18.self_attention.query_key_value_fused_output.weight
transformer.h.8.self_attention.query_key_value_fused_output.weight
transformer.h.12.self_attention.query_key_value_fused_output.weight
transformer.h.10.self_attention.query_key_value_fused_output.weight


In [49]:
for name, param in model.named_parameters():
    if "bias" in name:
        continue

    # print(name, param.shape)