## Setup

In [1]:
DEVICE = 'cuda:0' #change as needed

In [2]:
from collections.abc import Callable
from typing import Union

In [3]:
import gc

In [4]:
import pandas as pd
import sys

In [5]:
import torch

In [6]:
import torch.nn.functional as F

In [7]:
from transformer_lens import HookedTransformer

In [8]:
from neuron_choice import neuron_choice
from entropy.entropy_intervention import make_hooks
from argparse import Namespace
from utils import NAME_TO_COMBO

## Constants that define the experiment

In [9]:
INTERVENTION_TYPE = "zero_ablation"

## Input text and model

In [10]:
prompt = "Yesterday (21 December) the Government announced a package of support for hospitality and leisure businesses that are losing trade because of the O"

In [11]:
model = HookedTransformer.from_pretrained('allenai/OLMo-7B-0424-hf', refactor_glu=True, device=DEVICE)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loaded pretrained model allenai/OLMo-7B-0424-hf into HookedTransformer


In [12]:
my_input_ids = model.to_tokens(prompt)
str_tokens = model.to_str_tokens(my_input_ids)
print(str_tokens)

['<|endoftext|>', 'Yesterday', ' (', '21', ' December', ')', ' the', ' Government', ' announced', ' a', ' package', ' of', ' support', ' for', ' hospitality', ' and', ' leisure', ' businesses', ' that', ' are', ' losing', ' trade', ' because', ' of', ' the', ' O']


## List of weakening neurons

In [13]:
args = Namespace(
    work_dir='.', wcos_dir='.',
    model='allenai/OLMo-7B-0424-hf',
    neuron_subset_name="weakening",
    gate='-', post='+',
    activation_location='mlp.hook_post',
    intervention_type=INTERVENTION_TYPE,
)

In [14]:
neuron_list = neuron_choice(
            args,
            category_key=NAME_TO_COMBO[args.neuron_subset_name],
            subset=243,
            baseline=False
        )
#print(neuron_list)

In [15]:
short_neuron_list = [(lix, nix) for lix,nix in neuron_list if lix==model.cfg.n_layers-1]

## Running the model on the example

In [25]:
logits = model(my_input_ids)

In [17]:
logits.shape

torch.Size([1, 26, 50304])

## Functions

In [16]:
def make_all_hooks(layers: list[int]|int|None=None):
    if layers is None:
        layers = list(range(model.cfg.n_layers))
    elif isinstance(layers, int):
        layers = [layers]
    hook_list = []
    conditioning_values = {}
    for lix, nix in neuron_list:
        if layers is not None and lix not in layers:
            continue
        conditioning_values[(lix,nix)]={}#this is possible because one list is a subset of the other
        new_hooks = make_hooks(
            args,
            lix, nix,
            conditioning_value=conditioning_values[(lix,nix)],
            sign=(model.W_gate[lix,:,nix]@model.W_in[lix,:,nix]).item(),
            mean_value=0.0,
        )
        hook_list.extend(new_hooks)
    return hook_list, conditioning_values

In [17]:
def run_and_clear(model=model, input_ids=my_input_ids, fwd_hooks=[], conditioning_values={}):
    logits = model.run_with_hooks(input_ids, fwd_hooks=fwd_hooks)
    conditioning_values = {key:{} for key in conditioning_values.keys()}
    return logits, conditioning_values

In [18]:
def logits_to_entropy(logit_tensor: torch.Tensor) -> torch.Tensor:
    """Compute entropy from logits. Ignore any attention mask problems (not relevant in this notebook).

    Args:
        logit_tensor (torch.Tensor): tensor of logits, shape (batch, pos, vocab)

    Returns:
        torch.Tensor: tensor of entropies, shape (batch, pos)
    """
    probs = F.softmax(logit_tensor, dim=-1)
    entropies = - torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
    return entropies

In [19]:
def get_entropy_value(layers=None, input_ids=my_input_ids):
    hook_list, conditioning_values = make_all_hooks(layers=layers)
    logits, conditioning_values = run_and_clear(input_ids=input_ids, fwd_hooks=hook_list, conditioning_values=conditioning_values)
    entropy = logits_to_entropy(logits)[:,-1].item()
    del logits, hook_list
    gc.collect()
    torch.cuda.empty_cache()
    return entropy

## The crucial thing: how much of the entropy reduction is explained by the last layer?

In [26]:
entropy_clean = logits_to_entropy(logits)[:,-1].item()
del logits
gc.collect()
torch.cuda.empty_cache()
print("clean entropy:", entropy_clean)

entropy_all_ablated = get_entropy_value()
print("entropy when ablating all weakening neurons if gate-_post+ is fulfilled:", entropy_all_ablated)

entropy_final_ablated = get_entropy_value(31)
print("entropy when ablating final layer weakening neurons if gate-_post+ is fulfilled:", entropy_final_ablated)

clean entropy: 0.011461791582405567
entropy when ablating all weakening neurons if gate-_post+ is fulfilled: 9.571281433105469
entropy when ablating final layer weakening neurons if gate-_post+ is fulfilled: 0.014897964894771576


So it's not the final layer!!

How many weakening neurons are there inside and outside of the final layer?

In [23]:
print("total number of neurons considered:", len(neuron_list))
print("number of final layer neurons considered:", len(short_neuron_list))

total number of neurons considered: 243
number of final layer neurons considered: 58


In [24]:
243-58

185

## Doing the same for one layer after another

In [34]:
for layer in range(30, -1, -1):
    print(layer, get_entropy_value(layer))

30 0.023790016770362854
29 4.536865234375
28 0.012854362837970257
27 0.011443565599620342
26 0.01133407186716795
25 0.012656877748668194
24 0.01135966181755066
23 0.011542798951268196
22 0.011182578280568123
21 0.011500231921672821
20 0.011464115232229233
19 0.010262476280331612
18 0.013297613710165024
17 0.011551076546311378
16 0.011159299872815609
15 0.015131891705095768
14 0.011391527950763702
13 0.011461791582405567
12 0.011030402965843678
11 0.010766370221972466
10 0.013470765203237534
9 0.011461791582405567
8 0.011463585309684277
7 0.011395260691642761
6 0.011448143981397152
5 0.011465327814221382
4 0.011225365102291107
3 0.011691529303789139
2 0.011507169343531132
1 0.011632694862782955
0 0.011258020997047424


So layer 29 contributes a lot to that entropy increase *on its own*, while other layers *on their own* do not. So there is some cross-layer collaboration: either cross-layer superposition or some layers preparing stuff for later layers. Let's first see if we can recover (almost) the full entropy increase with layers 29 to 31:

In [41]:
print(get_entropy_value(29))

4.536865234375


In [40]:
print(get_entropy_value(list(range(29,32))))

0.11155463010072708


In [42]:
print(get_entropy_value(list(range(0,30))))#including 29

2.8095455169677734


In [43]:
print(get_entropy_value(list(range(29,31))))

0.08107975125312805


In [44]:
print(get_entropy_value(list(range(32))))

9.571281433105469


Ablating from layer n to the end, for various n:

In [45]:
for layer in range(31, -1, -1):
    print(f'{layer} to 31:', get_entropy_value(list(range(layer, 32))))

31 to 31: 0.014897964894771576
30 to 31: 0.031053991988301277
29 to 31: 0.11155463010072708
28 to 31: 0.12374565750360489
27 to 31: 0.12347183376550674
26 to 31: 0.12399730086326599
25 to 31: 0.14999280869960785
24 to 31: 0.15769124031066895
23 to 31: 0.16057711839675903
22 to 31: 0.15990473330020905
21 to 31: 0.16104717552661896
20 to 31: 0.16111084818840027
19 to 31: 0.2255602777004242
18 to 31: 0.6157550811767578
17 to 31: 0.7932028770446777
16 to 31: 0.8191139698028564
15 to 31: 0.9114950895309448
14 to 31: 0.9104797840118408
13 to 31: 0.9104797840118408
12 to 31: 2.4281532764434814
11 to 31: 6.424055099487305
10 to 31: 9.450294494628906
9 to 31: 9.450294494628906
8 to 31: 9.450304985046387
7 to 31: 9.508875846862793
6 to 31: 9.509500503540039
5 to 31: 9.512125015258789
4 to 31: 9.50985050201416
3 to 31: 9.516609191894531
2 to 31: 9.493454933166504
1 to 31: 9.594376564025879
0 to 31: 9.571281433105469


Layers 10 and maybe 11 seem crucial, but only in collaboration with (some or all) later layers.

How about layer 10 and 29 together?

In [20]:
print(get_entropy_value(layers=[10,29]))

6.3530426025390625


## Finding relevant tokens

In [63]:
logit_diff = logits - logits_ablated

In [64]:
logit_diff.shape

torch.Size([1, 96, 50304])

In [73]:
relevant_logit_diff = logit_diff[0,25,:]
relevant_logit_diff.shape

torch.Size([50304])

In [74]:
top = torch.topk(relevant_logit_diff, k=16)
bottom = torch.topk(relevant_logit_diff, k=16, largest=False)
print(top)
print(bottom)

torch.return_types.topk(
values=tensor([12.5329,  7.1575,  7.1229,  6.8121,  6.7228,  6.5633,  6.5550,  6.5279,
         6.2066,  6.1819,  6.1671,  5.9645,  5.8675,  5.6778,  5.6114,  5.5637],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([ 6185,  6402,  7373, 10929, 34611, 28184, 18227,  8555, 43441,  2494,
         4883, 12761, 25810, 36132,  1814, 16192], device='cuda:6'))
torch.return_types.topk(
values=tensor([-3.3217, -3.2983, -3.2578, -3.1851, -3.1702, -3.1052, -3.1043, -3.0887,
        -3.0541, -3.0457, -2.9698, -2.9301, -2.9189, -2.9149, -2.9032, -2.8691],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([12444, 16828,   214,  9676,  5020, 10713,  1972, 22165,   800,  3504,
        24761, 13793, 33158,  3153,   731,  1718], device='cuda:6'))


In [75]:
model.to_str_tokens(top.indices)

['mic',
 'MI',
 'mi',
 ' Mic',
 'MIC',
 'Mic',
 'micro',
 ' mic',
 'yster',
 ' micro',
 'NS',
 ' Micro',
 'Micro',
 'microm',
 'mb',
 'mn']

In [76]:
model.to_str_tokens(bottom.indices)

['об',
 'ру',
 '�',
 ' illegal',
 ' majority',
 ' bulk',
 'ances',
 ' Wayne',
 'ative',
 'я',
 'houses',
 'ع',
 '\n\t\t\n\t',
 ' live',
 ' them',
 ' invest']

In [77]:
model.to_str_tokens(' Omicron')

['<|endoftext|>', ' O', 'mic', 'ron']

## Ablating weakening neurons generally

In [None]:
args.gate=None
args.post=None

hooks_general = []
conditioning_values = {}
for layer_list, nix in neuron_list:
    conditioning_values[(layer_list,nix)]={}
    hooks_general += make_hooks(
        args,
        layer_list, nix,
        conditioning_value=conditioning_values[(layer_list,nix)],
        sign=(model.W_gate[layer_list,:,nix]@model.W_in[layer_list,:,nix]).item(),
        mean_value=0.0,
    )

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.38 GiB. GPU 6 has a total capacity of 47.40 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 44.98 GiB memory in use. Of the allocated memory 41.08 GiB is allocated by PyTorch, and 3.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Is there a single neuron responsible for this?

We're looking for a weakening neuron whose $w_\text{out}$ corresponds as closely as possible to the token 'mic'.

In [22]:
mic = model.W_U[:,model.to_single_token('mic')]

In [25]:
scores = []
for l,n in neuron_list:
    wout = model.W_out[l,n]
    wout /= torch.norm(wout)
    scores.append(wout@mic)
scores = torch.tensor(scores)
print(scores)

tensor([-5.5654e-03, -1.9865e-02,  2.3740e-03,  2.3131e-03, -1.0920e-02,
        -1.1843e-02, -5.0423e-03,  6.3631e-05,  6.2427e-03, -2.7632e-03,
        -1.7804e-02, -3.3106e-03, -2.4462e-02,  1.4472e-02,  3.5042e-03,
         1.0408e-03, -6.3770e-03,  1.5915e-02, -7.3378e-03,  4.2016e-03,
        -2.5157e-02, -1.2029e-02, -4.9860e-04,  2.5065e-03,  1.2379e-02,
        -1.4918e-02, -3.8517e-03,  6.9253e-04,  1.1168e-02, -1.2686e-02,
        -2.5985e-02, -1.1120e-02,  6.2533e-02, -4.0599e-04,  4.4302e-03,
        -7.2232e-05,  2.3728e-02, -7.4714e-03, -2.7623e-03, -9.6072e-03,
        -7.5862e-03, -3.2695e-02, -1.2145e-03, -1.2980e-02, -6.6941e-04,
         3.1989e-02, -3.7805e-03,  1.3651e-02, -6.6925e-03, -8.9840e-04,
        -2.2595e-02,  7.9196e-03,  1.5796e-02, -8.6340e-03, -5.0332e-04,
         3.4942e-04, -1.7101e-02,  8.4782e-03, -2.1459e-03, -1.4602e-02,
         4.4703e-03,  4.5846e-03,  5.6527e-03,  1.0334e-02, -2.8266e-03,
         1.0407e-02, -1.4816e-02,  1.0181e-03, -2.0

In [27]:
torch.topk(scores, k=16)

torch.return_types.topk(
values=tensor([0.1670, 0.1110, 0.0625, 0.0482, 0.0419, 0.0352, 0.0351, 0.0328, 0.0327,
        0.0327, 0.0321, 0.0320, 0.0260, 0.0249, 0.0246, 0.0237]),
indices=tensor([ 80, 168,  32, 204, 148,  72,  75, 173, 141, 155, 138,  45, 110,  95,
        206,  36]))

In [37]:
torch.topk(scores, k=16, largest=False)

torch.return_types.topk(
values=tensor([-0.0452, -0.0336, -0.0327, -0.0296, -0.0275, -0.0271, -0.0260, -0.0255,
        -0.0252, -0.0246, -0.0245, -0.0226, -0.0211, -0.0209, -0.0206, -0.0199]),
indices=tensor([105, 111,  41,  96, 215, 187,  30, 182,  20, 135,  12,  50, 183, 159,
        242,   1]))

In [28]:
neuron_list[80]

(tensor(31, device='cuda:0'), tensor(3498, device='cuda:0'))

In [34]:
values, indices = torch.topk(model.blocks[31].mlp.W_out[3498,:]@model.W_U, k=16)
for j in range(16):
    print((model.to_single_str_token(indices[j].item()), values[j]))

('re', tensor(0.1918, device='cuda:6'))
('b', tensor(0.1895, device='cuda:6'))
('c', tensor(0.1863, device='cuda:6'))
('w', tensor(0.1831, device='cuda:6'))
('p', tensor(0.1822, device='cuda:6'))
('g', tensor(0.1808, device='cuda:6'))
('to', tensor(0.1789, device='cuda:6'))
('f', tensor(0.1776, device='cuda:6'))
('m', tensor(0.1770, device='cuda:6'))
('st', tensor(0.1760, device='cuda:6'))
('sh', tensor(0.1730, device='cuda:6'))
('de', tensor(0.1711, device='cuda:6'))
('d', tensor(0.1711, device='cuda:6'))
('t', tensor(0.1700, device='cuda:6'))
('h', tensor(0.1699, device='cuda:6'))
('se', tensor(0.1661, device='cuda:6'))


In [36]:
ranks = []
for i,(l,n) in enumerate(neuron_list):
    wout = model.blocks[l].mlp.W_out[n,:]
    wout /= torch.norm(wout)
    all_scores = wout@model.W_U
    ranks.append(torch.count_nonzero(all_scores>scores[i])-1)
ranks = torch.tensor(ranks)
print(ranks)
print(torch.topk(ranks, k=16))

tensor([33027, 44794, 11722, 17507, 46631, 25442, 34436, 23220, 11846, 34613,
        46425, 25834, 46371,  5025, 15681, 22796, 31459,  8902, 36383, 11759,
        49177, 37462, 25061, 18214,  8414, 38222, 34092, 20984,  8966, 43250,
        48547, 35267,  1817, 25002, 11094, 23719,   204, 22263, 34918, 49288,
        43525, 32577, 27633, 26429, 24108,  8951, 36769,  5448, 34544, 22052,
        49540,  4933,  5594, 40541, 23993, 21516, 40685,  9317, 31069, 45143,
        16945, 13908,  9458,  8681, 33606,  7666, 45361, 18988, 28537, 36553,
         3530, 27726,  5330,  7637,  2332,  3959, 42244, 21798, 44665, 28791,
          767, 43145, 14870, 10781, 29001, 45069, 48040, 27130, 45341,  3937,
         4669, 12602, 35360,  3953, 34059,  3986, 48969,  2957, 25602, 13362,
        47640,  1403, 12754, 20264, 25564, 31879, 29734, 16825, 36359, 44763,
          360, 49651, 14619, 26784, 24379, 10822, 17850, 35104, 13581, 34743,
        43165,  1346, 27405,  4292,  8312, 15292, 32876, 12042, 

In [38]:
torch.topk(ranks, k=16, largest=False)

torch.return_types.topk(
values=tensor([ 204,  360,  739,  767, 1142, 1346, 1403, 1817, 2200, 2332, 2957, 3006,
        3262, 3340, 3530, 3931]),
indices=tensor([ 36, 110, 130,  80, 144, 121, 101,  32, 168,  74,  97, 195, 164, 189,
         70, 173]))

In [39]:
import gc

In [40]:
gc.collect()
torch.cuda.empty_cache()

## Analysing model hidden states on the example

In [46]:
for layer in range(32):
    for sublayer in ['mid', 'post']:
        print('===========================')
        print(f'Layer {layer}, {sublayer}:')
        top_token_vi = torch.topk(cache[f'blocks.{layer}.hook_resid_{sublayer}'][0,25]@model.W_U, k=4)
        for j in range(4):
            print((model.to_str_tokens(top_token_vi.indices[j]), top_token_vi.values[j]))

Layer 0, mid:
([' in'], tensor(0.2816, device='cuda:6'))
([' of'], tensor(0.2600, device='cuda:6'))
(['\n'], tensor(0.2556, device='cuda:6'))
([' then'], tensor(0.2409, device='cuda:6'))
Layer 0, post:
(["'"], tensor(0.3902, device='cuda:6'))
(['-'], tensor(0.3567, device='cuda:6'))
(['ops'], tensor(0.3376, device='cuda:6'))
(['xygen'], tensor(0.3136, device='cuda:6'))
Layer 1, mid:
(['xygen'], tensor(0.3256, device='cuda:6'))
(["'"], tensor(0.3223, device='cuda:6'))
(['ops'], tensor(0.3137, device='cuda:6'))
(['ats'], tensor(0.3126, device='cuda:6'))
Layer 1, post:
(['xygen'], tensor(0.3145, device='cuda:6'))
(['ats'], tensor(0.2860, device='cuda:6'))
(['asis'], tensor(0.2826, device='cuda:6'))
(['ops'], tensor(0.2819, device='cuda:6'))
Layer 2, mid:
(['xygen'], tensor(0.3330, device='cuda:6'))
(['cean'], tensor(0.3086, device='cuda:6'))
(['asis'], tensor(0.2923, device='cuda:6'))
(['ceans'], tensor(0.2916, device='cuda:6'))
Layer 2, post:
(['xygen'], tensor(0.3385, device='cuda:6'))


In [47]:
for layer in range(32):
    for sublayer in ['mid', 'post']:
        print(f'Layer {layer}, {sublayer}: {cache[f'blocks.{layer}.hook_resid_{sublayer}'][0,25]@mic}')

Layer 0, mid: -0.009818505495786667
Layer 0, post: 0.1756778061389923
Layer 1, mid: 0.1873852014541626
Layer 1, post: 0.16365423798561096
Layer 2, mid: 0.18060725927352905
Layer 2, post: 0.17816656827926636
Layer 3, mid: 0.16314618289470673
Layer 3, post: 0.19512557983398438
Layer 4, mid: 0.18644800782203674
Layer 4, post: 0.20833870768547058
Layer 5, mid: 0.16482438147068024
Layer 5, post: 0.15313085913658142
Layer 6, mid: 0.14475101232528687
Layer 6, post: 0.22487929463386536
Layer 7, mid: 0.22021272778511047
Layer 7, post: 0.2116979956626892
Layer 8, mid: 0.19317162036895752
Layer 8, post: 0.2249262034893036
Layer 9, mid: 0.2260773777961731
Layer 9, post: 0.187677800655365
Layer 10, mid: 0.1936216801404953
Layer 10, post: 0.2052566558122635
Layer 11, mid: 0.22408485412597656
Layer 11, post: 0.19123020768165588
Layer 12, mid: 0.21115386486053467
Layer 12, post: 0.23132723569869995
Layer 13, mid: 0.2442898154258728
Layer 13, post: 0.25902628898620605
Layer 14, mid: 0.27178335189819336