In [5]:
import torch
import torch.nn as nn
from models import SALT, SALTEdoraLinearV2 # or import directly if it's in the same file
from utils.svd_utils import svd_head_tail

def analyze_salt_trainables(in_features=768, out_features=768, r_top=32, tail_rank=64):
    """
    Creates a dummy SALT layer with the given dimensions and prints:
    - Total parameters
    - Trainable parameters
    - Non-trainable parameters
    - Breakdown by tensor
    """
    # Step 1️⃣: Create a dummy base linear layer (simulating BERT projection)
    base = nn.Linear(in_features, out_features, bias=True)
    salt = SALTEdoraLinearV2(base, tail_rank)

    # Step 2️⃣: Count total & trainable parameters
    total_params = sum(p.numel() for p in salt.parameters())
    trainable_params = sum(p.numel() for p in salt.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    print("=" * 80)
    print(f"SALT Parameter Analysis")
    print(f"Input dim: {in_features}, Output dim: {out_features}")
    print(f"r_top: {r_top}, tail_rank: {tail_rank}")
    print("-" * 80)
    print(f"Total parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Frozen parameters:    {frozen_params:,}")
    print(f"Trainable %:          {trainable_params / total_params * 100:.4f}%")
    print("-" * 80)

    # Step 3️⃣: Show detailed breakdown
    print("Breakdown of trainable tensors:")
    for name, param in salt.named_parameters():
        print(f"{name:<20} {'trainable' if param.requires_grad else 'frozen':<10} {param.numel():>10,}")

    print("=" * 80)
    return salt


# Example usage (simulating BERT dimensions)
if __name__ == "__main__":
    # BERT attention projection (768 x 768)
    salt_layer = analyze_salt_trainables(in_features=768, out_features=768, r_top=32, tail_rank=64)

    # or BERT feedforward layer (3072 x 768)
    # salt_layer = analyze_salt_trainables(in_features=768, out_features=3072, r_top=32, tail_rank=8)


SALT Parameter Analysis
Input dim: 768, Output dim: 768
r_top: 32, tail_rank: 64
--------------------------------------------------------------------------------
Total parameters:     594,880
Trainable parameters: 4,288
Frozen parameters:    590,592
Trainable %:          0.7208%
--------------------------------------------------------------------------------
Breakdown of trainable tensors:
alpha                trainable          64
beta                 trainable          64
D                    trainable          64
R                    trainable       4,096
base.weight          frozen        589,824
base.bias            frozen            768


In [6]:
def svd_head_tail(W, r):
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    p = S.numel()

    # Ensure we don't exceed the available rank
    r_tail = min(r, p)
    r_top = max(p - r_tail, 0)

    # Split head/tail
    U_top, S_top, Vh_top = U[:, :r_top], S[:r_top], Vh[:r_top, :]
    U_tail, S_tail, Vh_tail = U[:, r_top:], S[r_top:], Vh[r_top:, :]

    return (U_top, S_top, Vh_top), (U_tail, S_tail, Vh_tail)

In [7]:
W = torch.randn(768, 768)
(head, tail) = svd_head_tail(W, r=4)

print("U_top:", head[0].shape)
print("S_top:", head[1].shape)
print("Vh_top:", head[2].shape)
print("U_tail:", tail[0].shape)
print("S_tail:", tail[1].shape)
print("Vh_tail:", tail[2].shape)

U_top: torch.Size([768, 764])
S_top: torch.Size([764])
Vh_top: torch.Size([764, 768])
U_tail: torch.Size([768, 4])
S_tail: torch.Size([4])
Vh_tail: torch.Size([4, 768])


In [8]:
import torch

W = torch.randn(768, 768)

# Run your SVD split
(head, tail) = svd_head_tail(W, r=4)

# Unpack
U_top, S_top, Vh_top = head
U_tail, S_tail, Vh_tail = tail

# ✅ Reconstruct W from head + tail
W_head = (U_top * S_top) @ Vh_top
W_tail = (U_tail * S_tail) @ Vh_tail
W_recon = W_head + W_tail

# ✅ Compute reconstruction error (should be near zero)
recon_error = (W - W_recon).norm() / W.norm()
print(f"Reconstruction relative error: {recon_error.item():.3e}")

# ✅ Optional sanity check on shapes
print("U_top:", U_top.shape)
print("S_top:", S_top.shape)
print("Vh_top:", Vh_top.shape)
print("U_tail:", U_tail.shape)
print("S_tail:", S_tail.shape)
print("Vh_tail:", Vh_tail.shape)


Reconstruction relative error: 1.802e-06
U_top: torch.Size([768, 764])
S_top: torch.Size([764])
Vh_top: torch.Size([764, 768])
U_tail: torch.Size([768, 4])
S_tail: torch.Size([4])
Vh_tail: torch.Size([4, 768])


In [9]:
W

tensor([[-0.3186, -0.7997,  0.6924,  ..., -0.0438, -1.4127, -0.9437],
        [-0.9846, -2.0882,  1.5418,  ...,  1.9014,  0.2866, -1.0197],
        [ 0.5301, -1.9153, -1.0040,  ...,  0.0254, -1.7016,  0.8479],
        ...,
        [-0.4989, -0.2758,  1.0541,  ...,  0.2765, -0.4789, -0.6002],
        [-2.7582, -1.0077,  0.9167,  ..., -1.2549,  1.7061, -0.6058],
        [ 0.3568,  1.0276,  1.3051,  ...,  1.0987, -0.3125, -1.8925]])

In [15]:
W_head + W_tail

tensor([[-0.3186, -0.7997,  0.6924,  ..., -0.0438, -1.4127, -0.9437],
        [-0.9847, -2.0882,  1.5418,  ...,  1.9014,  0.2866, -1.0197],
        [ 0.5301, -1.9153, -1.0040,  ...,  0.0254, -1.7016,  0.8479],
        ...,
        [-0.4989, -0.2758,  1.0541,  ...,  0.2765, -0.4789, -0.6002],
        [-2.7582, -1.0077,  0.9167,  ..., -1.2549,  1.7061, -0.6058],
        [ 0.3568,  1.0276,  1.3051,  ...,  1.0987, -0.3125, -1.8925]])

In [11]:
W_tail

tensor([[-4.7981e-04,  5.9225e-04, -2.0136e-04,  ..., -1.5246e-04,
          8.7929e-05,  1.6057e-04],
        [-2.9198e-04,  2.6308e-04, -1.9486e-04,  ...,  6.7842e-06,
          1.1114e-04, -2.4365e-05],
        [-1.2387e-04,  1.6415e-04, -7.1647e-05,  ..., -3.0214e-05,
          6.3690e-05, -9.4080e-06],
        ...,
        [ 3.0820e-04, -1.6131e-04,  2.0097e-04,  ..., -4.1044e-06,
          8.5188e-05, -1.2484e-04],
        [-2.5827e-04,  4.5846e-04, -9.9390e-05,  ..., -1.3269e-04,
          1.7327e-04,  7.1233e-06],
        [ 1.1637e-04, -2.7852e-05,  9.4625e-05,  ..., -5.1087e-05,
         -1.9922e-05,  1.1857e-05]])