In [1]:
import os
import sys
import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# -----------------------------------------------------------------------------
# default config values
model_name = "gelu-2l" 
dataset_path = "NeelNanda/c4-code-tokenized-2b"

total_training_steps = 2_000 
batch_size = 4096 
new_cached_activations_path = (
    f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}"
)

hook_layer=1
hook_name=f"blocks.{hook_layer}.attn.hook_z"
hook_head_index = None
l1_coefficients = [2, 5, 10]
d_in= 512
expansion_factor = 32

total_training_tokens = total_training_steps * batch_size

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training steps.
print(f"lr_decay_steps: {lr_decay_steps}")
l1_warmup_steps = total_training_steps // 20  # 5% of training steps.
print(f"l1_warmup_steps: {l1_warmup_steps}")
log_to_wandb = False

for l1_coefficient in l1_coefficients:
    cfg = LanguageModelSAERunnerConfig(
        # Pick a tiny model to make this easier.
        model_name=model_name, 
        ## MLP Layer 0 ##
        hook_name=hook_name, 
        hook_layer=hook_layer, 
        hook_head_index=hook_head_index,
        d_in=d_in,
        dataset_path=dataset_path,
        streaming=True, # pre-download the token dataset
        context_size=1024,
        is_dataset_tokenized=True,
        prepend_bos=True,

        n_heads=8, # TODO: replace the hack

        # How big do we want our SAE to be?
        expansion_factor=expansion_factor,
        # Dataset / Activation Store
        # When we do a proper test
        # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
        # For now.
        use_cached_activations=False,
        #cached_activations_path="./gelu-2l",
        training_tokens=total_training_tokens,  # For initial testing I think this is a good number.
        train_batch_size_tokens=batch_size,
        # Loss Function
        ## Reconstruction Coefficient.
        mse_loss_normalization=None,  # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch.
        ## Anthropic does not mention using an Lp norm other than L1.
        l1_coefficient=l1_coefficient,
        lp_norm=1.0,
        # Instead, they multiply the L1 loss contribution
        # from each feature of the activations by the decoder norm of the corresponding feature.
        scale_sparsity_penalty_by_decoder_norm=True,
        # Learning Rate
        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps. # TODO: understand why it's constant
        l1_warm_up_steps=l1_warmup_steps,
        lr_warm_up_steps=lr_warm_up_steps,
        lr_decay_steps=lr_warm_up_steps,
        ## No ghost grad term.
        use_ghost_grads=False,
        # Initialization / Architecture
        architecture="block_diag",
        apply_b_dec_to_input=False,
        # encoder bias zero's. (I'm not sure what it is by default now)
        # decoder bias zero's.
        b_dec_init_method="zeros",
        normalize_sae_decoder=False,
        decoder_heuristic_init=True,
        init_encoder_as_decoder_transpose=True,
        # Optimizer
        lr=5e-5,
        ## adam optimizer has no weight decay by default so worry about this.
        adam_beta1=0.9,
        adam_beta2=0.999,
        # Buffer details won't matter in we cache / shuffle our activations ahead of time.
        n_batches_in_buffer=64,
        store_batch_size_prompts=16,
        normalize_activations="expected_average_only_in", # TODO: not sure what's the best choice # Activation Normalization Strategy. Either none, expected_average_only_in, or constant_norm_rescale.
        # Feature Store
        feature_sampling_window=1000,
        dead_feature_window=1000,
        dead_feature_threshold=1e-4,
        # WANDB
        log_to_wandb=log_to_wandb,  # always use wandb unless you are just testing code.
        wandb_project=f"{model_name}-attn-{hook_layer}-sae",
        wandb_log_frequency=50,
        eval_every_n_wandb_logs=10,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=2,
        checkpoint_path="checkpoints",
        dtype="float32",
    )

    print(f"Total Training Tokens: {total_training_tokens}")
    break


Using device: cuda
lr_decay_steps: 400
l1_warmup_steps: 100
Run name: 16384-L1-2-LR-5e-05-Tokens-8.192e+06
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 2000
Total wandb updates: 40
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 2 times.
Number tokens in sparsity calculation window: 4.10e+06
Total Training Tokens: 8192000


In [2]:
# look at the next cell to see some instruction for what to do while this is running.
torch.manual_seed(42)
runner = SAETrainingRunner(cfg)

Loaded pretrained model gelu-2l into HookedTransformer


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]



In [5]:
B = 16
T = 1024
n_heads = 8
d_heads = 64
torch.manual_seed(34)
x = torch.randn(B, T, n_heads * d_heads, device=runner.sae.device)
encode_out1, _ = runner.sae.encode_block_diag(x)
encode_out1.shape

torch.Size([16, 1024, 16384])

In [6]:
encode_out2 = runner.sae.encode_block_diag2(x)
encode_out2.shape

torch.Size([16, 1024, 16384])

In [7]:
torch.allclose(encode_out1, encode_out2)

True

In [8]:
decode_out1 = runner.sae.decode_block_diag(encode_out1)
decode_out1.shape

torch.Size([16, 1024, 512])

In [9]:
decode_out2 = runner.sae.decode_block_diag2(encode_out1)
decode_out2.shape

torch.Size([16, 1024, 512])

In [10]:
torch.allclose(decode_out1, decode_out2)

True

: 

Right, all of the encoder and decoder blocks are in sae.parameters(), but W_dec and W_enc are themselves not parameters.

In [13]:
torch.allclose(init_enc_block0_weight, trained_sae.enc_blocks["0"].weight)

False

In [12]:
trained_sae.enc_blocks["0"].weight

Parameter containing:
tensor([[-0.2146, -0.2100,  0.2851,  ...,  0.2254, -0.2738,  0.2695],
        [-0.2679, -0.3336,  0.4122,  ...,  0.2343, -0.3494,  0.2086],
        [-0.2055, -0.3342,  0.4172,  ...,  0.3241, -0.3312,  0.2757],
        ...,
        [-0.2709, -0.2905,  0.2231,  ...,  0.4075, -0.2526,  0.2426],
        [-0.2616, -0.3662,  0.2116,  ...,  0.2071, -0.1918,  0.3365],
        [-0.0830, -0.3156,  0.2708,  ...,  0.2718,  0.2263,  0.0888]],
       device='cuda:0', requires_grad=True)

In [3]:
init_enc_block0_weight = runner.sae.enc_blocks["0"].weight.data.clone()
print(type(runner.sae.enc_blocks["0"].weight))
init_enc_block0_weight

<class 'torch.nn.parameter.Parameter'>


tensor([[ 0.0956,  0.1038, -0.0293,  ..., -0.0853,  0.0385, -0.0430],
        [ 0.0383, -0.0260,  0.1037,  ..., -0.0729, -0.0428, -0.0987],
        [ 0.1048, -0.0248,  0.1075,  ...,  0.0134, -0.0221, -0.0373],
        ...,
        [ 0.0353,  0.0179, -0.0835,  ...,  0.1018,  0.0542, -0.0653],
        [ 0.0442, -0.0610, -0.0666,  ..., -0.0071,  0.1038,  0.0466],
        [-0.0996, -0.0305, -0.0282,  ..., -0.0201, -0.0599,  0.0170]],
       device='cuda:0')

In [14]:
trained_sae.W_dec

tensor([[-0.0211,  0.0075, -0.0008,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0076, -0.0160, -0.0113,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0147, -0.0185,  0.0202,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0030,  0.0142, -0.0211],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0172, -0.0179,  0.0185],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0214,  0.0196,  0.0009]],
       device='cuda:0', grad_fn=<BlockDiagBackward0>)

In [15]:
init_W_dec

tensor([[-0.0211,  0.0075, -0.0008,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0076, -0.0160, -0.0113,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0147, -0.0185,  0.0202,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0030,  0.0142, -0.0211],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0172, -0.0179,  0.0185],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0214,  0.0196,  0.0009]],
       device='cuda:0')

In [17]:
trained_sae.dec_blocks["0"].weight

Parameter containing:
tensor([[ 0.2814,  0.2926,  0.3111,  ...,  0.2991,  0.2828,  0.2396],
        [ 0.3061,  0.2845,  0.2847,  ...,  0.3076,  0.3029,  0.2191],
        [-0.2998, -0.3089, -0.2818,  ..., -0.2952, -0.3136, -0.3017],
        ...,
        [-0.3106, -0.2902, -0.3064,  ..., -0.2969, -0.3158, -0.2894],
        [ 0.2977,  0.2948,  0.2994,  ...,  0.3026,  0.3081, -0.2788],
        [-0.3063, -0.2794, -0.3147,  ..., -0.3104, -0.2895, -0.3027]],
       device='cuda:0', requires_grad=True)

In [4]:
init_W_dec = runner.sae.W_dec.data.clone()
init_W_enc = runner.sae.W_enc.data.clone()
print(type(runner.sae.W_dec))

<class 'torch.Tensor'>


In [5]:
trained_sae = runner.run()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshehper[0m. Use [1m`wandb login --relogin`[0m to force relogin


Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:35<00:00, 28.09it/s]
6000| MSE Loss 558.173 | L1 747.712: 100%|██████████| 24576000/24576000 [07:45<00:00, 52841.55it/s]


VBox(children=(Label(value='16.268 MB of 16.268 MB uploaded (0.002 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
details/current_l1_coefficient,▁▅██████████████████████████████████████
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,▁▅▆▇▇███████████████████████████████████
losses/mse_loss,▁▄▆▇▇███████▇███▇███████▇██▇█████▇█▇████
losses/overall_loss,▁▅▇▇████████████████████████████████████
metrics/CE_loss_score,▇▇▁▂▆▅▄▅▇█▂▃
metrics/ce_loss_with_ablation,▅▁▂▁▆▄▃▃▇▅█▂

0,1
details/current_l1_coefficient,2.0
details/current_learning_rate,5e-05
details/n_training_tokens,24576000.0
losses/auxiliary_reconstruction_loss,0.0
losses/ghost_grad_loss,0.0
losses/l1_loss,373.85617
losses/mse_loss,558.17346
losses/overall_loss,1305.88574
metrics/CE_loss_score,-0.00654
metrics/ce_loss_with_ablation,6.99475


In [6]:
trained_sae.W_dec

tensor([[-0.0211,  0.0075, -0.0008,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0076, -0.0160, -0.0113,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0147, -0.0185,  0.0202,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0030,  0.0142, -0.0211],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0172, -0.0179,  0.0185],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0214,  0.0196,  0.0009]],
       device='cuda:0', grad_fn=<BlockDiagBackward0>)

In [8]:
torch.allclose(init_W_dec, trained_sae.W_dec)

True

In [9]:
torch.allclose(init_W_enc, trained_sae.W_enc)

True

In [10]:
trained_sae.b_enc

tensor([ 0.0597,  0.0812, -0.0701,  ...,  0.0253, -0.1019, -0.1229],
       device='cuda:0', grad_fn=<CatBackward0>)

In [11]:
trained_sae.b_dec

tensor([ 1.3798e-03, -1.4406e-02,  6.6600e-03,  1.7949e-02, -2.0928e-02,
        -1.7106e-02, -6.7341e-03, -1.4137e-02,  1.8258e-02,  7.8433e-03,
        -1.6477e-02, -6.5140e-03,  6.9009e-03, -6.8866e-04,  2.0495e-02,
        -7.3730e-03,  1.2815e-03, -4.3822e-03, -4.4309e-04,  1.5135e-02,
        -2.4795e-03, -1.4174e-02,  2.8641e-03, -1.9982e-02,  9.7295e-03,
        -8.5921e-03, -1.9043e-02,  1.8825e-02,  2.1587e-02,  9.4444e-03,
         1.0079e-02, -9.3242e-03, -1.9717e-02, -1.8062e-02, -8.3509e-04,
         1.3807e-02,  3.9873e-03,  1.5570e-02, -6.1561e-03, -3.0117e-03,
         1.1046e-03,  1.3812e-02,  1.8806e-03,  2.1446e-02, -1.1533e-02,
         8.5836e-04,  1.5602e-02,  1.7486e-02, -4.0159e-03,  8.7399e-04,
        -5.6304e-03,  6.2776e-03,  1.8152e-02,  1.1937e-02, -1.7574e-02,
        -1.2467e-02, -8.6007e-03, -1.5669e-03,  1.9530e-03, -1.4182e-02,
         1.6515e-02, -8.2939e-03,  4.3075e-03, -7.4483e-03, -9.9832e-03,
         1.5400e-02, -7.8276e-03,  1.8333e-02,  1.6

In [42]:
type(trained_sae.W_dec)

torch.Tensor

In [43]:
type(trained_sae.b_dec)

torch.nn.parameter.Parameter

In [54]:
for p in trained_sae.parameters():
    print(p.shape)

torch.Size([16384])
torch.Size([512])


In [44]:
type(trained_sae.enc_blocks[0])

torch.nn.parameter.Parameter

In [45]:
# look at the next cell to see some instruction for what to do while this is running.
torch.manual_seed(42)
new_runner = SAETrainingRunner(cfg)

Loaded pretrained model gelu-2l into HookedTransformer


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

{'n_heads': 8, 'architecture': 'block_diag', 'd_in': 512, 'd_sae': 16384, 'activation_fn_str': 'relu', 'activation_fn_kwargs': {}, 'apply_b_dec_to_input': False, 'dtype': 'float32', 'model_name': 'gelu-2l', 'hook_name': 'blocks.1.attn.hook_z', 'hook_layer': 1, 'hook_head_index': None, 'device': 'cuda', 'context_size': 1024, 'prepend_bos': True, 'finetuning_scaling_factor': False, 'normalize_activations': 'expected_average_only_in', 'dataset_path': 'NeelNanda/c4-code-tokenized-2b', 'dataset_trust_remote_code': True, 'sae_lens_training_version': '3.11.0'}


In [48]:
type(new_runner.sae.b_dec)

torch.nn.parameter.Parameter

In [53]:
import torch.nn as nn
type(torch.block_diag(nn.Parameter(torch.randn(2, 2)), nn.Parameter(torch.randn(2, 2))))

torch.Tensor

In [55]:
for p in new_runner.sae.parameters():
    print(p.shape)

torch.Size([16384])
torch.Size([512])
