In [4]:
import warnings
# Ignore every warning from this point on. The filter is installed before the
# imports below execute, so it also applies to warnings raised while importing.
warnings.filterwarnings('ignore')

import torch

from transformer_lens import FactoredMatrix

In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
# CUDA availability is only queried when MPS is not available.
if torch.backends.mps.is_available():
    backend = 'mps'
elif torch.cuda.is_available():
    backend = 'cuda'
else:
    backend = 'cpu'
device = torch.device(backend)

print(f'Using device: {device}')

Using device: mps


# 1. Factored matrices

In [5]:
# Seed the global RNG so the random factors below (and any later torch.randn
# call in a top-to-bottom run) are identical after every kernel restart.
torch.manual_seed(0)

# A is 5x2 and B is 2x5, so the dense product AB is 5x5 with rank at most 2.
A = torch.randn(5, 2)
B = torch.randn(2, 5)
AB = A @ B
# Wrap the same two factors in a FactoredMatrix to compare it with the dense product.
AB_factor = FactoredMatrix(A, B)
# The two norms printed below are expected to agree up to floating-point error.
print("Norms:")
print(AB.norm())
print(AB_factor.norm())

print(f"Right dimension: {AB_factor.rdim}, Left dimension: {AB_factor.ldim}, Hidden dimension: {AB_factor.mdim}")

Norms:
tensor(11.3166)
tensor(11.3166)
Right dimension: 5, Left dimension: 5, Hidden dimension: 2


In [6]:
# Eigenvalues: dense 5x5 product first, then the factored representation.
print("Eigenvalues:")
dense_eigvals = torch.linalg.eig(AB).eigenvalues
print(dense_eigvals)
factored_eigvals = AB_factor.eigenvalues
print(factored_eigvals)

# Singular values: dense product first, then the factored representation.
print("\nSingular Values:")
dense_singular_values = torch.linalg.svd(AB).S
print(dense_singular_values)
factored_singular_values = AB_factor.S
print(factored_singular_values)

# Full decomposition as returned by the factored matrix.
print("\nFull SVD:")
factored_svd = AB_factor.svd()
print(factored_svd)

Eigenvalues:
tensor([-3.8594e-06+0.0000e+00j, -3.9839e-01+3.2362e-01j,
        -3.9839e-01-3.2362e-01j, -2.0023e-08+4.8973e-08j,
        -2.0023e-08-4.8973e-08j])
tensor([-0.3984+0.3236j, -0.3984-0.3236j])

Singular Values:
tensor([1.1045e+01, 2.4659e+00, 1.9530e-07, 7.8354e-08, 3.3726e-09])
tensor([11.0446,  2.4659])

Full SVD:
(tensor([[ 0.3214,  0.2192],
        [ 0.2939, -0.9133],
        [-0.6548, -0.1536],
        [-0.3436,  0.1314],
        [ 0.5133,  0.2776]]), tensor([11.0446,  2.4659]), tensor([[-0.6408, -0.4334],
        [-0.0991,  0.6347],
        [ 0.1738, -0.2532],
        [-0.6847,  0.4198],
        [ 0.2837,  0.4111]]))


In [7]:
# Right-multiply both the dense and the factored matrix by the same 5x300 matrix.
C = torch.randn(5, 300)
ABC = AB @ C
ABC_factor = AB_factor @ C

# Report shape and norm for the dense result, then for the factored result.
dense_shape = ABC.shape
dense_norm = ABC.norm()
print(f"Unfactored: shape={dense_shape}, norm={dense_norm}")
factored_shape = ABC_factor.shape
factored_norm = ABC_factor.norm()
print(f"Factored: shape={factored_shape}, norm={factored_norm}")

# Dimensions reported by the factored product.
right_dim = ABC_factor.rdim
left_dim = ABC_factor.ldim
hidden_dim = ABC_factor.mdim
print(f"\nRight dimension: {right_dim}, Left dimension: {left_dim}, Hidden dimension: {hidden_dim}")

Unfactored: shape=torch.Size([5, 300]), norm=194.84153747558594
Factored: shape=torch.Size([5, 300]), norm=194.84152221679688

Right dimension: 300, Left dimension: 5, Hidden dimension: 2


In [8]:
# Round-trip check: the matrix read back from the FactoredMatrix must match the
# dense product AB within assert_close's default tolerances.
AB_unfactored = AB_factor.AB
torch.testing.assert_close(actual=AB_unfactored, expected=AB)

# 2. Reverse-engineering circuits

# Sources

1. [Ground truth - Reverse-engineering induction circuits, by ARENA](https://arena-chapter1-transformer-interp.streamlit.app/[1.2]_Intro_to_Mech_Interp)