In [1]:
# RMSNorm Statistics

In [9]:
import torch
import numpy as np
import os
from utils import load_model
def _print_weight_stats(key, w):
    """Print one aligned line of summary statistics for the NumPy vector `w`."""
    print(f"{key:45s} | mean={w.mean():.6f} | std={w.std():.6f} | "
          f"min={w.min():.6f} | max={w.max():.6f} | norm={np.linalg.norm(w):.6f}")


def _weight_to_numpy(module):
    """Return `module.weight` as a NumPy array, or None when there is no weight.

    A module can expose a `weight` attribute that is None (for example a
    normalization layer created without learnable scale), so the attribute's
    mere presence is not enough to call `.detach()` on it.
    """
    weight = getattr(module, "weight", None)
    if weight is None:
        return None
    return weight.detach().cpu().float().numpy()


def save_rmsnorm_weights_with_stats(model, save_path="rmsnorm_weights.npz"):
    """
    Extracts and saves all RMSNorm weight vectors from a LLaMA model into a .npz file.
    Also prints basic statistics (mean, std, min, max, norm) for each RMSNorm layer.

    The saved file contains the ACTUAL weight vectors, not just the stats.

    Args:
        model: Object exposing `model.model.layers` (an iterable of modules that
            support `named_modules()`) and, optionally, `model.model.norm`.
        save_path: Destination of the archive. As with `np.savez_compressed`,
            a ".npz" suffix is appended when it is missing. Missing parent
            directories are created.

    Returns:
        list[str]: Keys of the saved arrays, in extraction order
        ("layer_{i}_{submodule name}" entries followed by "final_norm").
    """
    rms_weights = {}

    print("\n=== Extracting RMSNorm weights ===")
    # Iterate through all transformer layers
    for i, layer in enumerate(model.model.layers):
        for name, module in layer.named_modules():
            # Selection is by submodule name: anything whose name contains "norm".
            if "norm" not in name.lower():
                continue
            w = _weight_to_numpy(module)
            if w is None:
                # Norm without a weight tensor: nothing to save for it.
                continue
            key = f"layer_{i}_{name}"
            rms_weights[key] = w
            _print_weight_stats(key, w)

    # Final RMSNorm before LM head
    final_norm = getattr(model.model, "norm", None)
    if final_norm is not None:
        w = _weight_to_numpy(final_norm)
        if w is not None:
            key = "final_norm"
            rms_weights[key] = w
            _print_weight_stats(key, w)

    # np.savez_compressed appends ".npz" to paths that lack it; resolve the
    # real file name up front so the message below reports the actual file.
    save_path = os.fspath(save_path)
    if not save_path.endswith(".npz"):
        save_path += ".npz"

    # Save all weights
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    np.savez_compressed(save_path, **rms_weights)

    print(f"\n✅ Saved {len(rms_weights)} RMSNorm vectors to '{save_path}'")
    return list(rms_weights.keys())


In [11]:
# Location of the local Llama-3.1-8B-Instruct checkpoint. The LLAMA_MODEL_PATH
# environment variable overrides the machine-specific default below, so the
# notebook can run on another machine without editing this cell.
MODEL_PATH = os.environ.get(
    "LLAMA_MODEL_PATH",
    "/home/chashi/Desktop/Research/My Projects/models/Llama-3.1-8B-Instruct",
)
# load_model is imported from the project's utils module; it is unpacked here
# as a (model, tokenizer) pair.
model, tokenizer = load_model(MODEL_PATH)

# Dump every norm weight vector to disk and print per-layer statistics.
keys = save_rmsnorm_weights_with_stats(model, "../results/llama_rmsnorm_weights.npz")

print("\nExample saved keys:", keys[:5])

Loading tokenizer...
Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.38it/s]

Model loaded successfully on device: cuda:0
Vocabulary size: 128256

=== Extracting RMSNorm weights ===
layer_0_input_layernorm                       | mean=0.070758 | std=0.121427 | min=-0.210938 | max=1.085938 | norm=8.994487
layer_0_post_attention_layernorm              | mean=0.133770 | std=0.011952 | min=-0.000138 | max=0.234375 | norm=8.595359
layer_1_input_layernorm                       | mean=0.105187 | std=0.081226 | min=-0.012451 | max=0.890625 | norm=8.505515
layer_1_post_attention_layernorm              | mean=0.166858 | std=0.022191 | min=-0.000052 | max=0.281250 | norm=10.772962
layer_2_input_layernorm                       | mean=0.305289 | std=0.098434 | min=-0.001221 | max=0.843750 | norm=20.528986
layer_2_post_attention_layernorm              | mean=0.208197 | std=0.029590 | min=-0.000123 | max=0.229492 | norm=13.458517
layer_3_input_layernorm                       | mean=0.346206 | std=0.035772 | min=0.002411 | max=0.644531 | norm=22.275162
layer_3_post_attention_la


