In [1]:
import torch
import torch.nn as nn
from models import SALT  # or import directly if it's in the same file
from utils.svd_utils import svd_head_tail

def analyze_salt_trainables(in_features=768, out_features=768, r_top=32, tail_rank=8):
    """Build a throwaway SALT-wrapped linear layer and print its parameter counts.

    A fresh ``nn.Linear(in_features, out_features, bias=True)`` is wrapped in
    ``SALT`` and a report is written to stdout containing the total,
    trainable and frozen parameter counts, the trainable share as a
    percentage, and one line per named parameter with its status and size.

    Args:
        in_features: Input width of the dummy linear layer.
        out_features: Output width of the dummy linear layer.
        r_top: Forwarded unchanged to ``SALT`` as ``r_top``.
        tail_rank: Forwarded unchanged to ``SALT`` as ``tail_rank``.

    Returns:
        The constructed ``SALT`` module, so the caller can inspect it further.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Stand-in for a real projection layer; its weights are never used for
    # anything except being counted.
    linear = nn.Linear(in_features, out_features, bias=True)
    layer = SALT(linear, r_top=r_top, tail_rank=tail_rank)

    # Tally element counts in a single pass over the module's parameters.
    n_total = 0
    n_trainable = 0
    for tensor in layer.parameters():
        size = tensor.numel()
        n_total += size
        if tensor.requires_grad:
            n_trainable += size
    n_frozen = n_total - n_trainable

    # Summary section.
    print(heavy_rule)
    print("SALT Parameter Analysis")
    print(f"Input dim: {in_features}, Output dim: {out_features}")
    print(f"r_top: {r_top}, tail_rank: {tail_rank}")
    print(light_rule)
    print(f"Total parameters:     {n_total:,}")
    print(f"Trainable parameters: {n_trainable:,}")
    print(f"Frozen parameters:    {n_frozen:,}")
    trainable_pct = n_trainable / n_total * 100
    print(f"Trainable %:          {trainable_pct:.4f}%")
    print(light_rule)

    # Per-tensor section: every named parameter is listed, frozen ones too.
    print("Breakdown of trainable tensors:")
    for label, tensor in layer.named_parameters():
        status = "trainable" if tensor.requires_grad else "frozen"
        print(f"{label:<20} {status:<10} {tensor.numel():>10,}")

    print(heavy_rule)
    return layer


# Script entry point: run the analysis on a layer sized like a BERT projection.
if __name__ == "__main__":
    # Attention projection case: 768 input features -> 768 output features.
    salt_layer = analyze_salt_trainables(
        in_features=768,
        out_features=768,
        r_top=32,
        tail_rank=8,
    )

    # Alternative case, left disabled: feed-forward layer (768 -> 3072).
    # salt_layer = analyze_salt_trainables(in_features=768, out_features=3072, r_top=32, tail_rank=8)


SALT Parameter Analysis
Input dim: 768, Output dim: 768
r_top: 32, tail_rank: 8
--------------------------------------------------------------------------------
Total parameters:     591,168
Trainable parameters: 576
Frozen parameters:    590,592
Trainable %:          0.0974%
--------------------------------------------------------------------------------
Breakdown of trainable tensors:
alpha                trainable          32
beta                 trainable          32
X                    trainable         256
Y                    trainable         256
base.weight          frozen        589,824
base.bias            frozen            768
