In [1]:
!export CUDA_VISIBLE_DEVICES=7

## Setup

In [2]:
import pandas as pd
import torch

In [20]:
import datasets

In [25]:
from transformers import AutoTokenizer

In [39]:
from transformer_lens import HookedTransformer

In [7]:
from entropy.compare import unflattened_data

In [47]:
from neuron_choice import neuron_choice
from entropy.entropy_intervention import make_hooks
from argparse import Namespace
from utils import NAME_TO_COMBO

## Constants that define the experiment

In [3]:
INTERVENTION_TYPE = "zero_ablation"
METRIC = "entropy"
NEURON_SUBSET = "weakening_gate-_post+"

In [9]:
DATA_PATH = "intervention_results/allenai/OLMo-7B-0424-hf/dolma-small"

In [5]:
EXTREMUM = "min"

## Preparing the data

In [10]:
diff_data = unflattened_data(DATA_PATH, METRIC, NEURON_SUBSET, INTERVENTION_TYPE)

In [11]:
diff_data.shape

torch.Size([45734, 1024])

## Choosing example

In [12]:
torch.argmin(diff_data)

tensor(31864857)

In [14]:
vi = torch.min(diff_data, dim=0)

In [15]:
vi.indices.shape

torch.Size([1024])

In [16]:
torch.argmin(vi.values)

tensor(25)

In [17]:
vi.values[25]

tensor(-10.7500, dtype=torch.float16)

In [18]:
vi.indices[25]

tensor(31118)

In [19]:
diff_data[31118, 25]

tensor(-10.7500, dtype=torch.float16)

Okay, so weakening neurons provoke the biggest decrease of entropy at sequence 31118, position 25. Let us see what the text says:

In [21]:
text_dataset = datasets.load_from_disk('neuroscope/datasets/dolma-small')

In [22]:
text_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 45734
})

In [23]:
text_dataset[31118]

{'input_ids': tensor([50279, 43688,   313,  1797,  4565,    10,   253,  7295,  6138,   247,
          5522,   273,  1329,   323, 34598,   285, 27335,  9341,   326,   403,
         10305,  5454,   984,   273,   253,   473,  6185,  1406, 12955,    13,
           347,   973,   347,   247,   747, 34220, 14117, 11726,  1467,   621,
         13629,   313, 24584,    10,  6974,   281,  1329,   643,  9341,    13,
          1754,   327,  1980,  5054,   878,    15,   187,  6872, 15849,   588,
           320, 11966,   407,  1980,  9061,    15,   187,    38,  3354,  4412,
          6456,   310,  4390, 31288, 12925,   285,  1491,   670, 26281,   432,
           253,  7295,   285,   588,   452,   271,  2898,  1232,   873,   598,
           347,  3517,   347,  1896,    15,   187]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [30]:
my_input_ids = text_dataset[31118]['input_ids']

In [26]:
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMo-7B-0424-hf')

In [28]:
tokenizer.decode(text_dataset[31118]['input_ids'])

'<|endoftext|>Yesterday (21 December) the Government announced a package of support for hospitality and leisure businesses that are losing trade because of the Omicron variant, as well as a new discretionary Additional Restrictions Grant (ARG) scheme to support other businesses, based on local economic need.\nThese schemes will be administered by local authorities.\nEden District Council is currently awaiting guidance and information about eligibility from the Government and will have an application process set up as soon as possible.\n'

In [29]:
len(text_dataset[31118]['input_ids'])

96

In [31]:
tokenizer.decode(my_input_ids[:25])

'<|endoftext|>Yesterday (21 December) the Government announced a package of support for hospitality and leisure businesses that are losing trade because of the'

In [32]:
tokenizer.decode(my_input_ids[25])

' O'

In [33]:
tokenizer(' Omicron')

{'input_ids': [473, 6185, 1406], 'attention_mask': [1, 1, 1]}

In [37]:
print(diff_data[31118][:97])

tensor([ -0.0625,  -0.2441,  -2.4375,  -1.4219,  -0.2852,  -0.4121,   0.0352,
         -1.8320,  -1.8574,  -0.7656,  -0.6880,   0.2051,  -0.2881,  -1.3828,
         -0.9590,  -6.9531,   0.2861,  -1.7871,  -2.4844,  -1.9062,  -3.3203,
         -4.2227,  -3.4648,  -4.3438,  -2.7266, -10.7500,  -9.3359,  -1.5205,
         -0.4658,  -0.5234,  -1.0488,  -0.0432,  -1.0781,  -0.7148,  -0.9453,
          0.4102,  -1.9238,  -1.4707,  -2.0195,  -0.6436,  -1.0977,  -1.2080,
         -2.2520,  -1.6289,  -0.3613,   0.0996,  -1.6465,   0.2539,   0.4902,
         -0.7598,  -0.3340,  -1.9570,  -1.7227,   1.8818,  -0.3433,   0.2520,
         -1.3320,  -0.0176,  -0.7168,  -1.0020,  -1.8008,  -0.9746,  -0.7549,
         -2.1680,  -0.3379,  -0.4785,  -1.1602,  -0.0322,   1.3242,  -6.1289,
         -2.2031,  -0.8203,  -1.6133,  -0.1504,   0.2500,  -1.9395,   0.0586,
         -1.6250,  -0.1592,  -2.3535,  -3.2656,  -1.1445,  -3.1777,  -1.7031,
         -1.8994,  -3.1523,  -2.2051,  -1.5625,  -0.7256,  -0.38

## Running the model on the example

In [42]:
del model
torch.cuda.empty_cache()

In [43]:
model = HookedTransformer.from_pretrained('allenai/OLMo-7B-0424-hf', refactor_glu=True, device='cuda:6')

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Loaded pretrained model allenai/OLMo-7B-0424-hf into HookedTransformer


In [44]:
logits, cache = model.run_with_cache(my_input_ids)

In [45]:
logits.shape

torch.Size([1, 96, 50304])

In [46]:
cache.keys()

dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.m

In [60]:
args = Namespace(
    work_dir='.', wcos_dir='.',
    model='allenai/OLMo-7B-0424-hf',
    neuron_subset_name="weakening",
    gate='-', post='+',
    activation_location='mlp.hook_post',
    intervention_type=INTERVENTION_TYPE,
)

In [58]:
neuron_list = neuron_choice(
            args,
            category_key=NAME_TO_COMBO[args.neuron_subset_name],
            subset=243,
            baseline=False
        )
print(neuron_list)

[(tensor(12, device='cuda:0'), tensor(6572, device='cuda:0')), (tensor(17, device='cuda:0'), tensor(8074, device='cuda:0')), (tensor(31, device='cuda:0'), tensor(10867, device='cuda:0')), (tensor(28, device='cuda:0'), tensor(610, device='cuda:0')), (tensor(31, device='cuda:0'), tensor(1043, device='cuda:0')), (tensor(25, device='cuda:0'), tensor(4595, device='cuda:0')), (tensor(30, device='cuda:0'), tensor(4674, device='cuda:0')), (tensor(29, device='cuda:0'), tensor(4984, device='cuda:0')), (tensor(22, device='cuda:0'), tensor(336, device='cuda:0')), (tensor(30, device='cuda:0'), tensor(2463, device='cuda:0')), (tensor(0, device='cuda:0'), tensor(3467, device='cuda:0')), (tensor(18, device='cuda:0'), tensor(1210, device='cuda:0')), (tensor(1, device='cuda:0'), tensor(7896, device='cuda:0')), (tensor(1, device='cuda:0'), tensor(6917, device='cuda:0')), (tensor(26, device='cuda:0'), tensor(5984, device='cuda:0')), (tensor(7, device='cuda:0'), tensor(318, device='cuda:0')), (tensor(12, d

In [61]:
hooks = []
conditioning_values = {}
for lix, nix in neuron_list:
    conditioning_values[(lix,nix)]={}
    hooks += make_hooks(
        args,
        lix, nix,
        conditioning_value=conditioning_values[(lix,nix)],
        sign=(model.W_gate[lix,:,nix]@model.W_in[lix,:,nix]).item(),
        mean_value=0.0,
    )
#print(hooks)

In [62]:
logits_ablated = model.run_with_hooks(my_input_ids, fwd_hooks=hooks)

## Finding relevant tokens

In [63]:
logit_diff = logits - logits_ablated

In [64]:
logit_diff.shape

torch.Size([1, 96, 50304])

In [73]:
relevant_logit_diff = logit_diff[0,25,:]
relevant_logit_diff.shape

torch.Size([50304])

In [74]:
top = torch.topk(relevant_logit_diff, k=16)
bottom = torch.topk(relevant_logit_diff, k=16, largest=False)
print(top)
print(bottom)

torch.return_types.topk(
values=tensor([12.5329,  7.1575,  7.1229,  6.8121,  6.7228,  6.5633,  6.5550,  6.5279,
         6.2066,  6.1819,  6.1671,  5.9645,  5.8675,  5.6778,  5.6114,  5.5637],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([ 6185,  6402,  7373, 10929, 34611, 28184, 18227,  8555, 43441,  2494,
         4883, 12761, 25810, 36132,  1814, 16192], device='cuda:6'))
torch.return_types.topk(
values=tensor([-3.3217, -3.2983, -3.2578, -3.1851, -3.1702, -3.1052, -3.1043, -3.0887,
        -3.0541, -3.0457, -2.9698, -2.9301, -2.9189, -2.9149, -2.9032, -2.8691],
       device='cuda:6', grad_fn=<TopkBackward0>),
indices=tensor([12444, 16828,   214,  9676,  5020, 10713,  1972, 22165,   800,  3504,
        24761, 13793, 33158,  3153,   731,  1718], device='cuda:6'))


In [75]:
model.to_str_tokens(top.indices)

['mic',
 'MI',
 'mi',
 ' Mic',
 'MIC',
 'Mic',
 'micro',
 ' mic',
 'yster',
 ' micro',
 'NS',
 ' Micro',
 'Micro',
 'microm',
 'mb',
 'mn']

In [76]:
model.to_str_tokens(bottom.indices)

['об',
 'ру',
 '�',
 ' illegal',
 ' majority',
 ' bulk',
 'ances',
 ' Wayne',
 'ative',
 'я',
 'houses',
 'ع',
 '\n\t\t\n\t',
 ' live',
 ' them',
 ' invest']

In [77]:
model.to_str_tokens(' Omicron')

['<|endoftext|>', ' O', 'mic', 'ron']

## Ablating weakening neurons generally

In [78]:
args.gate=None
args.post=None

hooks_general = []
conditioning_values = {}
for lix, nix in neuron_list:
    conditioning_values[(lix,nix)]={}
    hooks_general += make_hooks(
        args,
        lix, nix,
        conditioning_value=conditioning_values[(lix,nix)],
        sign=(model.W_gate[lix,:,nix]@model.W_in[lix,:,nix]).item(),
        mean_value=0.0,
    )

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.38 GiB. GPU 6 has a total capacity of 47.40 GiB of which 2.41 GiB is free. Including non-PyTorch memory, this process has 44.98 GiB memory in use. Of the allocated memory 41.08 GiB is allocated by PyTorch, and 3.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)