In [1]:
%run Definition.ipynb

In [2]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from dawnet.model import ModelRunner

model_id = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
model = model.eval()
runner = ModelRunner(model)
print(runner._model)



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [3]:
handler1 = runner.cache_outputs("transformer.h.10")

In [5]:
# sae = SAEWithL1.load_from_checkpoint("/home/john/dawnet/experiments/lightning_logs/version_26/checkpoints/last.ckpt")
sae = L2DecoderWeightFull.load_from_checkpoint("/home/john/dawnet/experiments/lightning_logs/version_49/checkpoints/last.ckpt")
sae.cuda()

L2DecoderWeightFull(
  (loss): MSELoss()
)

In [6]:
text1 = "The sky is blue"
text2 = "The sun is red"
text3 = "I feel very blue"
text4 = "The sun is blue"
text5 = "The sky is dog"
text6 = "Sometimes, there is a dog going"
text7 = " blue"
text8 = "fuck"

In [7]:
def get(text):
    input_ids = tokenizer.encode("<|endoftext|>"+text, return_tensors="pt").to(model.device)
    print("Input ids:", input_ids)

    with torch.no_grad():
        output = runner(input_ids)

        acts = runner._output["transformer.h.10"][0].squeeze()[-1]
        feat = sae.encode(acts.unsqueeze(dim=0))
        feat = feat.squeeze().cpu().numpy()

    idxs = (feat != 0).nonzero()
    print(idxs)
    print(feat[idxs])
    return set(idxs[0]), feat[idxs], feat

In [8]:
t1, tfi1, tf1 = get(text1)
t2, tfi2, tf2 = get(text2)
t3, tfi3, tf3 = get(text3)
t4, tfi4, tf4 = get(text4)
t5, tfi5, tf5 = get(text5)
t6, tfi6, tf6 = get(text6)
t7, tfi7, tf7 = get(text7)
t8, tfi8, tf8 = get(text8)

Input ids: tensor([[50256,   464,  6766,   318,  4171]], device='cuda:0')
(array([    2,    10,    39, ..., 24556, 24571, 24574]),)
[1.6439335  2.3737485  0.37328902 ... 2.7317746  4.2512474  0.714765  ]
Input ids: tensor([[50256,   464,  4252,   318,  2266]], device='cuda:0')
(array([   10,    58,    62, ..., 24547, 24565, 24571]),)
[0.87956375 3.51472    1.5762739  ... 2.0717318  1.0222964  6.710807  ]
Input ids: tensor([[50256,    40,  1254,   845,  4171]], device='cuda:0')
(array([   15,    27,    28, ..., 24565, 24571, 24574]),)
[2.1132116  0.46094003 0.7298076  ... 0.6349358  7.0152063  3.0297208 ]
Input ids: tensor([[50256,   464,  4252,   318,  4171]], device='cuda:0')
(array([    2,    10,    27, ..., 24565, 24569, 24571]),)
[1.5378329  1.871917   0.977425   ... 0.22279675 0.297614   6.059757  ]
Input ids: tensor([[50256,   464,  6766,   318,  3290]], device='cuda:0')
(array([   10,    28,    31, ..., 24556, 24565, 24567]),)
[0.77650315 1.32107    0.5427977  ... 0.03357542 0.9

In [9]:
len(t1), len(t2), len(t3), len(t4), len(t5), len(t6), len(t7), len(t8)

(1669, 1732, 1854, 1651, 1899, 1915, 1237, 1641)

In [10]:
def feat_diff(feat1, feat2):
    total_item = feat1.union(feat2)
    intersection = feat1.intersection(feat2)
    feat1_has = feat1.difference(feat2)
    feat2_has = feat2.difference(feat1)
    print("Intersection:", len(intersection) / len(total_item), list(sorted(intersection)))
    print("Feat1 has:", len(feat1_has) / len(total_item), list(sorted(feat1_has)))
    print("Feat2 has:", len(feat2_has) / len(total_item), list(sorted(feat2_has)))
    return intersection

In [11]:
int2 = feat_diff(t1, t2)

Intersection: 0.30007645259938837 [10, 58, 62, 143, 176, 200, 299, 374, 375, 377, 383, 447, 479, 535, 601, 607, 608, 620, 668, 682, 794, 806, 850, 885, 907, 909, 927, 982, 1016, 1062, 1095, 1117, 1123, 1125, 1231, 1279, 1294, 1305, 1319, 1343, 1416, 1417, 1451, 1555, 1608, 1623, 1701, 1723, 1737, 1800, 1804, 1855, 1857, 1886, 1943, 1985, 1988, 2019, 2024, 2146, 2189, 2200, 2224, 2232, 2242, 2255, 2316, 2418, 2509, 2527, 2554, 2598, 2614, 2616, 2625, 2637, 2639, 2648, 2650, 2662, 2684, 2736, 2752, 2864, 2880, 2882, 2884, 2886, 2887, 2938, 2996, 3007, 3048, 3106, 3109, 3166, 3189, 3220, 3231, 3238, 3261, 3301, 3302, 3314, 3352, 3388, 3396, 3421, 3442, 3450, 3457, 3468, 3478, 3510, 3519, 3525, 3547, 3669, 3674, 3713, 3738, 3818, 3841, 3870, 3880, 3905, 3913, 3956, 3964, 3976, 3991, 4004, 4016, 4043, 4103, 4174, 4253, 4255, 4268, 4284, 4312, 4329, 4399, 4405, 4428, 4471, 4513, 4531, 4636, 4680, 4744, 4808, 4810, 4822, 4852, 4868, 4927, 4931, 4940, 4950, 5003, 5089, 5091, 5094, 5139, 5170, 

In [12]:
int3 = feat_diff(t1, t3)

Intersection: 0.1910074374577417 [39, 62, 143, 161, 185, 299, 375, 377, 487, 490, 514, 603, 607, 608, 683, 695, 769, 806, 850, 885, 907, 909, 1016, 1017, 1062, 1095, 1177, 1271, 1279, 1319, 1341, 1359, 1417, 1465, 1525, 1608, 1646, 1701, 1718, 1737, 1943, 1985, 2019, 2031, 2189, 2232, 2255, 2308, 2316, 2372, 2434, 2487, 2491, 2520, 2527, 2625, 2639, 2648, 2650, 2662, 2729, 2752, 2831, 2880, 2886, 2938, 2975, 2979, 2991, 3005, 3007, 3048, 3106, 3162, 3189, 3238, 3261, 3302, 3352, 3357, 3373, 3396, 3399, 3433, 3442, 3457, 3478, 3506, 3519, 3525, 3527, 3605, 3606, 3655, 3662, 3669, 3713, 3841, 3870, 3899, 3913, 3964, 3976, 3991, 4004, 4016, 4039, 4043, 4140, 4253, 4284, 4312, 4335, 4425, 4445, 4542, 4667, 4680, 4808, 4837, 4868, 4899, 4927, 4940, 4947, 4950, 4970, 5084, 5089, 5094, 5099, 5139, 5170, 5314, 5452, 5466, 5553, 5567, 5574, 5650, 5713, 5719, 5727, 5774, 5824, 5866, 5874, 6013, 6177, 6203, 6311, 6323, 6542, 6717, 6718, 6733, 6775, 6776, 6801, 6822, 6843, 6922, 7027, 7057, 7083, 

In [13]:
int4 = feat_diff(t1, t4)

Intersection: 0.5257352941176471 [2, 10, 39, 58, 62, 98, 143, 161, 176, 200, 284, 289, 299, 339, 375, 377, 444, 479, 490, 509, 534, 601, 607, 608, 621, 668, 680, 682, 683, 689, 713, 738, 769, 794, 806, 814, 829, 835, 850, 867, 885, 907, 909, 927, 982, 1016, 1062, 1095, 1125, 1231, 1271, 1279, 1294, 1305, 1306, 1319, 1343, 1351, 1416, 1417, 1502, 1525, 1542, 1550, 1608, 1623, 1630, 1646, 1701, 1723, 1737, 1741, 1800, 1834, 1855, 1882, 1886, 1985, 1988, 2019, 2024, 2064, 2146, 2170, 2189, 2200, 2224, 2232, 2242, 2255, 2296, 2316, 2319, 2418, 2445, 2487, 2491, 2494, 2499, 2509, 2527, 2530, 2536, 2554, 2598, 2614, 2616, 2625, 2637, 2639, 2648, 2662, 2666, 2675, 2684, 2736, 2752, 2787, 2838, 2864, 2873, 2880, 2882, 2884, 2886, 2887, 2938, 2975, 2979, 2981, 2996, 3005, 3007, 3048, 3106, 3109, 3127, 3156, 3162, 3166, 3189, 3220, 3238, 3261, 3301, 3302, 3314, 3339, 3373, 3388, 3390, 3396, 3399, 3416, 3421, 3433, 3435, 3450, 3457, 3478, 3506, 3510, 3519, 3525, 3563, 3566, 3606, 3662, 3669, 3674

In [14]:
int5 = feat_diff(t1, t5)

Intersection: 0.09548664415105926 [10, 58, 143, 176, 299, 375, 514, 607, 683, 769, 789, 814, 859, 927, 1016, 1095, 1117, 1271, 1294, 1341, 1416, 1741, 1800, 1985, 2019, 2232, 2242, 2296, 2316, 2396, 2509, 2614, 2616, 2625, 2752, 2831, 2864, 2884, 2887, 3162, 3220, 3301, 3302, 3352, 3357, 3388, 3399, 3416, 3481, 3547, 3655, 3662, 3738, 3964, 4425, 4434, 4471, 4495, 4531, 4744, 4852, 4868, 4927, 5084, 5094, 5099, 5130, 5177, 5223, 5279, 5422, 5466, 5574, 5594, 5669, 5762, 5907, 5937, 5984, 6106, 6121, 6137, 6300, 6313, 6390, 6457, 6606, 6716, 6717, 6728, 6775, 6822, 6979, 7056, 7083, 7085, 7188, 7280, 7308, 7360, 7433, 7466, 7505, 7719, 7728, 7792, 8080, 8290, 8486, 8616, 8743, 8919, 8958, 9085, 9101, 9265, 9438, 9530, 9545, 9588, 9709, 9745, 9772, 9897, 9941, 10039, 10120, 10121, 10210, 10249, 10294, 10369, 10465, 10472, 10478, 10601, 10649, 10803, 10845, 10888, 11211, 11268, 11367, 11389, 11455, 11581, 11626, 11867, 11899, 11953, 11981, 12002, 12085, 12150, 12154, 12218, 12237, 12251, 

In [15]:
int6 = feat_diff(t1, t6)

Intersection: 0.057227138643067846 [2, 61, 98, 502, 682, 806, 1016, 1062, 1123, 1597, 1654, 1664, 1741, 1769, 1882, 1985, 2019, 2308, 2527, 2598, 2645, 2648, 2662, 2675, 2752, 3007, 3220, 3314, 3352, 3357, 3566, 3605, 3669, 3761, 3913, 3964, 3976, 4284, 4443, 4680, 4868, 4931, 4940, 4970, 5005, 5089, 5170, 5293, 5369, 5574, 5727, 5728, 5762, 5907, 5981, 6203, 6313, 6434, 6437, 6467, 6716, 6717, 6733, 6775, 6801, 6822, 6860, 7085, 7719, 7987, 8290, 8429, 8538, 8743, 8836, 9193, 9222, 9545, 9588, 9742, 10256, 10454, 10478, 10649, 10740, 10813, 11063, 11211, 11364, 11379, 11457, 11492, 11581, 11592, 11867, 11894, 12062, 12154, 12218, 12373, 12496, 12544, 12562, 12614, 13132, 13167, 13354, 13507, 13672, 13805, 14430, 14471, 14569, 14680, 14760, 14798, 15125, 15200, 15202, 15204, 15242, 15438, 15464, 15501, 15558, 16165, 16196, 16232, 16239, 16253, 16355, 16383, 16512, 16570, 16597, 16698, 16900, 16996, 17120, 17189, 17589, 17590, 17594, 17820, 17905, 18186, 18202, 18338, 18542, 18665, 1870

In [16]:
int7 = feat_diff(t1, t7)

Intersection: 0.10410334346504559 [2, 39, 185, 200, 299, 377, 509, 514, 534, 621, 695, 769, 850, 885, 1017, 1095, 1177, 1247, 1319, 1542, 1552, 1608, 1855, 1985, 2019, 2232, 2308, 2316, 2396, 2499, 2520, 2657, 2675, 2684, 2736, 2752, 2880, 2991, 3107, 3166, 3373, 3433, 3469, 3527, 3913, 3991, 4004, 4104, 4140, 4268, 4312, 4495, 4545, 4560, 4636, 4667, 4808, 4822, 4868, 4927, 4931, 4947, 4950, 5033, 5084, 5177, 5276, 5567, 5692, 5774, 5874, 6203, 6313, 6323, 6440, 6513, 6614, 6717, 6775, 6822, 6904, 6991, 7076, 7085, 7200, 7270, 7277, 7294, 7306, 7316, 7370, 7466, 7586, 7653, 7719, 7754, 7841, 7888, 7987, 8072, 8080, 8532, 8538, 8732, 8921, 8946, 8956, 9265, 9294, 9459, 9568, 9709, 9978, 10256, 10287, 10301, 10324, 10342, 10369, 10454, 10565, 10722, 10845, 11211, 11364, 11369, 11392, 11527, 11578, 11581, 11651, 11766, 11834, 11880, 11899, 11928, 11969, 12218, 12237, 12252, 12427, 12544, 12561, 12570, 12757, 12877, 12935, 13125, 13167, 13241, 13286, 13294, 13507, 13527, 13554, 13827, 138

In [17]:
int8 = feat_diff(t1, t8)

Intersection: 0.05346912794398472 [2, 39, 143, 176, 695, 829, 848, 1016, 1095, 1834, 1837, 1882, 1985, 2019, 2232, 2396, 2981, 2991, 3005, 3106, 3238, 3373, 3416, 3433, 3457, 3468, 3469, 3525, 3563, 3913, 3964, 4174, 4399, 4495, 4542, 4808, 4868, 4931, 4940, 5005, 5063, 5084, 5129, 5130, 5177, 5276, 5293, 5314, 5594, 5692, 5765, 5932, 5981, 6137, 6192, 6467, 6543, 6717, 6733, 6775, 6909, 7085, 7293, 7432, 7631, 7677, 8080, 8616, 8743, 8838, 9193, 9981, 9994, 10565, 10649, 10773, 11211, 11213, 11581, 11626, 11880, 11899, 11969, 12218, 12237, 12338, 12442, 12484, 12561, 12627, 12757, 12943, 12981, 13167, 13334, 13554, 14093, 14154, 14376, 14420, 14539, 14760, 15001, 15165, 15166, 15256, 15353, 16196, 16271, 16308, 16355, 16407, 16512, 16570, 16597, 16633, 16698, 16725, 17212, 17354, 17378, 17614, 17924, 17985, 18083, 18126, 18167, 18210, 18266, 18338, 18402, 18542, 18612, 18704, 18938, 19159, 19228, 19454, 19788, 19856, 20024, 20030, 20048, 20368, 20501, 20556, 20663, 20693, 20812, 20814

In [88]:
to_check = [int2, int3, int4, int5, int6, int7]
intersect = to_check[0]
for each in to_check[1:]:
    intersect = intersect.intersection(each)
print(list(sorted(intersect)))

[3149, 8483]


In [17]:
len(intersect)

4

In [52]:
print(t2)
print(t2.shape)

[  284  2071  2090  2687  2897  3370  3375  3715  3772  4387  4933  5074
  5341  5912  6251  6774  7424  8355  8372  9254  9791 10107 10163 10207
 10573 10924 11027 11133 11372 12177 13445 14572 15194 15269 15302 15419
 15631 16432 16642 16859 17694 18391 18499 19193 19263 19823 22265 22299
 22490 22947 23300 24332 24356]
(53,)


In [20]:
def investigate(text1):
    input_ids = tokenizer.encode("<|endoftext|>"+text1, return_tensors="pt").to(model.device)
    print("Input ids:", input_ids)

    with torch.no_grad():
        output = runner(input_ids)

        acts = runner._output["transformer.h.10"][0].squeeze()[-1]
        feat = sae.encode(acts.unsqueeze(dim=0))
        feat = feat.squeeze().cpu().numpy()

    idxs = (feat != 0).nonzero()
    print(idxs)
    print(feat[idxs])

    return acts, feat

In [21]:
acts1, feat1 = investigate(text1)

Input ids: tensor([[50256,   464,  6766,   318,  4171]], device='cuda:0')
(array([   54,   305,  3149,  4782,  4812,  5002,  5199,  5501,  6785,
        7682,  8483,  8981, 11586, 11846, 12383, 12848, 13304, 13541,
       14103, 15386, 15620, 15867, 15883, 16526, 17853, 19566, 19996,
       20140, 20173, 21003, 22325, 22747, 22883, 23091, 23868, 24135]),)
[ 6.3777604   1.3098043  16.978262    2.0895884   2.2118785   3.028096
  7.2511697   5.1198187   1.7559111  18.820425    0.5514142   2.875419
 26.801739    1.8400564   0.35215202  5.2779455   0.971651   27.787973
  1.0997177   0.94836706  3.2699153   1.1426216  23.224455    1.181668
  0.7635544   3.5243092   1.3847507   0.5888878   1.1456937   1.6065216
  0.49929056  1.5055027  10.849372    1.7046075   4.2688594   3.3106172 ]


In [32]:
acts2, feat2 = investigate(text4)

Input ids: tensor([[50256,   464,  4252,   318,  4171]], device='cuda:0')
(array([   54,  2556,  3149,  4772,  4812,  5199,  5501,  6454,  6764,
        6785,  7682,  8483,  8981,  9143, 11586, 12572, 12848, 13304,
       13541, 13585, 14103, 15620, 15867, 15883, 16526, 18470, 19996,
       20027, 20140, 20173, 21003, 21253, 22883, 23091, 24076, 24135]),)
[ 0.5982488   2.2331626  12.675084    0.38048583  1.507112    0.0599636
  5.093863    2.4953861   1.0038148   2.4169827  18.393347    1.8376398
  5.2430606   0.5954744  25.534533    2.036723    3.6063006   3.6423447
 28.925066    3.2233129   0.37607336  4.9857507   0.6715972  25.493914
  0.75199765  1.1644429   0.37116593  0.5231185   1.7070831   2.4725382
  2.8669686   2.3166056  14.050442    2.2471902   1.3746876   0.34518868]


In [92]:
acts7, feat7 = investigate(text7)

Input ids: tensor([[50256,  4171]], device='cuda:0')
(array([  155,   305,   760,  2659,  2834,  3149,  3849,  4009,  4883,
        5094,  5261,  5502,  6052,  8319,  8483,  8570,  8694, 11586,
       11810, 12266, 12383, 12634, 13541, 13689, 13803, 14143, 15293,
       18123, 18246, 18470, 18802, 20356, 20671, 21162, 22137, 22325,
       23319, 23328, 23900, 24135, 24310]),)
[3.11422050e-02 9.52416515e+00 3.77603507e+00 1.85582995e+00
 1.42247400e+01 2.97373867e+01 6.31513745e-02 1.48426294e-02
 1.78389058e-01 1.14564157e+00 2.50294745e-01 1.04398990e+00
 1.62452266e-01 7.96857059e-01 1.09593725e+00 2.65095663e+00
 1.23478889e+01 4.76539224e-01 1.11205244e+00 7.30951965e-01
 1.09978586e-01 2.41723442e+01 3.77602615e+01 9.03683662e-01
 7.82823384e-01 7.02804089e+00 3.63073945e-02 8.08034539e-01
 3.16053838e-01 1.38868380e+01 1.15289807e-01 3.03316236e+00
 5.20504415e-01 4.18909401e-01 2.18765450e+00 2.74508405e+00
 1.40655985e+01 1.58566749e+00 1.09261203e+00 4.45262003e+00
 8.62535059

In [27]:
acts1.shape, acts4.shape

(torch.Size([768]), torch.Size([768]))

In [75]:
print("Act cosine:", nn.functional.cosine_similarity(acts1.unsqueeze(dim=0), acts2.unsqueeze(dim=0)))
print("Act cosine:", np.dot(acts1.cpu().numpy(), acts2.cpu().numpy()) / (np.linalg.norm(acts1.cpu().numpy()) * np.linalg.norm(acts2.cpu().numpy())))

Act cosine: tensor([0.9368], device='cuda:0')
Act cosine: 0.9368108


In [76]:
def similarity(act1, feat1, act2, feat2):
    print("Act MSE:", ((act1 - act2) ** 2).mean())
    print("Act cosine:", nn.functional.cosine_similarity(act1.unsqueeze(dim=0), act2.unsqueeze(dim=0)))
    print("Feat MSE:", ((feat1 - feat2) ** 2).mean())
    print("Feat cosine:", np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2)))

In [77]:
similarity(acts1, feat1, acts2, feat2)

Act MSE: tensor(9.3173, device='cuda:0')
Act cosine: tensor([0.9368], device='cuda:0')
Feat MSE: 0.10331639
Feat cosine: 0.58516407


In [78]:
similarity(acts1, feat1, acts3, feat3)

Act MSE: tensor(14.8855, device='cuda:0')
Act cosine: tensor([0.8938], device='cuda:0')
Feat MSE: 0.08940496
Feat cosine: 0.6052401


In [79]:
similarity(acts1, feat1, acts4, feat4)

Act MSE: tensor(2.5950, device='cuda:0')
Act cosine: tensor([0.9818], device='cuda:0')
Feat MSE: 0.010106247
Feat cosine: 0.95898825


In [80]:
similarity(acts1, feat1, acts5, feat5)

Act MSE: tensor(32.3265, device='cuda:0')
Act cosine: tensor([0.7686], device='cuda:0')
Feat MSE: 0.27190927
Feat cosine: 0.0949504


In [81]:
similarity(acts1, feat1, acts6, feat6)

Act MSE: tensor(39.4036, device='cuda:0')
Act cosine: tensor([0.7334], device='cuda:0')
Feat MSE: 0.23914807
Feat cosine: 0.10611745


In [93]:
similarity(acts1, feat1, acts7, feat7)

Act MSE: tensor(22.0498, device='cuda:0')
Act cosine: tensor([0.8605], device='cuda:0')
Feat MSE: 0.15096694
Feat cosine: 0.46576208


In [56]:
nn.functional.cosine_similarity(acts1.unsqueeze(dim=0), acts7.unsqueeze(dim=0))

tensor([0.7334], device='cuda:0')

In [38]:
idxs = (feat != 0).nonzero()
print(idxs)
print(feat[idxs])

(array([  284,  2071,  2090,  2687,  2897,  3370,  3375,  3715,  3772,
        4128,  4387,  4745,  4933,  5074,  5336,  5341,  5634,  5912,
        6251,  6774,  7424,  8355,  8372,  9254,  9353,  9791, 10107,
       10163, 10207, 10573, 10924, 11027, 11133, 11372, 11879, 12177,
       12189, 13445, 13780, 14572, 14655, 15194, 15269, 15302, 15419,
       15631, 16432, 16642, 16859, 17694, 18391, 18499, 18658, 19263,
       22265, 22299, 22490, 22947, 23300, 24332, 24356]),)
[2.423076   0.8249892  0.19186723 0.27324545 1.6404442  0.9554368
 0.22211158 0.47308585 0.10743587 0.35749412 0.2538712  0.08809737
 1.6620847  0.7667606  0.41524723 0.90271395 0.01954544 0.03700136
 0.38561237 1.0759873  0.01548168 0.4221773  0.44862297 0.6608852
 0.2803075  0.16518445 0.50297403 0.1111737  0.36621413 0.66537726
 0.8776589  1.0969954  0.9856529  0.6027818  0.16524598 1.5779285
 0.3279565  0.41394776 0.03233503 0.4117829  0.62311333 0.42536527
 0.91049016 0.46033752 0.4836884  0.33684498 0.0663612

In [39]:
input_ids = tokenizer.encode("<|endoftext|>"+text2, return_tensors="pt").to(model.device)
print("Input ids:", input_ids)

with torch.no_grad():
    output = runner(input_ids)

    acts = runner._output["transformer.h.10"][0].squeeze()[-1]
    feat = sae.encode(acts.unsqueeze(dim=0))
    feat = feat.squeeze().cpu().numpy()

idxs = (feat != 0).nonzero()
print(idxs)
print(feat[idxs])

Input ids: tensor([[50256,   464,  4252,   318,  2266]], device='cuda:0')
(array([  284,  2071,  2090,  2687,  2897,  3370,  3375,  3715,  3772,
        4387,  4933,  5074,  5341,  5912,  6251,  6774,  7424,  8355,
        8372,  9254,  9791, 10107, 10163, 10207, 10573, 10924, 11027,
       11133, 11372, 12177, 13445, 14572, 15194, 15269, 15302, 15419,
       15631, 16432, 16642, 16859, 17694, 18391, 18499, 19193, 19263,
       19823, 22265, 22299, 22490, 22947, 23300, 24332, 24356]),)
[2.0284038  1.2658823  0.11693037 0.22030735 1.4555813  0.5612621
 0.21379471 0.6794169  0.18634202 0.36882305 1.6883729  0.71657604
 1.3570757  0.16324691 0.36243808 1.0410396  0.03622173 0.34265715
 0.73360455 0.28753704 0.16668846 0.4607747  0.16775732 0.24128908
 0.4485487  0.68137664 0.7856412  1.1274025  0.62342066 1.4122851
 0.11364459 0.35753712 0.27992207 0.69837713 0.50452495 0.26336506
 0.21755058 0.0123381  0.30366915 1.5134838  0.48188084 0.39419085
 0.01330312 0.18634865 0.35071725 0.887250

In [41]:
input_ids = tokenizer.encode("<|endoftext|>"+text3, return_tensors="pt").to(model.device)
print("Input ids:", input_ids)

with torch.no_grad():
    output = runner(input_ids)

    acts = runner._output["transformer.h.10"][0].squeeze()[-1]
    feat = sae.encode(acts.unsqueeze(dim=0))
    feat = feat.squeeze().cpu().numpy()

idxs = (feat != 0).nonzero()
print(idxs)
print(feat[idxs])

Input ids: tensor([[50256,    40,  1254,   845,  4171]], device='cuda:0')
(array([  284,  1461,  2071,  2090,  2687,  2897,  3051,  3370,  3375,
        3567,  3715,  3772,  4933,  5074,  5336,  5341,  5912,  6251,
        6774,  8355,  8372,  9254,  9791, 10107, 10207, 10573, 10924,
       11027, 11133, 11372, 11879, 12177, 12514, 13445, 14572, 14655,
       15194, 15269, 15302, 15419, 15631, 16642, 16859, 17339, 18499,
       19193, 19263, 19823, 22265, 22299, 22411, 22490, 22947, 23300,
       24332, 24356]),)
[1.74757695e+00 2.49852613e-02 1.01016474e+00 3.43876868e-01
 1.82686150e-01 7.74324179e-01 4.10946935e-01 6.77367806e-01
 3.46389174e-01 1.02291144e-01 7.32578576e-01 6.53356910e-02
 1.65316486e+00 6.85901582e-01 2.74194628e-01 5.02825797e-01
 9.48872119e-02 2.44514868e-01 6.90216720e-01 1.58180833e-01
 2.55117089e-01 3.68986309e-01 1.06268093e-01 3.52675498e-01
 2.26888806e-03 2.96502680e-01 4.89511490e-01 7.12561250e-01
 1.18790329e+00 8.16677094e-01 7.86374733e-02 9.942096

In [42]:
type(tokenizer)

transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast

In [97]:
(torch.zeros(10) != 0).nonzero().tolist()

[]