In [1]:
%load_ext autoreload
%autoreload 2
import torch as th

# Toy sizes for checking the per-layer einsum contraction used by a crosscoder encoder.
batch_size, layers, d_model, d_sae = 10, 2, 108, 1016

# randint yields integer tensors, so the per-layer matmul comparisons are exact.
batch = th.randint(0, 100, size=(batch_size, layers, d_model))
linear = th.randint(0, 100, size=(layers, d_model, d_sae))

# Contract d_model independently for each layer l: (b, l, d) x (l, d, D) -> (b, l, D).
sae_activations = th.einsum("bld, ldD -> blD", batch, linear)

print(sae_activations.shape)

torch.Size([10, 2, 1016])


In [2]:
# Verify the einsum against an explicit matmul for every layer
# (previously only layers 0 and 1 were checked, regardless of `layers`).
assert sae_activations.shape == (batch_size, layers, d_sae)
for layer_idx in range(layers):
    th.testing.assert_close(sae_activations[:, layer_idx], batch[:, layer_idx] @ linear[layer_idx])

In [3]:
linear_decoder = th.randint(0, 100, (layers, d_sae, d_model))
# Per-layer decode: (b, l, D) x (l, D, d) -> (b, l, d).
reconstructed = th.einsum("blD, lDd -> bld", sae_activations, linear_decoder)

print(reconstructed.shape)
assert reconstructed.shape == (batch_size, layers, d_model)
# Check every layer against an explicit matmul, not just layers 0 and 1.
for layer_idx in range(layers):
    th.testing.assert_close(
        reconstructed[:, layer_idx], sae_activations[:, layer_idx] @ linear_decoder[layer_idx]
    )

torch.Size([10, 2, 108])


## Test the CrossCoder

In [4]:
import torch
import pytest
from tempfile import NamedTemporaryFile
from dictionary import CrossCoder


def test_crosscoder_save_load():
    """Round-trip a CrossCoder through ``state_dict`` -> file -> ``from_pretrained``.

    Asserts that the reloaded model exposes exactly the same named parameters
    as the original and that every parameter tensor matches. The temporary
    checkpoint file is always removed afterwards, even if loading fails.
    """
    import os

    activation_dim = 64
    dict_size = 32
    num_layers = 4
    crosscoder = CrossCoder(activation_dim, dict_size, num_layers)

    # Only reserve a unique path here and close the handle immediately:
    # re-opening a still-open NamedTemporaryFile by name fails on Windows.
    with NamedTemporaryFile(suffix=".pt", delete=False) as tmp_file:
        tmp_file_name = tmp_file.name
    try:
        torch.save(crosscoder.state_dict(), tmp_file_name)
        loaded_crosscoder = CrossCoder.from_pretrained(tmp_file_name)
    finally:
        # delete=False means nothing else will clean this file up.
        os.remove(tmp_file_name)

    # Compare by parameter name: zip() over .parameters() would silently stop
    # at the shorter model and depends on both models yielding the same order.
    original_params = dict(crosscoder.named_parameters())
    loaded_params = dict(loaded_crosscoder.named_parameters())
    assert (
        original_params.keys() == loaded_params.keys()
    ), "Loaded CrossCoder parameters do not match the original"
    for name, param in original_params.items():
        assert torch.allclose(
            param, loaded_params[name]
        ), f"Loaded CrossCoder parameters do not match the original ({name})"


def test_crosscoder_forward_pass(**kwargs):
    """Smoke-test ``CrossCoder`` forward passes on random input.

    Any keyword arguments are forwarded unchanged to the ``CrossCoder``
    constructor, so the same checks can be repeated for different options.
    Checks output/feature shapes for both call forms and that the
    reconstruction is not simply the input.
    """
    activation_dim, dict_size, num_layers = 64, 32, 4
    crosscoder = CrossCoder(activation_dim, dict_size, num_layers, **kwargs)

    batch_size = 16
    x = torch.randn(batch_size, num_layers, activation_dim)
    expected_out_shape = (batch_size, num_layers, activation_dim)

    # Default call: only the reconstruction comes back.
    x_hat = crosscoder(x)
    assert x_hat.shape == expected_out_shape, "Output shape mismatch"

    # output_features=True also returns one dict_size-wide feature row per sample.
    x_hat, f_scaled = crosscoder(x, output_features=True)
    assert x_hat.shape == expected_out_shape, "Output shape mismatch"
    assert f_scaled.shape == (batch_size, dict_size), "Feature shape mismatch"

    # The reconstruction should not be a trivial copy of the input.
    assert not torch.allclose(x, x_hat), "Output is identical to input"

In [5]:
# Run both CrossCoder checks with the constructor's default options.
for run_check in (test_crosscoder_forward_pass, test_crosscoder_save_load):
    run_check()

In [16]:
# Repeat the forward-pass checks with `same_init_for_all_layers=True` passed to CrossCoder.
init_options = dict(same_init_for_all_layers=True)
test_crosscoder_forward_pass(**init_options)

In [17]:
# Repeat the forward-pass checks with `norm_init_scale=0.005` passed to CrossCoder.
init_options = dict(norm_init_scale=0.005)
test_crosscoder_forward_pass(**init_options)

In [18]:
# Repeat the forward-pass checks with `init_with_transpose=False` passed to CrossCoder.
init_options = dict(init_with_transpose=False)
test_crosscoder_forward_pass(**init_options)


In [6]:
from dictionary import CrossCoderEncoder, CrossCoderDecoder, CrossCoder

# Sizes for the standalone encoder/decoder checks.
# NOTE: `batch_size` and `th` are not defined here; they come from the first cell.
d_model, d_sae, num_layers = 108, 1016, 2
encoder = CrossCoderEncoder(d_model, d_sae, num_layers)
decoder = CrossCoderDecoder(d_model, d_sae, num_layers)

# Float input of shape (batch_size, num_layers, d_model); replaces the integer `batch` from the first cell.
batch = th.randn(batch_size, num_layers, d_model)

In [7]:
# Encoding gives one d_sae-wide feature row per sample (no layer axis) ...
f = encoder(batch)
assert tuple(f.shape) == (batch_size, d_sae)

# ... and decoding restores the per-layer (batch_size, num_layers, d_model) layout.
reconstructed = decoder(f)
assert tuple(reconstructed.shape) == (batch_size, num_layers, d_model)

In [8]:
# End-to-end pass through the combined model must return the per-layer input layout.
crosscoder = CrossCoder(d_model, d_sae, num_layers)
reconstructed_cc = crosscoder(batch)
expected_shape = (batch_size, num_layers, d_model)
assert reconstructed_cc.shape == expected_shape


In [9]:
# Shapes of the decoder weight, its norm over dim 2, that norm summed over dim 0, and the features.
decoder_weight = crosscoder.decoder.weight
for shown in (
    decoder_weight,
    decoder_weight.norm(dim=2),
    decoder_weight.norm(dim=2).sum(dim=0, keepdim=True),
    f,
):
    print(shown.shape)

torch.Size([2, 1016, 108])
torch.Size([2, 1016])
torch.Size([1, 1016])
torch.Size([10, 1016])


In [10]:
# Decoder weight norms over dim 2, summed over dim 0 (keepdim) so the result broadcasts against `f`.
summed_decoder_norms = crosscoder.decoder.weight.norm(dim=2).sum(dim=0, keepdim=True)
(f * summed_decoder_norms).shape

torch.Size([10, 1016])

In [11]:
# Feature matrix size (Tensor.size() returns the same torch.Size as .shape).
f.size()

torch.Size([10, 1016])

In [12]:
# One random (d_model, d_sae) matrix stacked num_layers times -> (num_layers, d_model, d_sae).
# Tensor.repeat copies the data (unlike expand), so each layer can be edited independently.
weight = th.randn(d_model, d_sae).repeat(num_layers, 1, 1)
weight.shape


torch.Size([2, 108, 1016])

In [13]:
# Zero out layer 0 in place; layer 1 must stay non-zero, which confirms the
# layers do not share storage.
weight[0].zero_()
print(weight[0].norm())
assert weight[1].norm() != 0


tensor(0.)


In [14]:
# Unit-normalise along dim 1 (the d_model axis). Layer 0 was zeroed above, so its
# norms are 0 and 0/0 produces NaN for that layer. Clamp the norms
# (e.g. `.clamp_min(1e-8)`) if a NaN-free normalisation is wanted.
weight = weight.div(weight.norm(dim=1, keepdim=True))

In [15]:
# Column norms after normalisation: 1.0 for non-zero layers, NaN where the layer was all zeros.
th.norm(weight, dim=1, keepdim=True)

tensor([[[   nan,    nan,    nan,  ...,    nan,    nan,    nan]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]]])