t-shoemaker/lm_probe


About

This code is for lm_probe, a little tool for training linear probes on neural language models. Use it to isolate model behavior via classification tasks.

Features:

  • Flexible probe configuration for classification tasks using logistic regression
  • Simultaneous, online training of multiple probes with minimal activation caching
  • Pre-/post-processing for model activations (probes default to mean pooling)
  • Metrics tracking for probe performance

Probes are trained using scikit-learn's SGDClassifier. The tool relies on nnsight to trace model activations. All you need is your own text and labels.
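
For reference, the default mean pooling mentioned in the features list is just a masked average over token positions, so each text collapses to a single hidden-size vector before it reaches the classifier. Below is a minimal sketch of the idea, not the library's actual implementation; it assumes activations of shape (batch, seq_len, hidden) and a 0/1 attention mask:

import torch

def mean_pool(activations: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # activations: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0s and 1s
    mask = attention_mask.unsqueeze(-1).to(activations.dtype)
    # Zero out padding positions, then divide by the number of real tokens
    summed = (activations * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts  # (batch, hidden): one feature vector per text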

Caveats:

  • This is a small tool for quickly standing up probe experiments; brittleness and sharp corners are likely
  • Probes are limited to classification tasks

Installation

While a full set of dependencies is in pyproject.toml, I've made no attempt to configure installation for different CPU/GPU setups; install this tool into an environment that already has whatever PyTorch configuration you use.

To install with pip, run the following:

pip install git+https://github.com/t-shoemaker/lm_probe

Or, with pixi:

pixi add --pypi "lm_probe @ git+https://github.com/t-shoemaker/lm_probe"

Usage

You can run your probes on any LanguageModel from nnsight (which wraps Hugging Face models and PyTorch modules).

import torch
import pandas as pd
from nnsight import LanguageModel
from lm_probe import ProbeConfig, ProbeRunner, ProbeDataset


gpt2 = LanguageModel("gpt2", device_map="auto")

Data preparation

Probes learn directly from activations, so you'll need to tokenize your texts separately. Below, we use a sample of IMDB reviews. The probes will attempt to classify positive and negative reviews using model activations. We wrap the tokenized reviews in a ProbeDataset, which handles batching for the model.

# Download the data, map string labels to integers, and sample
df = pd.read_csv("hf://datasets/scikit-learn/imdb/IMDB Dataset.csv")
df["label"] = df["sentiment"].map({"positive": 1, "negative": 0})
df = df.sample(1_000)

# Tokenize with the model
reviews, labels = df["review"].tolist(), df["label"].tolist()
inputs = gpt2.tokenizer(
    reviews, padding="max_length", truncation=True, return_tensors="pt"
)

# Wrap everything up in a ProbeDataset and do a train/test split
ds = ProbeDataset(
    inputs["input_ids"], inputs["attention_mask"], labels
)
train_set, test_set = ds.tt_split(test_size=0.1)

Probe configuration

With nnsight, you can extract activations from any part of a neural network. This functionality is exposed in lm_probe with stringified submodule names. If, for example, you wanted to extract output activations from layer 5's MLP, you'd write transformer.h.5.mlp.output.

gpt2
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): WrapperModule()
)

Pass that submodule name as the first argument of a ProbeConfig, along with the unique set of classes you're looking for, and you're ready to train a probe.

classes = set(labels)
submodule = "transformer.h.5.mlp.output"
config = ProbeConfig(submodule, classes)
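
The trailing .output refers to the submodule's traced output rather than a module itself. If you want to sanity-check that a dotted path points at a real module before training, torch.nn.Module.get_submodule resolves names like these. A quick sketch against a plain Hugging Face GPT-2, loaded separately here just for the check (this isn't part of lm_probe's API):

from transformers import AutoModelForCausalLM

# Load the underlying Hugging Face model directly, only for verification
hf_gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")

# Drop the trailing ".output" and resolve the dotted path;
# get_submodule raises AttributeError if the path doesn't exist
module_path = submodule.removesuffix(".output")  # "transformer.h.5.mlp"
print(hf_gpt2.get_submodule(module_path))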

Training

Use a ProbeRunner to train the probe. It will handle pre-/post-processing for activations as well as the control flow for batching, evaluation, early stopping, etc.

runner = ProbeRunner(gpt2, [config])
runner.fit_probes(train_set, test_set, batch_size=8)
2024-12-04 09:47:11 | INFO | Early stopping after 18 steps for transformer.h.5.mlp.output
2024-12-04 09:47:11 | INFO | Finished training 1 probe(s). Computing metrics

Inspect the results with the .metrics attribute, which holds a DataFrame recording each probe's performance on the test set.

print(runner.metrics.to_string())
                                loss  accuracy  precision    recall        f1  matthews
transformer.h.5.mlp.output  4.505457     0.875   0.881834  0.875941  0.874613  0.757752

Using trained probes

The runner's .get_probe_features() method extracts probe features for some input and returns them in a dictionary keyed by submodule name.

docs = test_set[0:5]
features = runner.get_probe_features(docs["input_ids"], docs["attention_mask"])
features[submodule].shape
(5, 768)

Index a ProbeRunner by a submodule name to get the corresponding probe. With it, you can predict directly on the features.

probe = runner[submodule]
probe.predict(features[submodule])
array([0, 1, 0, 0, 0])

Parallelized probing

The real point of lm_probe is that it parallelizes probe training. It does this with minimal activation caching, relying instead on nnsight to trace model layers as inputs pass through. While this front-loads memory costs during the forward pass, it enables quick experimentation with different probe configurations without keeping all activations in memory. With that in mind, we initialize a few probes below.

configs = [
    ProbeConfig(
        "transformer.h.11.output",
        classes,
        test_size=0.5,
        patience=10,
        warmup_steps=25,
    ),
    ProbeConfig(
        "transformer.h.4.mlp.output",
        classes,
        test_size=0.5,
        patience=10,
        warmup_steps=25,
    ),
    ProbeConfig(
        "transformer.h.7.attn.output",
        classes,
        test_size=0.5,
        patience=10,
        warmup_steps=25,
    ),
]
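
Configs are plain objects, so they can also be generated programmatically, for instance to probe every block's MLP output at once. A sketch reusing the keyword arguments above:

# One probe per transformer block's MLP output (GPT-2 has 12 blocks)
mlp_configs = [
    ProbeConfig(
        f"transformer.h.{i}.mlp.output",
        classes,
        test_size=0.5,
        patience=10,
        warmup_steps=25,
    )
    for i in range(12)
]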

Time to train.

runner = ProbeRunner(gpt2, configs)
runner.fit_probes(train_set, test_set, batch_size=8)
2024-12-04 09:50:59 | INFO | Early stopping after 38 steps for transformer.h.4.mlp.output
2024-12-04 09:51:28 | INFO | Early stopping after 42 steps for transformer.h.7.attn.output
2024-12-04 09:51:58 | INFO | Early stopping after 46 steps for transformer.h.11.output
2024-12-04 09:51:58 | INFO | Finished training 3 probe(s). Computing metrics

Which set of model activations leads to the most accurate classification?

print(
    runner.metrics
    .sort_values(["loss", "matthews"], ascending=[True, False])
    .to_string()
)
                                  loss  accuracy  precision    recall        f1  matthews
transformer.h.11.output       6.590683  0.815217   0.814490  0.815834  0.814774  0.630322
transformer.h.4.mlp.output    6.870814  0.807065   0.806108  0.807062  0.806435  0.613169
transformer.h.7.attn.output  11.722360  0.668478   0.673981  0.672990  0.668390  0.346969
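
Because .metrics is an ordinary DataFrame, picking the best activation site programmatically is a one-liner, for example by Matthews correlation:

runner.metrics["matthews"].idxmax()
'transformer.h.11.output'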
