In [1]:
from node_attribution.node_attribution import get_attributions

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_size = "560m"
data = ["Hello, I am an AlexPrize chatbot"]
prune_percent = 0.30
num_params_to_prune = 816115712 * prune_percent

In [3]:
num_params_to_prune

244834713.6

In [4]:
# Get attributions
avg_contributions, max_contributions, model, model_params = get_attributions(model_size, data)

Finished loading bigscience/bloom-560m.
                    Num Transformer Blocks: 24
                    Num Attention Heads: 16
                    Head Dim: 64
                    Hidden Size: 1024
                    Base Model Param Count: 559,214,592
                    Total Param Count (w/ LM Head): 816,115,712


In [22]:
layer_shape_map = {}
layer_names, contribution_tensors = zip(*avg_contributions.items())
num_layers = len(layer_names)

for i in range(num_layers):
    layer_name = layer_names[i]
    layer_size = contribution_tensors[i].shape[0]
    
    if i != 0:
        next_layer_size = contribution_tensors[i - 1].shape[0]
    else:
        next_layer_size = 250880
        
    
    if i != (num_layers - 1):
        prev_layer_size = contribution_tensors[i + 1].shape[0]
    else:
        prev_layer_size = 1024
        
    layer_shape_map[layer_name] = {
        "prev_layer_size": prev_layer_size,
        "next_layer_size": next_layer_size,
        "current_layer_size": layer_size
    }

In [24]:
layer_shape_map


{'lm_head.weight': {'prev_layer_size': 4096,
  'next_layer_size': 250880,
  'current_layer_size': 1024},
 'transformer.h.23.mlp.dense_4h_to_h.weight': {'prev_layer_size': 1024,
  'next_layer_size': 1024,
  'current_layer_size': 4096},
 'transformer.h.23.mlp.dense_h_to_4h.weight': {'prev_layer_size': 1024,
  'next_layer_size': 4096,
  'current_layer_size': 1024},
 'transformer.h.23.self_attention.dense.weight': {'prev_layer_size': 3072,
  'next_layer_size': 1024,
  'current_layer_size': 1024},
 'transformer.h.23.self_attention.query_key_value_fused_output.weight': {'prev_layer_size': 1024,
  'next_layer_size': 1024,
  'current_layer_size': 3072},
 'transformer.h.23.self_attention.query_key_value.weight': {'prev_layer_size': 4096,
  'next_layer_size': 3072,
  'current_layer_size': 1024},
 'transformer.h.22.mlp.dense_4h_to_h.weight': {'prev_layer_size': 1024,
  'next_layer_size': 1024,
  'current_layer_size': 4096},
 'transformer.h.22.mlp.dense_h_to_4h.weight': {'prev_layer_size': 1024,
 

In [12]:
all_nodes = []

for layer_name, contribution_tensor in avg_contributions.items():
    for node_id, node in enumerate(contribution_tensor.tolist()):
        node_name = f"{layer_name}.{node_id}"
        all_nodes.append((node_name, node))
    

In [29]:
# Sort all the nodes
all_nodes.sort(key = lambda x:x[1])

In [25]:
# Iterate through nodes to make the prune list. 
# Whenever a node is added to the prune list, check total number or params being 
nodes_to_prune = []
num_params_pruned = 0
node_num = 0

while num_params_pruned < num_params_to_prune:
    node = all_nodes[node_num]
    node_name = node[0]
    layer_name = ".".join(node_name.split(".")[:-1])
    print(layer_name)
                     
    nodes_to_prune.append(node)
    shapes = layer_shape_map[layer_name]
    num_params_pruned += shapes["prev_layer_size"]
    num_params_pruned += shapes["next_layer_size"]
    
    node_num += 1
    
print(node_num)
    


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [30]:
node_num / len(all_nodes)

0.22894920254149378

In [26]:
node_num

56501

In [31]:
nodes_to_prune

[('transformer.h.22.self_attention.query_key_value_fused_output.weight.1787',
  -17626.154296875),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1851',
  -17626.154296875),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1807',
  -5429.0966796875),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1743',
  -5429.09619140625),
 ('transformer.h.21.self_attention.query_key_value_fused_output.weight.2898',
  -4951.94775390625),
 ('transformer.h.21.self_attention.query_key_value_fused_output.weight.2962',
  -4951.947265625),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1785',
  -4597.2646484375),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1849',
  -4597.2646484375),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1779',
  -4437.36474609375),
 ('transformer.h.22.self_attention.query_key_value_fused_output.weight.1843',
  -4437.36474609375),
 ('tran

In [32]:
for name, param in model.named_parameters():
    if "bias" in name:
        continue

    print(name, param.shape)

transformer.word_embeddings.weight torch.Size([250880, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([3072, 1024])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([4096, 1024])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 4096])
transformer.h.1.input_layernorm.weight torch.Size([1024])
transformer.h.1.self_attention.query_key_value.weight torch.Size([3072, 1024])
transformer.h.1.self_attention.dense.weight torch.Size([1024, 1024])
transformer.h.1.post_attention_layernorm.weight torch.Size([1024])
transformer.h.1.mlp.dense_h_to_4h.weight torch.Size([4096, 1024])
transformer.h.1.mlp.dense_4h_to_h.weight torch.Size([1024, 4096])
transformer.h.2.input_layernorm.weight torch.Size([1024])
transf

# Pruning on a Layer Scale