In [1]:
from transformer_algebra import load_pythia_model, PromptedTransformer, expand
import torch

model, tokenizer = load_pythia_model("EleutherAI/pythia-14m")
T = PromptedTransformer(model, tokenizer, "Hello")
x = T(" world")

# Level 1 - should work
level1 = expand(x)
level1


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


embed(' world') + ΔB^0 + ΔB^1 + ΔB^2 + ΔB^3 + ΔB^4 + ΔB^5

In [2]:
level1.terms

[embed(' world'), ΔB^0, ΔB^1, ΔB^2, ΔB^3, ΔB^4, ΔB^5]

In [3]:
(type(x),x.tensor)

(transformer_algebra.core.ResidualVector,
 tensor([ 1.2352,  1.1555, -1.3898, -1.2035, -0.9764,  1.2874,  1.1204,  0.8023,
          0.9848, -1.2769, -1.3887, -0.9518,  1.1225,  1.0903, -0.9513, -1.1463,
         -1.4768,  1.0472,  1.1267,  0.8227, -1.1667,  1.1461,  1.2543, -1.2231,
          1.0704,  1.0590,  1.4470,  1.3633,  1.0963,  1.0429, -1.1218,  0.4890,
          1.1295,  0.9831,  1.2395, -1.1499, -1.0796,  1.0946, -0.4139,  1.3139,
          1.1675,  1.0662,  0.7090,  1.2749,  1.1652,  0.7929, -1.3682,  0.7762,
         -0.9647, -1.0131,  1.2733, -1.5118,  1.1571,  1.0614, -1.2359,  1.3117,
         -0.9022, -1.1848,  1.4249, -0.9446, -1.1270, -1.3343,  1.3980, -1.1564,
          1.2109, -1.2323,  1.4454,  0.9781,  1.0751, -0.8479, -1.2567,  0.9115,
          1.1432, -1.3322, -1.0355, -1.4295, -0.9601,  1.2278,  1.3536, -1.4367,
          1.2473,  1.1206,  1.0495,  0.7398, -1.2135,  1.1060,  1.3140, -0.9931,
         -1.3243,  1.1408, -1.5610, -0.9008, -1.4434,  1.1511, -1.2

In [4]:
for i, term in enumerate(level1.terms):
    if hasattr(term, 'expand'):
        block_exp = term.expand()
        matches = torch.allclose(block_exp.tensor, term.tensor, atol=1e-5)
        print(f"Block {i-1} expand matches: {matches}")
        if not matches:
            diff = (block_exp.tensor - term.tensor).abs().max()
            print(f"  Max diff: {diff.item():.6f}")

Block 0 expand matches: True
Block 1 expand matches: True
Block 2 expand matches: True
Block 3 expand matches: True
Block 4 expand matches: True
Block 5 expand matches: True


In [5]:

level2 = level1.expand()
print(f"\nLevel 2 sum matches: {torch.allclose(level2.tensor, x.tensor, atol=1e-5)}")
print(f"Level 2 sum norm: {level2.tensor.norm():.2f}")
print(f"Original norm: {x.tensor.norm():.2f}")



Level 2 sum matches: True
Level 2 sum norm: 13.31
Original norm: 13.31


In [6]:
level2

embed(' world') + ΔB^0_A + ΔB^0_M + ΔB^1_A + ΔB^1_M + ΔB^2_A + ΔB^2_M + ΔB^3_A + ΔB^3_M + ΔB^4_A + ΔB^4_M + ΔB^5_A + ΔB^5_M

In [7]:
level1

embed(' world') + ΔB^0 + ΔB^1 + ΔB^2 + ΔB^3 + ΔB^4 + ΔB^5

In [8]:
level1[0]

embed(' world')

In [9]:
block1 = level1[1].expand()
block1

ΔB^0_A + ΔB^0_M

In [12]:
block1[0]

ΔB^0_A

In [13]:
block1[0].tensor

tensor([-0.0489, -0.1850, -0.0565, -0.1207,  0.0055,  0.0788,  0.0243, -0.0319,
         0.0097,  0.0055, -0.0314,  0.0166, -0.0975,  0.0915, -0.0050, -0.0381,
        -0.0455,  0.0695,  0.0456,  0.1058,  0.0088, -0.0012, -0.0691,  0.0243,
        -0.0332,  0.0195, -0.0540,  0.0102,  0.1035,  0.0240, -0.0762, -0.0752,
         0.0327, -0.0476,  0.1197, -0.0551,  0.0858,  0.0542, -0.0450,  0.2256,
         0.0652, -0.0014, -0.0099, -0.0185, -0.0984,  0.0164,  0.0586,  0.0208,
         0.0695,  0.0676,  0.1302,  0.0494, -0.0075, -0.1021, -0.0482,  0.0105,
         0.0238, -0.0330, -0.1027,  0.1051,  0.0234, -0.0476,  0.0073, -0.0388,
        -0.0091, -0.0282, -0.1024,  0.1642,  0.2296, -0.0280,  0.1279, -0.0229,
         0.0140,  0.0808, -0.0271,  0.0248,  0.0083,  0.0240,  0.0377, -0.0225,
         0.0614,  0.0284,  0.0035, -0.0402, -0.0105, -0.0082,  0.0242, -0.0388,
        -0.0883, -0.0225, -0.0121, -0.0088, -0.0344,  0.0575,  0.0066,  0.0483,
         0.0002,  0.0353, -0.0236,  0.06

In [15]:
block1[0].tensor.shape

torch.Size([128])

In [15]:
(x._hidden_states[0].shape,x._hidden_states[1].shape,x._hidden_states[2].shape,x._hidden_states[3].shape)

(torch.Size([1, 2, 128]),
 torch.Size([1, 2, 128]),
 torch.Size([1, 2, 128]),
 torch.Size([1, 2, 128]))

In [16]:
x.model

AttributeError: 'ResidualVector' object has no attribute 'model'