### spectral µTransfer 350m

In [3]:
import torch

In [4]:
# Layer dimensions used to build the fan table below.
# NOTE(review): the constant names are inferred from the parameter names they
# appear under; confirm them against the actual model config.
HIDDEN_SIZE = 1024
QKV_SIZE = 2048
INTERMEDIATE_SIZE = 4096
GATE_UP_SIZE = 8192
VOCAB_SIZE = 49152

# Maps a parameter-name suffix to [fan_in, fan_out]; the cells that consume this
# table unpack every value in exactly that order. Bias entries use fan_out = 1.
NAME_TO_FAN_MAPPING = {
    # The embedding originally used the vocabulary size as fan_in:
    #   "token_embedding.weight": [VOCAB_SIZE, HIDDEN_SIZE]  (i.e. [49152, 1024])
    #   "token_embedding.bias":   [VOCAB_SIZE, 1]            (i.e. [49152, 1])
    # The active entries below use fan_in = 1 instead.
    "token_embedding.weight": [1, HIDDEN_SIZE],
    "token_embedding.bias": [1, 1],

    # Attention and MLP projections.
    "qkv_proj.weight": [HIDDEN_SIZE, QKV_SIZE],
    "qkv_proj.bias": [HIDDEN_SIZE, 1],
    "o_proj.weight": [HIDDEN_SIZE, HIDDEN_SIZE],
    "o_proj.bias": [HIDDEN_SIZE, 1],
    "gate_up_proj.weight": [HIDDEN_SIZE, GATE_UP_SIZE],
    "gate_up_proj.bias": [HIDDEN_SIZE, 1],
    "down_proj.weight": [INTERMEDIATE_SIZE, HIDDEN_SIZE],
    "down_proj.bias": [INTERMEDIATE_SIZE, 1],

    # Output head. There is no bias entry for it.
    # NOTE(review): a commented-out `"token_embedding.bias": [1024, 1]` line used
    # to sit here; it duplicated an existing key and looks like a leftover.
    "lm_head.pp_block.weight": [HIDDEN_SIZE, VOCAB_SIZE],
}

In [15]:
def spectral_lr(fan_in, fan_out, divisor=64):
    """Return the per-layer learning-rate multiplier ``(fan_out / fan_in) / divisor``.

    Follows the spectral parameterization from the
    [paper](https://arxiv.org/abs/2310.17813), in which a layer's learning rate
    is scaled by the ratio ``fan_out / fan_in``.

    Args:
        fan_in: Input dimension of the layer (must be non-zero).
        fan_out: Output dimension of the layer.
        divisor: Constant the ratio is divided by. Defaults to 64, the value
            that used to be hard-coded here.
            NOTE(review): the origin of 64 is not documented in this notebook.

    Returns:
        The multiplier as a float; callers multiply a base learning rate by it.

    Raises:
        ZeroDivisionError: If ``fan_in`` or ``divisor`` is zero.
    """
    # Same operation order as before so default-call results are bit-identical.
    return (fan_out / fan_in) / divisor

In [16]:
# Base learning rates to sweep: powers of two from 2**-20 up to 2**-2
# (the upper bound -1 is exclusive).
lrs = [2**exponent for exponent in range(-20, -1)]

In [18]:
# Show the sweep values; a bare last expression is rendered as the cell output.
lrs

[9.5367431640625e-07,
 1.9073486328125e-06,
 3.814697265625e-06,
 7.62939453125e-06,
 1.52587890625e-05,
 3.0517578125e-05,
 6.103515625e-05,
 0.0001220703125,
 0.000244140625,
 0.00048828125,
 0.0009765625,
 0.001953125,
 0.00390625,
 0.0078125,
 0.015625,
 0.03125,
 0.0625,
 0.125,
 0.25]

In [20]:
# For every base learning rate in the sweep, print the per-parameter learning
# rate obtained by scaling it with the spectral multiplier of that parameter.
for lr in lrs:
    # Header line identifying the base learning rate of this group.
    print(f"###### {lr} ###### ")
    for name, (fan_in, fan_out) in NAME_TO_FAN_MAPPING.items():
        multiplier = spectral_lr(fan_in, fan_out)
        print(f"{name}: {lr * multiplier}")

###### 9.5367431640625e-07 ###### 
token_embedding.weight: 1.52587890625e-05
token_embedding.bias: 1.4901161193847656e-08
qkv_proj.weight: 2.9802322387695312e-08
qkv_proj.bias: 1.4551915228366852e-11
o_proj.weight: 1.4901161193847656e-08
o_proj.bias: 1.4551915228366852e-11
gate_up_proj.weight: 1.1920928955078125e-07
gate_up_proj.bias: 1.4551915228366852e-11
down_proj.weight: 3.725290298461914e-09
down_proj.bias: 3.637978807091713e-12
lm_head.pp_block.weight: 7.152557373046875e-07
###### 1.9073486328125e-06 ###### 
token_embedding.weight: 3.0517578125e-05
token_embedding.bias: 2.9802322387695312e-08
qkv_proj.weight: 5.960464477539063e-08
qkv_proj.bias: 2.9103830456733704e-11
o_proj.weight: 2.9802322387695312e-08
o_proj.bias: 2.9103830456733704e-11
gate_up_proj.weight: 2.384185791015625e-07
gate_up_proj.bias: 2.9103830456733704e-11
down_proj.weight: 7.450580596923828e-09
down_proj.bias: 7.275957614183426e-12
lm_head.pp_block.weight: 1.430511474609375e-06
###### 3.814697265625e-06 ###### 

In [22]:
# Per-parameter learning rates for a single base learning rate of 0.001.
base_lr = 0.001
for name, (fan_in, fan_out) in NAME_TO_FAN_MAPPING.items():
    multiplier = spectral_lr(fan_in, fan_out)
    print(f"{name}: {base_lr * multiplier}")

token_embedding.weight: 0.016
token_embedding.bias: 1.5625e-05
qkv_proj.weight: 3.125e-05
qkv_proj.bias: 1.52587890625e-08
o_proj.weight: 1.5625e-05
o_proj.bias: 1.52587890625e-08
gate_up_proj.weight: 0.000125
gate_up_proj.bias: 1.52587890625e-08
down_proj.weight: 3.90625e-06
down_proj.bias: 3.814697265625e-09
lm_head.pp_block.weight: 0.00075
