<a href="https://colab.research.google.com/github/stvngo/Algoverse-AI-Model-Probing/blob/main/Linear_Probing_Qwen_3_0_6B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

When a model makes a correct prediction on a task it has been trained on, Probing classifeier can be used to identify if the model actually contains the relevant informatioin or knowledge required to make that prediction, or it is just making a lucky guess
- can be used to identify crucial insights for developing better models over time


### How it works

A nn takes it's input as a series of vectors, or representations, and transform them through a series of layers to produce an output
- develop representations that useful so that the final few layers of the network can be a good prediction

### Probes
- a features or representations from the model are easily seperable by a simple classifier ==> a probe
The only way the probe can perform well on this task is if the representation it is given are already good enough to make the prediction



## Using Qwen 3 0.6B to extract residual steam activations


In [11]:
# Install and load the model
!pip install transformers accelerate




In [12]:
# loading the transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn

model_name = "Qwen/Qwen1.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1024,), eps=1e-06)
    (rotary_emb): 

In [13]:
print(model)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((1024,), eps=1e-06)
    (rotary_emb): 

## PTSProbeDataset
1. Loads PTS samples(```text```, ```pivotal_tokens```)
2. Tokenizes using Qwen tokenizer
3. Captures residual activations at a chosen layer
4. Aligns pivotal tokens to labels
5. Returns(activation, is_pivotal_label) pairs

In [14]:
from torch.utils.data import Dataset
import torch

class PTSProbeDataset(Dataset):
    def __init__(self, samples, tokenizer, model, layer_index=16):
        """
        samples: list of dicts with keys "text" and "pivotal_tokens"
        tokenizer: Qwen tokenizer
        model: Qwen2ForCausalLM model
        layer_index: transformer block to hook
        """
        self.samples = samples
        self.tokenizer = tokenizer
        self.model = model
        self.layer_index = layer_index
        self.residuals = []
        self.labels = []
        self.hook_handle = None  # Store hook handle for cleanup

        # Validate model structure
        if not hasattr(self.model, 'model') or not hasattr(self.model.model, 'layers'):
            raise AttributeError("Model doesn't have expected structure: model.model.layers")

        if len(self.model.model.layers) <= layer_index:
            raise IndexError(f"Layer index {layer_index} is out of range. Model has {len(self.model.model.layers)} layers")

        print(f"Initializing PTSProbeDataset with {len(samples)} samples, hooking layer {layer_index}")

        # Preprocess everything once
        self._prepare_data()

    def _get_activations_for_sample(self, encoded_input):
        """Get activations for a single sample using a temporary hook"""
        activations = {}

        def hook_fn(module, input, output):
            try:
                if isinstance(output, tuple):
                    output = output[0]
                if isinstance(output, torch.Tensor):
                    activations["residual"] = output.detach().clone()
                else:
                    print(f"Warning: Hook output is not a tensor, got {type(output)}")
            except Exception as e:
                print(f"Error in hook function: {e}")
                raise

        # Register hook temporarily
        target_layer = self.model.model.layers[self.layer_index]
        hook_handle = target_layer.register_forward_hook(hook_fn)

        try:
            with torch.no_grad():
                outputs = self.model(**encoded_input)

            if "residual" not in activations:
                print(f"Available keys in activations: {list(activations.keys())}")
                print(f"Model output type: {type(outputs)}")
                raise RuntimeError("Hook failed to capture activations")

            return activations["residual"]
        finally:
            # Always remove the hook
            hook_handle.remove()

    def _prepare_data(self):
        for sample in self.samples:
            text = sample["text"]
            pivotal_tokens = sample["pivotal_tokens"]

            # Tokenize with character offsets
            encoded = self.tokenizer(
                text,
                return_offsets_mapping=True,
                return_tensors="pt",
                truncation=True,
                padding=False  # Don't pad during tokenization
            )
            offsets = encoded["offset_mapping"][0].tolist()

            # Get activations for this sample
            resid = self._get_activations_for_sample(encoded)
            resid = resid.squeeze(0)  # [seq_len, hidden_dim]

            # Build binary token labels aligned to offsets
            token_labels = []
            for start, end in offsets:
                # Handle special tokens that might have (0,0) offsets
                if start == 0 and end == 0:
                    token_labels.append(0)  # Special tokens are not pivotal
                else:
                    token_str = text[start:end]
                    is_pivotal = any(piv.lower() in token_str.lower() for piv in pivotal_tokens)
                    token_labels.append(1 if is_pivotal else 0)

            # Convert labels to tensor
            token_labels = torch.tensor(token_labels, dtype=torch.float)

            # Ensure alignment between residuals and labels
            seq_len = resid.shape[0]
            if len(token_labels) != seq_len:
                # Truncate or pad labels to match sequence length
                if len(token_labels) > seq_len:
                    token_labels = token_labels[:seq_len]
                else:
                    # Pad with zeros (same dtype)
                    pad_len = seq_len - len(token_labels)
                    padding = torch.zeros(pad_len, dtype=torch.float)
                    token_labels = torch.cat([token_labels, padding])

            self.residuals.append(resid)
            self.labels.append(token_labels)

    def __len__(self):
        return len(self.residuals)

    def __getitem__(self, idx):
        return self.residuals[idx], self.labels[idx]

    def __del__(self):
        """Cleanup method (though __del__ isn't guaranteed to be called)"""
        self.cleanup()

    def cleanup(self):
        """Explicitly remove any remaining hooks"""
        if hasattr(self, 'hook_handle') and self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None

PTSProbeDataset gives us:
- ```resid```: residual activations from a layer --> shape ```[seq_len, hidden_dim```]
- ```labels```: binary labels for each token --> shape ```[seq_len]```

In [15]:
print("Model type:", type(model))
print("Has model.model?", hasattr(model, 'model'))
print("Model.model type:", type(getattr(model, 'model', None)))

if hasattr(model, 'model'):
    print("Model.model has layers?", hasattr(model.model, 'layers'))
    print("Model.model.layers type:", type(getattr(model.model, 'layers', None)))
    if hasattr(model.model, 'layers'):
        print("Model.model.layers length:", len(model.model.layers))
samples = [
    {"text": "The quick brown fox jumps over the lazy dog.", "pivotal_tokens": ["quick", "jumps", "dog"]},
    {"text": "The model interprets language better with more data.", "pivotal_tokens": ["interprets", "data"]}
]

try:
  dataset = PTSProbeDataset(samples, tokenizer, model, layer_index=16)
  print(f"Dataset created successfully with {len(dataset)} samples")

  # Test getting on item
  resid, labels = dataset[0]
  print(f"Residual shape: {resid.shape}, Labels shape: {labels.shape}")
except Exception as e:
  print(f"Error creating dataset: {e}")
  import traceback
  traceback.print_exc()


Model type: <class 'transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM'>
Has model.model? True
Model.model type: <class 'transformers.models.qwen2.modeling_qwen2.Qwen2Model'>
Model.model has layers? True
Model.model.layers type: <class 'torch.nn.modules.container.ModuleList'>
Model.model.layers length: 24
Initializing PTSProbeDataset with 2 samples, hooking layer 16
Dataset created successfully with 2 samples
Residual shape: torch.Size([10, 1024]), Labels shape: torch.Size([10])


## DataLoaders


In [16]:
from torch.utils.data import DataLoader

def flatten_collate(batch):
  x_list, y_list = zip(*batch)
  x = torch.cat(x_list, dim=0)
  y = torch.cat(y_list, dim=0)
  return x, y

  dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=flatten_collate)

## The probe


In [17]:
## Create the Linear
# Define the probe ==> a linear layer + sigmoid

import torch.nn as nn

class LinearProbe(nn.Module):
  def __init__(self, hidden_dim=1024):
    super().__init__()
    self.linear = nn.Linear(hidden_dim, 1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.linear(x)
    x = self.sigmoid(x).squeeze(-1)
    return x

## Train the probe


In [18]:
import torch.nn.functional as F

def train_probe(probe, dataloader, num_epochs=5, lr=1e-3, verbose=True):
  """
  Trains a probe on a residual activations with binary labels

  Args:
      probe (nn.Module): The probe model (e.g., LinearProbe)
      dataloader (torch.utils.data.DataLoader): Dataloader with residuals and labels
      epoch (int): Number of epochs to train for
  """

  probe.train()
  optimizer = torch.ptim.Adam(probe.paramters(), lr=lr)
  loss_fn = nn.BGELoss()

  # Training loop

  for epoch in range(num_epochs):
    total_loss = 0.0
    correct, total = 0, 0
    for x,y in dataloader:
      optimizer.zero_grad()
      preds = probe(x) #[total_tokens]
      loss = loss_fn(preds, y)
      loss.backward()
      optimizer.step()
      total_loss += loss.item()

      # Accurarcy
      predicted = (preds >= 0.5).long()
      correct += (predicted == y).sum().item()
      total += y.size(0)

    acc = correct / total
    if verbose:
      print(f"Epoch {epoch+1}: Loss = {total_loss:.4f} | Accuracy = {acc:.4f}")


In [22]:
# Initialize the probe using the hidden dimesion of Qwen-0.6B(1024)
probe = LinearProbe(hidden_dim=1024)


## Prediction: Use Probe to make predictions on new sentences


In [25]:
def predict_with_probe(text, tokenizer, model, probe, layer_index=16):
    residual_activations = {}

    def hook_fn(module, input, output):
        if isinstance(output, tuple):
            output = output[0]
        residual_activations["resid"] = output.detach()

    # Register hook
    hook_handle = model.model.layers[layer_index].register_forward_hook(hook_fn)

    # Tokenize input
    encoded = tokenizer(text, return_tensors="pt")
    input_ids = encoded["input_ids"]

    with torch.no_grad():
        _ = model(**encoded)
    hook_handle.remove()

    # Get residuals
    resid = residual_activations.get("resid", None)
    if resid is None:
        print("Hook failed to capture activations")
        return []

    # Remove batch dim
    resid = resid.squeeze(0)  # [seq_len, hidden_dim]

    # Run through probe
    scores = probe(resid)  # shape: [seq_len], values in [0,1]

    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))
    return list(zip(tokens, scores.tolist()))


In [26]:
# Example usage

text = "The quick brown fox jumps over the lazy dog."
predictions = predict_with_probe(text, tokenizer, model, probe)

for token, score in predictions:
  print(f"{token}: {score:.4f}")

The: 1.0000
Ġquick: 0.6418
Ġbrown: 0.5964
Ġfox: 0.6003
Ġjumps: 0.6182
Ġover: 0.6308
Ġthe: 0.5279
Ġlazy: 0.5356
Ġdog: 0.5321
.: 0.4272


### Analysis

1. The having a probability of 1 probably means overfitting on this token
2. With the threshold set at 0.5, the pivotal tokens are "quick", "brown", "fox", "jumps", "over", "the" "lazy" "dogs"

## Generate aligned token-level ```is_pivotal``` labels from the PTS dataset


Based on the data in the [PTS repo](https://github.com/codelion/pts)

We want to:
1. Tokenize the text(using Qwen tokenizer)
2. Align the pivotal words to tokens
3. Mark each token with a binary label
- `1` if it maps to a pivotal word
- `0` otherwise

In [7]:
def get_token_labels(text, tokenizer, pivotal_words):
  # Tokenize text with character offsets

  encoded = tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
  offsets = encoded.offset_mapping[0].tolist()
  tokens = tokenizer.convert_ids_to_tokens(encoded.input_ids[0])

  # Find character spans of each pivotal word
  token_labels = []
  for start, end in offsets:
    token_str = text[start:end]
    is_pivotal = any(token_str in word for word in pivotal_words)
    token_labels.append(1 if is_pivotal else 0)

  return encoded, token_labels

In [None]:
'''
# Tokenize and run input
text = "The quick brown fox jumps over the lazy dog"
pivotal_words = ["quick", "jumps", "dog"]
inputs = tokenizer(text, return_tensors="pt")
# we don't want the model to update the parameters so we don't use gradient descent
with torch.no_grad():
    _ = model(**inputs)

'''


```pivotal_tokens``` should be a list of strings be a list of strings, like ```["quick", "jumps", "dogs"]```

In [9]:
encoded, token_labels = get_token_labels(samples[0].text, tokenizer, samples[0].pivotal_words)

AttributeError: 'dict' object has no attribute 'text'

The Qwen2 uses a hidden size of 1024, that's the hidden_dim


In [50]:
# Align activations with pivotal labels
resid = resid.squeeze(0) # [seq_len, 1024]
labels = torch.tensor(token_labels).float() # [seq_len]

In [None]:
# Custom Collate Function
def flatten_collate(batch):
  x_list, y_list = zip(*batch)
  x = torch.cat(x_list, dim=0)
  y = torch.cat(y_list, dim=0)
  return x, y


## Evaluate Accuracy


In [6]:
with torch.no_grad():
  correct, total = 0, 0
  for x, y in dataloader:
    preds = probe(x)
    preds = (preds >= 0.5).float()
    correct += (preds == y).sum().item()
    total += y.size(0)

  acc = correct/total
  print(f"Accuracy: {acc: .4f}")

NameError: name 'dataloader' is not defined