
# LLaMA 3.2 (3B) – Global Magnitude Pruning (20%) → Sparse CSR Export

This notebook:

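1. Loads **LLaMA 3.2 3B** (text) in **FP32**, falling back to a smaller public model (roughly 1–1.5B parameters) if you don't have access.
2. Applies **20% global magnitude pruning** across all `nn.Linear.weight` tensors. The criterion is the one used by `torch.nn.utils.prune.global_unstructured` with `L1Unstructured`, but the global threshold is estimated from a streamed histogram instead of one concatenated vector.
3. Zeroes the pruned weights **in place**, so they are *physically zeroed* and there are no pruning reparameterizations (`weight_orig`, `weight_mask`) left to remove.
4. Converts every linear **weight** to **CSR sparse** and saves the result as a separate checkpoint.
5. Runs a tiny sanity generation before and after pruning.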


In [1]:

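# Quote the version specifiers: an unquoted ">=" is treated by the shell as an output redirect.
!pip -q install "transformers>=4.43.3" "accelerate>=0.32.0" "torch>=2.1" --extra-index-url https://download.pytorch.org/whl/cu121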


In [1]:

import os, math, time, json, torch, platform, gc
import torch.nn as nn
from torch.nn.utils import prune
from transformers import AutoTokenizer, AutoModelForCausalLM

print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| Py:", platform.python_version())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
dtype = torch.float32  # <- keep FP32 as requested


Torch: 2.3.1+cu121 | CUDA: 12.1 | Py: 3.11.11
Device: cuda



## Load model (FP32)

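The loader tries each entry of `CANDIDATES` in order and keeps the first one that loads. The intended order is:

- `meta-llama/Llama-3.2-3B`
- `meta-llama/Llama-3.2-3B-Instruct`
- Fallbacks: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `Qwen/Qwen2.5-1.5B-Instruct`

As written, the `CANDIDATES` list in the cell below contains only the two fallbacks, which is why the recorded run loads TinyLlama. To use Llama 3.2, prepend the two `meta-llama` IDs to the list (they are gated and require approved access on Hugging Face).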


In [2]:

CANDIDATES = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen2.5-1.5B-Instruct",
]

chosen = None
last_err = None
for name in CANDIDATES:
    try:
        print(f"Trying to load: {name} (FP32)...")
        tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True, use_fast=False)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        mdl = AutoModelForCausalLM.from_pretrained(
            name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            device_map=None,
        )
        mdl.to(device)
        chosen = name
        print("Loaded:", name)
        break
    except Exception as e:
        last_err = e
        print("Failed:", name, "|", e.__class__.__name__, str(e)[:200])

if chosen is None:
    raise RuntimeError(f"Could not load any candidate model. Last error: {last_err}")

model, tokenizer = mdl, tok
model.eval();


Trying to load: TinyLlama/TinyLlama-1.1B-Chat-v1.0 (FP32)...



Loaded: TinyLlama/TinyLlama-1.1B-Chat-v1.0



## Quick pre‑prune sanity check


In [3]:

@torch.no_grad()
def sample_text(model, tokenizer, prompt: str, max_new_tokens=30):
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    out_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    new_tokens = out_ids[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]

demo_prompt = "Write a one-sentence fun fact about dolphins:"
print("Pre‑prune sample ->", sample_text(model, tokenizer, demo_prompt))


Pre‑prune sample -> 

Dolphins are the only mammals that can change the color of their skin.



## 20% Global Magnitude Pruning over Linear weights
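We collect all `nn.Linear.weight` parameters, estimate a single global magnitude threshold from a histogram of `|w|`, and zero every weight below that threshold in place. This targets the same result as `prune.global_unstructured` with `L1Unstructured`, without concatenating all weights into one vector.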
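
For reference, the built-in route through `torch.nn.utils.prune` can be sketched on a toy model as follows. The two-layer `toy` network is a hypothetical stand-in, not the notebook's model. `global_unstructured` ranks all selected weights together, and `prune.remove` then makes the zeros permanent. The cell that follows applies the same magnitude criterion with a streamed histogram, which avoids building one large vector of scores.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(0)
# Hypothetical toy model standing in for the LLM (40 Linear weights in total).
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Collect (module, parameter_name) pairs, as the pruning API expects.
params = [(m, "weight") for m in toy.modules() if isinstance(m, nn.Linear)]

# Rank all weights jointly by |w| and mask the smallest 20%.
prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=0.2)

# While the reparameterization is active, modules carry weight_orig and weight_mask.
has_mask = hasattr(toy[0], "weight_mask")

# Bake the mask into .weight and drop the extra tensors.
for m, pname in params:
    prune.remove(m, pname)

zeros = sum(int((m.weight == 0).sum()) for m, _ in params)
total = sum(m.weight.numel() for m, _ in params)
print(f"{zeros}/{total} weights zeroed")
```

Because the ranking is global, the per-layer sparsity is generally not 20% each; only the total is.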


In [4]:
import torch, torch.nn as nn, math, gc

def gather_linear_weight_params(module):
    params = []
    for name, m in module.named_modules():
        if isinstance(m, nn.Linear) and getattr(m, "weight", None) is not None:
            params.append((m, "weight"))
    return params

@torch.no_grad()
def _hist_threshold_for_global(amount, params_to_prune, bins=2048):
    """
    First pass: build a global histogram of |weights| across all Linear layers (CPU),
    then find the magnitude threshold that prunes `amount` fraction globally.
    """
    # 1) min/max over abs weights, streamed
    gmin, gmax = float("inf"), 0.0
    total_elems = 0
    for mod, pname in params_to_prune:
        w = getattr(mod, pname).detach()
        total_elems += w.numel()
        # move per-layer to CPU to avoid GPU peak
        a = w.abs().float().cpu()
        gmin = min(gmin, float(a.min()))
        gmax = max(gmax, float(a.max()))
        del a
    if not math.isfinite(gmin): gmin = 0.0
    if gmax <= gmin:  # all zeros?
        return 0.0, total_elems

    # 2) build histogram over [gmin, gmax]
    hist = torch.zeros(bins, dtype=torch.long)
    for mod, pname in params_to_prune:
        a = getattr(mod, pname).detach().abs().float().cpu()
        # torch.histc returns float; use long counts
        h = torch.histc(a, bins=bins, min=gmin, max=gmax).to(torch.long)
        hist += h
        del a, h
    cum = torch.cumsum(hist, dim=0)

    # 3) pick cutoff bin for desired prune count
    k = int(amount * total_elems)
    k = max(0, min(k, total_elems - 1))
    idx = int(torch.searchsorted(cum, torch.tensor(k, dtype=torch.long)))
    idx = max(0, min(idx, bins - 1))

    # 4) convert bin index -> magnitude threshold (upper edge of that bin)
    bin_width = (gmax - gmin) / bins
    threshold = gmin + (idx + 1) * bin_width
    return float(threshold), total_elems

@torch.no_grad()
def global_prune_linear_weights_streamed(model, amount=0.20, bins=2048):
    """
    Global unstructured magnitude pruning without allocating a giant vector on GPU.
    Finds a single global |w| threshold, then zeroes weights layer-by-layer.
    """
    params = gather_linear_weight_params(model)
    total = sum(getattr(m, p).numel() for m, p in params)
    print(f"Linear weights found: {len(params)} modules | {total/1e6:.2f}M params")

    thr, total_elems = _hist_threshold_for_global(amount, params, bins=bins)
    print(f"Computed global threshold |w| < {thr:.6g} to prune ~{amount:.0%} of {total_elems/1e6:.2f}M params")

    # Second pass: apply mask per layer (on current device)
    pruned = 0
    for mod, pname in params:
        W = getattr(mod, pname)
        mask = (W.abs() >= thr)
        pruned += (mask.numel() - mask.sum()).item()
        W.mul_(mask.to(W.dtype))  # in-place zero
        del mask
    sparsity = pruned / max(1, total)
    print(f"Applied global pruning. Sparsity: {sparsity:.2%}")
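    torch.cuda.empty_cache(); gc.collect()

# Apply the 20% global prune. Defining the function alone leaves the model dense.
global_prune_linear_weights_streamed(model, amount=0.20)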

# (Optional) Convert dense weights to sparse CSR tensors for storage/reporting.
# Note: PyTorch Linear expects dense weights for forward(). Use this only to export.
@torch.no_grad()
def export_sparse_state_dict(model):
    sd = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) and getattr(m, "weight", None) is not None:
            W = m.weight.detach().to("cpu")
            # nonzero() returns indices in row-major order, so they are already
            # sorted by row and then by column; no extra sort is needed.
            idx = W.nonzero(as_tuple=False)
            rows, cols = idx[:, 0], idx[:, 1]
            values = W[rows, cols]
            # CSR row pointers: crow[0] = 0 and crow[i+1] - crow[i] = nonzeros in row i.
            counts = torch.bincount(rows, minlength=W.size(0))
            crow = torch.zeros(W.size(0) + 1, dtype=torch.int64)
            crow[1:] = torch.cumsum(counts, dim=0)
            sd[f"{name}.weight_csr"] = {
                "shape": tuple(W.shape),
                "crow_indices": crow,
                "col_indices": cols.clone(),
                "values": values,
            }
    return sd
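
The histogram threshold used by `_hist_threshold_for_global` is an approximation: it returns the upper edge of a bin, so it can prune slightly more than the requested fraction. A minimal sketch on two hypothetical toy tensors compares it with the exact k-th smallest magnitude. The shapes and scale factors are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
# Two toy "layers" standing in for Linear weights (hypothetical shapes).
layers = [torch.randn(64, 32), torch.randn(16, 64) * 0.5]
amount, bins = 0.20, 2048

flat = torch.cat([w.abs().flatten() for w in layers])
k = int(amount * flat.numel())

# Exact threshold: the k-th smallest magnitude over all layers.
exact_thr = torch.kthvalue(flat, k).values.item()

# Histogram threshold: accumulate per-layer histograms over a shared range,
# then take the upper edge of the first bin whose cumulative count reaches k.
gmin, gmax = flat.min().item(), flat.max().item()
hist = torch.zeros(bins)
for w in layers:
    hist += torch.histc(w.abs(), bins=bins, min=gmin, max=gmax)
cum = torch.cumsum(hist, dim=0)
idx = int((cum < k).sum())  # index of the first bin with cum >= k
hist_thr = gmin + (idx + 1) * (gmax - gmin) / bins

exact_sparsity = (flat <= exact_thr).float().mean().item()
hist_sparsity = (flat < hist_thr).float().mean().item()
print(f"exact: thr={exact_thr:.5f}, sparsity={exact_sparsity:.4f}")
print(f"hist : thr={hist_thr:.5f}, sparsity={hist_sparsity:.4f}")
```

The gap between the two shrinks as `bins` grows, at the cost of a larger histogram; the per-layer pass keeps peak memory at one layer's worth of magnitudes.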



## Convert pruned Linear weights to CSR sparse and save
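This exports a mapping `{linear_name.weight: CSR tensor}` to `pruned_sparse_ckpt/linear_weights_csr.pt`. CSR stores a column index for every nonzero value, so at 20% sparsity this file is larger than the dense FP32 weights; it only becomes smaller at much higher sparsity. The recorded output below reports 0.00% sparsity, which indicates that this run was captured without the pruning function having been applied to the model. After a 20% prune the density should be close to 80%.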
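
To make the storage trade-off concrete, here is a back-of-the-envelope estimate. It assumes FP32 values and 64-bit indices (the default index type PyTorch produces for CSR tensors), uses a hypothetical 2048×2048 layer, and ignores serialization overhead.

```python
def dense_bytes(rows, cols, value_bytes=4):
    """Bytes needed to store a dense rows x cols matrix."""
    return rows * cols * value_bytes


def csr_bytes(rows, cols, density, value_bytes=4, index_bytes=8):
    """Bytes for CSR: one value and one column index per nonzero, plus row pointers."""
    nnz = int(rows * cols * density)
    return nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes


rows = cols = 2048  # hypothetical layer shape
for density in (0.8, 0.5, 0.2):
    ratio = csr_bytes(rows, cols, density) / dense_bytes(rows, cols)
    print(f"density {density:.0%}: CSR/dense size ratio = {ratio:.2f}")
```

Under these assumptions CSR breaks even at a density of about one third, so it pays off only above roughly 67% sparsity.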


In [6]:
from pathlib import Path

save_dir = Path("pruned_sparse_ckpt")
save_dir.mkdir(parents=True, exist_ok=True)

sparse_dump = {}
nonzero_total = 0
elem_total = 0

with torch.no_grad():
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and hasattr(module, "weight") and module.weight is not None:
            W = module.weight.detach().to("cpu", copy=True)
            elem_total += W.numel()
            nonzero_total += (W != 0).sum().item()
            W_csr = W.to_sparse_csr()
            sparse_dump[f"{name}.weight"] = W_csr

torch.save(sparse_dump, save_dir / "linear_weights_csr.pt")
print("Saved CSR sparse weights to:", save_dir / "linear_weights_csr.pt")
print(f"Overall density: {nonzero_total/elem_total:.2%} | sparsity: {1 - nonzero_total/elem_total:.2%}")




Saved CSR sparse weights to: pruned_sparse_ckpt/linear_weights_csr.pt
Overall density: 100.00% | sparsity: 0.00%
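
As a small illustration of what each saved CSR tensor contains, the sketch below converts a hand-written 3×3 matrix (an arbitrary example, not taken from the model) and reads back its three components.

```python
import torch

W = torch.tensor([[0.0, 1.5, 0.0],
                  [0.0, 0.0, 0.0],
                  [2.0, 0.0, -3.0]])

W_csr = W.to_sparse_csr()

crow = W_csr.crow_indices()  # row pointers: row i owns entries crow[i]:crow[i+1]
col = W_csr.col_indices()    # column index of each stored value
vals = W_csr.values()        # the nonzero values, in row-major order

print("crow_indices:", crow.tolist())
print("col_indices :", col.tolist())
print("values      :", vals.tolist())

# Converting back recovers the original dense matrix exactly.
restored = W_csr.to_dense()
print("round trip ok:", torch.equal(restored, W))
```

Since `nn.Linear` expects a dense weight in `forward()`, a reloaded CSR tensor would need `to_dense()` before being copied back into a module.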



## (Optional) Save dense pruned checkpoint
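Useful if you want to reload the pruned weights later with `load_state_dict`. The zeros are stored as ordinary dense values, so this file is the same size as an unpruned checkpoint.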


In [7]:

dense_path = save_dir / "pruned_dense_state_dict.pt"
torch.save(model.state_dict(), dense_path)
print("Saved dense pruned state_dict to:", dense_path)


Saved dense pruned state_dict to: pruned_sparse_ckpt/pruned_dense_state_dict.pt
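
Reloading such a checkpoint follows the usual `state_dict` pattern. The sketch below uses a toy `nn.Linear`, a temporary directory, and a hypothetical file name instead of the real model and `dense_path`. For the notebook's model, the assumed equivalent would be to build the architecture again with `AutoModelForCausalLM.from_pretrained(chosen)` and then call `load_state_dict` on it.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
src = nn.Linear(4, 3)
# Zero small weights in place so the saved tensor contains real zeros.
with torch.no_grad():
    src.weight.mul_((src.weight.abs() >= 0.2).to(src.weight.dtype))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_state_dict.pt")  # hypothetical file name
    torch.save(src.state_dict(), path)
    state = torch.load(path, map_location="cpu")

# A freshly initialised module of the same shape takes over the pruned weights.
dst = nn.Linear(4, 3)
dst.load_state_dict(state)
print("weights match:", torch.equal(dst.weight, src.weight))
```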



## Post‑prune sanity check


In [8]:

print("Post‑prune sample ->", sample_text(model, tokenizer, demo_prompt))


Post‑prune sample -> 

Dolphins are the only mammals that can change the color of their skin.
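
Identical greedy output before and after does not by itself show that pruning was applied, so a direct check on the weights is a useful complement. A minimal sketch follows; `linear_sparsity` is a hypothetical helper, shown here on a toy model where half of one layer is zeroed by hand.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def linear_sparsity(module):
    """Fraction of exactly-zero entries across all nn.Linear weights."""
    zeros = total = 0
    for m in module.modules():
        if isinstance(m, nn.Linear):
            zeros += int((m.weight == 0).sum())
            total += m.weight.numel()
    return zeros / max(1, total)


torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))  # 64 + 16 = 80 weights
before = linear_sparsity(toy)

with torch.no_grad():
    toy[0].weight[:4].zero_()  # zero 4 of 8 rows in the first layer (32 weights)
after = linear_sparsity(toy)

print(f"sparsity before: {before:.2%} | after: {after:.2%}")
```

Calling the same helper on `model` after pruning should report a value close to the requested 20%.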
