In [1]:
%load_ext autoreload
%autoreload 2

In [16]:
from sprint.loading import load_all, load_sae

model, data, sae_l1 = load_all(model_name="gelu-2l", run_id="l1")
sae_l0 = load_sae(run_id="l0")

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cuda
Changing model dtype to torch.float16
Model device: cuda:0
Tokens shape: torch.Size([215402, 128]), dtype: torch.int64, device: cuda:0
{'act_name': 'blocks.1.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 512,
 'device': 'cuda:0',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 1,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 50,
 'seq_len': 128,
 'site': 'mlp_out'}
Encoder device: cuda:0
{'act_name': 'blocks.0.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 512,
 'device': 'cuda:1',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32'

In [3]:
# Check that it's two layers
len(model.blocks)

2

In [20]:
# Shapes - gelu-2l

print(model.blocks[0].mlp.W_in.shape)
print(sae.W_enc.shape)

torch.Size([512, 2048])
torch.Size([512, 16384])


In [22]:
# That doesn't look right... let's check this for the old model

model, data, sae = load_all()

print(model.blocks[0].mlp.W_in.shape)
print(sae.W_enc.shape)

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  cuda
Changing model dtype to torch.float16
Model device: cuda:0
Tokens shape: torch.Size([215402, 128]), dtype: torch.int64, device: cuda:0
{'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 2048,
 'dict_mult': 8,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'lr': 0.0001,
 'model_batch_size': 512,
 'num_tokens': 2000000000,
 'seed': 52,
 'seq_len': 128}
Encoder device: cuda:0
torch.Size([512, 2048])
torch.Size([2048, 16384])


In [4]:
# Verify linearization still works at L0

from sprint.linearization import analyze_linearized_feature

analyze_linearized_feature(
    feature_idx=100,
    sample_idx=30,
    token_idx=30,
    model=model,
    data=data,
    encoder=sae_l0,
    layer=0,
    batch_size=32,
    n_tokens=10
)

{'feature': tensor([ 0.0340,  0.0626,  0.0676,  ...,  0.1301, -0.0206,  0.0258],
        device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>),
 'sae activations': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.1543, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0117, 0.0000]],
        device='cuda:0', dtype=torch.float16, grad_fn=<ReluBackward0>),
 'domain': tensor([ 2.1805e-02, -4.8462e-02,  1.9946e-01,  2.5977e-01,  3.3417e-02,
         -7.5073e-02, -1.3672e-01,  1.1084e-01, -1.6577e-01,  1.0535e-01,
          1.1176e-01, -3.1445e-01, -1.1035e-01,  3.7155e-03,  1.6016e-01,
          9.6802e-02, -5.1270e-02, -9.2590e-02, -1.1841e-01,  1.1920e-01,
          1.7261e-01, -5.5695e-02, -7.4158e

In [5]:
# Repeat at L1 (this multiplies by the wrong features, but I'm just seeing what breaks)

analyze_linearized_feature(
    feature_idx=100,
    sample_idx=30,
    token_idx=30,
    model=model,
    layer=1,
    data=data,
    encoder=sae_l1,
    batch_size=32,
    n_tokens=10
)

{'feature': tensor([ 0.0340,  0.0626,  0.0676,  ...,  0.1301, -0.0206,  0.0258],
        device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>),
 'sae activations': tensor([[17.1562, 27.7969, 26.9375,  ...,  0.0000,  0.0000, 36.5625],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
        device='cuda:0', dtype=torch.float16, grad_fn=<ReluBackward0>),
 'domain': tensor([-3.8300e-02, -3.7573e-01, -3.3716e-01, -2.6074e-01,  4.0771e-01,
          2.7100e-01,  1.9458e-01, -2.4731e-01, -6.2378e-02,  1.8213e-01,
          2.7893e-02, -3.2928e-02, -6.9641e-02, -2.1631e-01,  4.9347e-02,
          1.6833e-01, -9.8705e-04, -1.9080e-01,  1.2230e-02, -2.0508e-01,
       

In [7]:
load_all(model_name="gelu-2l", run="l1")

Loaded pretrained model gelu-2l into HookedTransformer
Moving model to device:  cuda
Changing model dtype to torch.float16
Model device: cuda:0
Tokens shape: torch.Size([215402, 128]), dtype: torch.int64, device: cuda:0
{'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'd_mlp': 2048,
 'dict_mult': 8,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'lr': 0.0001,
 'model_batch_size': 512,
 'num_tokens': 2000000000,
 'seed': 52,
 'seq_len': 128}
Encoder device: cuda:0


(HookedTransformer(
   (embed): Embed()
   (hook_embed): HookPoint()
   (pos_embed): PosEmbed()
   (hook_pos_embed): HookPoint()
   (blocks): ModuleList(
     (0-1): 2 x TransformerBlock(
       (ln1): LayerNormPre(
         (hook_scale): HookPoint()
         (hook_normalized): HookPoint()
       )
       (ln2): LayerNormPre(
         (hook_scale): HookPoint()
         (hook_normalized): HookPoint()
       )
       (attn): Attention(
         (hook_k): HookPoint()
         (hook_q): HookPoint()
         (hook_v): HookPoint()
         (hook_z): HookPoint()
         (hook_attn_scores): HookPoint()
         (hook_pattern): HookPoint()
         (hook_result): HookPoint()
       )
       (mlp): MLP(
         (hook_pre): HookPoint()
         (hook_post): HookPoint()
       )
       (hook_attn_in): HookPoint()
       (hook_q_input): HookPoint()
       (hook_k_input): HookPoint()
       (hook_v_input): HookPoint()
       (hook_mlp_in): HookPoint()
       (hook_attn_out): HookPoint()
       (ho