In [None]:
#!export CUDA_VISIBLE_DEVICES=7

## Setup

In [1]:
#import pandas as pd
import torch

In [2]:
import datasets

In [25]:
from transformers import AutoTokenizer

In [3]:
from transformer_lens import HookedTransformer

In [4]:
from entropy.compare import unflattened_data

In [5]:
from neuron_choice import neuron_choice
from entropy.entropy_intervention import make_hooks
from argparse import Namespace
from utils import NAME_TO_COMBO

In [29]:
from utils import neuron_analysis

## Constants that define the experiment

In [6]:
INTERVENTION_TYPE = "zero_ablation"
METRIC = "entropy"
NEURON_SUBSET = "weakening_gate-_post+"

In [7]:
DATA_PATH = "intervention_results/allenai/OLMo-7B-0424-hf/dolma-small"

In [8]:
EXTREMUM = "min"

## Preparing the data

In [9]:
diff_data = unflattened_data(DATA_PATH, METRIC, NEURON_SUBSET, INTERVENTION_TYPE)

In [10]:
diff_data.shape

torch.Size([45734, 1024])

## Choosing example

In [None]:
#torch.argmin(diff_data)

tensor(31864857)

In [14]:
vi = torch.min(diff_data, dim=0)

In [15]:
vi.indices.shape

torch.Size([1024])

In [16]:
torch.argmin(vi.values)

tensor(25)

In [17]:
vi.values[25]

tensor(-10.7500, dtype=torch.float16)

In [18]:
vi.indices[25]

tensor(31118)

In [19]:
diff_data[31118, 25]

tensor(-10.7500, dtype=torch.float16)

Okay, so weakening neurons provoke the biggest decrease of entropy at sequence 31118, position 25. Let us see what the text says:

In [11]:
text_dataset = datasets.load_from_disk('neuroscope/datasets/dolma-small')

In [22]:
text_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 45734
})

In [12]:
text_dataset[31118]

{'input_ids': tensor([50279, 43688,   313,  1797,  4565,    10,   253,  7295,  6138,   247,
          5522,   273,  1329,   323, 34598,   285, 27335,  9341,   326,   403,
         10305,  5454,   984,   273,   253,   473,  6185,  1406, 12955,    13,
           347,   973,   347,   247,   747, 34220, 14117, 11726,  1467,   621,
         13629,   313, 24584,    10,  6974,   281,  1329,   643,  9341,    13,
          1754,   327,  1980,  5054,   878,    15,   187,  6872, 15849,   588,
           320, 11966,   407,  1980,  9061,    15,   187,    38,  3354,  4412,
          6456,   310,  4390, 31288, 12925,   285,  1491,   670, 26281,   432,
           253,  7295,   285,   588,   452,   271,  2898,  1232,   873,   598,
           347,  3517,   347,  1896,    15,   187]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [13]:
my_input_ids = text_dataset[31118]['input_ids']

In [14]:
model = HookedTransformer.from_pretrained('allenai/OLMo-7B-0424-hf', refactor_glu=True, device='cuda:6')

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loaded pretrained model allenai/OLMo-7B-0424-hf into HookedTransformer


In [15]:
model.to_str_tokens(text_dataset[31118]['input_ids'])

['<|endoftext|>',
 'Yesterday',
 ' (',
 '21',
 ' December',
 ')',
 ' the',
 ' Government',
 ' announced',
 ' a',
 ' package',
 ' of',
 ' support',
 ' for',
 ' hospitality',
 ' and',
 ' leisure',
 ' businesses',
 ' that',
 ' are',
 ' losing',
 ' trade',
 ' because',
 ' of',
 ' the',
 ' O',
 'mic',
 'ron',
 ' variant',
 ',',
 ' as',
 ' well',
 ' as',
 ' a',
 ' new',
 ' discretionary',
 ' Additional',
 ' Rest',
 'rict',
 'ions',
 ' Grant',
 ' (',
 'ARG',
 ')',
 ' scheme',
 ' to',
 ' support',
 ' other',
 ' businesses',
 ',',
 ' based',
 ' on',
 ' local',
 ' economic',
 ' need',
 '.',
 '\n',
 'These',
 ' schemes',
 ' will',
 ' be',
 ' administered',
 ' by',
 ' local',
 ' authorities',
 '.',
 '\n',
 'E',
 'den',
 ' District',
 ' Council',
 ' is',
 ' currently',
 ' awaiting',
 ' guidance',
 ' and',
 ' information',
 ' about',
 ' eligibility',
 ' from',
 ' the',
 ' Government',
 ' and',
 ' will',
 ' have',
 ' an',
 ' application',
 ' process',
 ' set',
 ' up',
 ' as',
 ' soon',
 ' as',
 ' pos

In [29]:
len(text_dataset[31118]['input_ids'])

96

In [17]:
model.to_single_str_token(my_input_ids[26].item())

'mic'

In [18]:
model.to_tokens(' Omicron')

tensor([[50279,   473,  6185,  1406]], device='cuda:6')

In [19]:
print(diff_data[31118][:97])

tensor([ -0.0625,  -0.2441,  -2.4375,  -1.4219,  -0.2852,  -0.4121,   0.0352,
         -1.8320,  -1.8574,  -0.7656,  -0.6880,   0.2051,  -0.2881,  -1.3828,
         -0.9590,  -6.9531,   0.2861,  -1.7871,  -2.4844,  -1.9062,  -3.3203,
         -4.2227,  -3.4648,  -4.3438,  -2.7266, -10.7500,  -9.3359,  -1.5205,
         -0.4658,  -0.5234,  -1.0488,  -0.0432,  -1.0781,  -0.7148,  -0.9453,
          0.4102,  -1.9238,  -1.4707,  -2.0195,  -0.6436,  -1.0977,  -1.2080,
         -2.2520,  -1.6289,  -0.3613,   0.0996,  -1.6465,   0.2539,   0.4902,
         -0.7598,  -0.3340,  -1.9570,  -1.7227,   1.8818,  -0.3433,   0.2520,
         -1.3320,  -0.0176,  -0.7168,  -1.0020,  -1.8008,  -0.9746,  -0.7549,
         -2.1680,  -0.3379,  -0.4785,  -1.1602,  -0.0322,   1.3242,  -6.1289,
         -2.2031,  -0.8203,  -1.6133,  -0.1504,   0.2500,  -1.9395,   0.0586,
         -1.6250,  -0.1592,  -2.3535,  -3.2656,  -1.1445,  -3.1777,  -1.7031,
         -1.8994,  -3.1523,  -2.2051,  -1.5625,  -0.7256,  -0.38

## List of weakening neurons

In [20]:
args = Namespace(
    work_dir='.', wcos_dir='.',
    model='allenai/OLMo-7B-0424-hf',
    neuron_subset_name="weakening",
    gate='-', post='+',
    activation_location='mlp.hook_post',
    intervention_type=INTERVENTION_TYPE,
)

In [21]:
neuron_list = neuron_choice(
            args,
            category_key=NAME_TO_COMBO[args.neuron_subset_name],
            subset=243,
            baseline=False
        )
#print(neuron_list)

## Running the model on the example

In [41]:
logits, cache = model.run_with_cache(my_input_ids)

In [45]:
logits.shape

torch.Size([1, 96, 50304])

In [46]:
cache.keys()

dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.m

In [61]:
hooks = []
conditioning_values = {}
for lix, nix in neuron_list:
    conditioning_values[(lix,nix)]={}
    hooks += make_hooks(
        args,
        lix, nix,
        conditioning_value=conditioning_values[(lix,nix)],
        sign=(model.W_gate[lix,:,nix]@model.W_in[lix,:,nix]).item(),
        mean_value=0.0,
    )
#print(hooks)

In [62]:
logits_ablated = model.run_with_hooks(my_input_ids, fwd_hooks=hooks)

## Finding relevant tokens

In [63]:
logit_diff = logits - logits_ablated

In [64]:
logit_diff.shape

torch.Size([1, 96, 50304])

In [73]:
relevant_logit_diff = logit_diff[0,25,:]
relevant_logit_diff.shape

torch.Size([50304])

In [74]:
top = torch.topk(relevant_logit_diff, k=16)
bottom = torch.topk(relevant_logit_diff, k=16, largest=False)
print(top)
print(bottom)

torch.return_types.topk(
values=tensor([12.5329,  7.1575,  7.1229,  6.8121,  6.7228,  6.5633,  6.5550,  6.5279,
         6.2066,  6.1819,  6.1671,  5.9645,  5.8675,  5.6778,  5.6114,  5.5637],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([ 6185,  6402,  7373, 10929, 34611, 28184, 18227,  8555, 43441,  2494,
         4883, 12761, 25810, 36132,  1814, 16192], device='cuda:6'))
torch.return_types.topk(
values=tensor([-3.3217, -3.2983, -3.2578, -3.1851, -3.1702, -3.1052, -3.1043, -3.0887,
        -3.0541, -3.0457, -2.9698, -2.9301, -2.9189, -2.9149, -2.9032, -2.8691],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([12444, 16828,   214,  9676,  5020, 10713,  1972, 22165,   800,  3504,
        24761, 13793, 33158,  3153,   731,  1718], device='cuda:6'))


In [75]:
model.to_str_tokens(top.indices)

['mic',
 'MI',
 'mi',
 ' Mic',
 'MIC',
 'Mic',
 'micro',
 ' mic',
 'yster',
 ' micro',
 'NS',
 ' Micro',
 'Micro',
 'microm',
 'mb',
 'mn']

In [76]:
model.to_str_tokens(bottom.indices)

['об',
 'ру',
 '�',
 ' illegal',
 ' majority',
 ' bulk',
 'ances',
 ' Wayne',
 'ative',
 'я',
 'houses',
 'ع',
 '\n\t\t\n\t',
 ' live',
 ' them',
 ' invest']

In [77]:
model.to_str_tokens(' Omicron')

['<|endoftext|>', ' O', 'mic', 'ron']

## Ablating weakening neurons generally

In [78]:
args.gate=None
args.post=None

hooks_general = []
conditioning_values = {}
for lix, nix in neuron_list:
    conditioning_values[(lix,nix)]={}
    hooks_general += make_hooks(
        args,
        lix, nix,
        conditioning_value=conditioning_values[(lix,nix)],
        sign=(model.W_gate[lix,:,nix]@model.W_in[lix,:,nix]).item(),
        mean_value=0.0,
    )

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.38 GiB. GPU 6 has a total capacity of 47.40 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 44.98 GiB memory in use. Of the allocated memory 41.08 GiB is allocated by PyTorch, and 3.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Is there a single neuron responsible for this?

We're looking for a weakening neuron whose $w_\text{out}$ corresponds as closely as possible to the token 'mic'.

In [22]:
mic = model.W_U[:,model.to_single_token('mic')]

In [25]:
scores = []
for l,n in neuron_list:
    wout = model.W_out[l,n]
    wout /= torch.norm(wout)
    scores.append(wout@mic)
scores = torch.tensor(scores)
print(scores)

tensor([-5.5654e-03, -1.9865e-02,  2.3740e-03,  2.3131e-03, -1.0920e-02,
        -1.1843e-02, -5.0423e-03,  6.3631e-05,  6.2427e-03, -2.7632e-03,
        -1.7804e-02, -3.3106e-03, -2.4462e-02,  1.4472e-02,  3.5042e-03,
         1.0408e-03, -6.3770e-03,  1.5915e-02, -7.3378e-03,  4.2016e-03,
        -2.5157e-02, -1.2029e-02, -4.9860e-04,  2.5065e-03,  1.2379e-02,
        -1.4918e-02, -3.8517e-03,  6.9253e-04,  1.1168e-02, -1.2686e-02,
        -2.5985e-02, -1.1120e-02,  6.2533e-02, -4.0599e-04,  4.4302e-03,
        -7.2232e-05,  2.3728e-02, -7.4714e-03, -2.7623e-03, -9.6072e-03,
        -7.5862e-03, -3.2695e-02, -1.2145e-03, -1.2980e-02, -6.6941e-04,
         3.1989e-02, -3.7805e-03,  1.3651e-02, -6.6925e-03, -8.9840e-04,
        -2.2595e-02,  7.9196e-03,  1.5796e-02, -8.6340e-03, -5.0332e-04,
         3.4942e-04, -1.7101e-02,  8.4782e-03, -2.1459e-03, -1.4602e-02,
         4.4703e-03,  4.5846e-03,  5.6527e-03,  1.0334e-02, -2.8266e-03,
         1.0407e-02, -1.4816e-02,  1.0181e-03, -2.0

In [27]:
torch.topk(scores, k=16)

torch.return_types.topk(
values=tensor([0.1670, 0.1110, 0.0625, 0.0482, 0.0419, 0.0352, 0.0351, 0.0328, 0.0327,
        0.0327, 0.0321, 0.0320, 0.0260, 0.0249, 0.0246, 0.0237]),
indices=tensor([ 80, 168,  32, 204, 148,  72,  75, 173, 141, 155, 138,  45, 110,  95,
        206,  36]))

In [37]:
torch.topk(scores, k=16, largest=False)

torch.return_types.topk(
values=tensor([-0.0452, -0.0336, -0.0327, -0.0296, -0.0275, -0.0271, -0.0260, -0.0255,
        -0.0252, -0.0246, -0.0245, -0.0226, -0.0211, -0.0209, -0.0206, -0.0199]),
indices=tensor([105, 111,  41,  96, 215, 187,  30, 182,  20, 135,  12,  50, 183, 159,
        242,   1]))

In [28]:
neuron_list[80]

(tensor(31, device='cuda:0'), tensor(3498, device='cuda:0'))

In [34]:
values, indices = torch.topk(model.blocks[31].mlp.W_out[3498,:]@model.W_U, k=16)
for j in range(16):
    print((model.to_single_str_token(indices[j].item()), values[j]))

('re', tensor(0.1918, device='cuda:6'))
('b', tensor(0.1895, device='cuda:6'))
('c', tensor(0.1863, device='cuda:6'))
('w', tensor(0.1831, device='cuda:6'))
('p', tensor(0.1822, device='cuda:6'))
('g', tensor(0.1808, device='cuda:6'))
('to', tensor(0.1789, device='cuda:6'))
('f', tensor(0.1776, device='cuda:6'))
('m', tensor(0.1770, device='cuda:6'))
('st', tensor(0.1760, device='cuda:6'))
('sh', tensor(0.1730, device='cuda:6'))
('de', tensor(0.1711, device='cuda:6'))
('d', tensor(0.1711, device='cuda:6'))
('t', tensor(0.1700, device='cuda:6'))
('h', tensor(0.1699, device='cuda:6'))
('se', tensor(0.1661, device='cuda:6'))


In [36]:
ranks = []
for i,(l,n) in enumerate(neuron_list):
    wout = model.blocks[l].mlp.W_out[n,:]
    wout /= torch.norm(wout)
    all_scores = wout@model.W_U
    ranks.append(torch.count_nonzero(all_scores>scores[i])-1)
ranks = torch.tensor(ranks)
print(ranks)
print(torch.topk(ranks, k=16))

tensor([33027, 44794, 11722, 17507, 46631, 25442, 34436, 23220, 11846, 34613,
        46425, 25834, 46371,  5025, 15681, 22796, 31459,  8902, 36383, 11759,
        49177, 37462, 25061, 18214,  8414, 38222, 34092, 20984,  8966, 43250,
        48547, 35267,  1817, 25002, 11094, 23719,   204, 22263, 34918, 49288,
        43525, 32577, 27633, 26429, 24108,  8951, 36769,  5448, 34544, 22052,
        49540,  4933,  5594, 40541, 23993, 21516, 40685,  9317, 31069, 45143,
        16945, 13908,  9458,  8681, 33606,  7666, 45361, 18988, 28537, 36553,
         3530, 27726,  5330,  7637,  2332,  3959, 42244, 21798, 44665, 28791,
          767, 43145, 14870, 10781, 29001, 45069, 48040, 27130, 45341,  3937,
         4669, 12602, 35360,  3953, 34059,  3986, 48969,  2957, 25602, 13362,
        47640,  1403, 12754, 20264, 25564, 31879, 29734, 16825, 36359, 44763,
          360, 49651, 14619, 26784, 24379, 10822, 17850, 35104, 13581, 34743,
        43165,  1346, 27405,  4292,  8312, 15292, 32876, 12042, 

In [38]:
torch.topk(ranks, k=16, largest=False)

torch.return_types.topk(
values=tensor([ 204,  360,  739,  767, 1142, 1346, 1403, 1817, 2200, 2332, 2957, 3006,
        3262, 3340, 3530, 3931]),
indices=tensor([ 36, 110, 130,  80, 144, 121, 101,  32, 168,  74,  97, 195, 164, 189,
         70, 173]))

In [39]:
import gc

In [40]:
gc.collect()
torch.cuda.empty_cache()

## Analysing model hidden states on the example

In [46]:
for layer in range(32):
    for sublayer in ['mid', 'post']:
        print('===========================')
        print(f'Layer {layer}, {sublayer}:')
        top_token_vi = torch.topk(cache[f'blocks.{layer}.hook_resid_{sublayer}'][0,25]@model.W_U, k=4)
        for j in range(4):
            print((model.to_str_tokens(top_token_vi.indices[j]), top_token_vi.values[j]))

Layer 0, mid:
([' in'], tensor(0.2816, device='cuda:6'))
([' of'], tensor(0.2600, device='cuda:6'))
(['\n'], tensor(0.2556, device='cuda:6'))
([' then'], tensor(0.2409, device='cuda:6'))
Layer 0, post:
(["'"], tensor(0.3902, device='cuda:6'))
(['-'], tensor(0.3567, device='cuda:6'))
(['ops'], tensor(0.3376, device='cuda:6'))
(['xygen'], tensor(0.3136, device='cuda:6'))
Layer 1, mid:
(['xygen'], tensor(0.3256, device='cuda:6'))
(["'"], tensor(0.3223, device='cuda:6'))
(['ops'], tensor(0.3137, device='cuda:6'))
(['ats'], tensor(0.3126, device='cuda:6'))
Layer 1, post:
(['xygen'], tensor(0.3145, device='cuda:6'))
(['ats'], tensor(0.2860, device='cuda:6'))
(['asis'], tensor(0.2826, device='cuda:6'))
(['ops'], tensor(0.2819, device='cuda:6'))
Layer 2, mid:
(['xygen'], tensor(0.3330, device='cuda:6'))
(['cean'], tensor(0.3086, device='cuda:6'))
(['asis'], tensor(0.2923, device='cuda:6'))
(['ceans'], tensor(0.2916, device='cuda:6'))
Layer 2, post:
(['xygen'], tensor(0.3385, device='cuda:6'))


In [47]:
for layer in range(32):
    for sublayer in ['mid', 'post']:
        print(f'Layer {layer}, {sublayer}: {cache[f'blocks.{layer}.hook_resid_{sublayer}'][0,25]@mic}')

Layer 0, mid: -0.009818505495786667
Layer 0, post: 0.1756778061389923
Layer 1, mid: 0.1873852014541626
Layer 1, post: 0.16365423798561096
Layer 2, mid: 0.18060725927352905
Layer 2, post: 0.17816656827926636
Layer 3, mid: 0.16314618289470673
Layer 3, post: 0.19512557983398438
Layer 4, mid: 0.18644800782203674
Layer 4, post: 0.20833870768547058
Layer 5, mid: 0.16482438147068024
Layer 5, post: 0.15313085913658142
Layer 6, mid: 0.14475101232528687
Layer 6, post: 0.22487929463386536
Layer 7, mid: 0.22021272778511047
Layer 7, post: 0.2116979956626892
Layer 8, mid: 0.19317162036895752
Layer 8, post: 0.2249262034893036
Layer 9, mid: 0.2260773777961731
Layer 9, post: 0.187677800655365
Layer 10, mid: 0.1936216801404953
Layer 10, post: 0.2052566558122635
Layer 11, mid: 0.22408485412597656
Layer 11, post: 0.19123020768165588
Layer 12, mid: 0.21115386486053467
Layer 12, post: 0.23132723569869995
Layer 13, mid: 0.2442898154258728
Layer 13, post: 0.25902628898620605
Layer 14, mid: 0.27178335189819336