In [1]:
import torch
import numpy as np
import pickle as pkl

from node_attribution.blender_bot_gradient_node_attribution import get_attributions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the human-filtered conversation pairs. The `with` block closes the file
# handle (the bare `open(...)` previously leaked it).
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open("44_human_filtered_conv_pairs.pkl", "rb") as pairs_file:
    human_filtered_pairs = pkl.load(pairs_file)

In [3]:
# Calibration set = the first 22 human-filtered pairs (matches the "22pair"
# tag used in the contribution file names saved below).
num_calibration_pairs = 22
calibration_data = human_filtered_pairs[:num_calibration_pairs]

In [4]:
# Gradient-based node attributions for BlenderBot over the calibration pairs.
# get_attributions returns (avg contributions, max contributions, model, state dict).
model_size = "400M-distill"
attribution_results = get_attributions(model_size, calibration_data)
avg_contributions, max_contributions, model, model_params = attribution_results

Finished loading facebook/blenderbot-400M-distill.
                    Num Encoder Transformer Blocks: 2
                    Num Decoder Transformer Blocks: 12
                    Num Encoder Attention Heads: 32
                    Num Decoder Attention Heads: 32
                    Encoder Head Dim: 40
                    Decoder Head Dim: 40
                    Hidden Size: 1280
                    Base Model Param Count: 364,802,560
                    Total Param Count (w/ LM Head): 375,052,800
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22


In [5]:
# Replace every layer's average contribution tensor with its magnitude (the
# dict is updated in place), so the ranking/argsort below works on |contribution|.
for layer_name in list(avg_contributions.keys()):
    avg_contributions[layer_name] = avg_contributions[layer_name].abs()

In [6]:
# Persist the (absolute-valued) average contributions. Using `with` guarantees
# the file is flushed and closed instead of relying on garbage collection.
with open("avg_contri_blender_bot_400M_22pair_calibration.pkl", "wb") as avg_file:
    pkl.dump(avg_contributions, avg_file)

In [7]:
# Persist the max contributions (not abs-transformed; only avg_contributions
# was). Using `with` guarantees the file is flushed and closed.
with open("max_contri_blender_bot_400M_22pair_calibration.pkl", "wb") as max_file:
    pkl.dump(max_contributions, max_file)

In [28]:
# Pruning budget. prune_percent is a fraction (e.g. 0.1 = 10%) of the full
# parameter count reported by the load log ("Total Param Count (w/ LM Head)").
total_param_count = 375_052_800
prune_percent = 0.0
num_params_to_prune = total_param_count * prune_percent

# Attention geometry of blenderbot-400M-distill, as printed in the load log.
head_dim, num_heads = 40, 32

In [29]:
# Number the non-bias parameters in state-dict order, in both directions, so a
# weight's neighbouring weight can be found by index arithmetic.
weight_names = [name for name in model_params.keys() if "bias" not in name]
params_to_index = {name: position for position, name in enumerate(weight_names)}
index_to_params = dict(enumerate(weight_names))

In [30]:
# Flatten every fc2 layer's per-node contributions into
# ("<layer name>.<node id>", contribution) pairs.
all_nodes = [
    (f"{layer_name}.{node_id}", node)
    for layer_name, contribution_tensor in avg_contributions.items()
    if "fc2.weight" in layer_name
    for node_id, node in enumerate(contribution_tensor.tolist())
]

In [31]:
# Score each fc2 layer by its mean node contribution (mean over dim 0); the
# pruning loop uses these scores to decide which layer to prune next.
all_layers = [
    (layer_name, contribution_tensor.mean(0).item())
    for layer_name, contribution_tensor in avg_contributions.items()
    if "fc2.weight" in layer_name
]

In [32]:
# Ascending by mean contribution: all_layers[0] is the next pruning candidate.
all_layers.sort(key=lambda layer_entry: layer_entry[1])

In [33]:
# fc2 layers in ascending order of mean contribution (first = lowest).
all_layers

[('model.encoder.layers.0.fc2.weight', 0.00012805909500457346),
 ('model.encoder.layers.1.fc2.weight', 0.00016461317136418074),
 ('model.decoder.layers.0.fc2.weight', 0.0002610544615890831),
 ('model.decoder.layers.1.fc2.weight', 0.00031016708817332983),
 ('model.decoder.layers.2.fc2.weight', 0.00037199354846961796),
 ('model.decoder.layers.3.fc2.weight', 0.00040806649485602975),
 ('model.decoder.layers.4.fc2.weight', 0.00044847078970633447),
 ('model.decoder.layers.9.fc2.weight', 0.00045117439003661275),
 ('model.decoder.layers.8.fc2.weight', 0.0004720468132290989),
 ('model.decoder.layers.10.fc2.weight', 0.00047606881707906723),
 ('model.decoder.layers.7.fc2.weight', 0.0004931202274747193),
 ('model.decoder.layers.5.fc2.weight', 0.000497086439281702),
 ('model.decoder.layers.6.fc2.weight', 0.0005099925328977406),
 ('model.decoder.layers.11.fc2.weight', 0.0005775314057245851)]

In [34]:
# Greedy structured pruning of FFN hidden nodes (fc1 -> fc2).
# Each step takes the fc2 layer whose still-unpruned nodes have the lowest mean
# contribution, prunes that layer's lowest-contribution remaining node, re-scores
# the layer and re-sorts `all_layers`.
#
# layer_masks: layer name -> (mask, ascending argsort of contributions,
#                             number of nodes pruned); mask[i] == 1 means pruned.
layer_masks = {}
num_params_pruned = 0
node_num = 0
min_nodes = 24  # a layer must always keep MORE than min_nodes nodes
hidden_size = 1280  # "Hidden Size" from the load log
# Removing one FFN node drops one row of fc1.weight, one entry of fc1.bias and
# one column of fc2.weight when the masks are applied to the state dict.
params_per_node = 2 * hidden_size + 1

while num_params_pruned < num_params_to_prune:
    lowest_contr_layer_name, lowest_contr_score = all_layers[0]

    # Every layer is marked exhausted (score inf): the budget is unreachable,
    # so stop instead of looping forever.
    if lowest_contr_score == float("inf"):
        print(f"Stopping early: pruned {num_params_pruned} of {num_params_to_prune} params; "
              f"no layer can lose another node without reaching min_nodes={min_nodes}.")
        break

    if lowest_contr_layer_name in layer_masks:
        mask, sorted_contributions, num_pruned = layer_masks[lowest_contr_layer_name]
    else:
        layer_contributions = avg_contributions[lowest_contr_layer_name]
        mask = torch.zeros_like(layer_contributions)
        sorted_contributions = torch.argsort(layer_contributions)
        num_pruned = 0

    # Check the floor BEFORE touching the mask: the stored mask is mutated in
    # place, so masking first and then refusing would leave an extra, uncounted
    # pruned node in it.
    nodes_left = torch.numel(mask) - (num_pruned + 1)

    if nodes_left > min_nodes:
        mask[sorted_contributions[num_pruned]] = 1
        layer_masks[lowest_contr_layer_name] = (mask, sorted_contributions, num_pruned + 1)
        num_params_pruned += params_per_node
        node_num += 1
        # Re-score the layer over its remaining (unmasked) nodes only.
        remaining_contributions = avg_contributions[lowest_contr_layer_name][mask == 0]
        new_layer_contr_score = remaining_contributions.mean().item()
    else:
        # Layer is at its floor; push it to the end so it is never picked again.
        new_layer_contr_score = float("inf")

    all_layers[0] = (lowest_contr_layer_name, new_layer_contr_score)

    # re-sort layers now that this one has been pruned and pick the lowest contributing layer again
    all_layers.sort(key=lambda x: x[1])

In [35]:
# Per-layer pruning state; empty when the loop above did not run (num_params_to_prune == 0).
layer_masks

{}

In [36]:
# Line up weights to prune and weights in the state dict.
# Each layer_masks entry is (mask, ascending argsort of node contributions,
# number of nodes pruned). Nodes were pruned in that ascending order, so the
# first `num_pruned` argsort indices are exactly the dropped nodes.
mask_index = 0
sorted_weight_index = 1
num_pruned_index = 2
# Shallow copy: tensors are shared, but index_select returns new tensors, so
# model_params itself is left untouched.
pruned_model_params = model_params.copy()

for layer_name, layer_entry in layer_masks.items():
    # Use the recorded count rather than summing the mask: the mask can carry
    # one extra, uncounted 1 if a node was masked and then refused for min_nodes.
    num_nodes_to_drop = layer_entry[num_pruned_index]
    keep_index = torch.sort(layer_entry[sorted_weight_index][num_nodes_to_drop:]).values

    # Prune when nodes are the input: drop columns (last dim) of fc2.weight.
    pruned_model_params[layer_name] = torch.index_select(
        pruned_model_params[layer_name], -1, keep_index
    )

    # Go to previous non-bias layer and prune when nodes are the output (rows).
    prev_layer_name = index_to_params[params_to_index[layer_name] - 1]
    expected_prev_name = layer_name.replace("fc2.weight", "fc1.weight")
    assert prev_layer_name == expected_prev_name, (
        f"expected {expected_prev_name} right before {layer_name}, found {prev_layer_name}"
    )
    pruned_model_params[prev_layer_name] = torch.index_select(
        pruned_model_params[prev_layer_name], 0, keep_index
    )

    # Also do the bias term of the producing layer (one entry per node).
    bias_layer_name = prev_layer_name.replace("weight", "bias")
    pruned_model_params[bias_layer_name] = torch.index_select(
        pruned_model_params[bias_layer_name], 0, keep_index
    )

In [37]:
# Write the pruned state dict to disk.
torch.save(pruned_model_params, f="pruned_400m_blender_bot2.pt")

In [38]:
# Parameter name -> tensor shape of the pruned state dict (filled in below).
state_dict_shapes = dict()

In [39]:
# Record and print the shape of every tensor in the pruned state dict, then
# save the mapping next to the pruned weights. Rebuilding the dict keeps the
# cell idempotent; `with` guarantees the pickle file is flushed and closed.
state_dict_shapes = {}
for param_name, param_tensor in pruned_model_params.items():
    state_dict_shapes[param_name] = param_tensor.shape
    print(param_name, param_tensor.shape)

with open("pruned_400m_blender_bot2_state_dict_shapes.pkl", "wb") as shapes_file:
    pkl.dump(state_dict_shapes, shapes_file)

final_logits_bias torch.Size([1, 8008])
model.shared.weight torch.Size([8008, 1280])
model.encoder.embed_tokens.weight torch.Size([8008, 1280])
model.encoder.embed_positions.weight torch.Size([128, 1280])
model.encoder.layers.0.self_attn.k_proj.weight torch.Size([1280, 1280])
model.encoder.layers.0.self_attn.k_proj.bias torch.Size([1280])
model.encoder.layers.0.self_attn.v_proj.weight torch.Size([1280, 1280])
model.encoder.layers.0.self_attn.v_proj.bias torch.Size([1280])
model.encoder.layers.0.self_attn.q_proj.weight torch.Size([1280, 1280])
model.encoder.layers.0.self_attn.q_proj.bias torch.Size([1280])
model.encoder.layers.0.self_attn.out_proj.weight torch.Size([1280, 1280])
model.encoder.layers.0.self_attn.out_proj.bias torch.Size([1280])
model.encoder.layers.0.self_attn_layer_norm.weight torch.Size([1280])
model.encoder.layers.0.self_attn_layer_norm.bias torch.Size([1280])
model.encoder.layers.0.fc1.weight torch.Size([5120, 1280])
model.encoder.layers.0.fc1.bias torch.Size([5120])

In [40]:
# Fraction of all fc2 input nodes that were pruned (0.0 if there are no nodes,
# avoiding a ZeroDivisionError).
node_num / len(all_nodes) if all_nodes else 0.0

0.0

In [41]:
# Total number of FFN nodes pruned by the greedy loop.
node_num

0