In [1]:
import torch
import numpy as np
import pickle as pkl

from collections import OrderedDict
from node_attribution.gradient_node_attribution import get_attributions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
human_filtered_pairs = pkl.load(open("44_human_filtered_conv_pairs.pkl", "rb"))

In [3]:
calibration_data = human_filtered_pairs[:22]

In [4]:
# Get attributions
model_size = "560m"
avg_contributions, max_contributions, model, model_params = get_attributions(model_size, [calibration_data[0]])

Finished loading bigscience/bloom-560m.
                    Num Transformer Blocks: 24
                    Num Attention Heads: 16
                    Head Dim: 64
                    Hidden Size: 1024
                    Base Model Param Count: 559,214,592
                    Total Param Count (w/ LM Head): 816,115,712
1


In [5]:
#take absolute value of average contributions
for layer_name, contribution_tensor in avg_contributions.items():
    avg_contributions[layer_name] = torch.abs(avg_contributions[layer_name])

In [6]:
#pkl.dump(avg_contributions, open("avg_contri_560m_22pair_calibration.pkl", "wb"))

In [7]:
# pkl.dump(max_contributions, open("max_contri_3B_22pair_calibration.pkl", "wb"))

In [8]:
#avg_contributions = pkl.load(open("avg_contri_560m_22pair_calibration.pkl", "rb"))

In [29]:
prune_percent = 0.30
num_params_to_prune = 559214592 * prune_percent

head_dim = 64
num_heads = 16
hidden_size = 1024
num_layers = 24

print(num_params_to_prune)

167764377.6


In [30]:
index = 0
params_to_index = {}

for param_name in model_params.keys():
    if "bias" not in param_name:
        params_to_index[param_name] = index
        index += 1
    
index_to_params = {params_to_index[param_name]: param_name for param_name in params_to_index.keys()}

In [31]:
# layer_shape_map = {}
# layer_names, contribution_tensors = zip(*avg_contributions.items())
# num_layers = len(layer_names)

# for i in range(num_layers):
#     layer_name = layer_names[i]
#     layer_size = contribution_tensors[i].shape[0]
    
#     if i != 0:
#         next_layer_size = contribution_tensors[i - 1].shape[0]
#     else:
#         next_layer_size = 250880
        
    
#     if i != (num_layers - 1):
#         prev_layer_size = contribution_tensors[i + 1].shape[0]
#     else:
#         prev_layer_size = 1024
        
#     layer_shape_map[layer_name] = {
#         "prev_layer_size": prev_layer_size,
#         "next_layer_size": next_layer_size,
#         "current_layer_size": layer_size
#     }

In [32]:
all_nodes = []

for layer_name, contribution_tensor in avg_contributions.items():
    for node_id, node in enumerate(contribution_tensor.tolist()):
        node_name = f"{layer_name}.{node_id}"
        all_nodes.append((node_name, node))
    

In [33]:
# Sort all the nodes
all_nodes.sort(key = lambda x:x[1])

In [34]:
all_layers = []

# Figure out which layers have the most to prune
for layer_name, contribution_tensor in avg_contributions.items():
    
    # Don't prune self_attention.dense.weight directly, use value matrix to decide what to prune
    # Get average contribution over the whole layer
    mean_contribution = torch.mean(contribution_tensor, 0).item()
    all_layers.append((layer_name, mean_contribution))

In [35]:
all_layers.sort(key = lambda x:x[1])

In [36]:
all_layers

[('transformer.h.23.mlp.dense_4h_to_h.weight', 0.00235293572768569),
 ('transformer.h.9.mlp.dense_4h_to_h.weight', 0.0025348886847496033),
 ('transformer.h.8.mlp.dense_4h_to_h.weight', 0.002768360311165452),
 ('transformer.h.7.mlp.dense_4h_to_h.weight', 0.0037689958699047565),
 ('transformer.h.2.self_attention.value.weight', 0.00436257291585207),
 ('transformer.h.8.self_attention.value.weight', 0.004458594601601362),
 ('transformer.h.10.self_attention.value.weight', 0.004510129801928997),
 ('transformer.h.23.self_attention.value.weight', 0.00453915074467659),
 ('transformer.h.3.self_attention.value.weight', 0.0045631686225533485),
 ('transformer.h.0.mlp.dense_4h_to_h.weight', 0.004628546070307493),
 ('transformer.h.10.mlp.dense_4h_to_h.weight', 0.004696134943515062),
 ('transformer.h.5.mlp.dense_4h_to_h.weight', 0.004815695807337761),
 ('transformer.h.1.mlp.dense_4h_to_h.weight', 0.004817564971745014),
 ('transformer.h.11.mlp.dense_4h_to_h.weight', 0.004893769044429064),
 ('transformer

In [37]:
layer_masks = {}
num_params_pruned = 0
node_num = 0
min_nodes = 24
min_heads = 4
value_dim = 2

while num_params_pruned < num_params_to_prune:
    lowest_contr_layer_name = all_layers[0][0]               
    stop_pruning_layer = False
    
    if "self_attention.value.weight" in lowest_contr_layer_name:
        if lowest_contr_layer_name not in layer_masks:
            layer_contributions = avg_contributions[lowest_contr_layer_name]
            layer_contributions = layer_contributions.view(num_heads, head_dim)
            head_contributions = []
            
            for head in range(num_heads):
                head_score = torch.mean(layer_contributions[head])
                head_contributions.append(head_score)
                
            head_contributions = torch.stack(head_contributions)
            sorted_heads = torch.argsort(head_contributions)        
            mask = torch.zeros_like(layer_contributions)
            num_pruned = 0

        else:
            mask = layer_masks[lowest_contr_layer_name][0]
            sorted_heads = layer_masks[lowest_contr_layer_name][1]
            num_pruned = layer_masks[lowest_contr_layer_name][2] 
            
            
        head_to_prune = sorted_heads[num_pruned]
        for i in range(head_dim):
            mask[head_to_prune][i] = 1
            
        heads_left = num_heads - num_pruned - 1
        
        # Keep from deleting all nodes in a layer
        if heads_left > min_heads:
            layer_masks[lowest_contr_layer_name] = (mask, sorted_heads, num_pruned + 1)
            num_params_pruned += (head_dim * 3 * hidden_size)
            num_params_pruned += (head_dim * hidden_size)
            node_num += (head_dim * 3)
        else:
            stop_pruning_layer = True
    
    else:
        # Prune one node at time
        if lowest_contr_layer_name not in layer_masks:
            layer_contributions = avg_contributions[lowest_contr_layer_name]
            mask = torch.zeros_like(layer_contributions)
            sorted_contributions = torch.argsort(layer_contributions)
            num_pruned = 0

        else:
            mask = layer_masks[lowest_contr_layer_name][0]
            sorted_contributions = layer_masks[lowest_contr_layer_name][1]
            num_pruned = layer_masks[lowest_contr_layer_name][2]

        index_to_mask = sorted_contributions[num_pruned]
        mask[index_to_mask] = 1

        nodes_left = torch.numel(mask) - int(torch.sum(mask).item())

        # Keep from deleting all nodes in a layer
        if nodes_left > min_nodes:
            layer_masks[lowest_contr_layer_name] = (mask, sorted_contributions, num_pruned + 1)
            num_params_pruned += 2 * hidden_size
            node_num += 1
        else:
            stop_pruning_layer = True

    # Apply mask and update the layer mean in "all_layers"
    if stop_pruning_layer:
        new_layer_contr_score = float('inf')
    else:
        mean_array = np.ma.array(data=avg_contributions[lowest_contr_layer_name], mask=mask)
        new_layer_contr_score = mean_array.mean()

    all_layers[0] = (lowest_contr_layer_name, new_layer_contr_score)
    all_layers.sort(key = lambda x:x[1])

In [38]:
for key, value in layer_masks.items():
    print(key, value[-1])

transformer.h.23.mlp.dense_4h_to_h.weight 3315
transformer.h.9.mlp.dense_4h_to_h.weight 3287
transformer.h.8.mlp.dense_4h_to_h.weight 3293
transformer.h.7.mlp.dense_4h_to_h.weight 2949
transformer.h.2.self_attention.value.weight 11
transformer.h.8.self_attention.value.weight 11
transformer.h.10.self_attention.value.weight 11
transformer.h.23.self_attention.value.weight 11
transformer.h.3.self_attention.value.weight 11
transformer.h.0.mlp.dense_4h_to_h.weight 2662
transformer.h.10.mlp.dense_4h_to_h.weight 2560
transformer.h.5.mlp.dense_4h_to_h.weight 2656
transformer.h.1.mlp.dense_4h_to_h.weight 2698
transformer.h.11.mlp.dense_4h_to_h.weight 2569
transformer.h.1.self_attention.value.weight 11
transformer.h.6.self_attention.value.weight 11
transformer.h.4.self_attention.value.weight 11
transformer.h.5.self_attention.value.weight 11
transformer.h.0.self_attention.value.weight 11
transformer.h.7.self_attention.value.weight 11
transformer.h.6.mlp.dense_4h_to_h.weight 2521
transformer.h.11.s

In [39]:
# Line up weights to prune and weights in the state dict
mask_index = 0
sorted_weight_index = 1
copied_model_params = model_params.copy()
pruned_model_params = OrderedDict()

for layer_name in copied_model_params:
    if layer_name in layer_masks.keys():
        if "self_attention.value.weight" in layer_name:   
            # Prune as input
            # Look at value matrix to decide what should be droped in "self_attention.dense.weight"
            value_reshape_mask = layer_masks[layer_name][mask_index].reshape(num_heads * head_dim)
            value_keep_index = torch.tensor([i for i in range(value_reshape_mask.shape[-1]) if value_reshape_mask[i] == 0])

            qk_layer_name = layer_name.replace("value", "query_key")
            qk_keep_index = torch.add(value_keep_index, hidden_size)
            qk_keep_index = torch.cat([value_keep_index, qk_keep_index], 0)

            pruned_query_key_output_weights = torch.index_select(copied_model_params[qk_layer_name], 0, qk_keep_index)
            pruned_value_output_weights = torch.index_select(copied_model_params[layer_name], 0, value_keep_index)

            qkv_layer_name = layer_name.replace("value", "query_key_value")
            pruned_query_key_value_weights = torch.cat((pruned_query_key_output_weights, pruned_value_output_weights), 0)
            pruned_model_params[qkv_layer_name] = pruned_query_key_value_weights

            # Also do bias terms
            qk_bias_name = qk_layer_name.replace("weight", "bias")
            value_bias_name = layer_name.replace("weight", "bias")
            qkv_bias_name = qkv_layer_name.replace("weight", "bias")
            pruned_qk_bias = torch.index_select(copied_model_params[qk_bias_name], 0, qk_keep_index)
            pruned_value_bias = torch.index_select(copied_model_params[value_bias_name], 0, value_keep_index)
            pruned_query_key_value_bias = torch.cat((pruned_qk_bias, pruned_value_bias), 0)
            pruned_model_params[qkv_bias_name] = pruned_query_key_value_bias
            
            dense_layer_name = layer_name.replace("value", "dense")
            pruned_dense_input_weights = torch.index_select(copied_model_params[dense_layer_name], -1, value_keep_index)
            pruned_model_params[dense_layer_name] = pruned_dense_input_weights

        else:
            # Prune when nodes are the input
            num_nodes_to_drop = int(sum(layer_masks[layer_name][mask_index]).item())
            keep_index = torch.sort(layer_masks[layer_name][sorted_weight_index][num_nodes_to_drop:]).values
            pruned_input_weights = torch.index_select(copied_model_params[layer_name], -1, keep_index)
            pruned_model_params[layer_name] = pruned_input_weights

            # Go to previous layer and prune when nodes are the output
            prev_layer_index = params_to_index[layer_name] - 1
            prev_layer_name = index_to_params[prev_layer_index]

            if "layernorm" in prev_layer_name:
                pruned_layer_norm_weights = torch.index_select(copied_model_params[prev_layer_name], 0, keep_index)
                pruned_model_params[prev_layer_name] = pruned_layer_norm_weights 

                # Also do bias term
                bias_layer_name = prev_layer_name.replace("weight", "bias")
                pruned_bias_weights = torch.index_select(copied_model_params[bias_layer_name], 0, keep_index)
                pruned_model_params[bias_layer_name] = pruned_bias_weights

                prev_layer_index = prev_layer_index - 1
                prev_layer_name = index_to_params[prev_layer_index]

            pruned_output_weights = torch.index_select(copied_model_params[prev_layer_name], 0, keep_index)
            pruned_model_params[prev_layer_name] = pruned_output_weights

            # Also do bias term
            bias_layer_name = prev_layer_name.replace("weight", "bias")
            pruned_bias_weights = torch.index_select(copied_model_params[bias_layer_name], 0, keep_index)
            pruned_model_params[bias_layer_name] = pruned_bias_weights
            
    elif "self_attention.value.weight" in layer_name:
        value_weights = copied_model_params[layer_name]
        query_key_layer = layer_name.replace("value", "query_key")
        query_key_weights = copied_model_params[query_key_layer]
        
        new_layer_name = layer_name.replace("value", "query_key_value")
        new_layer_weights = torch.cat((query_key_weights, value_weights), 0)
        pruned_model_params[new_layer_name] = new_layer_weights
        
        # Do bias
        value_bias_name = layer_name.replace("weight", "bias")
        query_key_bias_name = query_key_layer.replace("weight", "bias")
        new_layer_bias_name = new_layer_name.replace("weight", "bias")
        
        value_bias = copied_model_params[value_bias_name]
        query_key_bias = copied_model_params[query_key_bias_name]
        new_layer_bias = torch.cat((query_key_bias, value_bias), 0)
        pruned_model_params[new_layer_bias_name] = new_layer_bias
        
        # Do dense layer
        dense_layer_name = layer_name.replace("value", "dense")
        pruned_model_params[dense_layer_name] = copied_model_params[dense_layer_name]
        
    elif "self_attention.value.bias" in layer_name:
        continue
    elif "self_attention.query_key.weight" in layer_name:
        continue
    elif "self_attention.query_key.bias" in layer_name:
        continue
    elif "self_attention.dense.weight" in layer_name:
        continue
    else:
        pruned_model_params[layer_name] = copied_model_params[layer_name]

In [40]:
pruned_model_params

OrderedDict([('transformer.word_embeddings.weight',
              tensor([[-0.0099, -0.0048, -0.0111,  ..., -0.0426,  0.0099,  0.0212],
                      [ 0.0048, -0.0127,  0.0138,  ..., -0.0448,  0.0003, -0.0120],
                      [ 0.0065,  0.0239,  0.0050,  ..., -0.0431, -0.0067,  0.0137],
                      ...,
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012],
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012],
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012]])),
             ('transformer.word_embeddings_layernorm.weight',
              tensor([0.4409, 0.3167, 0.4749,  ..., 0.0816, 0.2927, 0.6006])),
             ('transformer.word_embeddings_layernorm.bias',
              tensor([-0.0513,  0.0164,  0.0052,  ...,  0.2412,  0.0072, -0.0292])),
             ('transformer.h.0.input_layernorm.weight',
              tensor([0.6621, 0.8457, 0.5884,  ..., 3.2090, 0.9180, 0.55

In [41]:
torch.save(pruned_model_params, "pruned_30percent_560m_bloom.pt")

In [42]:
state_dict_shapes = {}

In [43]:
num_params_now = 0
all_num_heads = []
for param_name in pruned_model_params.keys():
    shape = pruned_model_params[param_name].shape
    state_dict_shapes[param_name] = shape
    print(param_name, shape)
    
    if "self_attention.query_key_value.bias" in param_name:
        num_heads = shape[0] // 3 // head_dim
        all_num_heads.append(num_heads)
    
    if param_name == "lm_head.weight":
        continue
    elif len(shape) == 1:
        num_params_now += shape[0]
    else:
        num_params_now += (shape[0] * shape[1])
    
pkl.dump((state_dict_shapes, all_num_heads), open("pruned_30percent_560m_bloom_state_dict_shapes.pkl", "wb"))

transformer.word_embeddings.weight torch.Size([250880, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
transformer.word_embeddings_layernorm.bias torch.Size([1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
transformer.h.0.input_layernorm.bias torch.Size([1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([768, 1024])
transformer.h.0.self_attention.query_key_value.bias torch.Size([768])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 256])
transformer.h.0.self_attention.dense.bias torch.Size([1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
transformer.h.0.post_attention_layernorm.bias torch.Size([1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([1434, 1024])
transformer.h.0.mlp.dense_h_to_4h.bias torch.Size([1434])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 1434])
transformer.h.0.mlp.dense_4h_to_h.bias torch.Size([1024])
transformer.h.1.input_layernorm.weight to

In [44]:
num_params_now / 559214592

0.6909103884756999

In [45]:
all_num_heads

[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 5, 4, 11, 16, 16, 4]

In [46]:
node_num / (len(all_nodes) + (24 * 1024 * 2))

0.5639706566220238

In [47]:
node_num

97021