# Train multiple linear probes at once
In this notebook, we find out at which layer a transformer holds the most linearly separable information needed for causal language modelling on WikiText data. The nice thing about *transformer_heads* is that all of this takes just one training run: a separate linear probe is attached to each layer of interest, and all probes are trained together on top of the same frozen model.
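The idea behind "many probes, one run" can be sketched in plain PyTorch. This is not how *transformer_heads* is implemented internally. It is a minimal illustration under stated assumptions: a hypothetical toy backbone stands in for the frozen LLM, it runs once per batch, one bias-free linear probe reads each hooked hidden state, and the per-probe next-token losses are summed so that a single backward pass trains every probe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size, num_blocks = 50, 16, 6

# Hypothetical toy stand-in for a frozen pretrained transformer.
embed = nn.Embedding(vocab_size, hidden_size)
blocks = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_blocks)])
for p in [*embed.parameters(), *blocks.parameters()]:
    p.requires_grad_(False)

def hidden_states(input_ids):
    """Return [embedding output, block 1 output, ..., block N output]."""
    h = embed(input_ids)
    states = [h]
    for block in blocks:
        h = torch.tanh(block(h))
        states.append(h)
    return states

# One bias-free linear probe per hooked layer (negative, Python-style indices).
hooks = [-1, -3, -5]
probes = nn.ModuleDict(
    {f"probe_{-k}": nn.Linear(hidden_size, vocab_size, bias=False) for k in hooks}
)
optimizer = torch.optim.AdamW(probes.parameters(), lr=1e-2)
input_ids = torch.randint(0, vocab_size, (4, 12))

def train_step():
    with torch.no_grad():  # a single forward pass through the frozen backbone
        states = hidden_states(input_ids)
    losses = {}
    for k in hooks:
        logits = probes[f"probe_{-k}"](states[k])
        # Next-token objective: position t predicts token t + 1.
        losses[k] = F.cross_entropy(
            logits[:, :-1].reshape(-1, vocab_size), input_ids[:, 1:].reshape(-1)
        )
    total = sum(losses.values())  # one backward pass updates every probe
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    return total.item(), {k: v.item() for k, v in losses.items()}

first_total, _ = train_step()
for _ in range(200):
    last_total, per_probe = train_step()
print(first_total, last_total, per_probe)
```

Because the backbone never receives gradients, adding another probe costs only one extra matrix multiplication per batch, not another training run.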

In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise, get_top_n_preds
import torch

In [3]:
# Parameters
model_path = "meta-llama/Llama-2-7b-hf"
train_epochs = 1
eval_epochs = 1
logging_steps = 100

In [4]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

{'vocab_size': 32000, 'max_position_embeddings': 4096, 'hidden_size': 4096, 'intermediate_size': 11008, 'num_hidden_layers': 32, 'num_attention_heads': 32, 'num_key_value_heads': 32, 'hidden_act': 'silu', 'initializer_range': 0.02, 'rms_norm_eps': 1e-05, 'pretraining_tp': 1, 'use_cache': True, 'rope_theta': 10000.0, 'rope_scaling': None, 'attention_bias': False, 'attention_dropout': 0.0, 'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float16', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': False, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.

Let's define the heads in a loop. They are hooked at layers -1, -3, -5, -7, -9, -11 and -13, using Python indexing: layer -1, for example, means the output of the last transformer block. We'll also keep the original pretrained lm_head of the transformer model, frozen, for comparison.
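The hook indices can be checked with a few lines of arithmetic. The mapping to block numbers below rests on an assumption: that *transformer_heads* indexes a Hugging Face-style `hidden_states` tuple, which has `num_hidden_layers + 1` entries (entry 0 is the embedding output). In Hugging Face's Llama implementation, the last entry of that tuple also has the final RMSNorm applied, which is what the pretrained lm_head expects.

```python
# Reproduce the hook indices generated in the HeadConfig loop.
hooks = [-(1 + (i - 1) * 2) for i in range(1, 8)]
print(hooks)  # [-1, -3, -5, -7, -9, -11, -13]

# Assumption: hooks index a hidden_states tuple with num_hidden_layers + 1
# entries, where entry 0 is the embedding output.
num_hidden_layers = 32  # printed by get_model_params for Llama-2-7b
num_states = num_hidden_layers + 1
block_for_hook = {k: num_states + k for k in hooks}  # 1-based block number
for k, block in block_for_hook.items():
    print(f"layer_hook={k:>3} -> output of block {block} of {num_hidden_layers}")
```

Under this assumption, the deepest probe reads block 32 and the shallowest reads block 20.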

In [5]:
head_configs = [
    HeadConfig(
        name=f"wikitext_head_{k}",
        layer_hook=-k,  # k-th hidden state counted from the end
        in_size=hidden_size,
        # hidden_size=0 with num_layers=1: a single linear layer, i.e. a linear probe
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
    )
    for k in range(1, 14, 2)  # hooks at layers -1, -3, ..., -13
]
# The model's original output head, kept frozen for comparison.
head_configs.append(
    HeadConfig(
        name="lm_head",
        layer_hook=-1,
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
        trainable=False,
    )
)

In [6]:
dd = load_dataset("wikitext", "wikitext-2-v1")

In the *tokenize_function*, we define labels for each head. For a causal LM head, the labels are simply a copy of the input_ids.
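Copying input_ids is enough because the one-token offset is applied when the loss is computed. We assume *transformer_heads* does this for heads with `is_causal_lm=True`, as Hugging Face's causal-LM models do. The sketch below shows that convention in plain PyTorch (`causal_lm_loss` is a hypothetical helper, not a library function) and checks it against a hand computation.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits, labels, ignore_index=-100):
    """Next-token cross-entropy where labels are a copy of input_ids."""
    # Position t is scored against the token at t + 1, so drop the last
    # logit and the first label before flattening.
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)

torch.manual_seed(0)
input_ids = torch.tensor([[5, 2, 7, 1]])
labels = input_ids.clone()  # the per-head label column is just a copy
logits = torch.randn(1, 4, 10)  # (batch, seq_len, toy vocab of 10)
loss = causal_lm_loss(logits, labels)

# Same value by hand: predict 2 from position 0, 7 from 1 and 1 from 2.
log_probs = F.log_softmax(logits[0], dim=-1)
manual = -(log_probs[0, 2] + log_probs[1, 7] + log_probs[2, 1]) / 3
print(loss.item(), manual.item())
```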

In [7]:
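tokenizer = AutoTokenizer.from_pretrained(model_path)
# Llama-2 ships without a pad token, so reuse the EOS token for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    # The tokenizer defines no maximum length, so truncation=True has no effect
    # here (see the warning printed below).
    out = tokenizer(examples["text"], padding=False, truncation=True)
    for hc in head_configs:
        # Causal-LM labels are the input ids themselves; one label column per head.
        out[hc.name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    # Drop empty lines and very short fragments (10 characters or fewer).
    dd[split] = dd[split].filter(function=lambda example: len(example["text"]) > 10)
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")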

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.

In [8]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},
)

Some weights of TransformerWithHeads were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['heads.wikitext_head_1.lins.0.weight', 'heads.wikitext_head_11.lins.0.weight', 'heads.wikitext_head_13.lins.0.weight', 'heads.wikitext_head_3.lins.0.weight', 'heads.wikitext_head_5.lins.0.weight', 'heads.wikitext_head_7.lins.0.weight', 'heads.wikitext_head_9.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


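Hugging Face reports that our seven newly added heads are newly initialized, as expected. The lm_head is not in the list, since its pretrained weights are loaded from the checkpoint.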

In [9]:
print_trainable_parameters(model)

all params: 4417916928 || trainable params: 917504000 || trainable%: 20.76779656459851
params by dtype: defaultdict(<class 'int'>, {torch.float32: 1179914240, torch.uint8: 3238002688})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 917504000})


Many heads with a large vocabulary mean a large number of trainable parameters: each probe alone is a 4096 × 32000 weight matrix.
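The printed counts can be reproduced exactly from the model dimensions, under two assumptions: the token embeddings, the frozen lm_head and the RMSNorm weights stay in float32, and bitsandbytes stores the 4-bit NF4 projection weights packed two per uint8 byte. The second assumption also means the "all params" total counts each quantized weight as half a parameter, which inflates the trainable percentage.

```python
# Model dimensions printed by get_model_params above (Llama-2-7b).
hidden_size, vocab_size = 4096, 32000
num_hidden_layers, intermediate_size = 32, 11008

probe = hidden_size * vocab_size  # one bias-free probe (same shape as lm_head)
trainable = 7 * probe  # the seven wikitext heads; lm_head is frozen

# float32 (assumption): 7 probes + frozen lm_head + token embeddings
# + RMSNorm weights (two per block plus the final norm).
norms = (2 * num_hidden_layers + 1) * hidden_size
float32_params = 7 * probe + probe + probe + norms

# uint8 (assumption): q/k/v/o and gate/up/down projections in 4 bits,
# packed two weights per byte.
per_block = 4 * hidden_size * hidden_size + 3 * hidden_size * intermediate_size
uint8_params = num_hidden_layers * per_block // 2

total = float32_params + uint8_params
print(trainable, float32_params, uint8_params, total)
print(f"trainable%: {100 * trainable / total:.2f}")
```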

In [10]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head_1', 'wikitext_head_3', 'wikitext_head_5', 'wikitext_head_7', 'wikitext_head_9', 'wikitext_head_11', 'wikitext_head_13', 'lm_head'],
    num_rows: 23627
})

In [11]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': ['에', 'ény', 'junto', 'Unity', 'zelf'], 'wikitext_head_3': ['Point', 'Everything', 'encode', 'nob', 'fal'], 'wikitext_head_5': ['чных', 'mlung', 'ismus', 'stress', 'подацима'], 'wikitext_head_7': ['мене', 'unsafe', 'lear', 'North', 'het'], 'wikitext_head_9': ['ște', 'rayed', 'credit', 'particul', 'marriage'], 'wikitext_head_11': ['CA', 'U', 'Ale', 'Æ', 'DOCTYPE'], 'wikitext_head_13': ['фе', 'продол', 'pag', 'attach', '改'], 'lm_head': ['the', 'this', 'a', '', 'The']}


The untrained heads predict essentially random tokens, while the pretrained lm_head already suggests plausible continuations such as 'the' and 'this'.
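We assume *get_top_n_preds* runs the prompt through the model and, for each head, ranks the vocabulary by the logits at the last position. The ranking step alone looks roughly like this; the six-token vocabulary, the logits and the helper `top_n_tokens` are hypothetical stand-ins for the real tokenizer and model output.

```python
import torch

def top_n_tokens(logits, id_to_token, n=5):
    """Return the n most likely next tokens from the last position's logits.

    logits: (seq_len, vocab_size) tensor for a single prompt.
    """
    top_ids = torch.topk(logits[-1], k=n).indices  # sorted, most likely first
    return [id_to_token[i] for i in top_ids.tolist()]

# Hypothetical toy vocabulary and hand-written logits for a 2-token prompt.
vocab = ["the", "this", "a", "The", "Old", "D"]
logits = torch.tensor([
    [0.1, 0.2, 0.3, 0.0, 0.0, 0.0],
    [2.0, 1.5, 1.0, 0.5, -1.0, 0.2],
])
print(top_n_tokens(logits, vocab, n=3))  # ['the', 'this', 'a']
```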

In the collator, we need to make sure that the labels for each head are padded correctly. Here, we pad them with -100, the default *ignore_index* of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss.
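What the collator has to do can be sketched with `torch.nn.utils.rnn.pad_sequence`. This is a hypothetical re-implementation for illustration, not the code of *transformer_heads*' `DataCollatorWithPadding`. The token ids are made up, and the pad id 2 is assumed to be Llama-2's EOS id, which this notebook reuses as the pad token.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def pad_batch(features, feature_name_to_padding_value):
    """Right-pad each listed feature to the longest sequence in the batch."""
    batch = {}
    for name, pad_value in feature_name_to_padding_value.items():
        seqs = [torch.as_tensor(f[name]) for f in features]
        batch[name] = pad_sequence(seqs, batch_first=True, padding_value=pad_value)
    return batch

# Two toy examples of different length with one head's label column.
features = [
    {"input_ids": [1, 306, 763], "attention_mask": [1, 1, 1], "wikitext_head_1": [1, 306, 763]},
    {"input_ids": [1, 450], "attention_mask": [1, 1], "wikitext_head_1": [1, 450]},
]
batch = pad_batch(features, {"input_ids": 2, "attention_mask": 0, "wikitext_head_1": -100})
print(batch["wikitext_head_1"])

# Positions labelled -100 are skipped by cross_entropy's default ignore_index.
logits = torch.randn(2, 5)
with_pad = F.cross_entropy(logits, torch.tensor([3, -100]))
without_pad = F.cross_entropy(logits[:1], torch.tensor([3]))
```

Padding the labels with the pad token id instead would make every probe learn to predict padding after the end of each sequence.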

In [12]:
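args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,
    logging_steps=logging_steps,
    do_eval=False,
    # per_device_train_batch_size keeps its default of 8: ceil(23627 / 8) = 2954 steps.
    # Keep the per-head label columns, which Trainer might otherwise drop as unused.
    remove_unused_columns=False,
)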
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        **{key.name: -100 for key in head_configs},
    }
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

wandb: Currently logged in as: ykeller (chm-hci). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.3
wandb: Run data is saved locally in /raven/u/ykeller/transformer_heads/wandb/run-20240324_102422-d2nnltle
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run firm-smoke-217
wandb: ⭐️ View project at https://wandb.ai/chm-hci/huggingface
wandb: 🚀 View run at https://wandb.ai/chm-hci/huggingface/runs/d2nnltle

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...

| Step | Training Loss |
|-----:|--------------:|
| 100 | 57.6001 |
| 200 | 38.5629 |
| 300 | 33.7243 |
| 400 | 30.8404 |
| 500 | 29.6798 |
| 600 | 28.3661 |
| 700 | 27.4192 |
| 800 | 27.1187 |
| 900 | 26.3536 |
| 1000 | 25.88 |


Checkpoint destination directory linear_probe_test/checkpoint-500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory linear_probe_test/checkpoint-1000 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory linear_probe_test/checkpoint-1500 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory linear_probe_test/checkpoint-2000 already exists and is non-empty. Saving will proceed but saved results may be invalid.
Checkpoint destination directory linear_probe_test/checkpoint-2500 already exists and is non-empty. Saving will proceed but saved results may be invalid.

TrainOutput(global_step=2954, training_loss=26.48923161325216, metrics={'train_runtime': 20574.7915, 'train_samples_per_second': 1.148, 'train_steps_per_second': 0.144, 'total_flos': 3.1610522837397504e+17, 'train_loss': 26.48923161325216, 'epoch': 1.0})

In [13]:
print(evaluate_head_wise(model, dd["validation"], collator, epochs=eval_epochs))


Evaluating: 100%|██████████| 308/308 [13:32<00:00,  2.64s/it]

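(23.908249817885363, {'wikitext_head_1': 2.7227502832939097, 'wikitext_head_3': 2.876460525896642, 'wikitext_head_5': 2.860558650323323, 'wikitext_head_7': 2.877980626248694, 'wikitext_head_9': 2.8937691941663815, 'wikitext_head_11': 2.950445142659274, 'wikitext_head_13': 3.052337732407954, 'lm_head': 3.673947743007115})

Nothing too surprising here. The validation loss falls almost monotonically as the hook moves deeper into the model, from 3.05 at layer -13 to 2.72 at layer -1; the only exception is layer -5, which scores slightly better than its neighbours -3 and -7. In other words, each transformer block the input passes through tends to leave more (linearly separable) information about causal language modelling on WikiText in the hidden state. That makes a lot of sense, as the model was also pretrained for causal language modelling. The frozen pretrained lm_head (3.67) does worse than the probe trained at the same layer (2.72), presumably because that probe was fitted to WikiText specifically. The first number in the output, 23.91, is the sum of the eight per-head losses.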
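The per-head numbers are easier to compare as perplexities. The snippet below uses the values printed above and assumes each one is a mean token-level cross-entropy in nats (PyTorch's default natural-log convention), so that exp(loss) is the perplexity. It also confirms that the first returned value matches the sum of the per-head losses.

```python
import math

# Per-head validation losses copied from the evaluate_head_wise output above.
head_losses = {
    "wikitext_head_1": 2.7227502832939097,
    "wikitext_head_3": 2.876460525896642,
    "wikitext_head_5": 2.860558650323323,
    "wikitext_head_7": 2.877980626248694,
    "wikitext_head_9": 2.8937691941663815,
    "wikitext_head_11": 2.950445142659274,
    "wikitext_head_13": 3.052337732407954,
    "lm_head": 3.673947743007115,
}
reported_total = 23.908249817885363

total = sum(head_losses.values())
print(f"sum of head losses: {total:.4f} (reported: {reported_total:.4f})")

# Assumption: mean cross-entropy in nats, so exp(loss) is the perplexity.
perplexity = {name: math.exp(loss) for name, loss in head_losses.items()}
for name in sorted(head_losses, key=head_losses.get):
    print(f"{name:>16}  loss={head_losses[name]:.3f}  ppl={perplexity[name]:.1f}")
```

Under that assumption, the best probe (layer -1) reaches a perplexity of about 15.2, versus about 39.4 for the frozen lm_head.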

In [14]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': ['the', '', 'D', 'this', 'Old'], 'wikitext_head_3': ['the', '', 'D', 'this', 'F'], 'wikitext_head_5': ['the', '', 'D', 'Var', 'these'], 'wikitext_head_7': ['the', '', 'D', 'Old', 'this'], 'wikitext_head_9': ['the', '', 'D', 'Old', 'Ha'], 'wikitext_head_11': ['the', '', 'O', 'D', 'this'], 'wikitext_head_13': ['the', '', 'this', 'a', 'A'], 'lm_head': ['the', 'this', 'a', '', 'The']}


The heads now predict far more plausible tokens: every probe ranks 'the' first, matching the top prediction of the pretrained lm_head.