In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
import torch

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [4]:
# Load the Llama2-7B model and tokenizer
model_name = "meta-llama/Llama-2-7b-hf"  # Replace with the actual model name
access_token = "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"
tokenizer = LlamaTokenizer.from_pretrained(model_name, use_fast=False, token=access_token)
model = LlamaForCausalLM.from_pretrained(model_name, device_map="auto", token=access_token)

# Define a sample input
sample_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(sample_text, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}  # Move inputs to GPU

# Record norms for attention heads and MLP activations during the first forward pass
attention_head_norms = []
mlp_neuron_norms = []

def hook_attention_head(module, inputs, outputs):
    # Record the norm of the attention head outputs
    attention_head_norms.append(torch.norm(outputs, dim=-1).mean(dim=(0, 1)).cpu())

def hook_mlp_neurons(module, inputs, outputs):
    # Record the norm of MLP neuron activations
    mlp_neuron_norms.append(torch.norm(outputs, dim=0).cpu())

# Register hooks for attention heads and MLP layers
attention_hooks = []
mlp_hooks = []

for name, module in model.named_modules():
    if "attn.out_proj" in name:  # Attention output projection
        attention_hooks.append(module.register_forward_hook(hook_attention_head))
    elif "mlp" in name:  # MLP layers
        mlp_hooks.append(module.register_forward_hook(hook_mlp_neurons))

Downloading shards: 100%|██████████| 2/2 [02:05<00:00, 62.53s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.46s/it]


In [8]:
# Perform the forward pass
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# Remove hooks
for hook in attention_hooks + mlp_hooks:
    hook.remove()

# # Analyze sparsity in attention heads and MLP activations
# attention_head_sparsity = 100 * (1 - torch.stack(attention_head_norms) > 0.01).float().mean().item()
# mlp_neuron_sparsity = 100 * (1 - torch.stack(mlp_neuron_norms) > 0.01).float().mean().item()

# print(attention_head_sparsity, mlp_neuron_sparsity)



In [12]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [15]:
print(len(outputs.attentions))
print(outputs.attentions[0].shape)
norms = outputs.attentions[0].abs().sum(dim=-1).mean(dim=-2)  # Average norm per head
print(norms.shape)

32
torch.Size([1, 32, 13, 13])
torch.Size([1, 13])


In [16]:
# Define sparsity tracking
def track_attention_and_ffn_sparsity(model, inputs):
    attention_sparsity = []
    ffn_sparsity = []
    
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
        
        # Analyze sparsity in attention heads
        for attention in outputs.attentions:  # B x Head x L x L
            norms = attention.abs().sum(dim=-1).mean(dim=-2)  # Average norm per head
            sparsity = (norms < 0.1).float().mean().item() * 100  # % of heads below threshold
            attention_sparsity.append(sparsity)
        
        # Analyze sparsity in feed-forward neurons
        for hidden_state in outputs.hidden_states:  # B x L x D
            norms = hidden_state.abs().mean(dim=0)  # Average norm per neuron
            sparsity = (norms < 0.1).float().mean().item() * 100  # % of neurons below threshold
            ffn_sparsity.append(sparsity)
    
    return attention_sparsity, ffn_sparsity

# Compute sparsity
attention_sparsity, ffn_sparsity = track_attention_and_ffn_sparsity(model, inputs)
print(attention_sparsity)
print(ffn_sparsity)

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[99.98873472213745, 98.24782609939575, 85.32715439796448, 80.77486753463745, 72.77268767356873, 60.888671875, 51.7822265625, 44.79604959487915, 40.84472954273224, 36.23046875, 32.45192468166351, 29.383262991905212, 28.280875086784363, 25.828200578689575, 23.955830931663513, 20.941632986068726, 18.0908203125, 16.586539149284363, 15.25503396987915, 13.713191449642181, 12.336613982915878, 11.525315791368484, 10.439828783273697, 9.870793670415878, 9.510216861963272, 8.749625086784363, 8.501727879047394, 8.146785199642181, 7.557091861963272, 7.127028703689575, 6.156099960207939, 7.279147207736969, 5.357947945594788]
