# Train multiple linear probes at once
In this notebook we find out at which layer a transformer holds the most linearly separable information for causal language modelling on wikitext data. The nice thing about *transformer_heads* is that all probes can be trained in a single training run.

In [1]:
from transformer_heads import load_headed
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    MistralForCausalLM,
    Trainer,
    BitsAndBytesConfig,
    TrainingArguments,
    GPT2Model,
    GPT2LMHeadModel,
)
from transformer_heads.util.helpers import DataCollatorWithPadding, get_model_params
from peft import LoraConfig
from transformer_heads.config import HeadConfig
from transformer_heads.util.model import print_trainable_parameters
from transformer_heads.util.evaluate import evaluate_head_wise, get_top_n_preds
import torch

In [2]:
model_path = "gpt2"
train_epochs = 1
eval_epochs = 1
logging_steps = 100

In [3]:
# Parameters (these values override the defaults in the previous cell)
model_path = "meta-llama/Llama-2-7b-hf"
train_epochs = 0.1
eval_epochs = 0.1
logging_steps = 40


In [4]:
model_params = get_model_params(model_path)
model_class = model_params["model_class"]
hidden_size = model_params["hidden_size"]
vocab_size = model_params["vocab_size"]
print(model_params)

{'model_class': <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>, 'hidden_size': 4096, 'vocab_size': 32000}


Let's define many heads in a loop. The heads are hooked at layers -1, -3, -5, -7, -9, -11 and -13, using Python-style negative indexing: layer -1, for example, is the output of the last transformer block. We also keep the original pretrained lm_head of the transformer model, frozen, for comparison.
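As a quick sanity check of the indexing, the sketch below reproduces the `layer_hook` values generated by the loop in the next cell and maps them onto block numbers. It assumes a 32-block model (Llama-2-7B has 32 transformer blocks) and uses a list of placeholder strings as a stand-in for the per-block hidden states; it does not call *transformer_heads*.

```python
NUM_BLOCKS = 32  # assumption: Llama-2-7B, which has 32 transformer blocks

# Same expression as layer_hook in the HeadConfig loop: -1, -3, ..., -13
layer_hooks = [-(1 + (i - 1) * 2) for i in range(1, 8)]

# Placeholder stand-ins for the outputs of blocks 1..32 (1-based numbering)
block_outputs = [f"block_{n}" for n in range(1, NUM_BLOCKS + 1)]

# Python negative indexing: -1 is the last block's output, -3 the third-to-last, ...
probed = {hook: block_outputs[hook] for hook in layer_hooks}
for hook, block in probed.items():
    print(f"layer_hook={hook:>3} -> output of {block}")
```

Under this assumption, the deepest probe, `wikitext_head_13`, reads the output of block 20 of 32.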

In [5]:
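# One linear probe (no hidden layer, no bias) at every second layer, counting back from the last block
head_configs = [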
    HeadConfig(
        name=f"wikitext_head_{(1+(i-1)*2)}",
        layer_hook=-(1 + (i - 1) * 2),
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
    )
    for i in range(1, 8)
]
# The pretrained lm_head, kept frozen as a baseline
head_configs.append(
    HeadConfig(
        name="lm_head",
        layer_hook=-1,
        in_size=hidden_size,
        hidden_size=0,
        num_layers=1,
        output_activation="linear",
        is_causal_lm=True,
        loss_fct="cross_entropy",
        num_outputs=vocab_size,
        is_regression=False,
        output_bias=False,
        trainable=False,
    )
)

In [6]:
dd = load_dataset("wikitext", "wikitext-2-v1")

In *tokenize_function*, we define labels for each head. For causal language modelling, the labels are simply a copy of the input_ids.
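Why a plain copy is enough: in the usual Hugging Face convention for causal LMs, labels equal the inputs and the model shifts them by one position when computing the loss, so position *t* is scored against token *t + 1*. The sketch below illustrates that shift on whole words instead of real token ids; the helper names are hypothetical, and that `is_causal_lm=True` heads in *transformer_heads* apply the same shift is our assumption, not something this notebook shows directly.

```python
from typing import Dict, List, Tuple

# Hypothetical subset of head names, mirroring the notebook's head_configs
HEAD_NAMES = ["wikitext_head_1", "wikitext_head_3", "lm_head"]


def add_causal_lm_labels(tokens: List[str]) -> Dict[str, List[str]]:
    """Mimic tokenize_function: every head gets its own copy of the inputs as labels."""
    out = {"input_ids": tokens}  # words stand in for token ids here
    for name in HEAD_NAMES:
        out[name] = list(tokens)
    return out


def shifted_pairs(inputs: List[str], labels: List[str]) -> List[Tuple[str, str]]:
    """Pair each position with the *next* label, as a causal LM loss does."""
    return list(zip(inputs[:-1], labels[1:]))


example = add_causal_lm_labels(["The", "historical", "significance", "of"])
pairs = shifted_pairs(example["input_ids"], example["lm_head"])
print(pairs)
# [('The', 'historical'), ('historical', 'significance'), ('significance', 'of')]
```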

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None:
    # Some tokenizers (e.g. GPT-2's) define no pad token; fall back to EOS
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    # No padding here; the collator pads each batch later
    out = tokenizer(examples["text"], padding=False)
    for hc in head_configs:
        # Causal LM labels: one copy of input_ids per head
        out[hc.name] = out["input_ids"].copy()
    return out


for split in dd.keys():
    # Drop empty and very short lines
    dd[split] = dd[split].filter(function=lambda example: len(example["text"]) > 10)
    dd[split] = dd[split].map(tokenize_function, batched=True)
dd.set_format(
    type="torch",
    columns=["input_ids", "attention_mask"] + [x.name for x in head_configs],
)
for split in dd.keys():
    dd[split] = dd[split].remove_columns("text")

In [8]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = load_headed(
    model_class,
    model_path,
    head_configs=head_configs,
    quantization_config=quantization_config,
    device_map={"": torch.cuda.current_device()},
)

Some weights of TransformerWithHeads were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['heads.wikitext_head_1.lins.0.weight', 'heads.wikitext_head_11.lins.0.weight', 'heads.wikitext_head_13.lins.0.weight', 'heads.wikitext_head_3.lins.0.weight', 'heads.wikitext_head_5.lins.0.weight', 'heads.wikitext_head_7.lins.0.weight', 'heads.wikitext_head_9.lins.0.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


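Hugging Face reports that the seven probes are newly initialized, as expected. The `lm_head` is missing from that list because its weights are loaded from the pretrained checkpoint.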

In [9]:
print_trainable_parameters(model)

all params: 4417916928 || trainable params: 917504000 || trainable%: 20.76779656459851
params by dtype: defaultdict(<class 'int'>, {torch.float32: 1179914240, torch.uint8: 3238002688})
trainable params by dtype: defaultdict(<class 'int'>, {torch.float32: 917504000})


Many heads with a large vocabulary mean many trainable parameters: each probe is a bias-free 4096 × 32000 linear map.
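The trainable-parameter count can be reproduced by hand. The sketch below uses the sizes printed by `get_model_params` and treats every probe as a bias-free `hidden_size × vocab_size` matrix. The breakdown of the frozen float32 parameters at the end is our own assumption about Llama-2-7B (token embeddings, the frozen lm_head and 65 unquantized RMSNorm weight vectors); it is not printed by the notebook, but it matches the reported float32 total exactly.

```python
NUM_BLOCKS = 32     # assumption: Llama-2-7B has 32 transformer blocks
hidden_size = 4096  # printed by get_model_params above
vocab_size = 32000
num_probes = 7      # wikitext_head_1 ... wikitext_head_13 (lm_head is frozen)

# Each probe is a bias-free hidden_size x vocab_size linear map
params_per_probe = hidden_size * vocab_size
trainable = num_probes * params_per_probe

all_params = 4_417_916_928  # "all params" printed by print_trainable_parameters
trainable_pct = 100 * trainable / all_params

# Assumed breakdown of the frozen float32 parameters: token embeddings,
# the frozen lm_head, and 2 RMSNorm weights per block plus a final norm
frozen_fp32 = 2 * vocab_size * hidden_size + (2 * NUM_BLOCKS + 1) * hidden_size

print(f"{params_per_probe:,} params per probe")                    # 131,072,000 params per probe
print(f"{trainable:,} trainable params ({trainable_pct:.2f}%)")    # 917,504,000 trainable params (20.77%)
print(f"{trainable + frozen_fp32:,} float32 params in total")      # 1,179,914,240 float32 params in total
```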

In [10]:
dd["train"]

Dataset({
    features: ['input_ids', 'attention_mask', 'wikitext_head_1', 'wikitext_head_3', 'wikitext_head_5', 'wikitext_head_7', 'wikitext_head_9', 'wikitext_head_11', 'wikitext_head_13', 'lm_head'],
    num_rows: 23627
})

In [11]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': ['力', 'Wilson', 'Ps', 'Louis', 'Char'], 'wikitext_head_3': ['convergence', 'Liber', 'Source', 'pc', 'pom'], 'wikitext_head_5': ['Nav', 'Fac', 'equations', 'Vert', '�'], 'wikitext_head_7': ['charged', 'constraint', 'decay', 'anten', "'):"], 'wikitext_head_9': ['ISO', 'selenium', 'Gas', 'eingesetzt', 'ves'], 'wikitext_head_11': ['vs', 'grade', 'varepsilon', 'コ', '."'], 'wikitext_head_13': ['ihm', 'epo', 'listen', 'épisode', 'terug'], 'lm_head': ['the', 'this', 'a', '', 'The']}


The untrained probes predict essentially random tokens, whereas the pretrained lm_head already suggests plausible continuations.

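In the collator, we have to make sure that the labels for each head are padded correctly. We pad them with -100, the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss. Padding the labels with the pad token instead (which falls back to the EOS token) would train the probes to predict EOS at every padded position.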
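The collator itself comes from *transformer_heads*; the sketch below is not its implementation, only a minimal pure-Python illustration of the two ideas involved. It right-pads each feature with its own value (right-padding is an assumption) and computes a mean cross-entropy that skips positions labelled -100, mirroring the default mean reduction of PyTorch's `CrossEntropyLoss`. The names `pad_batch`, `mean_cross_entropy` and `PAD_ID` are hypothetical.

```python
import math
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
PAD_ID = 2           # hypothetical stand-in for tokenizer.pad_token_id


def pad_batch(features: List[Dict[str, List[int]]],
              padding_values: Dict[str, int]) -> Dict[str, List[List[int]]]:
    """Right-pad every feature to the longest sequence in the batch,
    using a separate padding value per feature (hypothetical helper)."""
    max_len = max(len(f[name]) for f in features for name in padding_values)
    return {
        name: [f[name] + [pad] * (max_len - len(f[name])) for f in features]
        for name, pad in padding_values.items()
    }


def mean_cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # padded position: adds nothing to the loss
        log_sum_exp = math.log(sum(math.exp(x) for x in row))
        losses.append(log_sum_exp - row[label])
    # Like PyTorch, return NaN when every position is ignored
    return sum(losses) / len(losses) if losses else float("nan")


features = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "wikitext_head_1": [5, 6, 7]},
    {"input_ids": [8], "attention_mask": [1], "wikitext_head_1": [8]},
]
batch = pad_batch(features, {
    "input_ids": PAD_ID,
    "attention_mask": 0,
    "wikitext_head_1": IGNORE_INDEX,
})
print(batch["input_ids"])        # [[5, 6, 7], [8, 2, 2]]
print(batch["wikitext_head_1"])  # [[5, 6, 7], [8, -100, -100]]

# Only the first position is scored; the -100 label is skipped
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 5.0]]
print(mean_cross_entropy(logits, [0, IGNORE_INDEX]))
```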

In [12]:
args = TrainingArguments(
    output_dir="linear_probe_test",
    learning_rate=0.0002,
    num_train_epochs=train_epochs,
    logging_steps=logging_steps,
    do_eval=False,
    remove_unused_columns=False,
)
collator = DataCollatorWithPadding(
    feature_name_to_padding_value={
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        **{key.name: -100 for key in head_configs},
    }
)
trainer = Trainer(
    model,
    args=args,
    train_dataset=dd["train"],
    data_collator=collator,
)
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...




| Step | Training Loss |
|-----:|--------------:|
| 40 | 70.8382 |
| 80 | 52.1676 |
| 120 | 44.7254 |
| 160 | 41.396 |
| 200 | 39.0564 |
| 240 | 38.3483 |
| 280 | 36.635 |


TrainOutput(global_step=296, training_loss=45.68603969264675, metrics={'train_runtime': 2073.3781, 'train_samples_per_second': 1.14, 'train_steps_per_second': 0.143, 'total_flos': 3.1915708082159616e+16, 'train_loss': 45.68603969264675, 'epoch': 0.1})
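Only 296 optimisation steps were run because `train_epochs` is 0.1. The check below reproduces that number from the 23,627 training rows, assuming the `TrainingArguments` defaults that the notebook does not override (`per_device_train_batch_size=8`, no gradient accumulation) and a single GPU. To our understanding, rounding the fractional epoch count up to whole steps is how the Hugging Face `Trainer` derives its step count.

```python
import math

num_rows = 23_627       # dd["train"].num_rows, printed above
batch_size = 8          # assumed: Trainer default per_device_train_batch_size on one GPU
num_train_epochs = 0.1  # train_epochs from the parameters cell

batches_per_epoch = math.ceil(num_rows / batch_size)          # the last partial batch is kept
max_steps = math.ceil(num_train_epochs * batches_per_epoch)   # 295.4 is rounded up
print(batches_per_epoch, max_steps)  # 2954 296
```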

In [13]:
print(evaluate_head_wise(model, dd["validation"], collator, epochs=eval_epochs))

(35.15710258483887, {'wikitext_head_1': 4.096696861088276, 'wikitext_head_3': 4.405081398785114, 'wikitext_head_5': 4.553819946944714, 'wikitext_head_7': 4.596806988120079, 'wikitext_head_9': 4.691593900322914, 'wikitext_head_11': 4.619848169386387, 'wikitext_head_13': 4.73256553709507, 'lm_head': 3.4606898948550224})





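The numbers above come from Llama-2-7B (the parameters cell replaces the `gpt2` default) after training the probes for only 0.1 epochs. The first value of the tuple is the total loss, i.e. the sum of the eight per-head losses. Among the probes, the one hooked after the last transformer block (`wikitext_head_1`, 4.10) performs best, and the loss mostly grows the earlier the probe is hooked; `wikitext_head_11` (4.62) beating `wikitext_head_9` (4.69) is the only exception. None of the probes reaches the frozen, pretrained `lm_head` (3.46), which was pretrained jointly with the rest of the model, whereas the probes only saw 296 optimisation steps.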
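To make the evaluation output easier to read, the snippet below copies the per-head losses printed above, confirms that they add up to the first value of the tuple (up to floating-point rounding), and ranks the heads. Only the numbers come from the notebook; the code is a stand-alone helper.

```python
# Per-head validation losses copied from the evaluate_head_wise output above
head_losses = {
    "wikitext_head_1": 4.096696861088276,
    "wikitext_head_3": 4.405081398785114,
    "wikitext_head_5": 4.553819946944714,
    "wikitext_head_7": 4.596806988120079,
    "wikitext_head_9": 4.691593900322914,
    "wikitext_head_11": 4.619848169386387,
    "wikitext_head_13": 4.73256553709507,
    "lm_head": 3.4606898948550224,
}
total_loss = 35.15710258483887  # first element of the returned tuple

# The total matches the sum of the per-head losses
assert abs(sum(head_losses.values()) - total_loss) < 1e-4

# Rank heads from lowest to highest loss:
# lm_head first, then wikitext_head_1; wikitext_head_13 last
ranking = sorted(head_losses, key=head_losses.get)
for name in ranking:
    print(f"{name:<17} {head_losses[name]:.3f}")
```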

In [14]:
print(get_top_n_preds(5, model, "The historical significance of", tokenizer))

{'wikitext_head_1': ['the', 'a', '', 'their', 'both'], 'wikitext_head_3': ['the', '', 'this', 'E', 'V'], 'wikitext_head_5': ['the', '', 'society', 'Y', 'its'], 'wikitext_head_7': ['the', '', 'O', 'sculpt', 'N'], 'wikitext_head_9': ['the', '"', 'a', 'their', 'this'], 'wikitext_head_11': ['the', '', 'a', 'this', 'his'], 'wikitext_head_13': ['the', '', 'this', 'a', 'his'], 'lm_head': ['the', 'this', 'a', '', 'The']}


After training, the probes predict plausible tokens, with 'the' ranked first by every head. Note that predicting *<* is an artefact of the wikitext data, which replaces rare words with the `<unk>` placeholder.