In [None]:
import argparse
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from datasets import load_dataset

In [None]:
# Hub identifier of the checkpoint that will be compressed.
model_id = "distilbert/distilgpt2"

# Weights are loaded in float16; the SVD helpers below upcast each matrix to
# float32 on CPU before factorizing it.
model_svd = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
# trust_remote_code is left at its default (False): enabling it lets Python code
# shipped in the hub repository run locally, which a stock GPT-2 tokenizer does
# not require.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
def _thin_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return the reduced SVD ``(U, S, Vh)`` of ``M``; fall back to the legacy ``torch.svd`` on RuntimeError."""
    try:
        return torch.linalg.svd(M, full_matrices=False)
    except RuntimeError:
        U, S, V = torch.svd(M)
        return U, S, V.t()


def compute_svd(W: torch.Tensor, k: Optional[int] = None, energy: Optional[float] = None,
                sample_rows: int = 0, sample_cols: int = 0) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Factor a 2-D weight into a rank-``r`` product using an SVD computed on CPU in float32.

    Returns ``(A, B, r)`` such that ``W (out, in) ~= A (out, r) @ B (r, in)``, where
    ``A = U[:, :r] * S[:r]`` and ``B = Vh[:r, :]``. Both factors are cast back to
    ``W``'s original dtype.

    Rank selection:
      * ``k`` given      -> ``r = min(k, min(W.shape))``; sampling is not needed and is ignored.
      * ``energy`` given -> smallest ``r`` whose cumulative squared singular values reach
        ``energy`` of the total. The spectrum is taken from the (optionally sub-sampled) matrix.
      * neither          -> every singular value of the (optionally sub-sampled) matrix.

    ``sample_rows`` / ``sample_cols`` (when > 0 and smaller than the matching dimension)
    pick evenly spaced rows/columns to estimate the spectrum. The returned factors
    always come from the FULL matrix, so ``A`` and ``B`` have the shapes stated above.

    Raises:
        ValueError: if ``k`` is given and is smaller than 1.
    """
    assert W.dim() == 2
    if k is not None and k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    orig_dtype = W.dtype

    # Move to CPU float32 for a stable SVD
    W_cpu = W.detach().to(torch.float32).cpu()

    with torch.no_grad():
        full_factors = None  # SVD of the full matrix; computed at most once

        if k is not None:
            r = min(int(k), min(W_cpu.shape))
        else:
            # Optional uniform sub-sampling, used only to estimate the spectrum
            W_s = W_cpu
            if 0 < sample_rows < W_s.shape[0]:
                rows = torch.linspace(0, W_s.shape[0] - 1, steps=sample_rows).long()
                W_s = W_s[rows]
            if 0 < sample_cols < W_s.shape[1]:
                cols = torch.linspace(0, W_s.shape[1] - 1, steps=sample_cols).long()
                W_s = W_s[:, cols]

            U_s, S_s, Vt_s = _thin_svd(W_s)
            if W_s is W_cpu:
                # Nothing was sampled: this already is the full decomposition, reuse it.
                full_factors = (U_s, S_s, Vt_s)

            if energy is not None and S_s.numel() > 0:
                cum = torch.cumsum(S_s * S_s, dim=0)
                tot = cum[-1]
                r = int(torch.searchsorted(cum, tot * float(energy)).item() + 1)
                r = min(r, S_s.numel())
            else:
                r = S_s.numel()

        if full_factors is None:
            full_factors = _thin_svd(W_cpu)
        U, S, Vt = full_factors
        r = min(r, S.numel())

        A = U[:, :r] * S[:r].unsqueeze(0)  # (out, r)
        B = Vt[:r, :]  # (r, in)

    return A.to(orig_dtype), B.to(orig_dtype), r


def replace_linear_with_lowrank(module_parent, name: str, module: nn.Linear, A: torch.Tensor, B: torch.Tensor,
                                preserve_bias=True, device=None):
    """Swap ``module`` for ``nn.Sequential(Linear(in, r), Linear(r, out))`` built from ``A`` and ``B``.

    Args:
        module_parent: module that owns the layer being replaced.
        name: attribute name of the layer on ``module_parent``. A dotted path
            (``"block.fc"``) is resolved by walking through the intermediate attributes.
        module: the ``nn.Linear`` being replaced; supplies shapes, dtype, device and bias.
        A: ``(out_features, r)`` weight of the second layer.
        B: ``(r, in_features)`` weight of the first layer.
        preserve_bias: copy ``module.bias`` onto the second layer when it exists.
        device: target device for the new layers; defaults to the device of ``module.weight``.

    Returns:
        The ``nn.Sequential`` that now sits where ``module`` was.

    Raises:
        ValueError: if ``A``/``B`` do not have the shapes above, or ``name`` has an empty last component.
    """
    out_features, in_features = module.weight.shape
    # Validate explicitly: ``copy_`` broadcasts, so a wrongly shaped factor could be accepted silently.
    if (A.dim() != 2 or B.dim() != 2 or A.shape[0] != out_features
            or B.shape != (A.shape[1], in_features)):
        raise ValueError(
            f"expected A of shape ({out_features}, r) and B of shape (r, {in_features}), "
            f"got {tuple(A.shape)} and {tuple(B.shape)}"
        )
    r = A.shape[1]

    *path, leaf = name.split(".")
    if not leaf:
        raise ValueError(f"invalid module name: {name!r}")

    keep_bias = preserve_bias and module.bias is not None
    down = nn.Linear(in_features, r, bias=False)
    up = nn.Linear(r, out_features, bias=keep_bias)
    seq = nn.Sequential(down, up)

    # Match the replaced layer's dtype (and device unless overridden) so the new
    # layers accept the same activations as the rest of the model.
    target_device = module.weight.device if device is None else device
    seq.to(device=target_device, dtype=module.weight.dtype)
    seq.train(module.training)

    with torch.no_grad():
        down.weight.copy_(B)
        up.weight.copy_(A)
        if keep_bias:
            up.bias.copy_(module.bias.detach())

    # Walk dotted paths so the real child is replaced rather than registering a
    # new child whose name contains a dot.
    owner = module_parent
    for part in path:
        owner = getattr(owner, part)
    setattr(owner, leaf, seq)
    return seq


def find_linear_modules(model: nn.Module):
    """Yield ``(parent, name, linear)`` for every ``nn.Linear`` that is a direct child of some module in ``model``.

    ``name`` is the attribute under which ``parent`` holds the layer, so
    ``getattr(parent, name) is linear`` holds for every yielded triple.
    Only direct children are inspected (``named_children``): each layer is reported
    once per owning parent, never with a dotted path and never as its own parent.
    The result is collected before the first item is yielded, so callers may
    replace layers while iterating.
    """
    found = [
        (parent, name, child)
        for parent in model.modules()
        for name, child in parent.named_children()
        if isinstance(child, nn.Linear)
    ]
    yield from found

def pick_rank_knee_from_singulars(s_vals, layer_name=None, energy=0.99, verbose=True):
    """Pick a truncation rank near the energy rank, at the sharpest drop of the spectrum.

    Steps:
      1. ``k_energy``: smallest rank whose cumulative squared singular values
         reach ``energy`` of the total (computed on the raw values).
      2. Restrict the search to the window ``[0.95 * k_energy, 1.05 * k_energy]``.
      3. Inside the window, take the log-ratio ``log(s[i+1] / s[i])`` of consecutive
         values and cut right before the most negative one (the largest relative drop).
         When the window holds fewer than two values, ``k_energy`` is returned.

    Args:
        s_vals: 1-D sequence of singular values, expected in descending order.
        layer_name: label used in the debug print and the plot title.
        energy: fraction of the squared spectrum used to locate the window.
        verbose: when True (default) print the debug summary and plot the window.

    Returns:
        ``(chosen_rank, ratios)`` where ``ratios`` holds the log-ratios of consecutive
        singular values inside the window (empty when the window is too small).

    Raises:
        ValueError: if ``s_vals`` is empty.
    """
    s_raw = np.asarray(s_vals, dtype=np.float64).ravel()
    n = s_raw.size
    if n == 0:
        raise ValueError("s_vals must contain at least one singular value")

    # energy rank, from the raw (not log-transformed) singular values
    cum = np.cumsum(s_raw * s_raw)
    tot = cum[-1]
    k_energy = min(int(np.searchsorted(cum, tot * float(energy)) + 1), n)

    # log scale turns ratios of singular values into differences; clamp to avoid log(0)
    log_s = np.log(np.maximum(s_raw, 1e-12))

    # window range
    k_start = int(max(1, int(0.95 * k_energy)))
    k_end = int(min(1.05 * k_energy, n))

    # compute ratio only inside window
    window_s = log_s[k_start:k_end]  # shape (k_end-k_start,)
    ratios = window_s[1:] - window_s[:-1]  # shape (k_end-k_start-1,), <= 0 for a descending spectrum

    if ratios.size > 0:
        # most negative log-ratio == largest relative drop; keep everything before it
        local_idx = int(np.argmin(ratios))
        chosen = k_start + local_idx + 1
    else:
        # window too small to compare neighbours: fall back to the energy rank
        chosen = k_energy

    if verbose:
        print("\n--- Knee debug for layer:", layer_name)
        print(f"num svals={n}")
        print(f"energy={energy}, k_energy={k_energy}")
        print(f"window=[{k_start} → {k_end}]")
        print(f"chosen_rank={chosen}")

        # Plot log singular values (window only); x axis is the 1-based rank
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.plot(np.arange(k_start, k_end + 1),
                log_s[k_start - 1:k_end],
                marker='o', markersize=2)
        ax.axvline(chosen, color='red', linestyle='--', label="chosen rank")
        ax.set_title(f"Singular Values (window) — {layer_name}")
        ax.set_xlabel("rank")
        ax.set_ylabel("log singular value")
        ax.grid()
        ax.legend()
        plt.show()
        # this runs once per layer; close the figure so figures do not accumulate
        plt.close(fig)

    return chosen, ratios


In [None]:
def _energy_rank(svals, energy) -> int:
    """Smallest rank whose squared singular values reach ``energy`` of the total (all of them if ``energy`` is None)."""
    n = int(svals.shape[0])
    if energy is None or n == 0:
        return n
    cum = np.cumsum(svals * svals)
    return min(int(np.searchsorted(cum, cum[-1] * float(energy)) + 1), n)


def compress_model(model, out_dir, method = "knee",
                   energy = 0.99, device = "cpu", sample_rows = 0, sample_cols = 0,
                   dry_run = False, debug=True):
    """Compress the ``nn.Linear`` layers of ``model`` in place with truncated SVD.

    For every layer returned by ``find_linear_modules`` the singular values are
    computed (optionally on a sub-sampled matrix), a rank is chosen, and the layer
    is replaced by a two-layer low-rank factorization. Layers for which the
    factorization would not reduce the parameter count are left untouched.

    Args:
        model: module to compress; it is switched to eval mode and modified in place.
        out_dir: directory passed to ``model.save_pretrained`` (created if missing).
        method: ``'knee'`` (``pick_rank_knee_from_singulars``) or ``'adaptive'`` (energy rank only).
        energy: energy threshold; drives the rank in ``'adaptive'`` mode, is forwarded to the
            knee search, and caps the rank in both modes. ``None`` disables the cap.
        device: device the replacement layers are moved to.
        sample_rows, sample_cols: when > 0, evenly spaced rows/columns used to estimate the spectrum.
        dry_run: only report the ranks and projected savings; nothing is replaced or saved.
        debug: print a per-layer summary.

    Raises:
        ValueError: if ``method`` is not ``'knee'`` or ``'adaptive'``.
    """
    if method not in ("knee", "adaptive"):
        raise ValueError("Unknown method; choose 'knee' or 'adaptive'")

    device = torch.device(device)
    model.eval()

    # Fully qualified names give unique stats keys; "ParentClass.attr" repeats
    # whenever the model stacks several blocks of the same class.
    qualified = {id(mod): qual_name for qual_name, mod in model.named_modules()}

    stats = {}
    total_params = 0
    saved_params = 0

    for parent, name, linear in tqdm(list(find_linear_modules(model)), desc="Scanning linears"):
        # Skip entries that do not (or no longer) resolve to this layer on this
        # parent, e.g. a layer that was already replaced through another entry.
        if getattr(parent, name, None) is not linear:
            continue

        label = qualified.get(id(linear)) or f"{parent.__class__.__name__}.{name}"
        W = linear.weight.data  # (out, in)
        bias_numel = linear.bias.numel() if linear.bias is not None else 0
        out_f, in_f = W.shape
        total = out_f * in_f + bias_numel

        # 1) spectrum of the (optionally sub-sampled) weight, on CPU in float32
        W_cpu = W.detach().to(torch.float32).cpu()
        W_s = W_cpu
        if 0 < sample_rows < W_s.shape[0]:
            rows = torch.linspace(0, W_s.shape[0] - 1, steps=sample_rows).long()
            W_s = W_s[rows]
        if 0 < sample_cols < W_s.shape[1]:
            cols = torch.linspace(0, W_s.shape[1] - 1, steps=sample_cols).long()
            W_s = W_s[:, cols]
        try:
            # only the singular values are needed here, not U / Vh
            svals = torch.linalg.svdvals(W_s).numpy()
        except RuntimeError as e:
            print("SVD sampling failed for layer", name, "- falling back to full CPU SVD. Error:", e)
            svals = torch.linalg.svdvals(W_cpu).numpy()

        # 2) energy rank: the rank itself in 'adaptive' mode, and the cap in both modes
        energy_rank = _energy_rank(svals, energy)

        # 3) decide the candidate rank
        if method == "knee":
            # forward the caller's energy so the knee window and the cap agree
            knee_kwargs = {} if energy is None else {"energy": energy}
            knee_rank, _ratios = pick_rank_knee_from_singulars(svals, layer_name=label, **knee_kwargs)
        else:
            knee_rank = energy_rank
            if debug:
                print(f"Layer {label}: energy-only adaptive rank = {energy_rank}")

        final_k = min(int(knee_rank), energy_rank)
        # rank actually achievable on the full matrix
        chosen_k = min(final_k, out_f, in_f)

        # debug prints
        if debug:
            print("\n========== Layer summary ==========")
            print(f"Layer: {label}")
            print(f"shape: ({out_f}, {in_f}) — total params: {total}")
            print(f"knee_rank: {knee_rank}, energy_rank: {energy_rank}, final_k (min): {final_k}")
            print("------------------------------------")

        # 4) parameter accounting, done before any expensive work
        saved = total - (out_f * chosen_k + chosen_k * in_f + bias_numel)
        total_params += total

        if saved <= 0:
            # The factorization would be at least as large as the dense layer: keep the layer.
            stats[label] = {"out": out_f, "in": in_f, "k": chosen_k, "saved_params": 0, "skipped": True}
            if debug:
                print(f"Skipping {label}: rank {chosen_k} would not reduce the parameter count")
            continue

        stats[label] = {"out": out_f, "in": in_f, "k": chosen_k, "saved_params": saved}
        saved_params += saved

        if dry_run:
            continue

        # 5) truncated SVD of the full weight, then swap the layer
        A, B, _ = compute_svd(W, k=chosen_k, energy=None, sample_rows=0, sample_cols=0)
        replace_linear_with_lowrank(parent, name, linear, A, B, preserve_bias=True, device=device)

        print("Total:", total)
        print("Saved:", saved)

    # summary
    print("\nCompression summary:")
    for k, v in stats.items():
        print(k, v)

    print(f"Total params touched: {total_params:,}")
    print(f"Approx params saved: {saved_params:,}")

    if not dry_run:
        # NOTE(review): the saved config presumably does not describe the factorized
        # layers, so reloading may require rebuilding the same structure — confirm.
        print(f"Saving compressed model to {out_dir}...")
        os.makedirs(out_dir, exist_ok=True)
        model.save_pretrained(out_dir)

# Directory the compressed checkpoint is written to (created if missing).
# Defined here so the cell also runs on a fresh kernel.
out_dir = "distilgpt2_svd_compressed"

compress_model(model_svd, out_dir, method='knee', energy=0.98, device='cpu', debug=True)