In [1]:
import torch
import numpy as np
import pickle as pkl

from node_attribution.blender_bot_gradient_node_attribution import get_attributions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
human_filtered_pairs = pkl.load(open("44_human_filtered_conv_pairs.pkl", "rb"))

In [3]:
calibration_data = human_filtered_pairs[:22]

In [75]:
# Get attributions
model_size = "3B"
avg_contributions, max_contributions, model, model_params = get_attributions(model_size, calibration_data)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Downloading (…)okenizer_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 74.0/74.0 [00:00<00:00, 25.1kB/s]
Downloading (…)lve/main/config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1.29k/1.29k [00:00<00:00, 471kB/s]
Downloading (…)olve/main/vocab.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 150k/150k [00:00<00:00, 952kB/s]
Downloading (…)olve/main/merges.txt: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 62.9k/62.9k [00:00<00:00, 423kB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 130/130 [00:00<00:00, 15.0kB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings

Finished loading facebook/blenderbot-3B.
                    Num Encoder Transformer Blocks: 2
                    Num Decoder Transformer Blocks: 24
                    Num Encoder Attention Heads: 32
                    Num Decoder Attention Heads: 32
                    Encoder Head Dim: 80
                    Decoder Head Dim: 80
                    Hidden Size: 2560
                    Base Model Param Count: 2,696,268,800
                    Total Param Count (w/ LM Head): 2,716,769,280
final_logits_bias torch.Size([1, 8008])
model.shared.weight torch.Size([8008, 2560])
model.encoder.embed_tokens.weight torch.Size([8008, 2560])
model.encoder.embed_positions.weight torch.Size([128, 2560])
model.encoder.layers.0.self_attn.k_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.k_proj.bias torch.Size([2560])
model.encoder.layers.0.self_attn.v_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.v_proj.bias torch.Size([2560])
model.encoder.layers.0.se

In [76]:
# take absolute value of average contributions
for layer_name, contribution_tensor in avg_contributions.items():
    avg_contributions[layer_name] = torch.abs(avg_contributions[layer_name])

In [77]:
pkl.dump(avg_contributions, open("avg_contri_blender_bot_3B_22pair_calibration.pkl", "wb"))

In [78]:
pkl.dump(max_contributions, open("max_contri_blender_bot_3B_22pair_calibration.pkl", "wb"))

In [123]:
prune_percent = 0.15
num_params_to_prune = 2696268800 * prune_percent

head_dim = 80
num_heads = 32

In [124]:
index = 0
params_to_index = {}
for param_name in model_params.keys():
    if "bias" not in param_name:
        params_to_index[param_name] = index
        index += 1
    
index_to_params = {params_to_index[param_name]: param_name for param_name in params_to_index.keys()}

In [125]:
all_nodes = []

for layer_name, contribution_tensor in avg_contributions.items():
    if "fc2.weight" in layer_name:
        for node_id, node in enumerate(contribution_tensor.tolist()):
            node_name = f"{layer_name}.{node_id}"
            all_nodes.append((node_name, node))

In [126]:
all_layers = []

# Figure out which layers have the most to prune
for layer_name, contribution_tensor in avg_contributions.items():
    
    if "fc2.weight" in layer_name:
        # Get average contribution over the whole layer
        mean_contribution = torch.mean(contribution_tensor, 0).item()
        all_layers.append((layer_name, mean_contribution))

In [127]:
all_layers.sort(key = lambda x:x[1])

In [128]:
all_layers

[('model.encoder.layers.0.fc2.weight', 1.6720172425266355e-05),
 ('model.encoder.layers.1.fc2.weight', 2.361466977163218e-05),
 ('model.decoder.layers.22.fc2.weight', 9.73643982433714e-05),
 ('model.decoder.layers.21.fc2.weight', 0.0001065301476046443),
 ('model.decoder.layers.20.fc2.weight', 0.00010903981456067413),
 ('model.decoder.layers.23.fc2.weight', 0.00010948643466690555),
 ('model.decoder.layers.19.fc2.weight', 0.00011663961049634963),
 ('model.decoder.layers.18.fc2.weight', 0.00012284422700759023),
 ('model.decoder.layers.17.fc2.weight', 0.00012652072473429143),
 ('model.decoder.layers.16.fc2.weight', 0.00013740375288762152),
 ('model.decoder.layers.15.fc2.weight', 0.0001586541038705036),
 ('model.decoder.layers.14.fc2.weight', 0.00016885095101315528),
 ('model.decoder.layers.13.fc2.weight', 0.00018126037321053445),
 ('model.decoder.layers.12.fc2.weight', 0.00018229440320283175),
 ('model.decoder.layers.11.fc2.weight', 0.00021381850820034742),
 ('model.decoder.layers.0.fc2.we

In [129]:
layer_masks = {}
num_params_pruned = 0
node_num = 0
min_nodes = 24

while num_params_pruned < num_params_to_prune:
    lowest_contr_layer_name = all_layers[0][0]               
    stop_pruning_layer = False
 
    # Prune one node at time
    if lowest_contr_layer_name not in layer_masks:
        layer_contributions = avg_contributions[lowest_contr_layer_name]
        mask = torch.zeros_like(layer_contributions)
        sorted_contributions = torch.argsort(layer_contributions)
        num_pruned = 0

    else:
        mask = layer_masks[lowest_contr_layer_name][0]
        sorted_contributions = layer_masks[lowest_contr_layer_name][1]
        num_pruned = layer_masks[lowest_contr_layer_name][2]

    index_to_mask = sorted_contributions[num_pruned]
    mask[index_to_mask] = 1

    nodes_left = torch.numel(mask) - int(torch.sum(mask).item())

    # Keep from deleting all nodes in a layer
    if nodes_left > min_nodes:
        layer_masks[lowest_contr_layer_name] = (mask, sorted_contributions, num_pruned + 1)
        num_params_pruned += 5120
        node_num += 1
    else:
        stop_pruning_layer = True

    # Apply mask and update the layer mean in "all_layers"
    if stop_pruning_layer:
        new_layer_contr_score = float('inf')
    else:
        mean_array = np.ma.array(data=avg_contributions[lowest_contr_layer_name], mask=mask)
        new_layer_contr_score = mean_array.mean()

    all_layers[0] = (lowest_contr_layer_name, new_layer_contr_score)

    # re-sort layers now that this one has been pruned and pick the lowest contributing layer again
    all_layers.sort(key = lambda x:x[1])

In [130]:
layer_masks

{'model.encoder.layers.0.fc2.weight': (tensor([1., 1., 1.,  ..., 1., 1., 1.]),
  tensor([5048, 8807, 1189,  ..., 3884, 5643, 4459]),
  10215),
 'model.encoder.layers.1.fc2.weight': (tensor([1., 1., 1.,  ..., 1., 1., 1.]),
  tensor([2575, 4733, 6676,  ..., 9328, 7370, 6048]),
  10150),
 'model.decoder.layers.22.fc2.weight': (tensor([1., 1., 1.,  ..., 0., 0., 1.]),
  tensor([ 740, 2624, 2448,  ..., 4023, 8360, 8561]),
  7573),
 'model.decoder.layers.21.fc2.weight': (tensor([0., 1., 1.,  ..., 1., 0., 0.]),
  tensor([ 3977,  5433,  7186,  ...,  9161, 10129,  6500]),
  7104),
 'model.decoder.layers.20.fc2.weight': (tensor([0., 1., 1.,  ..., 1., 1., 1.]),
  tensor([10016,  5858,  3353,  ...,  9204,  7077,  5807]),
  6894),
 'model.decoder.layers.23.fc2.weight': (tensor([1., 1., 0.,  ..., 1., 1., 0.]),
  tensor([5087, 5559, 3759,  ..., 7899, 4355,  709]),
  6362),
 'model.decoder.layers.19.fc2.weight': (tensor([1., 0., 1.,  ..., 0., 1., 0.]),
  tensor([3420, 3955, 8052,  ..., 2362,  234, 1810

In [131]:
# Line up weights to prune and weights in the state dict
mask_index = 0
sorted_weight_index = 1
pruned_model_params = model_params.copy()

for layer_name in layer_masks.keys():
    # Prune when nodes are the input
    num_nodes_to_drop = int(sum(layer_masks[layer_name][mask_index]).item())
    keep_index = torch.sort(layer_masks[layer_name][sorted_weight_index][num_nodes_to_drop:]).values
    pruned_input_weights = torch.index_select(pruned_model_params[layer_name], -1, keep_index)
    pruned_model_params[layer_name] = pruned_input_weights

    # Go to previous layer and prune when nodes are the output
    prev_layer_index = params_to_index[layer_name] - 1
    prev_layer_name = index_to_params[prev_layer_index]

    pruned_output_weights = torch.index_select(pruned_model_params[prev_layer_name], 0, keep_index)
    pruned_model_params[prev_layer_name] = pruned_output_weights

    # Also do bias term
    bias_layer_name = prev_layer_name.replace("weight", "bias")
    pruned_bias_weights = torch.index_select(pruned_model_params[bias_layer_name], 0, keep_index)
    pruned_model_params[bias_layer_name] = pruned_bias_weights

In [132]:
torch.save(pruned_model_params, "pruned_3B_blender_bot2.pt")

In [133]:
state_dict_shapes = {}

In [134]:
for param_name in pruned_model_params.keys():
    state_dict_shapes[param_name] = pruned_model_params[param_name].shape
    print(param_name, pruned_model_params[param_name].shape)
    
pkl.dump(state_dict_shapes, open("pruned_3B_blender_bot2_state_dict_shapes.pkl", "wb"))

final_logits_bias torch.Size([1, 8008])
model.shared.weight torch.Size([8008, 2560])
model.encoder.embed_tokens.weight torch.Size([8008, 2560])
model.encoder.embed_positions.weight torch.Size([128, 2560])
model.encoder.layers.0.self_attn.k_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.k_proj.bias torch.Size([2560])
model.encoder.layers.0.self_attn.v_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.v_proj.bias torch.Size([2560])
model.encoder.layers.0.self_attn.q_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.q_proj.bias torch.Size([2560])
model.encoder.layers.0.self_attn.out_proj.weight torch.Size([2560, 2560])
model.encoder.layers.0.self_attn.out_proj.bias torch.Size([2560])
model.encoder.layers.0.self_attn_layer_norm.weight torch.Size([2560])
model.encoder.layers.0.self_attn_layer_norm.bias torch.Size([2560])
model.encoder.layers.0.fc1.weight torch.Size([24, 2560])
model.encoder.layers.0.fc1.bias torch.Size([24])
mod

In [135]:
node_num / len(all_nodes)

0.29669846754807694

In [136]:
node_num

78993