<a href="https://colab.research.google.com/github/tomonari-masada/course2025-nlp/blob/main/superposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: Elhage et al., "Toy Models of Superposition" (Transformer Circuits Thread, 2022): https://transformer-circuits.pub/2022/toy_model/

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Seed NumPy's and PyTorch's global generators so every run draws the same
# random numbers (data sampling and weight initialization are reproducible).
np.random.seed(seed=0)
torch.manual_seed(seed=0)

In [None]:
class SuperpositionModel(nn.Module):
    """Tied-weight toy autoencoder: x' = ReLU(x W W^T + b).

    ``dim_n`` input features are projected into a ``dim_m``-dimensional
    hidden space by ``W`` and reconstructed with the transpose of the same
    matrix plus a per-feature bias, followed by a ReLU.

    Args:
        dim_n: number of input features.
        dim_m: size of the hidden (bottleneck) representation.
    """

    def __init__(self, dim_n, dim_m):
        super().__init__()
        self.dim_n = dim_n
        self.dim_m = dim_m
        # Row i of W is the hidden-space embedding of feature i: shape (dim_n, dim_m).
        self.W = nn.Parameter(torch.randn(dim_n, dim_m))
        # One bias per reconstructed feature, initialized to zero.
        self.b = nn.Parameter(torch.zeros(dim_n))

    def forward(self, x):
        """Encode ``x`` into the hidden space and reconstruct it.

        Args:
            x: tensor whose last dimension has size ``dim_n``.

        Returns:
            ``(x_reconstructed, h)`` where ``h = x @ W`` (last dim ``dim_m``)
            and ``x_reconstructed = relu(h @ W.T + b)`` (last dim ``dim_n``).
        """
        h = torch.matmul(x, self.W)
        # The decoder reuses W (tied weights); the ReLU lets the model filter
        # out small interference from other features packed into the same space.
        x_reconstructed = torch.relu(torch.matmul(h, self.W.t()) + self.b)
        return x_reconstructed, h

In [None]:
def training(model, optimizer, sparsity=0.0, num_epochs=100000, num_data=10000,
             batch_size=100, log_every=10000):
    """Train ``model`` to reconstruct sparse synthetic inputs.

    A fixed dataset of ``num_data`` samples with ``model.dim_n`` features is
    drawn once. Each entry is uniform in [0, 1) and is independently zeroed
    with probability ``sparsity``. Every step samples a minibatch of
    ``batch_size`` rows (with replacement) and minimizes the mean squared
    reconstruction error. Feature i is weighted by an importance of 0.7**i.

    Args:
        model: module exposing ``dim_n`` and returning
            ``(x_reconstructed, hidden)`` from ``model(x)``.
        optimizer: optimizer over ``model``'s parameters.
        sparsity: probability that each feature entry is zero; must be in [0, 1].
        num_epochs: number of optimization steps. Each step uses one
            minibatch; it is not a full pass over the dataset.
        num_data: number of samples in the synthetic dataset.
        batch_size: minibatch size per step.
        log_every: print the loss every ``log_every`` steps; 0/None disables logging.

    Returns:
        The loss of the final step as a float, or ``None`` if ``num_epochs`` is 0.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1].
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f'sparsity must be in [0, 1], got {sparsity}')

    dim_n = model.dim_n
    # Build every tensor on the model's device so a GPU-resident model works too.
    device = next(model.parameters()).device
    # Importance decays geometrically with the feature index.
    loss_weights = torch.pow(0.7, torch.arange(dim_n, device=device).float())
    print(f'Dimension of input: {dim_n}')
    print(f'Loss weights: {loss_weights}')

    # Make the data sparse, so that only a few of the many features are active
    # at the same time: each entry survives with probability 1 - sparsity.
    X = torch.rand(num_data, dim_n, device=device)
    zero_mask = (torch.rand_like(X) >= sparsity).float()
    X = X * zero_mask

    print("Starting training...")
    model.train()
    loss = None
    for epoch in range(num_epochs):
        x = X[torch.randint(0, num_data, (batch_size,), device=device)]
        optimizer.zero_grad()
        x_reconstructed, _ = model(x)
        loss = ((x_reconstructed - x) ** 2 * loss_weights).mean()
        loss.backward()
        optimizer.step()
        # Call .item() only when logging, because it forces a device sync.
        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')
    return None if loss is None else loss.item()

In [None]:
def evaluate(model):
    """Report per-feature norms and how strongly each feature is superposed.

    Feature i is row ``W_i`` of ``model.W``. Its superposition score is
    sum_{j != i} (W_hat_i . W_j)**2 with W_hat_i = W_i / ||W_i||, following
    Elhage et al., "Toy Models of Superposition" (2022). Normalizing W_i keeps
    the score from growing with feature i's own norm. A feature whose
    embedding is (numerically) zero gets a score of 0 instead of NaN.

    Args:
        model: object exposing ``W`` (2-D tensor, one row per feature) and ``dim_n``.

    Returns:
        ``(feat_norm, superposition)``: NumPy arrays of length ``dim_n``.
    """
    with torch.no_grad():
        W = model.W
        norms = torch.norm(W, dim=1)
        feat_norm = norms.cpu().numpy()
        print("Feature norms:", feat_norm)
        # Unit-normalize each row; the clamp avoids 0/0 for dead features.
        W_hat = W / norms.clamp_min(1e-12).unsqueeze(1)
        overlap = (W_hat @ W.t()) ** 2
        # Mask out the i == j term; build the mask on W's device and dtype.
        off_diag = 1.0 - torch.eye(model.dim_n, device=W.device, dtype=W.dtype)
        superposition = (overlap * off_diag).sum(dim=1)
        print("Superposition:", superposition)
        return feat_norm, superposition.cpu().numpy()

In [None]:
def visualize(feature_norm, superposition):
    """Draw a horizontal bar chart of feature norms, colored by superposition.

    Bars run top to bottom from Feature 0 to Feature n-1. Each bar's color
    maps its superposition score through viridis on a fixed [0, 1] scale,
    which is shown in the colorbar on the right.

    Args:
        feature_norm: 1-D sequence of per-feature norms.
        superposition: 1-D sequence of per-feature superposition scores,
            same length as ``feature_norm``.
    """
    fig, (ax_bar, ax_cbar) = plt.subplots(
        1, 2, figsize=(10, 6), gridspec_kw={'width_ratios': [4, 0.2]})
    y_pos = range(len(feature_norm))
    colors = plt.cm.viridis(superposition)
    ax_bar.barh(y_pos, feature_norm, color=colors)
    ax_bar.set_yticks(y_pos)
    ax_bar.set_yticklabels([f'Feature {i}' for i in y_pos])
    # Put Feature 0 at the top by inverting the bar axes explicitly.
    # plt.gca() would be the last-created axes (the colorbar here), and
    # inverting that flips the colorbar instead of the bar chart.
    ax_bar.invert_yaxis()
    ax_bar.set_xlabel('Feature Norm')
    ax_bar.set_title('Feature Norms and Superposition')
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=ax_cbar)
    cbar.set_label('Superposition')
    plt.show()

In [None]:
# Experiment: embed dim_n = 10 features in a dim_m = 2 dimensional hidden
# space, with each feature entry zeroed with probability 0.99.
dim_n, dim_m = 10, 2

model = SuperpositionModel(dim_n, dim_m)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
training(
    model,
    optimizer,
    sparsity=0.99,
    num_data=10000,
    batch_size=100,
)
feature_norm, superposition = evaluate(model)
visualize(feature_norm, superposition)

In [None]:
# Gram matrix W W^T: entry (i, j) is the inner product of the hidden-space
# embeddings of features i and j (the diagonal holds squared feature norms).
# Inner products are signed. A diverging colormap with limits symmetric about
# 0 maps zero to the neutral middle color and keeps positive and negative
# interference between features visually distinct.
gram = (model.W @ model.W.t()).detach().cpu().numpy()
limit = float(np.abs(gram).max())
plt.imshow(gram, aspect=1, cmap='RdBu_r', vmin=-limit, vmax=limit)
plt.colorbar(label='Inner Product')
plt.title('Feature Inner Products')
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Plot each feature's hidden-space embedding (row i of W) as an arrow from the origin.
W_cpu = model.W.detach().cpu().numpy()
num_features, hidden_dim = W_cpu.shape
if hidden_dim != 2:
    # With dim_m == 1 the column indexing below fails, and with dim_m > 2 only
    # two dimensions would be drawn. Refuse rather than show a misleading plot.
    raise ValueError(f'Arrow plot requires dim_m == 2, got {hidden_dim}')
origin = np.zeros((num_features, 2))
plt.quiver(*origin.T, W_cpu[:, 0], W_cpu[:, 1], angles='xy', scale_units='xy', scale=1)
# Fit the view to the largest arrow coordinate (at least the unit square) plus
# a 10% margin, instead of a fixed +/-3 window that can clip long features.
lim = 1.1 * max(1.0, float(np.abs(W_cpu).max()))
plt.xlim(-lim, lim)
plt.ylim(-lim, lim)
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.title('Features in 2D Space')
plt.grid()
plt.gca().set_aspect('equal')
plt.show()