In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# -----------------------------------------------------------------------------
# default config values
model_name = "gelu-2l" 
dataset_path = "NeelNanda/c4-code-tokenized-2b"

total_training_steps = 25_000 
batch_size = 4096 
new_cached_activations_path = (
    f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}"
)

hook_layer=1
hook_name=f"blocks.{hook_layer}.attn.hook_z"
hook_head_index = None
l1_coefficients = [2, 5, 10]
d_in= 512
expansion_factor = 32

total_training_tokens = total_training_steps * batch_size

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training steps.
print(f"lr_decay_steps: {lr_decay_steps}")
l1_warmup_steps = total_training_steps // 20  # 5% of training steps.
print(f"l1_warmup_steps: {l1_warmup_steps}")
log_to_wandb = True

def save_init_weights(runner):
    init_weights = {}
    for n, p in runner.sae.named_parameters():
        init_weights[n] = p.data.cpu().clone()
    return init_weights

for l1_coefficient in l1_coefficients:
    cfg = LanguageModelSAERunnerConfig(
        # Pick a tiny model to make this easier.
        model_name=model_name, 
        ## MLP Layer 0 ##
        hook_name=hook_name, 
        hook_layer=hook_layer, 
        hook_head_index=hook_head_index,
        d_in=d_in,
        dataset_path=dataset_path,
        streaming=True, # pre-download the token dataset
        context_size=1024,
        is_dataset_tokenized=True,
        prepend_bos=True,

        n_heads=8, # TODO: replace the hack

        # How big do we want our SAE to be?
        expansion_factor=expansion_factor,
        # Dataset / Activation Store
        # When we do a proper test
        # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
        # For now.
        use_cached_activations=False,
        #cached_activations_path="./gelu-2l",
        training_tokens=total_training_tokens,  # For initial testing I think this is a good number.
        train_batch_size_tokens=batch_size,
        # Loss Function
        ## Reconstruction Coefficient.
        mse_loss_normalization=None,  # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch.
        ## Anthropic does not mention using an Lp norm other than L1.
        l1_coefficient=l1_coefficient,
        lp_norm=1.0,
        # Instead, they multiply the L1 loss contribution
        # from each feature of the activations by the decoder norm of the corresponding feature.
        scale_sparsity_penalty_by_decoder_norm=True,
        # Learning Rate
        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps. # TODO: understand why it's constant
        l1_warm_up_steps=l1_warmup_steps,
        lr_warm_up_steps=lr_warm_up_steps,
        lr_decay_steps=lr_warm_up_steps,
        ## No ghost grad term.
        use_ghost_grads=False,
        # Initialization / Architecture
        architecture="block_diag",
        apply_b_dec_to_input=False,
        # encoder bias zero's. (I'm not sure what it is by default now)
        # decoder bias zero's.
        b_dec_init_method="zeros",
        normalize_sae_decoder=False,
        decoder_heuristic_init=True,
        init_encoder_as_decoder_transpose=True,
        # Optimizer
        lr=5e-5,
        ## adam optimizer has no weight decay by default so worry about this.
        adam_beta1=0.9,
        adam_beta2=0.999,
        # Buffer details won't matter in we cache / shuffle our activations ahead of time.
        n_batches_in_buffer=64,
        store_batch_size_prompts=16,
        normalize_activations="expected_average_only_in", # TODO: not sure what's the best choice # Activation Normalization Strategy. Either none, expected_average_only_in, or constant_norm_rescale.
        # Feature Store
        feature_sampling_window=1000,
        dead_feature_window=1000,
        dead_feature_threshold=1e-4,
        # WANDB
        log_to_wandb=log_to_wandb,  # always use wandb unless you are just testing code.
        wandb_project=f"{model_name}-attn-{hook_layer}-sae",
        wandb_log_frequency=50,
        eval_every_n_wandb_logs=10,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=2,
        checkpoint_path="checkpoints",
        dtype="float32",
    )

    print(f"Total Training Tokens: {total_training_tokens}")
    break

Using device: cuda
lr_decay_steps: 5000
l1_warmup_steps: 1250
Run name: 16384-L1-2-LR-5e-05-Tokens-1.024e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 25000
Total wandb updates: 500
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 25 times.
Number tokens in sparsity calculation window: 4.10e+06
Total Training Tokens: 102400000


In [3]:
runner = SAETrainingRunner(cfg)
init_weights = save_init_weights(runner=runner)
runner.run()

Loaded pretrained model gelu-2l into HookedTransformer


Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshehper[0m. Use [1m`wandb login --relogin`[0m to force relogin


Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:36<00:00, 27.77it/s]
1000| MSE Loss 105.488 | L1 159.019:   4%|▍         | 4096000/102400000 [01:46<31:30, 52008.24it/s]

interrupted, saving progress
done saving


InterruptedException: 

1000| MSE Loss 105.488 | L1 159.019:   4%|▍         | 4096000/102400000 [01:57<31:30, 52008.24it/s]

: 

#### Initial weights

In [None]:
init_weights = {}
for n, p in runner.sae.named_parameters():
    init_weights[n] = p.data.clone()

In [5]:
init_weights['enc_blocks.0.weight']

tensor([[ 0.0956,  0.1038, -0.0293,  ..., -0.0853,  0.0385, -0.0430],
        [ 0.0383, -0.0260,  0.1037,  ..., -0.0729, -0.0428, -0.0987],
        [ 0.1048, -0.0248,  0.1075,  ...,  0.0134, -0.0221, -0.0373],
        ...,
        [ 0.0353,  0.0179, -0.0835,  ...,  0.1018,  0.0542, -0.0653],
        [ 0.0442, -0.0610, -0.0666,  ..., -0.0071,  0.1038,  0.0466],
        [-0.0996, -0.0305, -0.0282,  ..., -0.0201, -0.0599,  0.0170]],
       device='cuda:0')

In [8]:
init_weights["W_enc"]

tensor([[ 0.0956,  0.0383,  0.1048,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1038, -0.0260, -0.0248,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0293,  0.1037,  0.1075,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0785,  0.0362, -0.0852],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1197, -0.0665,  0.0464],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0890,  0.0489, -0.0697]],
       device='cuda:0', grad_fn=<CloneBackward0>)

In [9]:
init_weights["W_dec"]

tensor([[-0.0211,  0.0075, -0.0008,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0076, -0.0160, -0.0113,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0147, -0.0185,  0.0202,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0030,  0.0142, -0.0211],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0172, -0.0179,  0.0185],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0214,  0.0196,  0.0009]],
       device='cuda:0', grad_fn=<CloneBackward0>)

In [10]:
init_weights["b_enc"]

tensor([ 0.0597,  0.0812, -0.0701,  ...,  0.0253, -0.1019, -0.1229],
       device='cuda:0', grad_fn=<CloneBackward0>)

In [11]:
init_weights["b_dec"]

tensor([ 1.3798e-03, -1.4406e-02,  6.6600e-03,  1.7949e-02, -2.0928e-02,
        -1.7106e-02, -6.7341e-03, -1.4137e-02,  1.8258e-02,  7.8433e-03,
        -1.6477e-02, -6.5140e-03,  6.9009e-03, -6.8866e-04,  2.0495e-02,
        -7.3730e-03,  1.2815e-03, -4.3822e-03, -4.4309e-04,  1.5135e-02,
        -2.4795e-03, -1.4174e-02,  2.8641e-03, -1.9982e-02,  9.7295e-03,
        -8.5921e-03, -1.9043e-02,  1.8825e-02,  2.1587e-02,  9.4444e-03,
         1.0079e-02, -9.3242e-03, -1.9717e-02, -1.8062e-02, -8.3509e-04,
         1.3807e-02,  3.9873e-03,  1.5570e-02, -6.1561e-03, -3.0117e-03,
         1.1046e-03,  1.3812e-02,  1.8806e-03,  2.1446e-02, -1.1533e-02,
         8.5836e-04,  1.5602e-02,  1.7486e-02, -4.0159e-03,  8.7399e-04,
        -5.6304e-03,  6.2776e-03,  1.8152e-02,  1.1937e-02, -1.7574e-02,
        -1.2467e-02, -8.6007e-03, -1.5669e-03,  1.9530e-03, -1.4182e-02,
         1.6515e-02, -8.2939e-03,  4.3075e-03, -7.4483e-03, -9.9832e-03,
         1.5400e-02, -7.8276e-03,  1.8333e-02,  1.6

#### Run

We initiate `runner` by instantiating `SAETrainingRunner(cfg)`. This class is defined in `sae_training_runner.py`.

It has an attribute `sae` which is an instance of `TrainingSAE` defined in `sae_lens.training.training_sae`.

In the `__init__` method of `TrainingSAE`, I have placed

```
if self.cfg.architecture == "block_diag":
    self.encode_with_hidden_pre_fn = self.encode_block_diag
    self.decode_fn = self.decode_block_diag
```

We train the SAE by calling `runner.run()`. This defines `trainer` to be an instance of `SAETrainer`, which is defined in `sae_lens.training.sae_trainer`.

Training is done by calling `trainer.fit()`, which calls `self._train_step` method. This method calls `self.sae.training_forward_pass` which is a method of the `TrainingSAE` class. It also calls `self._run_and_log_evals`, which calls `run_evals` from `sae_lens.evals.py`.

My job is to make sure that both `TrainingSAE.training_forward_pass` and `run_evals` are consistent with the block diagonal architecture. 

In [6]:
runner.run()

Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:34<00:00, 28.85it/s]


AttributeError: 'TrainingSAE' object has no attribute 'W_dec'

In [8]:
import torch 
import torch.nn as nn
type(torch.cat([nn.Parameter(torch.randn(2, )), nn.Parameter(torch.randn(2, ))]))

torch.Tensor

In [13]:
n_in = 16
n_blocks = 8
n_latents = 32
device = "cuda"
dec_blocks = nn.ModuleDict({str(i): nn.Linear(n_latents // n_blocks, n_in // n_blocks).to(device=device) for i in range(n_blocks)})

In [14]:
dec_blocks["0"].weight

Parameter containing:
tensor([[ 0.1662, -0.2651, -0.4207, -0.2487],
        [-0.3440,  0.3033,  0.1173, -0.0140]], device='cuda:0',
       requires_grad=True)

In [15]:
W_dec = torch.block_diag(*[dec_blocks[str(i)].weight.t() for i in range(n_blocks)])

In [16]:
W_dec.shape

torch.Size([32, 16])

In [18]:
W_dec.norm(dim=1)

tensor([0.3820, 0.4028, 0.4367, 0.2491, 0.3898, 0.1922, 0.4765, 0.5371, 0.4102,
        0.1023, 0.1718, 0.4046, 0.3490, 0.2021, 0.4888, 0.2652, 0.1489, 0.3086,
        0.2403, 0.4287, 0.1808, 0.3064, 0.1680, 0.3593, 0.4716, 0.4752, 0.4245,
        0.4255, 0.4611, 0.2527, 0.2090, 0.4547], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)

In [21]:
dec_blocks["0"].weight.t().norm(dim=1)

tensor([0.3820, 0.4028, 0.4367, 0.2491], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)

In [None]:
torch.tensor([nn.Parameter(2, 4), nn.Parameter(2, 4)])