# Train multiple linear probes at once
In this notebook we find out at which layer a transformer holds the most linearly separable information for causal language modelling on wikitext data. The nice thing about *transformer_heads* is that all of this takes just one training run: every probe is trained at the same time on top of a single frozen base model.

**GPU Requirements:** With GPT-2, 8GB of GPU memory may be enough. With about 24GB you should be able to run any 7B model (e.g. Llama or Mistral), and an 80GB GPU (A100) may be enough for a 70B model.
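A rough sanity check of these figures, as a sketch under stated assumptions: base weights stored in 4-bit NF4 (about 0.5 bytes per parameter), seven trainable float32 probes of size hidden size × vocabulary size, and an AdamW-style optimizer that keeps gradients plus two moment buffers (16 bytes per trainable parameter in total). Activations, the frozen `lm_head`, higher-precision embeddings and CUDA overhead are ignored, so real usage will be higher. The Mistral-7B sizes (about 7.2B parameters, hidden size 4096, vocabulary 32000) are assumptions taken from its public config, and `estimate_gib` is a hypothetical helper.

```python
def estimate_gib(n_base_params, hidden_size, vocab_size, n_trainable_heads=7):
    """Rough GPU memory estimate in GiB (weights + probe training state only)."""
    base_bytes = n_base_params * 0.5  # assumed 4-bit quantized base weights
    probe_params = n_trainable_heads * hidden_size * vocab_size
    # Assumed fp32 weights (4) + fp32 grads (4) + two fp32 Adam moments (8) = 16 bytes
    probe_bytes = probe_params * 16
    return (base_bytes + probe_bytes) / 2**30


print(f"GPT-2:      {estimate_gib(124e6, 768, 50257):.1f} GiB")
print(f"Mistral-7B: {estimate_gib(7.2e9, 4096, 32000):.1f} GiB")
```

Under these assumptions the estimate lands around 4 GiB for GPT-2 and 17 GiB for a 7B model, below the 8GB and 24GB figures above; the remaining headroom has to cover activations.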

In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise, get_top_n_preds
import torch

In [2]:
# GPT-2 is the fastest and needs the least memory. This works the same with any Llama or Mistral model: just set model_path to its Hugging Face Hub ID.
model_path = "gpt2"
train_epochs = 1
eval_epochs = 1
logging_steps = 100

In [4]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

{'vocab_size': 50257, 'n_positions': 1024, 'n_embd': 768, 'n_layer': 12, 'n_head': 12, 'n_inner': None, 'activation_function': 'gelu_new', 'resid_pdrop': 0.1, 'embd_pdrop': 0.1, 'attn_pdrop': 0.1, 'layer_norm_epsilon': 1e-05, 'initializer_range': 0.02, 'summary_type': 'cls_index', 'summary_use_proj': True, 'summary_activation': None, 'summary_first_dropout': 0.1, 'summary_proj_to_labels': True, 'scale_attn_weights': True, 'use_cache': True, 'scale_attn_by_inverse_layer_idx': False, 'reorder_and_upcast_attn': False, 'bos_token_id': 50256, 'eos_token_id': 50256, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': None, 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, '

Let's define a lot of heads in a loop. The heads are hooked at layers -1, -3, -5, -7, -9, -11 and -13, using Python indexing: layer -1, for example, means the output of the last transformer block. Each head is named after the absolute value of its hook, so `wikitext_head_3` reads from layer -3. We also keep the original pretrained `lm_head` of the transformer model, frozen, for comparison.
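To make the negative indexing concrete, here is a small sketch. It assumes that `layer_hook` indexes into the tuple of hidden states a Hugging Face model returns with `output_hidden_states=True`, which for GPT-2 has 13 entries: the embedding output followed by the output of each of the 12 blocks. `describe_hook` is a hypothetical helper, not part of *transformer_heads*.

```python
N_LAYER = 12  # GPT-2 has 12 transformer blocks ('n_layer' in the config above)

# Assumed layout: index 0 = embedding output, index k = output of block k.
hidden_state_names = ["embeddings"] + [f"block {k}" for k in range(1, N_LAYER + 1)]


def describe_hook(layer_hook: int) -> str:
    """Map a negative layer_hook to the representation it would read (hypothetical)."""
    return hidden_state_names[layer_hook]


for k in range(1, 14, 2):
    print(f"wikitext_head_{k}: layer_hook={-k} -> {describe_hook(-k)}")
```

Under this assumption, -3 reads the output of block 10 (the third-to-last block) and -13 reads the embedding output, before any transformer block has been applied.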

In [5]:
head_configs = [
    HeadConfig(
        name=f"wikitext_head_{layer}",
        layer_hook=-layer,  # negative (Python-style) index, e.g. -1 = last block
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,  # a single linear layer, i.e. a linear probe
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
    )
    for layer in range(1, 14, 2)  # 1, 3, 5, ..., 13
]
head_configs.append(
    HeadConfig(
        name=f"lm_head",
        layer_hook=-1,
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
        trainable=False,
    )
)

In [6]:
dd = load_dataset("wikitext", "wikitext-2-v1")

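In the *tokenize_function*, we define labels for each head. For a causal LM head, the labels are just a copy of `input_ids`; the one-position shift between inputs and targets is presumably applied inside the head's loss, as in Hugging Face causal LM models.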
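A minimal pure-Python sketch of such a shifted loss, assuming the head behaves like a Hugging Face causal LM head: the logits at position t are scored against the label at position t+1, and positions labelled -100 are skipped. `causal_lm_loss` is a hypothetical helper, not a *transformer_heads* function.

```python
import math


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy.

    logits: list of per-position score lists (seq_len x vocab_size)
    labels: list of token ids (seq_len), usually a copy of input_ids
    """
    total, count = 0.0, 0
    # Position t predicts token t+1: drop the last logit row and the first label.
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue  # padded positions contribute nothing
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_norm - scores[target]
        count += 1
    return total / count if count else 0.0


# Uniform scores over a 4-token vocabulary give a loss of ln(4) per token.
uniform = [[0.0] * 4 for _ in range(5)]
print(causal_lm_loss(uniform, [1, 2, 3, 0, -100]))  # ≈ 1.386
```

For reference, a head with uniform outputs over GPT-2's 50,257 tokens would score ln(50257) ≈ 10.82 per token.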

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    # GPT-2 has no padding token, so reuse EOS (padded positions are masked anyway)
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    out = tokenizer(examples["text"], padding=False, truncation=True)
    # Causal LM labels for every head are simply a copy of the input ids
    for hc in head_configs:
        out[hc.name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    # Drop empty and very short lines (wikitext contains many blank lines)
    dd[split] = dd[split].filter(function=lambda example: len(example["text"]) > 10)
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")

In [8]:
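# 4-bit NF4 quantization of the frozen base model; the llm_int8_* options only
# matter for 8-bit loading and have no effect here.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},  # whole model on the current GPU
)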

Some weights of TransformerWithHeads were not initialized from the model checkpoint at gpt2 and are newly initialized: ['heads.wikitext_head_1.lins.0.weight', 'heads.wikitext_head_11.lins.0.weight', 'heads.wikitext_head_13.lins.0.weight', 'heads.wikitext_head_3.lins.0.weight', 'heads.wikitext_head_5.lins.0.weight', 'heads.wikitext_head_7.lins.0.weight', 'heads.wikitext_head_9.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


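Hugging Face warns that the seven new probe heads are newly initialized, which is exactly what we expect. `lm_head` is not in the list because its weights are loaded from the pretrained checkpoint.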

In [9]:
print_trainable_parameters(model)

all params: 352154112 || trainable params: 270181632 || trainable%: 76.72255492504372
params by dtype: defaultdict(<class 'int'>, {torch.float32: 309686784, torch.uint8: 42467328})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 270181632})


Many heads with a large vocabulary size mean a high number of trainable parameters. Only the seven new probes are trainable; the base model and `lm_head` stay frozen.
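The trainable count can be reproduced with simple arithmetic: each probe is a bias-free linear map from the 768-dimensional hidden state to the 50,257-token vocabulary, and seven of the eight heads are trainable. The GPT-2 sizes are taken from the config printed earlier.

```python
gpt2_hidden, gpt2_vocab = 768, 50257  # GPT-2 values from the config printed above
params_per_probe = gpt2_hidden * gpt2_vocab  # output_bias=False, so no bias term
n_trainable_probes = 7  # wikitext_head_1 ... wikitext_head_13; lm_head is frozen

print(params_per_probe)                       # 38597376
print(n_trainable_probes * params_per_probe)  # 270181632, the trainable count above
```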

In [10]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head_1', 'wikitext_head_3', 'wikitext_head_5', 'wikitext_head_7', 'wikitext_head_9', 'wikitext_head_11', 'wikitext_head_13', 'lm_head'],
    num_rows: 23627
})

In [11]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': ['llan', 'genre', 'pal', ' tennis', ' shack'], 'wikitext_head_3': [' injure', 'ersion', '525', 'cluded', ' differed'], 'wikitext_head_5': [' pun', ' Request', ' Close', ' evaluation', 'chanted'], 'wikitext_head_7': [' z', ' throwing', ' wrongful', 'write', '覚醒'], 'wikitext_head_9': [' walking', ' Republic', 'Rex', ' Tur', ' gubernatorial'], 'wikitext_head_11': [' margins', ' step', 'inia', 'Prosecut', ' Finance'], 'wikitext_head_13': ['ony', ' Sher', ' Trout', ' Pas', ' unit'], 'lm_head': [' the', ' this', ' these', ' that', ' his']}


The untrained heads predict essentially random tokens, while the pretrained `lm_head` already suggests plausible continuations.

In the collator, we need to make sure that the labels for each head are padded correctly. Here we pad them with -100, the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss.
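Conceptually, the collator pads each feature to the longest sequence in the batch, each with its own fill value. The pure-Python sketch below illustrates that idea; `pad_features` is a hypothetical stand-in, not the `DataCollatorWithPadding` implementation from *transformer_heads*, and it assumes right-padding and returns nested lists instead of tensors.

```python
def pad_features(batch, feature_name_to_padding_value):
    """Right-pad every listed feature to the batch's max length with its own fill value."""
    out = {}
    for name, pad_value in feature_name_to_padding_value.items():
        seqs = [example[name] for example in batch]
        max_len = max(len(s) for s in seqs)
        out[name] = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return out


batch = [
    {"input_ids": [10, 11, 12], "attention_mask": [1, 1, 1], "lm_head": [10, 11, 12]},
    {"input_ids": [20], "attention_mask": [1], "lm_head": [20]},
]
# 50256 is GPT-2's EOS token, reused as the pad token above
padded = pad_features(batch, {"input_ids": 50256, "attention_mask": 0, "lm_head": -100})
print(padded["lm_head"])  # [[10, 11, 12], [20, -100, -100]]
```

Padded `input_ids` positions are hidden by the zero `attention_mask`, and the matching label positions are skipped by the loss because they hold -100.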

In [12]:
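args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,
    logging_steps=logging_steps,
    do_eval=False,
    # Keep the per-head label columns; by default Trainer drops columns
    # that do not match the model's forward() signature
    remove_unused_columns=False,
)
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        # -100 marks padded label positions to be ignored by the loss
        **{key.name: -100 for key in head_configs},
    }
)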
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mykeller[0m ([33mchm-hci[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: wandb version 0.16.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Tracking run with wandb version 0.16.3


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/raven/u/ykeller/transformer_heads/wandb/run-20240323_235900-4fskzy02[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mvibrant-oath-211[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/chm-hci/huggingface[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/chm-hci/huggingface/runs/4fskzy02[0m


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...




Step,Training Loss
100,81.2958
200,57.8715
300,51.4641
400,48.5801
500,47.1243
600,45.6415
700,44.4631
800,44.0778
900,43.2774
1000,42.7283


Checkpoint destination directory linear_probe_test/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.




Checkpoint destination directory linear_probe_test/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.




Checkpoint destination directory linear_probe_test/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.




Checkpoint destination directory linear_probe_test/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.




Checkpoint destination directory linear_probe_test/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.




TrainOutput(global_step=2954, training_loss=43.90577940530848, metrics={'train_runtime': 1177.9828, 'train_samples_per_second': 20.057, 'train_steps_per_second': 2.508, 'total_flos': 1.2795692896977408e+16, 'train_loss': 43.90577940530848, 'epoch': 1.0})

In [13]:
print(evaluate_head_wise(model, dd["validation"], collator, epochs=eval_epochs))


Evaluating: 100%|██████████| 308/308 [04:50<00:00,  1.06it/s]

(37.68969941448856, {'wikitext_head_1': 4.537133027206767, 'wikitext_head_3': 4.251083857827372, 'wikitext_head_5': 4.451241332989235, 'wikitext_head_7': 4.719932771348334, 'wikitext_head_9': 4.847185150369421, 'wikitext_head_11': 4.841892783517961, 'wikitext_head_13': 5.563870134291711, 'lm_head': 4.4773603600341})





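This is an interesting result. Although GPT-2 was pretrained to do causal language modelling on the output of its last transformer block, the linear probe hooked after the third-to-last block (`wikitext_head_3`, validation loss 4.25) performs best on wikitext data. It beats the probe trained on layer -1 (`wikitext_head_1`, 4.54) and even the frozen pretrained `lm_head` (4.48); the probe on layer -5 (4.45) also edges out `lm_head`. The first value in the output, 37.69, is the sum of the eight per-head losses.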
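The evaluation output can be checked and ranked with a few lines of plain Python. The values below are copied from the printed tuple; nothing else is assumed.

```python
total, per_head = (37.68969941448856, {
    "wikitext_head_1": 4.537133027206767,
    "wikitext_head_3": 4.251083857827372,
    "wikitext_head_5": 4.451241332989235,
    "wikitext_head_7": 4.719932771348334,
    "wikitext_head_9": 4.847185150369421,
    "wikitext_head_11": 4.841892783517961,
    "wikitext_head_13": 5.563870134291711,
    "lm_head": 4.4773603600341,
})

# The overall loss is the sum of the eight per-head losses.
print(round(sum(per_head.values()), 4))  # 37.6897

# Rank heads from best (lowest validation loss) to worst.
for name, loss in sorted(per_head.items(), key=lambda kv: kv[1]):
    print(f"{name:18s} {loss:.3f}")
```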

In [14]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': [' the', ' his', ' this', ' its', ' a'], 'wikitext_head_3': [' the', ' this', ' his', ' these', ' a'], 'wikitext_head_5': [' the', ' this', ' a', ' his', ' <'], 'wikitext_head_7': [' the', ' this', ' his', ' its', ' a'], 'wikitext_head_9': [' the', ' this', ' <', ' a', ' his'], 'wikitext_head_11': [' the', ' <', ' a', ' his', ' this'], 'wikitext_head_13': [' the', ' a', ' <', ' his', '@'], 'lm_head': [' the', ' this', ' these', ' a', ' his']}


All heads now predict plausible tokens. Note that predicting *<* is an artefact of the wikitext data, in which rare words are replaced by `<unk>`; likewise, *@* comes from wikitext's `@-@` style separators.