# Find Important LoRA Layers

Find the most important layers in a LoRA, using a variety of measurements. 

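By default, layers are ranked by a softmax over their spectral norms, where a layer's spectral norm is the largest singular value of its weight update ΔW = B·A.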

Some LoRAs concentrate upwards of 90% of the softmax score in a single layer, while others spread it across 20 or more layers.

After training a LoRA on all layers, use this notebook to find the most significant layers, then target only those layers in future training runs.

#### TODO: include ai-toolkit layer config

In [1]:
from __future__ import annotations

import math
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import safetensors.torch
import torch
from tqdm.notebook import tqdm

In [11]:
# Path to the LoRA checkpoint under analysis; edit to point at another file.
lora_path = "/mnt/models/tensors/loras/qwen_image/cumpire_qwen_v6_000005500.safetensors"
lora = safetensors.torch.load_file(lora_path)  # mapping: tensor name -> tensor

In [12]:
# Peek at the first ten tensor names to see the key naming scheme.
keys = list(lora.keys())[:10]
keys

['diffusion_model.transformer_blocks.0.attn.add_k_proj.lora_A.weight',
 'diffusion_model.transformer_blocks.0.attn.add_k_proj.lora_B.weight',
 'diffusion_model.transformer_blocks.0.attn.add_q_proj.lora_A.weight',
 'diffusion_model.transformer_blocks.0.attn.add_q_proj.lora_B.weight',
 'diffusion_model.transformer_blocks.0.attn.add_v_proj.lora_A.weight',
 'diffusion_model.transformer_blocks.0.attn.add_v_proj.lora_B.weight',
 'diffusion_model.transformer_blocks.0.attn.to_add_out.lora_A.weight',
 'diffusion_model.transformer_blocks.0.attn.to_add_out.lora_B.weight',
 'diffusion_model.transformer_blocks.0.attn.to_k.lora_A.weight',
 'diffusion_model.transformer_blocks.0.attn.to_k.lora_B.weight']

In [17]:
# ----------------------------
# Helpers: pairing & key parsing
# ----------------------------
_A_SUFFIX = ".lora_A.weight"  # key suffix of the LoRA "A" factor
_B_SUFFIX = ".lora_B.weight"  # key suffix of the LoRA "B" factor

def find_lora_pairs(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    """
    Return, sorted, every base name that has BOTH an A and a B weight.

    Example:
        'diffusion_model.transformer_blocks.0.attn.add_k_proj.lora_A.weight'
        -> 'diffusion_model.transformer_blocks.0.attn.add_k_proj'

    Bases with only one of the two factors are ignored.
    """
    def _bases_with(suffix: str) -> set:
        cut = len(suffix)
        return {name[:-cut] for name in state_dict.keys() if name.endswith(suffix)}

    return sorted(_bases_with(_A_SUFFIX) & _bases_with(_B_SUFFIX))

def get_pair(state_dict: Dict[str, torch.Tensor], base: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (A, B) tensors stored under `base`; KeyError if either is missing."""
    return state_dict[f"{base}{_A_SUFFIX}"], state_dict[f"{base}{_B_SUFFIX}"]


# ----------------------------
# Statistic functions (extensible registry)
# ----------------------------
StatFn = Callable[[torch.Tensor], Dict[str, Any]]

def _ensure_1d(x: torch.Tensor) -> torch.Tensor:
    """Flatten `x` into a 1-D tensor; complex inputs are replaced by their magnitude first."""
    magnitude = x.abs() if x.is_complex() else x
    return magnitude.reshape(-1)

def stat_basic(x: torch.Tensor) -> Dict[str, Any]:
    """
    Summary statistics over every entry of `x`: min, max, mean, population std, median.

    Note: torch's median returns the lower of the two middle values when the
    element count is even (it does not average them).
    """
    flat = _ensure_1d(x)
    reductions = (
        ("min", lambda t: t.min()),
        ("max", lambda t: t.max()),
        ("mean", lambda t: t.mean()),
        ("std", lambda t: t.std(unbiased=False)),  # population std (divide by N)
        ("median", lambda t: t.median()),
    )
    return {label: float(op(flat).item()) for label, op in reductions}

def stat_sign(x: torch.Tensor) -> Dict[str, Any]:
    """
    Count strictly positive, strictly negative and exactly-zero entries of `x`.

    Also reports each count as a fraction of the total. The fractions are 0.0 for
    an empty tensor. NaN entries fall into none of the three buckets.
    """
    flat = _ensure_1d(x)
    total = flat.numel()
    counts = {
        "pos": (flat > 0).sum().item(),
        "neg": (flat < 0).sum().item(),
        "zero": (flat == 0).sum().item(),
    }
    summary: Dict[str, Any] = {"count": int(total)}
    summary.update({label: int(c) for label, c in counts.items()})
    summary.update(
        {f"{label}_frac": (float(c / total) if total else 0.0) for label, c in counts.items()}
    )
    return summary

def stat_norms(x: torch.Tensor) -> Dict[str, Any]:
    """L1, L2 (Euclidean) and L-infinity norms of `x` treated as one flat vector."""
    flat = _ensure_1d(x)
    magnitudes = flat.abs()
    return {
        "l1": float(magnitudes.sum().item()),
        "l2": float(torch.linalg.norm(flat, ord=2).item()),
        "linf": float(magnitudes.max().item()),
    }

def stat_percentiles(x: torch.Tensor, ps: Optional[List[float]] = None) -> Dict[str, Any]:
    """
    Percentiles of every entry of `x`, keyed "p<percentile>" (e.g. "p50", "p99.9").

    Args:
        x: any tensor; flattened (complex -> magnitude) before analysis.
        ps: percentiles in [0, 100]; defaults to [1, 5, 25, 50, 75, 95, 99].

    Returns:
        One float per requested percentile. All values are NaN when `x` is empty
        or a requested percentile is invalid (e.g. outside [0, 100]).
    """
    # None sentinel instead of a shared mutable list default.
    if ps is None:
        ps = [1, 5, 25, 50, 75, 95, 99]
    # ":g" keeps integral labels unchanged ("p50") but stops fractional ones
    # from collapsing together (99.5 and 99.9 would both become "p99" with int()).
    labels = [f"p{p:g}" for p in ps]

    flat = _ensure_1d(x).detach().cpu()
    # numpy has no bfloat16, so .numpy() would raise; upcast first.
    if flat.dtype == torch.bfloat16:
        flat = flat.to(torch.float32)
    arr = flat.numpy()

    if arr.size == 0:
        return {label: float("nan") for label in labels}
    try:
        q = np.percentile(arr, ps)
    except (ValueError, IndexError):
        # e.g. a percentile outside [0, 100]
        return {label: float("nan") for label in labels}
    return {label: float(val) for label, val in zip(labels, q)}

def _power_iteration_spectral_norm(W: torch.Tensor, iters: int = 50, seed: Optional[int] = 0) -> float:
    """
    Approximate the spectral norm (largest singular value) of a 2-D matrix by power iteration.

    Used when an exact SVD is too large or fails numerically.

    Args:
        W: 2-D matrix (ValueError for any other rank). Half-precision and integer
            inputs are upcast to float32, so the random start vector and W share a dtype.
        iters: number of power-iteration steps; at least one step always runs.
        seed: seed for the start vector (drawn with a CPU generator, then moved to
            W's device), so repeated runs give identical estimates. None uses torch's
            global RNG instead.

    Returns:
        |u^T W v| for the final unit vectors u, v. This is a lower bound on ||W||_2
        that converges to it as `iters` grows. Returns 0.0 for an empty matrix.
    """
    m, n = W.shape  # ValueError for non-2-D input
    if W.numel() == 0:
        return 0.0
    # randn(dtype=...) below must match W; low-precision/integer W would otherwise mismatch.
    if W.dtype in (torch.float16, torch.bfloat16) or not (W.is_floating_point() or W.is_complex()):
        W = W.to(torch.float32)
    # ||W||_2 == ||W^T||_2: orient W so the start vector lives on the smaller side.
    if m < n:
        W = W.t()
        m, n = n, m

    gen = None
    if seed is not None:
        gen = torch.Generator()
        gen.manual_seed(seed)
    v = torch.randn(n, generator=gen, dtype=W.dtype).to(W.device)
    v = v / (v.norm() + 1e-12)
    for _ in range(max(1, iters)):
        u = W @ v
        u = u / (u.norm() + 1e-12)
        v = W.t() @ u
        v = v / (v.norm() + 1e-12)
    sigma = (u @ (W @ v)).item()
    return float(abs(sigma))

def stat_svd(W: torch.Tensor, topk: int = 8, svd_exact_max_elems: int = 3072 * 3072) -> Dict[str, Any]:
    """
    Singular-value statistics for a 2-D matrix.

    Small matrices (at most `svd_exact_max_elems` entries) get the full spectrum via
    torch.linalg.svdvals. Output keys, in this order:
      spectral_norm, energy_total, energy_topk, energy_ratio_topk, nuclear_norm, s1..sk.
    "Energy" is the sum of squared singular values.

    Larger matrices, or a RuntimeError from the SVD, fall back to a power-iteration
    estimate, and only "spectral_norm" is guaranteed.

    Non-2-D input yields {"spectral_norm": nan}.
    """
    if W.ndim != 2:
        return {"spectral_norm": float("nan")}

    rows, cols = W.shape
    if rows * cols > svd_exact_max_elems:
        # Full spectrum too expensive: approximate the spectral norm only.
        return {"spectral_norm": _power_iteration_spectral_norm(W)}

    out: Dict[str, Any] = {}
    try:
        spectrum = torch.sort(torch.linalg.svdvals(W), descending=True).values
        out["spectral_norm"] = float(spectrum[0].item()) if spectrum.numel() > 0 else float("nan")
        total = float((spectrum ** 2).sum().item())
        out["energy_total"] = total
        leading = spectrum[:min(topk, spectrum.numel())]
        head = float((leading ** 2).sum().item())
        out["energy_topk"] = head
        out["energy_ratio_topk"] = float(head / total) if total > 0 else float("nan")
        out["nuclear_norm"] = float(spectrum.sum().item())
        for idx, value in enumerate(leading.tolist(), start=1):
            out[f"s{idx}"] = float(value)
    except RuntimeError:
        # Numerical failure in the exact path: fall back to the iterative estimate.
        out["spectral_norm"] = _power_iteration_spectral_norm(W)
    return out


# Element-wise stats run on each raw LoRA factor (A and B).
# Extend by appending (name, fn) pairs; results are namespaced as "<name>.<key>".
TENSOR_STATS: List[Tuple[str, StatFn]] = [
    ("basic", stat_basic),
    ("sign", stat_sign),
    ("norms", stat_norms),
    ("percentiles", stat_percentiles),
]
# ΔW (2-D) gets the same element-wise stats plus SVD stats (spectral norm & energy capture).
DELTA_MATRIX_STATS: List[Tuple[str, Callable[[torch.Tensor], Dict[str, Any]]]] = [
    *TENSOR_STATS,
    ("svd", stat_svd),
]


# ----------------------------
# Core analysis
# ----------------------------
def compute_stats(x: torch.Tensor,
                  fns: List[Tuple[str, StatFn]]) -> Dict[str, Any]:
    """
    Run a list of (name, fn) over the tensor and merge their dict outputs.
    """
    out = {}
    for name, fn in fns:
        try:
            res = fn(x)
            # namespace keys by function name
            for k, v in res.items():
                out[f"{name}.{k}"] = v
        except Exception as e:
            out[f"{name}.error"] = f"{type(e).__name__}: {e}"
    return out


def analyze_lora(
    state_dict: Dict[str, torch.Tensor],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    svd_topk: int = 8,
    svd_exact_max_elems: int = 3072 * 3072,
    return_df: bool = True,
) -> pd.DataFrame | List[Dict[str, Any]]:
    """
    Analyze LoRA layers (A, B, and ΔW = B @ A) and produce one row of stats per layer.

    Args:
        state_dict: dict-like with LoRA weights. Only bases with both a
            ".lora_A.weight" and a ".lora_B.weight" entry are analyzed.
        device: move tensors to this device for analysis (None: leave them where they are)
        dtype: convert tensors to this dtype for analysis (e.g., torch.float32)
        svd_topk: number of top singular values/energy to report for ΔW
        svd_exact_max_elems: ΔW element count above which exact SVD is skipped
            and only a power-iteration spectral norm is reported
        return_df: if True, return a pandas DataFrame; else a list of row dicts

    Returns:
        One row per layer. Columns: "layer", the A/B/ΔW shapes, "rank", "rank_match",
        "in_dim", "out_dim", "delta_error", then stats namespaced as "A.*", "B.*"
        and "dW.*". The "dW.*" columns are absent for a layer whose ΔW could not be formed.

    Notes:
        - ΔW is the raw product B @ A; no LoRA alpha/rank scaling is applied.
        - Factors with more than 2 dims (e.g. conv LoRA) are flattened to
          [dim0, -1] before forming ΔW. The "A.*"/"B.*" stats use the unflattened tensors.
        - A problem with one layer (unsupported shapes, matmul failure) is recorded
          in "delta_error" instead of aborting the whole analysis.
    """
    # Build the ΔW stat list once; only the SVD entry needs the caller's parameters.
    def _svd_cfg(W: torch.Tensor) -> Dict[str, Any]:
        return stat_svd(W, topk=svd_topk, svd_exact_max_elems=svd_exact_max_elems)

    d_matrix_stats = [
        (n, _svd_cfg if n == "svd" else f) for (n, f) in DELTA_MATRIX_STATS
    ]

    results: List[Dict[str, Any]] = []
    for base in tqdm(find_lora_pairs(state_dict)):
        A, B = get_pair(state_dict, base)

        # Dtype/device management
        if device is not None:
            A = A.to(device)
            B = B.to(device)
        if dtype is not None:
            A = A.to(dtype)
            B = B.to(dtype)

        # Expected shapes: A: [r, in], B: [out, r]. Collapse trailing dims (e.g. conv
        # kernels) so ΔW is always a matrix.
        A2 = A.flatten(start_dim=1) if A.ndim > 2 else A
        B2 = B.flatten(start_dim=1) if B.ndim > 2 else B

        deltaW: Optional[torch.Tensor] = None
        delta_err: Optional[str] = None
        r = in_dim = out_dim = None
        rank_match = False
        if A2.ndim != 2 or B2.ndim != 2:
            delta_err = f"unsupported LoRA factor shapes: A={tuple(A.shape)}, B={tuple(B.shape)}"
        else:
            r, in_dim = A2.shape
            out_dim, rB = B2.shape
            rank_match = r == rB  # sanity check for rank
            try:
                deltaW = B2 @ A2  # [out_dim, in_dim]
            except RuntimeError as e:
                # e.g. rank mismatch: record the error and skip ΔW stats
                delta_err = f"{type(e).__name__}: {e}"

        # Aggregate row (key order defines the DataFrame column order)
        row: Dict[str, Any] = {
            "layer": base,
            "A.shape": tuple(A.shape),
            "B.shape": tuple(B.shape),
            "deltaW.shape": tuple(deltaW.shape) if deltaW is not None else None,
            "rank": int(r) if r is not None else None,
            "rank_match": bool(rank_match),
            "in_dim": int(in_dim) if in_dim is not None else None,
            "out_dim": int(out_dim) if out_dim is not None else None,
            "delta_error": delta_err,
        }
        # namespace A/B/ΔW stats
        row.update({f"A.{k}": v for k, v in compute_stats(A, TENSOR_STATS).items()})
        row.update({f"B.{k}": v for k, v in compute_stats(B, TENSOR_STATS).items()})
        if deltaW is not None:
            row.update({f"dW.{k}": v for k, v in compute_stats(deltaW, d_matrix_stats).items()})
        results.append(row)

        # Free intermediates to reduce peak memory when there are many layers
        del A, B, A2, B2, deltaW
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return pd.DataFrame(results) if return_df else results


# ----------------------------
# Convenience: run & display
# ----------------------------
def analyze_and_display(
    lora_state: Dict[str, torch.Tensor],
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = torch.float32,
    sort_by: str = "dW.svd.spectral_norm",
    ascending: bool = False,
) -> pd.DataFrame:
    """
    One-liner for notebooks: analyze a LoRA, sort the per-layer table, display and return it.

    Args:
        lora_state: LoRA state dict, passed straight to analyze_lora.
        device: device name such as "cpu"; None leaves tensors where they are.
        dtype: dtype used for the analysis (float32 by default).
        sort_by: column to sort on. Typical keys (names as produced by analyze_lora):
          - "dW.svd.spectral_norm"
          - "dW.svd.energy_ratio_topk"
          - "A.norms.l2"
          - "B.norms.l2"
        ascending: sort direction; NaNs always go last.

    Returns:
        The analysis DataFrame, sorted with a fresh index when `sort_by` exists.
        An unknown `sort_by` emits a UserWarning and leaves the table unsorted.

    Relies on the notebook-provided `display` function.
    """
    import warnings

    dev = torch.device(device) if device else None
    df = analyze_lora(lora_state, device=dev, dtype=dtype)
    if sort_by in df.columns:
        df = df.sort_values(sort_by, ascending=ascending, na_position="last").reset_index(drop=True)
    else:
        # Without this, a typo in sort_by silently shows an unsorted table.
        warnings.warn(
            f"sort_by={sort_by!r} is not a column of the analysis table; showing it unsorted.",
            stacklevel=2,
        )
    display(df)  # notebook display
    return df


In [18]:
# Analyze every LoRA pair on CPU in float32, strongest ΔW spectral norm first.
df = analyze_and_display(
    lora,
    device="cpu",
    dtype=torch.float32,
    sort_by="dW.svd.spectral_norm",
    ascending=False,
)

  0%|          | 0/840 [00:00<?, ?it/s]

Unnamed: 0,layer,A.shape,B.shape,deltaW.shape,rank,rank_match,in_dim,out_dim,delta_error,A.basic.min,...,dW.svd.energy_ratio_topk,dW.svd.nuclear_norm,dW.svd.s1,dW.svd.s2,dW.svd.s3,dW.svd.s4,dW.svd.s5,dW.svd.s6,dW.svd.s7,dW.svd.s8
0,diffusion_model.transformer_blocks.9.txt_mod.1,"(32, 3072)","(18432, 32)","(18432, 3072)",32,True,3072,18432,,-0.049072,...,,,,,,,,,,
1,diffusion_model.transformer_blocks.33.img_mod.1,"(32, 3072)","(18432, 32)","(18432, 3072)",32,True,3072,18432,,-0.120117,...,,,,,,,,,,
2,diffusion_model.transformer_blocks.37.img_mlp....,"(32, 3072)","(12288, 32)","(12288, 3072)",32,True,3072,12288,,-0.091797,...,,,,,,,,,,
3,diffusion_model.transformer_blocks.30.img_mod.1,"(32, 3072)","(18432, 32)","(18432, 3072)",32,True,3072,18432,,-0.088867,...,,,,,,,,,,
4,diffusion_model.transformer_blocks.17.img_mod.1,"(32, 3072)","(18432, 32)","(18432, 3072)",32,True,3072,18432,,-0.045410,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
835,diffusion_model.transformer_blocks.56.txt_mlp....,"(32, 12288)","(3072, 32)","(3072, 12288)",32,True,12288,3072,,-0.009033,...,,,,,,,,,,
836,diffusion_model.transformer_blocks.59.txt_mlp....,"(32, 3072)","(12288, 32)","(12288, 3072)",32,True,3072,12288,,-0.018066,...,,,,,,,,,,
837,diffusion_model.transformer_blocks.59.txt_mlp....,"(32, 12288)","(3072, 32)","(3072, 12288)",32,True,12288,3072,,-0.009033,...,,,,,,,,,,
838,diffusion_model.transformer_blocks.59.attn.to_...,"(32, 3072)","(3072, 32)","(3072, 3072)",32,True,3072,3072,,-0.018066,...,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [20]:
df.head(1)  # top-ranked layer

Unnamed: 0,layer,A.shape,B.shape,deltaW.shape,rank,rank_match,in_dim,out_dim,delta_error,A.basic.min,...,dW.svd.energy_ratio_topk,dW.svd.nuclear_norm,dW.svd.s1,dW.svd.s2,dW.svd.s3,dW.svd.s4,dW.svd.s5,dW.svd.s6,dW.svd.s7,dW.svd.s8
0,diffusion_model.transformer_blocks.9.txt_mod.1,"(32, 3072)","(18432, 32)","(18432, 3072)",32,True,3072,18432,,-0.049072,...,,,,,,,,,,


In [21]:
df.keys()  # column Index: every stat produced by the analysis

Index(['layer', 'A.shape', 'B.shape', 'deltaW.shape', 'rank', 'rank_match',
       'in_dim', 'out_dim', 'delta_error', 'A.basic.min', 'A.basic.max',
       'A.basic.mean', 'A.basic.std', 'A.basic.median', 'A.sign.count',
       'A.sign.pos', 'A.sign.neg', 'A.sign.zero', 'A.sign.pos_frac',
       'A.sign.neg_frac', 'A.sign.zero_frac', 'A.norms.l1', 'A.norms.l2',
       'A.norms.linf', 'A.percentiles.p1', 'A.percentiles.p5',
       'A.percentiles.p25', 'A.percentiles.p50', 'A.percentiles.p75',
       'A.percentiles.p95', 'A.percentiles.p99', 'B.basic.min', 'B.basic.max',
       'B.basic.mean', 'B.basic.std', 'B.basic.median', 'B.sign.count',
       'B.sign.pos', 'B.sign.neg', 'B.sign.zero', 'B.sign.pos_frac',
       'B.sign.neg_frac', 'B.sign.zero_frac', 'B.norms.l1', 'B.norms.l2',
       'B.norms.linf', 'B.percentiles.p1', 'B.percentiles.p5',
       'B.percentiles.p25', 'B.percentiles.p50', 'B.percentiles.p75',
       'B.percentiles.p95', 'B.percentiles.p99', 'dW.basic.min',
       'dW.

In [28]:
df.loc[:, ["layer", "dW.svd.spectral_norm"]]

Unnamed: 0,layer,dW.svd.spectral_norm
0,diffusion_model.transformer_blocks.9.txt_mod.1,25.801443
1,diffusion_model.transformer_blocks.33.img_mod.1,22.892666
2,diffusion_model.transformer_blocks.37.img_mlp....,20.750673
3,diffusion_model.transformer_blocks.30.img_mod.1,20.218285
4,diffusion_model.transformer_blocks.17.img_mod.1,19.840977
...,...,...
835,diffusion_model.transformer_blocks.56.txt_mlp....,0.009179
836,diffusion_model.transformer_blocks.59.txt_mlp....,0.000000
837,diffusion_model.transformer_blocks.59.txt_mlp....,0.000000
838,diffusion_model.transformer_blocks.59.attn.to_...,0.000000


In [31]:
# Layers whose ΔW spectral norm exceeds 10.
is_large = df["dW.svd.spectral_norm"] > 10
df_filtered = df.loc[is_large, ["layer", "dW.svd.spectral_norm"]]
df_filtered

Unnamed: 0,layer,dW.svd.spectral_norm
0,diffusion_model.transformer_blocks.9.txt_mod.1,25.801443
1,diffusion_model.transformer_blocks.33.img_mod.1,22.892666
2,diffusion_model.transformer_blocks.37.img_mlp....,20.750673
3,diffusion_model.transformer_blocks.30.img_mod.1,20.218285
4,diffusion_model.transformer_blocks.17.img_mod.1,19.840977
...,...,...
107,diffusion_model.transformer_blocks.39.attn.add...,10.174883
108,diffusion_model.transformer_blocks.35.attn.add...,10.169664
109,diffusion_model.transformer_blocks.18.txt_mlp....,10.137526
110,diffusion_model.transformer_blocks.32.txt_mod.1,10.099074


In [32]:
import numpy as np

# Softmax of the per-layer ΔW spectral norm across ALL layers.
# Layers with a missing/NaN norm get zero weight; otherwise one NaN
# would make np.max (and every softmax score) NaN.
vals = df["dW.svd.spectral_norm"].to_numpy(dtype=float)
finite = np.isfinite(vals)
if not finite.any():
    raise ValueError("No finite spectral norms available for softmax.")

# subtract the max for numerical stability, then exponentiate
exp_vals = np.zeros_like(vals)
exp_vals[finite] = np.exp(vals[finite] - vals[finite].max())
softmax_vals = exp_vals / exp_vals.sum()

# add as a new column
df["dW.softmax_norm"] = softmax_vals

# sort to see the most significant layers
df_sorted = (
    df[["layer", "dW.svd.spectral_norm", "dW.softmax_norm"]]
    .sort_values("dW.softmax_norm", ascending=False)
    .reset_index(drop=True)
)

display(df_sorted)

Unnamed: 0,layer,dW.svd.spectral_norm,dW.softmax_norm
0,diffusion_model.transformer_blocks.9.txt_mod.1,25.801443,9.324702e-01
1,diffusion_model.transformer_blocks.33.img_mod.1,22.892666,5.085915e-02
2,diffusion_model.transformer_blocks.37.img_mlp....,20.750673,5.971914e-03
3,diffusion_model.transformer_blocks.30.img_mod.1,20.218285,3.506712e-03
4,diffusion_model.transformer_blocks.17.img_mod.1,19.840977,2.404569e-03
...,...,...,...
835,diffusion_model.transformer_blocks.56.txt_mlp....,0.009179,5.864043e-12
836,diffusion_model.transformer_blocks.59.txt_mlp....,0.000000,5.810461e-12
837,diffusion_model.transformer_blocks.59.txt_mlp....,0.000000,5.810461e-12
838,diffusion_model.transformer_blocks.59.attn.to_...,0.000000,5.810461e-12


In [33]:
df_sorted[0:100]

Unnamed: 0,layer,dW.svd.spectral_norm,dW.softmax_norm
0,diffusion_model.transformer_blocks.9.txt_mod.1,25.801443,9.324702e-01
1,diffusion_model.transformer_blocks.33.img_mod.1,22.892666,5.085915e-02
2,diffusion_model.transformer_blocks.37.img_mlp....,20.750673,5.971914e-03
3,diffusion_model.transformer_blocks.30.img_mod.1,20.218285,3.506712e-03
4,diffusion_model.transformer_blocks.17.img_mod.1,19.840977,2.404569e-03
...,...,...,...
95,diffusion_model.transformer_blocks.43.img_mod.1,10.789688,2.819114e-07
96,diffusion_model.transformer_blocks.29.img_mod.1,10.768396,2.759725e-07
97,diffusion_model.transformer_blocks.40.txt_mlp....,10.764683,2.749495e-07
98,diffusion_model.transformer_blocks.31.img_mlp....,10.694803,2.563921e-07


In [34]:
import re
import numpy as np
import pandas as pd
from typing import Tuple

def block_importance_from_softmax(df: pd.DataFrame,
                                  layer_col: str = "layer",
                                  norm_col: str = "dW.svd.spectral_norm",
                                  pattern: str = r"transformer_blocks\.(\d+)",
                                  temperature: float = 1.0,
                                 ) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Aggregate a per-layer softmax of `norm_col` into per-block importance scores.

    Steps:
      1. Softmax over ALL layers of norm_col / temperature. NaN/inf norms get weight 0.
      2. Extract the block index from `layer_col` with `pattern`, which must have
         exactly one capture group of digits.
      3. Sum the softmax weights per block.

    Layers whose name does not match `pattern` are left out of the per-block sums.
    Their softmax mass stays in the denominator, so block importances are shares
    of the total over all layers and may sum to less than 1.

    Args:
        temperature: softmax temperature (> 0). 1.0 gives the plain softmax of the norms;
            larger values spread importance more evenly across layers.

    Returns:
        (agg, working):
          agg: one row per block with block_importance, layers, max_norm and
               mean_norm, sorted by block_importance descending.
          working: the per-layer rows that were kept, with added "block" and
               "softmax_norm" columns.

    Raises:
        ValueError: temperature <= 0, or no finite value in norm_col.
    """
    if not temperature > 0:
        raise ValueError("temperature must be > 0.")

    # 1) grab values; NaN (coerced or missing) and inf are excluded from the softmax
    vals = pd.to_numeric(df[norm_col], errors="coerce").to_numpy(dtype=float)
    finite_mask = np.isfinite(vals)
    if not finite_mask.any():
        raise ValueError("No finite spectral norms available for softmax.")

    # 2) stable softmax over LAYERS (subtract the max before exponentiating)
    scaled = vals[finite_mask] / temperature
    exp = np.zeros_like(vals, dtype=float)
    exp[finite_mask] = np.exp(scaled - scaled.max())
    softmax = pd.Series(exp / exp.sum(), index=df.index)

    # 3) parse block index; unparseable rows are dropped from the per-block math
    block_idx = df[layer_col].str.extract(pattern, expand=False)
    keep = block_idx.notna()
    working = df.loc[keep, [layer_col, norm_col]].copy()
    working["block"] = block_idx[keep].astype(int).to_numpy()
    working["softmax_norm"] = softmax[keep].to_numpy()

    # 4) aggregate per block
    agg = (
        working
        .groupby("block", as_index=False)
        .agg(
            block_importance=("softmax_norm", "sum"),   # share of total impact
            layers=("softmax_norm", "size"),
            max_norm=(norm_col, "max"),
            mean_norm=(norm_col, "mean"),
        )
        .sort_values("block_importance", ascending=False, ignore_index=True)
    )

    return agg, working

In [37]:
agg_blocks, per_layer = block_importance_from_softmax(df)
display(agg_blocks.iloc[:10])  # ten most significant blocks

Unnamed: 0,block,block_importance,layers,max_norm,mean_norm
0,9,0.932471,14,25.801443,6.732249
1,33,0.050871,14,22.892666,7.804079
2,37,0.006049,14,20.750673,9.015193
3,30,0.0036,14,20.218285,8.4999
4,17,0.002427,14,19.840977,6.631526
5,38,0.001347,14,19.237164,8.224072
6,2,0.000794,14,18.732994,5.663805
7,41,0.000448,14,18.02359,7.89482
8,36,0.000273,14,17.633244,8.048292
9,29,0.000225,14,17.445343,7.079424


In [38]:
display(per_layer.iloc[:5])  # first per-layer rows with their block index and softmax weight

Unnamed: 0,layer,dW.svd.spectral_norm,block,softmax_norm
0,diffusion_model.transformer_blocks.9.txt_mod.1,25.801443,9,0.93247
1,diffusion_model.transformer_blocks.33.img_mod.1,22.892666,33,0.050859
2,diffusion_model.transformer_blocks.37.img_mlp....,20.750673,37,0.005972
3,diffusion_model.transformer_blocks.30.img_mod.1,20.218285,30,0.003507
4,diffusion_model.transformer_blocks.17.img_mod.1,19.840977,17,0.002405


In [39]:
agg_blocks.sort_values(by="mean_norm", ascending=False).head(10)  # blocks with the highest average ΔW norm

Unnamed: 0,block,block_importance,layers,max_norm,mean_norm
2,37,0.006049024,14,20.750673,9.015193
3,30,0.003600428,14,20.218285,8.4999
5,38,0.001347241,14,19.237164,8.224072
8,36,0.0002728016,14,17.633244,8.048292
7,41,0.0004479018,14,18.02359,7.89482
1,33,0.05087057,14,22.892666,7.804079
24,42,3.542766e-05,14,15.195049,7.743481
20,39,4.25136e-05,14,15.565969,7.728747
14,32,0.0001136685,14,16.106354,7.699471
25,43,3.428324e-05,14,15.530858,7.509102


In [40]:
def top_blocks_by_threshold(
    df: pd.DataFrame,
    threshold: float = 0.90,
    layer_col: str = "layer",
    norm_col: str = "dW.svd.spectral_norm",
    return_layers: bool = False,
):
    """
    Smallest prefix of blocks, ranked by softmax importance, whose cumulative
    importance reaches `threshold`.

    If the threshold is never reached (floating-point shortfall at 1.0, or layers
    whose block index could not be parsed), every block is returned.

    Returns:
        The selected blocks with an added "cumulative" column. With
        return_layers=True, returns (blocks, layers), where layers holds the
        per-layer rows of those blocks sorted by softmax_norm descending.

    Raises:
        ValueError: threshold not in (0, 1].
    """
    if not (0.0 < threshold <= 1.0):
        raise ValueError("threshold must be in (0, 1].")

    agg, per_layer = block_importance_from_softmax(
        df, layer_col=layer_col, norm_col=norm_col
    )
    if agg.empty:
        no_layers = per_layer.iloc[0:0]
        return (agg, no_layers) if return_layers else agg

    # agg arrives sorted by importance, so the running total is non-decreasing.
    running = agg["block_importance"].cumsum().to_numpy()
    # first position reaching the threshold; clamp to the last block if never reached
    cutoff = min(int(np.searchsorted(running, threshold, side="left")), len(running) - 1)

    top = agg.head(cutoff + 1).copy()
    top["cumulative"] = top["block_importance"].cumsum()
    top = top.reset_index(drop=True)

    if not return_layers:
        return top

    chosen = set(top["block"].tolist())
    layers = (
        per_layer.loc[per_layer["block"].isin(chosen)]
        .sort_values("softmax_norm", ascending=False)
        .reset_index(drop=True)
    )
    return top, layers

In [41]:
# Smallest set of blocks covering 90% of the softmax importance.
top90 = top_blocks_by_threshold(df, 0.90)
display(top90)

Unnamed: 0,block,block_importance,layers,max_norm,mean_norm,cumulative
0,9,0.932471,14,25.801443,6.732249,0.932471


In [42]:
# Smallest set of blocks covering 99% of the softmax importance.
top99 = top_blocks_by_threshold(df, 0.99)
display(top99)

Unnamed: 0,block,block_importance,layers,max_norm,mean_norm,cumulative
0,9,0.932471,14,25.801443,6.732249,0.932471
1,33,0.050871,14,22.892666,7.804079,0.983341
2,37,0.006049,14,20.750673,9.015193,0.98939
3,30,0.0036,14,20.218285,8.4999,0.992991
