In [1]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.cuda import empty_cache
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
model_id = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Eager attention is requested explicitly: the forward hooks registered later in
# this notebook read the attention probabilities returned by every self_attn
# module, and fused kernels (SDPA / flash-attention) do not materialise them.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
).to("cuda")
# Inference only: switch off dropout and other train-time behaviour.
model.eval()

In [3]:
def cleanup_hooks():
    """
    Detach every forward hook recorded in the notebook-level handle lists.

    Reads the globals `activation_handles` and `attention_handles`, calls
    `.remove()` on each stored handle and then empties both lists in place, so
    already-removed handles are not kept around (or removed a second time) when
    this is called again.
    """
    for handles in (activation_handles, attention_handles):
        for handle in handles:
            handle.remove()
        # The handles are dead after remove(); drop them so the lists only ever
        # describe hooks that are still attached.
        handles.clear()

In [4]:
# Detach hooks left over from a previous run of this cell. Rebinding the handle
# lists below would otherwise orphan them: they would stay attached to the model
# and every forward pass would be captured twice.
for _stale_handle in globals().get("activation_handles", []) + globals().get("attention_handles", []):
    _stale_handle.remove()

# Captured tensors keyed by hook name; each value is a list holding one CPU
# tensor per hooked forward call, in call order.
activations = {}
attentions = {}

# Handles returned when the hooks are registered, kept so they can be removed.
activation_handles = []
attention_handles = []


def activation_hook(layer_name):
    """
    Build a forward hook that records a module's output under `layer_name`.

    The returned hook appends a detached CPU copy of `output` to
    `activations[layer_name]` on every forward call.
    """
    def hook(module, input, output):
        activations.setdefault(layer_name, []).append(output.detach().cpu())
    return hook


def attention_hook(layer_name):
    """
    Build a forward hook that records attention weights under `layer_name`.

    The hooked module is expected to return a tuple whose second element is the
    attention-weight tensor. Only "at least two elements" is assumed, so both
    `(attn_output, attn_weights)` and `(attn_output, attn_weights, cache)`
    layouts are accepted. Calls where the weights are missing (`None`, or the
    output is not such a tuple) are skipped instead of aborting the forward pass.
    """
    def hook(module, input, output):
        if not isinstance(output, tuple) or len(output) < 2:
            return
        attn_weights = output[1]
        if attn_weights is None:
            # The module did not return attention probabilities for this call.
            return
        attentions.setdefault(layer_name, []).append(attn_weights.detach().cpu())
    return hook

# Attach one capture hook to the MLP and one to the self-attention module of
# every decoder layer, remembering the handles so they can be removed later.
for layer_idx, decoder_layer in enumerate(model.model.layers):
    if hasattr(decoder_layer, "mlp"):
        mlp_handle = decoder_layer.mlp.register_forward_hook(
            activation_hook(f"layer_{layer_idx}_mlp")
        )
        activation_handles.append(mlp_handle)
    if hasattr(decoder_layer, "self_attn"):
        attn_handle = decoder_layer.self_attn.register_forward_hook(
            attention_hook(f"layer_{layer_idx}_self_attn")
        )
        attention_handles.append(attn_handle)

In [5]:
# Uncomment and run once capturing is finished to detach every forward hook:
# cleanup_hooks()

In [6]:
# CNN/DailyMail summarisation data; only the training split is used.
dataset = load_dataset("lighteval/summarization", "cnn-dm", split="train")

In [7]:
def prep_toks(x):
    """
    Render one dataset row as a chat prompt for the summarisation task.

    x: mapping with an "article" entry holding the text to summarise.
    Returns: {"tokens": prompt}. Because `tokenize=False` is passed, the value
    is the templated prompt string (with the generation prompt appended), not
    token ids.
    """
    system_turn = {
        "role": "system",
        "content": "You are an helpful AI assistant whose job is to provide a concise summarize of the given content",
    }
    user_turn = {"role": "user", "content": x["article"]}
    prompt = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return {"tokens": prompt}


# Keep the first 128 rows and add the rendered prompt to each as a "tokens" column.
dataset = dataset.select(range(128)).map(prep_toks)

In [None]:
BATCH_SIZE = 16  # prompts per generate() call

# The "tokens" column holds the rendered prompt strings produced by prep_toks.
input_prompts = dataset.select_columns("tokens").to_dict()["tokens"]
all_outputs = []

# Start from empty capture buffers so re-running this cell does not append a
# second copy of every activation / attention tensor to the previous run's.
activations.clear()
attentions.clear()

# NOTE(review): `truncation=True` is given no max_length and generate() is
# capped with `max_length=1024`, which bounds prompt + generated tokens
# together. Long articles may therefore leave little or no room for the
# summary -- confirm this is intended (max_new_tokens would bound only the
# continuation).
# NOTE(review): decoding uses the checkpoint's default generation settings; if
# those enable sampling, seed torch beforehand for reproducible outputs.
for bs in tqdm(range(0, len(input_prompts), BATCH_SIZE)):
    batch_prompts = input_prompts[bs:bs + BATCH_SIZE]
    input_tokens = tokenizer(batch_prompts, return_tensors="pt", truncation=True, padding=True, padding_side="left")
    input_tokens = {k: v.cuda() for k, v in input_tokens.items()}

    with torch.no_grad():
        # output_attentions=True makes the self-attention modules return their
        # weights, which is what the registered hooks record. Hidden states are
        # not requested: only 'sequences' is kept from the result.
        outputs = model.generate(
            **input_tokens,
            return_dict_in_generate=True,
            output_attentions=True,
            max_length=1024,
        )['sequences'].cpu()
    all_outputs.extend(outputs)

    # Release cached GPU blocks between batches.
    empty_cache()

In [None]:
# Spot-check a single generation: take the second sequence collected in
# all_outputs and decode its token ids back to text.
single_seq = all_outputs[1]
tokenizer.decode(single_seq)

In [None]:
# Structure of the captured activations:
#   activations[layer_name] -> list with one tensor per hooked forward call (all steps)
#   each tensor has shape (batch_size, seq_len, hidden_dim)

In [None]:
# One line per hooked layer: number of captured tensors and shape of the first.
print("Activations:")
for layer_name, captured in activations.items():
    print(f"{layer_name}: {len(captured)}, {captured[0].shape}")

# print("\nAttentions:")
# for key, value in attentions.items():
#     print(f"{key}: {len(value)}, {value[0].shape}")

In [None]:
def compute_activation_variance(activation_tensor):
    """
    Compute the average variance of activations over tokens.

    activation_tensor: expected shape (batch, seq_len, hidden_dim).
    Returns: a Python float -- the population variance along the token axis
    (dim 1), averaged over batch and hidden units.

    Half-precision inputs (float16 / bfloat16) are upcast to float32 first, so
    the variance and the final mean are not rounded to ~3 significant digits.
    A tensor with a single token (seq_len == 1) yields exactly 0.
    """
    if activation_tensor.dtype in (torch.float16, torch.bfloat16):
        activation_tensor = activation_tensor.float()
    # Population variance (unbiased=False) per sample and hidden unit.
    var = activation_tensor.var(dim=1, unbiased=False)  # shape: (batch, hidden_dim)
    return var.mean().item()

def compute_attention_entropy(attn_tensor):
    """
    Compute the average attention entropy per head.

    attn_tensor: expected shape (batch, num_heads, seq_len_query, seq_len_key),
    with each row along the last dimension being an attention distribution.
    Returns: list of length num_heads; entry h is -sum(p * log(p + eps)) over
    the key dimension, averaged over batch and query positions.

    Half-precision inputs (float16 / bfloat16) are upcast to float32 before the
    log / sum so the result is not limited to half-precision resolution.
    """
    eps = 1e-9  # keeps log() finite where the attention weight is exactly 0
    if attn_tensor.dtype in (torch.float16, torch.bfloat16):
        attn_tensor = attn_tensor.float()
    # Entropy over the key dimension (last dim).
    entropy = -(attn_tensor * torch.log(attn_tensor + eps)).sum(dim=-1)  # (batch, num_heads, seq_len_query)
    # Average over batch and query tokens.
    avg_entropy = entropy.mean(dim=(0, 2))  # shape: (num_heads,)
    return avg_entropy.tolist()

def compute_attention_kl(attn_tensor):
    """
    Compute KL(attention || uniform) for each head.

    attn_tensor: expected shape (batch, num_heads, seq_len_query, seq_len_key),
    with each row along the last dimension being an attention distribution.
    Returns: list of length num_heads; entry h is sum(p * (log p - log u)) over
    the key dimension, averaged over batch and query positions.

    The uniform reference is defined over the *key* dimension (the axis the
    distribution lives on), so the result is also correct when the query and
    key lengths differ. Half-precision inputs are upcast to float32.
    """
    eps = 1e-9  # keeps log() finite where the attention weight is exactly 0
    if attn_tensor.dtype in (torch.float16, torch.bfloat16):
        attn_tensor = attn_tensor.float()
    _, _, _, num_keys = attn_tensor.shape
    uniform = torch.full((num_keys,), 1.0 / num_keys, device=attn_tensor.device, dtype=attn_tensor.dtype)
    # `uniform` broadcasts against the key dimension of every query row.
    kl = (attn_tensor * (torch.log(attn_tensor + eps) - torch.log(uniform + eps))).sum(dim=-1)  # (batch, num_heads, seq_len_query)
    avg_kl = kl.mean(dim=(0, 2))  # shape: (num_heads,)
    return avg_kl.tolist()

import torch
import math # Import math

def compute_attention_kl_attn_vs_uniform(attn_tensor):
    """
    Compute KL(Attention || Uniform) for each head's attention distribution.

    Measures how far the attention distribution P is from the uniform
    distribution Q over the keys:
        KL(P || Q) = sum(P * log(P / Q)) = log(L_key) - Entropy(P)

    attn_tensor: shape (batch, num_heads, seq_len_query, seq_len_key), assumed
        normalised along the last dimension.
    Returns: list of length num_heads with the KL divergence averaged over
        batch and query positions.
    Raises: ValueError if the key dimension is empty.

    Half-precision inputs (float16 / bfloat16) are upcast to float32, and
    log(L_key) is kept as a full-precision Python float instead of being cast
    back to the input dtype, so neither term is rounded to half precision.
    """
    _, _, _, num_keys = attn_tensor.shape
    if num_keys <= 0:
        # Should not happen with real attention tensors.
        raise ValueError(f"Length of attention distribution (L_key) must be positive, but got L_key={num_keys}")

    eps = 1e-9  # keeps log() finite where the attention weight is exactly 0
    if attn_tensor.dtype in (torch.float16, torch.bfloat16):
        attn_tensor = attn_tensor.float()

    # log(1) == 0, so a single-key distribution needs no special case.
    log_num_keys = math.log(num_keys)

    # H(P) = -sum(P * log(P + eps)) over the key dimension.
    entropy = -(attn_tensor * torch.log(attn_tensor + eps)).sum(dim=-1)  # (batch, num_heads, seq_len_query)
    kl_div = log_num_keys - entropy

    # Average over batch (dim 0) and query tokens (dim 2).
    return kl_div.mean(dim=(0, 2)).tolist()

In [None]:
# Result containers, keyed by hook name and filled by the aggregation cell.
activation_variances, attention_entropies, attention_kl = {}, {}, {}

In [None]:
# Compute metrics for activations
for key, act_list in activations.items():
    variances = []
    for tensor in act_list:
        # Ensure tensor shape is (batch, seq_len, hidden_dim)
        variances.append(compute_activation_variance(tensor))
    activation_variances[key] = sum(variances) / len(variances)

# Compute metrics for attentions
for key, attn_list in attentions.items():
    entropies = []
    kls = []

    for tensor in attn_list:
        # validate_tensor(tensor)
        # Ensure tensor shape is (batch, num_heads, seq_len, seq_len)
        entropies.append(compute_attention_entropy(tensor))
        kls.append(compute_attention_kl_attn_vs_uniform(tensor))

    # Convert list of lists to tensor and average over decoding steps
    entropies_tensor = torch.tensor(entropies)  # shape: (num_steps, num_heads)
    kl_tensor = torch.tensor(kls)              # shape: (num_steps, num_heads)
    avg_entropy = entropies_tensor.mean(dim=0).tolist()
    avg_kl = kl_tensor.mean(dim=0).tolist()
    attention_entropies[key] = avg_entropy
    attention_kl[key] = avg_kl

In [None]:
# Print out the metrics
print("Activation Variances per layer:")
for key, var in activation_variances.items():
    print(f"{key}: {var}")

print("\nAttention Entropy per head per layer:")
for key, entropy in attention_entropies.items():
    print(f"{key}: {entropy}")

print("\nAttention KL Divergence per head per layer:")
for key, kl in attention_kl.items():
    print(f"{key}: {kl}")

In [None]:
def compute_composite_scores(activation_variances, attention_entropies, attention_kl, alpha=1.0, beta=1.0):
    """
    Normalise the raw metrics and combine the attention metrics per head.

    activation_variances: {layer_name: float}.
    attention_entropies / attention_kl: {layer_name: [value per head]}; both
        must contain the same layers and the same number of heads per layer.
    alpha, beta: weights of the entropy term and the KL term.

    Returns (norm_act, composite_scores):
        norm_act: activation variances min-max scaled to [0, 1] across layers.
        composite_scores: {layer_name: {head_index: score}} with
            score = alpha * (1 - norm_entropy) + beta * (1 - norm_kl),
            where both terms are min-max scaled across the heads of that layer.

    Whenever all values being scaled are equal, every entry becomes 0.5 (this
    also applies to heads within a layer, instead of dividing by zero); an
    empty input scales to an empty dict.

    NOTE(review): when attention_kl was derived as log(N) - entropy (as
    compute_attention_kl_attn_vs_uniform does), the scaled KL of a head is
    1 - its scaled entropy, so with alpha == beta the score is (almost) the
    same constant for every head. Confirm the intended sign of the two terms.
    """
    def normalize(metric_dict):
        # Min-max scale the dict's values to [0, 1].
        if not metric_dict:
            return {}
        values = list(metric_dict.values())
        min_val, max_val = min(values), max(values)
        if max_val == min_val:
            return {k: 0.5 for k in metric_dict}
        span = max_val - min_val
        return {k: (v - min_val) / span for k, v in metric_dict.items()}

    def normalize_heads(per_layer):
        # Scale each layer's heads independently, keyed by head index.
        return {
            layer: normalize(dict(enumerate(head_values)))
            for layer, head_values in per_layer.items()
        }

    norm_act = normalize(activation_variances)
    norm_att_ent = normalize_heads(attention_entropies)
    norm_att_kl = normalize_heads(attention_kl)

    composite_scores = {
        layer: {
            head: alpha * (1 - ent) + beta * (1 - norm_att_kl[layer][head])
            for head, ent in heads.items()
        }
        for layer, heads in norm_att_ent.items()
    }

    return norm_act, composite_scores


In [None]:
# Min-max normalise the per-layer MLP variances and build the per-head composite
# attention scores, using the default weights alpha = beta = 1.0.
mlp_scores, attn_scores = compute_composite_scores(activation_variances, attention_entropies, attention_kl)

In [None]:
# Share of the model's parameter count held by each named parameter tensor.
total_params = sum(p.numel() for p in model.parameters())
weight_contributions = {
    name: tensor.numel() / total_params
    for name, tensor in model.named_parameters()
}


def combine_layer_weights(model_state_dict, layer_num):
    """
    Sum the per-parameter values that belong to one decoder layer.

    model_state_dict: mapping from parameter name to a number (the notebook
        passes each parameter's share of the total parameter count).
    layer_num: index of the layer; entries are matched by the name prefix
        "model.layers.{layer_num}.".

    Returns {"self_attn": ..., "mlp": ..., "overall": ...}: the summed values
    for the layer's self_attn.* entries, its mlp.* entries and all of its
    entries, each multiplied by 100 and rounded to two decimals.
    """
    layer_prefix = f"model.layers.{layer_num}."
    self_attn_prefix = f"{layer_prefix}self_attn."
    mlp_prefix = f"{layer_prefix}mlp."

    combined = {"self_attn": 0.0, "mlp": 0.0, "overall": 0.0}

    for key, value in model_state_dict.items():
        # The trailing dot in the prefix keeps layer 1 from matching layer 10, 11, ...
        if not key.startswith(layer_prefix):
            continue
        combined["overall"] += value
        if key.startswith(self_attn_prefix):
            combined["self_attn"] += value
        elif key.startswith(mlp_prefix):
            combined["mlp"] += value

    return {k: round(100 * v, 2) for k, v in combined.items()}

def combine_all_layers(model_state_dict, num_layers=24):
    """
    Apply combine_layer_weights to layers 0 .. num_layers - 1.

    Returns {"layer_{i}": combine_layer_weights(model_state_dict, i)} for every
    layer index i in that range.
    """
    return {
        f"layer_{layer_num}": combine_layer_weights(model_state_dict, layer_num)
        for layer_num in range(num_layers)
    }


# Pass the loaded model's actual depth rather than relying on the default of 24,
# so the table stays complete if model_id is changed.
combined_weights = combine_all_layers(weight_contributions, num_layers=len(model.model.layers))

In [None]:
# Display the per-layer table built by combine_all_layers (values are x100, rounded to 2 decimals).
combined_weights

In [None]:
def visualize_attention(attn, title="Attention Map"):
    """
    Show a 2-D attention map as a heat map.

    attn: tensor to plot; rows are labelled "Query" and columns "Key".
    title: figure title.
    """
    # Cast to float16 before the NumPy conversion.
    heatmap = attn.to(torch.float16).cpu().numpy()
    plt.imshow(heatmap, cmap='viridis')
    plt.colorbar()
    plt.xlabel("Key")
    plt.ylabel("Query")
    plt.title(title)
    plt.show()

In [None]:
# Total composite score per layer (sum over that layer's heads).
for layer_name, head_scores in attn_scores.items():
    print(layer_name, sum(head_scores.values()))

In [None]:
# Pick one attention map to inspect: layer, head and capture ("step") index.
layer_idx = 0
head_idx = 6
step_idx = 0

# First batch element ([0]) and the chosen head of the selected capture.
temp = attentions[f'layer_{layer_idx}_self_attn'][step_idx][0, head_idx]
# Shape and grand total of the map, printed as a quick sanity check.
# NOTE(review): presumably every query row sums to 1 -- confirm how padded
# positions behave before relying on the total.
print(temp.shape, temp.sum())
visualize_attention(temp)

In [None]:
def prune_attention_hook(prune_heads, num_heads):
    """
    Build a hook that zeroes the attention weights of selected heads.

    prune_heads: head index or iterable of head indices to zero; None or an
        empty collection disables masking.
    num_heads: currently unused; kept so existing callers keep working.

    The returned hook expects `output` to be a tuple starting with
    (attn_output, attn_weights). It zeroes `attn_weights[:, heads, :, :]`
    in place and returns the tuple with any trailing elements preserved. If the
    output is not such a tuple the hook returns None; if the weights are None
    the tuple is returned unmodified.

    NOTE(review): only the returned weights are masked. If this is installed as
    a forward hook, attn_output has already been computed from the unmasked
    weights, so the module's actual output is not changed -- confirm that this
    is the intended meaning of "pruning".
    """
    # Normalise once. A bare integer is accepted so that head 0 is not lost to a
    # truthiness test on the raw argument.
    if prune_heads is None:
        heads = []
    elif isinstance(prune_heads, int):
        heads = [prune_heads]
    else:
        heads = list(prune_heads)

    def hook(module, input, output):
        if not isinstance(output, tuple) or len(output) < 2:
            # No attention weights to mask.
            return None

        attn_output, attn_weights = output[:2]
        if heads and attn_weights is not None:
            # In-place: the caller's tensor object is modified.
            attn_weights[:, heads, :, :] = 0.0

        return (attn_output, attn_weights) + output[2:]
    return hook
