<a href="https://colab.research.google.com/github/stvngo/Algoverse-AI-Model-Probing/blob/SN-updates/Steven_Qwen_PTS_Linear_Probing_Prototype.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dataset Preparation and Model Probing

Link to our GitHub repository: https://github.com/stvngo/Algoverse-AI-Model-Probing

Link to this colab: https://colab.research.google.com/drive/1lPYyJzPMA3MBKDzJQ-X3hVCp_kEFky1s#scrollTo=363e9e8d&uniqifier=2

**Model Paths**
*   [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B): Qwen/Qwen3-0.6B
*   [DeepSeek R1 Distill Qwen 1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B): deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B




In [23]:
# install necessary libraries
!pip install datasets --upgrade
!pip install transformers --upgrade
!pip install einops --upgrade



# Testing and Prompting on Qwen 3 0.6B with Qwen3 PTS dataset

# Load Dataset

- Load the dataset from its Hugging Face path.
- Import scikit-learn's `train_test_split` function.
- First, split by query; then create many negative examples while extracting token positions and labels.
- Lastly, balance the dataset by sampling one negative per positive, so the balanced set has twice as many examples as there are positives.
- TODO: The dataset splits vary between runs despite using `random_state=42`, `random.seed(42)`, and `torch.manual_seed(42)`. The cause still needs to be confirmed.
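One possible cause of the varying splits in the TODO is that the query IDs pass through a `set`. If the IDs are strings, the iteration order of a set can differ between Python sessions, so a seeded shuffle starts from a different ordering each time. The sketch below is a minimal illustration using only the standard library; `split_ids` is a hypothetical helper, not part of the notebook, and it assumes the IDs are sortable.

```python
import random

def split_ids(ids, test_size=0.2, seed=42):
    """Split unique IDs into train/test lists, independent of input order."""
    unique_ids = sorted(set(ids))   # sorting removes any set-order variation
    rng = random.Random(seed)       # local RNG; the global random state is untouched
    rng.shuffle(unique_ids)
    n_test = max(1, round(len(unique_ids) * test_size))
    return unique_ids[n_test:], unique_ids[:n_test]

ids = [f"q{i}" for i in range(10)]
train_a, test_a = split_ids(ids)
train_b, test_b = split_ids(list(reversed(ids)))  # same IDs, different input order
print(train_a == train_b and test_a == test_b)    # True
```

Sorting before the seeded shuffle is the key step: the same set of IDs then always produces the same split, regardless of the order in which they arrive.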

In [24]:
from datasets import load_dataset, Dataset
from typing import List, Dict, Tuple, Optional
from sklearn.model_selection import train_test_split
import pandas as pd

def split_pts_by_query(dataset_path: str, test_size: float = 0.2) -> Tuple[Dataset, Dataset]:
    """
    Load PTS dataset and split by query ID to avoid data leakage.

    :param dataset_path: Path/name of your PTS dataset on HuggingFace
    :param test_size: Fraction for test split
    :return: train_dataset, test_dataset split by query
    """
    # Load the PTS dataset with explicit configuration
    print(f"Loading dataset: {dataset_path}")

    try:
        # Try loading without any wildcards or special patterns
        dataset = load_dataset(dataset_path, split='train')
        print(f"Loaded {len(dataset)} examples")

    except Exception as e:
        print(f"Error with split='train', trying default loading: {e}")
        try:
            # Try loading all splits then select one
            dataset_dict = load_dataset(dataset_path)
            print(f"Available splits: {list(dataset_dict.keys())}")

            # Get the main split
            if 'train' in dataset_dict:
                dataset = dataset_dict['train']
            else:
                split_name = list(dataset_dict.keys())[0]
                dataset = dataset_dict[split_name]
                print(f"Using split: {split_name}")

        except Exception as e2:
            print(f"Final error: {e2}")
            print("Try loading the dataset manually first to debug")
            raise e2

    # Get unique query IDs. Sort them because set iteration order is not guaranteed
    # to be the same across Python sessions, which would change the split below.
    unique_query_ids = sorted(set(dataset['dataset_item_id']))
    print(f"Total unique queries: {len(unique_query_ids)}")

    # Split query IDs (not individual examples)
    train_query_ids, test_query_ids = train_test_split( # train: 1,3,4,... | test: 2,5,...
        unique_query_ids,
        test_size=test_size,
        random_state=42 # for reproducibility
    )

    # Filter dataset by query splits (sets give constant-time membership checks)
    train_id_set, test_id_set = set(train_query_ids), set(test_query_ids)
    train_dataset = dataset.filter(lambda x: x['dataset_item_id'] in train_id_set)
    test_dataset = dataset.filter(lambda x: x['dataset_item_id'] in test_id_set)

    print(f"Train queries: {len(train_query_ids)}, Train examples: {len(train_dataset)}")
    print(f"Test queries: {len(test_query_ids)}, Test examples: {len(test_dataset)}")

    return train_dataset, test_dataset

# Load Model(s)

Load the model with configurations for interpretability work, including disabled gradients and activation extraction.

In [37]:
# import necessary packages
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time # Import time for timing

# manual seed for reproducibility
torch.manual_seed(42)

# torch.set_default_device("cuda")

# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# model name
model_name = "Qwen/Qwen3-0.6B"

# load model and tokenizer
# Ensure model and tokenizer are on the correct device AFTER loading
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype="auto",
                                             trust_remote_code=True,
                                             output_hidden_states=True) # access internal activations

# Move the model to the determined device
model.to(device)
print("Model moved to device.")


tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

print("Model and tokenizer loaded.")

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Using device: cuda
Model moved to device.
Model and tokenizer loaded.


**Tokenization Validation** (not needed)

A critical function that checks whether the pivot tokens in the dataset align with how the model actually tokenizes the text. This is essential because a mismatch means the pivotal-token labels point at the wrong positions, which would invalidate the interpretability results.

Notes:

*   tokenizer.encode(): simpler, lower-level method that returns token IDs as a list of integers.
*   tokenizer(): full-featured, high-level method that returns a dictionary with multiple components, e.g. token IDs, attention mask, and token type IDs. Use it for actual model forward passes to get the attention mask, which is needed later to tell real tokens from padding when extracting activations at specific token positions.
*   tokenizer.decode(): the reverse of tokenization; converts token IDs back into human-readable text.
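
The alignment concern can be illustrated without a real model. The sketch below is a minimal example that assumes a toy greedy longest-match tokenizer over a made-up vocabulary; `VOCAB`, `toy_tokenize`, and `is_aligned` are hypothetical and do not come from the notebook or from `transformers`. It shows that when the context and pivot are tokenized together, tokens can merge across the boundary, so the context tokens are no longer a prefix of the full sequence.

```python
# Toy vocabulary; inputs below only use characters covered by it.
VOCAB = ["the", "re", "there", "fore", " ", "t", "h", "e", "r", "f", "o"]

def toy_tokenize(text):
    """Greedy longest-match tokenization over the toy vocabulary."""
    tokens, i = [], 0
    while i < len(text):
        match = max((v for v in VOCAB if text.startswith(v, i)), key=len)
        tokens.append(match)
        i += len(match)
    return tokens

def is_aligned(context, pivot):
    """True if tokenizing context + pivot keeps the context tokens as a prefix."""
    ctx = toy_tokenize(context)
    full = toy_tokenize(context + pivot)
    expected_len = len(ctx) + len(toy_tokenize(pivot))
    return full[:len(ctx)] == ctx and len(full) == expected_len

print(is_aligned("the ", "fore"))  # True: the boundary falls after a space
print(is_aligned("the", "re"))     # False: "the" + "re" merges into "there"
```

Real subword tokenizers behave differently in detail, but the same two checks (prefix match and expected length) are what detect a misaligned pivot.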





In [26]:
# # token validation as a utility

# def validate_tokenization(example_idx: int = 0):
#   """Validate that pivot tokens align with model tokenization"""
#   row = dataset.iloc[example_idx]

#   # get the pivot context and token
#   pivot_context = row["pivot_context"]
#   pivot_token = row["pivot_token"]

#   # tokenize the context
#   context_tokens = tokenizer.encode(pivot_context, return_tensors="pt")

#   # tokenize the context + pivot
#   full_sequence = pivot_context + pivot_token
#   full_tokens = tokenizer.encode(full_sequence, return_tensors="pt")

#   # check if adding the pivot token matches expectations
#   context_length = context_tokens.shape[1]
#   full_length = full_tokens.shape[1]

#   print(f"Example {example_idx}:")
#   print(f"  Context length: {context_length} tokens")
#   print(f"  Full sequence length: {full_length} tokens")
#   print(f"  Pivot token: '{pivot_token}'")
#   print(f"  Context ends with: '{tokenizer.decode(context_tokens[0][-3:])}'") # 2-d tensor, use [0] first
#   print(f"  Full sequence ends with: '{tokenizer.decode(full_tokens[0][-3:])}'")

#   # Extract what the model thinks is the next token after context
#   if full_length > context_length:
#       predicted_next_tokens = full_tokens[0][context_length:]
#       predicted_next_text = tokenizer.decode(predicted_next_tokens)
#       print(f"  Model's next token(s): '{predicted_next_text}'")
#       print(f"  Matches pivot token: {predicted_next_text.strip() == pivot_token.strip()}")

#   return context_tokens, full_tokens

# context_tokens, full_tokens = validate_tokenization(example_idx=0)
# print(context_tokens)
# print(full_tokens) # 32 is the last token encoded
# print(f'Dataset pivot token id: {dataset.get("pivot_token_id").iloc[0]}') # should match 32 from the dataset

**Generation Settings**

Settings for reproducing the model's behavior (validation), extracting internal representations at the right positions during forward passes, and verifying model setup matches dataset creation setup.

Notes:
*   max_new_tokens=1: generate only the next token, which is all that is needed for the probability calculation
*   do_sample=False: always pick the most probable token (greedy decoding), which ensures reproducibility
*   pad_token_id + eos_token_id: the tokens that represent padding and end-of-sequence, used for sequence boundaries
*   output_attentions=False: do not return attention weights from all layers (not needed, saves memory)
*   **output_hidden_states=True: returns activations from ALL LAYERS, needed to train linear probes**
*   return_dict_in_generate=True: gives access to hidden states during generation



In [27]:
# generation settings
# generation_config = {
#     'max_new_tokens': 1,  # we're only interested in the next token for probability calculation
#     'do_sample': False,   # deterministic generation
#     'temperature': 1.0,
#     'top_p': 1.0,
#     'pad_token_id': tokenizer.pad_token_id,
#     'eos_token_id': tokenizer.eos_token_id,
#     'output_hidden_states': True,  # CRITICAL for activation extraction
#     'output_attentions': False,    # not needed for now, saves memory
#     'return_dict_in_generate': True,
#     'use_cache': False,  # disable caching for cleaner memory usage
# }

# Dataset Preparation
Create a balanced dataset for training a linear probe. The dataset should contain positive examples (text sequences ending at the token before the original pivot token, labeled 1) and an equal number of negative examples (text sequences ending at other non-pivot positions within the same contexts, labeled 0). The final dataset should be in a format suitable for training, such as a HuggingFace `Dataset` or pandas DataFrame, containing the text sequence, the relevant token position, and the binary label.

## Extract token positions and labels

*   Identify positive examples (position before the original pivot token).
*   Identify potential negative examples (other positions within the context).
*   Create pairs of (text sequence, token position, label) for both positive and negative examples.
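
As a rough illustration of this labelling scheme, the sketch below works on a plain list of tokens instead of tokenizer output. `label_positions` is a hypothetical helper, and whitespace-split words stand in for real tokens.

```python
def label_positions(context_tokens, item_id=None):
    """Return one record per context position; only the last position is labeled 1."""
    positive_position = len(context_tokens) - 1  # token right before the pivot
    return [
        {
            "token_position": pos,
            "label": int(pos == positive_position),
            "original_dataset_item_id": item_id,
        }
        for pos in range(len(context_tokens))
    ]

records = label_positions("Let x = 3 , so".split(), item_id="example-0")
print([r["label"] for r in records])  # [0, 0, 0, 0, 0, 1]
```

A context of n tokens yields one positive and n - 1 negatives, which is why the raw examples are heavily imbalanced before the balancing step.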


Define the function `prepare_balanced_probe_data` to iterate through the dataset, identify positive and negative positions, and collect the data in the specified format.



In [39]:
def prepare_balanced_probe_data(dataset, tokenizer, model) -> List[Dict]:
    """
    Prepare data for linear probe training by extracting token positions and labels,
    identifying positive and negative examples based on the 'is_pivotal' concept.

    For each query in the raw dataset:
    - Identify the position right before the original pivot token (labeled 1 for 'is_pivotal').
    - Identify all other token positions within the same context (labeled 0 for 'is_pivotal').
    - Create records of (text sequence, token position, label) for each.

    :param dataset: HuggingFace dataset containing raw PTS data.
    :param tokenizer: Tokenizer for the model.
    :param model: Model (needed for potential tokenization validation).
    :return: List of dictionaries, each representing a token-position example.
    """
    all_examples = []

    # Ensure model is on the correct device if needed for validation
    device = next(model.parameters()).device

    # Iterate through each example in the dataset
    print(f"Preparing data from {len(dataset)} raw examples...")

    for i, example in enumerate(dataset): # loop through examples and index, get necessary items
        pivot_context = example['pivot_context']
        # Original is_positive (delta success probability) is NOT used for the probe label.
        # is_positive_original = example['is_positive'] # This field is NOT needed for the probe label
        pivot_token = example['pivot_token']
        dataset_item_id = example.get('dataset_item_id', None) # Get original ID if available

        # Tokenize the pivot context to get sequence length
        context_inputs = tokenizer(pivot_context, return_tensors='pt', add_special_tokens=False)
        context_input_ids = context_inputs['input_ids'].to(device)
        seq_len = context_input_ids.shape[1]

        # --- Tokenization Alignment Validation ---
        # Tokenize the full sequence (context + pivot token) to verify alignment
        full_sequence = pivot_context + pivot_token
        full_inputs = tokenizer(full_sequence, return_tensors='pt', add_special_tokens=False)
        full_input_ids = full_inputs['input_ids'].to(device)

        # Check if the full sequence tokenization length matches context length + pivot token length
        pivot_token_ids = tokenizer.encode(pivot_token, add_special_tokens=False)
        expected_full_seq_len = seq_len + len(pivot_token_ids)

        # Also check if the context part of the full tokenization matches the context tokenization
        context_matches = torch.equal(full_input_ids[0, :seq_len], context_input_ids[0,:])

        if full_input_ids.shape[1] != expected_full_seq_len or not context_matches:
            print(f"Warning: Raw example {i} (dataset_item_id: {dataset_item_id}) - Tokenization mismatch. Skipping.")
            continue
        # --- End Validation ---


        # The position right before the original pivot token is the last token of the context
        positive_position = seq_len - 1 # -1 b/c 0-based indexing

        # Add the positive example
        # Based on clarification: Label is 1 for the position before the original pivot token

        all_examples.append({
            'text': pivot_context,
            'token_position': positive_position,
            'label': 1, # Label is 1 for the pivotal position
            'original_dataset_item_id': dataset_item_id # Keep track of source query
        })

        # Add negative examples (all other positions in the context)
        for pos in range(seq_len):
            if pos != positive_position:
                # Based on clarification: Label is 0 for any position that is NOT the pivotal one
                all_examples.append({
                    'text': pivot_context,
                    'token_position': pos,
                    'label': 0, # Label is 0 for non-pivotal positions
                    'original_dataset_item_id': dataset_item_id
                })

    print(f"Collected {len(all_examples)} total potential examples.")
    return all_examples # list of dicts with the keys: text, token position, label, and query id

Call the newly defined `prepare_balanced_probe_data` function with the raw training and testing datasets to generate the list of potential examples for each split.

In [40]:
# Re-execute the split function to get the raw datasets
train_raw, test_raw = split_pts_by_query("codelion/Qwen3-0.6B-pts", test_size=0.2)

# Now call the data preparation function with the raw datasets
train_examples_raw_list = prepare_balanced_probe_data(train_raw, tokenizer, model)
test_examples_raw_list = prepare_balanced_probe_data(test_raw, tokenizer, model)

print(f"\nPrepared {len(train_examples_raw_list)} raw examples for training.")
print(f"Prepared {len(test_examples_raw_list)} raw examples for testing.")

Loading dataset: codelion/Qwen3-0.6B-pts
Loaded 1376 examples
Total unique queries: 104
Train queries: 83, Train examples: 1075
Test queries: 21, Test examples: 301
Preparing data from 1075 raw examples...
Collected 249585 total potential examples.
Preparing data from 301 raw examples...
Collected 64380 total potential examples.

Prepared 249585 raw examples for training.
Prepared 64380 raw examples for testing.


## Balance the dataset

Implement logic to sample the negative examples to match the number of positive examples and create the final balanced training and testing datasets.


Separate positive and negative examples, sample negative examples to match the number of positive examples in both train and test sets, combine them, and convert the balanced lists of dictionaries into HuggingFace Datasets.
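
The balancing step can also be written as one reusable helper. The sketch below is a minimal version under the assumption that each example is a dict with a binary `label` key; `balance_examples` is a hypothetical name. It uses a local `random.Random` instance, so the result depends only on the seed and the input order, not on other calls to the global `random` module.

```python
import random

def balance_examples(examples, seed=42):
    """Downsample negatives to the number of positives and shuffle the result."""
    rng = random.Random(seed)  # local RNG, unaffected by other random calls
    positives = [ex for ex in examples if ex["label"] == 1]
    negatives = [ex for ex in examples if ex["label"] == 0]
    k = min(len(positives), len(negatives))  # use all negatives if there are too few
    balanced = positives + rng.sample(negatives, k)
    rng.shuffle(balanced)
    return balanced

toy = [{"label": 1, "token_position": 9}] + [
    {"label": 0, "token_position": p} for p in range(9)
]
balanced = balance_examples(toy)
print(len(balanced), sum(ex["label"] for ex in balanced))  # 2 1
```

Calling the helper once for the train list and once for the test list would give the same kind of balance as the cell that follows.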



In [41]:
import random
from datasets import Dataset # creates batches of examples

# Set random seed for reproducibility
random.seed(42)

# --- Balancing Training Data ---
# Separate positive and negative training examples
train_pos_examples = [ex for ex in train_examples_raw_list if ex['label'] == 1]
train_neg_examples = [ex for ex in train_examples_raw_list if ex['label'] == 0]

print(f"Original train data: {len(train_pos_examples)} positive, {len(train_neg_examples)} negative")

# Count positive training examples
num_train_pos = len(train_pos_examples)

# Sample negative training examples by the same amount of positive examples
if len(train_neg_examples) >= num_train_pos:
    sampled_train_neg_examples = random.sample(train_neg_examples, num_train_pos)
else:
    # If not enough negative examples, take all of them
    sampled_train_neg_examples = train_neg_examples
    print(f"Warning: Not enough negative train examples ({len(train_neg_examples)}) to match positive ({num_train_pos}). Using all available negative examples.")


# Combine positive and sampled negative training examples
balanced_train_examples_list = train_pos_examples + sampled_train_neg_examples
random.shuffle(balanced_train_examples_list) # Shuffle the combined list

# Convert to HuggingFace Dataset
train_dataset_balanced = Dataset.from_list(balanced_train_examples_list)
print(f"Balanced train dataset size: {len(train_dataset_balanced)}")

# ------------------------------------------------------------------------------

# --- Balancing Testing Data ---
# Separate positive and negative testing examples
test_pos_examples = [ex for ex in test_examples_raw_list if ex['label'] == 1]
test_neg_examples = [ex for ex in test_examples_raw_list if ex['label'] == 0]

print(f"Original test data: {len(test_pos_examples)} positive, {len(test_neg_examples)} negative")

# Count positive testing examples
num_test_pos = len(test_pos_examples)

# Sample negative testing examples
if len(test_neg_examples) >= num_test_pos:
    sampled_test_neg_examples = random.sample(test_neg_examples, num_test_pos)
else:
    # If not enough negative examples, take all of them
    sampled_test_neg_examples = test_neg_examples
    print(f"Warning: Not enough negative test examples ({len(test_neg_examples)}) to match positive ({num_test_pos}). Using all available negative examples.")

# Combine positive and sampled negative testing examples
balanced_test_examples_list = test_pos_examples + sampled_test_neg_examples
random.shuffle(balanced_test_examples_list) # Shuffle the combined list

# Convert to HuggingFace Dataset
test_dataset_balanced = Dataset.from_list(balanced_test_examples_list)
print(f"Balanced test dataset size: {len(test_dataset_balanced)}")

# ------------------------------------------------------------------------------

# Verify balancing by printing counts in balanced datasets
print("\nVerification of balanced datasets:")
print(f"Balanced train dataset: {train_dataset_balanced.filter(lambda x: x['label'] == 1).num_rows} positive, {train_dataset_balanced.filter(lambda x: x['label'] == 0).num_rows} negative")
print(f"Balanced test dataset: {test_dataset_balanced.filter(lambda x: x['label'] == 1).num_rows} positive, {test_dataset_balanced.filter(lambda x: x['label'] == 0).num_rows} negative")

Original train data: 1075 positive, 248510 negative
Balanced train dataset size: 2150
Original test data: 301 positive, 64079 negative
Balanced test dataset size: 602

Verification of balanced datasets:



Balanced train dataset: 1075 positive, 1075 negative



Balanced test dataset: 301 positive, 301 negative


# Linear Probe, Dataset Structure, Activation Extraction, Saving States, and Training

In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset # Use alias to avoid conflict
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer # Import necessary classes
import numpy as np # Import numpy for potential use
from tqdm.auto import tqdm # Import tqdm for progress bars
import os # Import os for path joining

# Configure logging for visibility
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define a directory to save probe states
PROBE_SAVE_DIR = "./probe_states"
os.makedirs(PROBE_SAVE_DIR, exist_ok=True)


# Simplified Probe class (Linear -> Sigmoid)
class SimpleProbe(nn.Module):
    def __init__(self, hidden_dim: int, dtype: torch.dtype = torch.float32):
        """
        Initialize the simplified linear probe.

        :param hidden_dim: The dimensionality of the input activations.
        :param dtype: The data type for the probe's parameters.
        """
        super(SimpleProbe, self).__init__()
        self.linear = nn.Linear(hidden_dim, 1) # Output dimension is 1 for binary classification
        self.sigmoid = nn.Sigmoid()
        # Ensure probe parameters are of the specified dtype
        self.to(dtype)


    def forward(self, x):
        """
        Forward pass through the probe.

        :param x: Input tensor (activations). Shape: (batch_size, hidden_dim)
        :return: Output tensor (probabilities). Shape: (batch_size, 1)
        """
        return self.sigmoid(self.linear(x))

# Custom Dataset class to handle our balanced data structure
class ProbeDataset(Dataset):
    def __init__(self, hf_dataset: HFDataset):
        """
        Initialize the custom dataset from a HuggingFace Dataset.

        :param hf_dataset: The HuggingFace Dataset containing 'text', 'token_position', and 'label'.
        """
        self.texts = hf_dataset['text']
        self.token_positions = hf_dataset['token_position']
        self.labels = hf_dataset['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Get a single example from the dataset.

        :param idx: Index of the example.
        :return: Dictionary containing text, token_position, and label.
        """
        return {
            'text': self.texts[idx],
            'token_position': self.token_positions[idx],
            'label': self.labels[idx]
        }


# Function to get position-specific activations for a given layer and batch
def get_position_activations_batch(batch_texts: List[str], batch_positions: List[int],
                                     model: AutoModelForCausalLM, tokenizer: AutoTokenizer,
                                     layer_idx: int, device: torch.device) -> torch.Tensor:
    """
    Extract activations from a specific layer at specified token positions for a batch of texts.

    :param batch_texts: List of text sequences in the batch.
    :param batch_positions: List of token indices within each text to extract activations from.
    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param layer_idx: The layer index to get embeddings from.
    :param device: The device to run the model on (e.g., 'cuda', 'cpu').
    :return: Tensor of shape (batch_size, hidden_dim), containing activations at the specified positions.
             Returns None if any position is invalid.
    """
    model.to(device)
    model.eval() # Set model to eval mode for activation extraction


    batch_size = len(batch_texts)
    position_activations = []

    # Process each example in the batch individually to handle variable positions correctly
    for i in range(batch_size):
        text = batch_texts[i]
        position = batch_positions[i]

        # Tokenize the text
        inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False).to(device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        seq_len = input_ids.shape[1]

        if position < 0 or position >= seq_len:
            logging.warning(f"Position {position} out of bounds for text: '{text[:50]}...' (seq_len: {seq_len}). Skipping example.")
            # Simplified handling for now: drop the whole batch if any position is invalid
            return None

        with torch.no_grad(): # Use no_grad for activation extraction to save memory
            outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
            # Tuple of tensors: index 0 is the embedding output, followed by one entry per layer
            hidden_states = outputs.hidden_states

        # Get activations for the specified layer (negative indices count from the end)
        if -len(hidden_states) <= layer_idx < len(hidden_states):
            layer_activations = hidden_states[layer_idx]
        else:
            logging.warning(f"Layer {layer_idx} out of bounds. Using last layer.")
            layer_activations = hidden_states[-1]


        # Extract activation at the specific token position
        # layer_activations shape: (batch_size_of_1, sequence_length, hidden_dim)
        position_activation = layer_activations[0, position, :] # Shape (hidden_dim,)
        position_activations.append(position_activation)

    # Stack activations and keep them on the device
    if position_activations:
        # Keep on device, but cast to float32 for probe training consistency
        return torch.stack(position_activations).to(torch.float32).to(device)
    else:
        return torch.empty(0, model.config.hidden_size, dtype=torch.float32, device=device) # Return empty tensor on device


# Function to train and evaluate a probe for a single layer
def train_and_evaluate_probe(layer_idx: int, train_dataloader: DataLoader, test_dataloader: DataLoader,
                             model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: torch.device,
                             num_epochs: int = 10, learning_rate: float = 0.001) -> Tuple[float, Dict]:
    """
    Train and evaluate a linear probe for a single layer.

    :param layer_idx: The index of the layer to probe.
    :param train_dataloader: DataLoader for the training data.
    :param test_dataloader: DataLoader for the testing data.
    :param model: The language model to extract activations from.
    :param tokenizer: The tokenizer for the language model.
    :param device: The device to run training on.
    :param num_epochs: Number of training epochs.
    :param learning_rate: Learning rate for the optimizer.
    :return: A tuple containing:
             - Accuracy of the probe on the test set for this layer (float).
             - The state_dict of the best performing probe from any epoch for THIS layer (Dict).
    """
    logging.info(f"Training probe for layer {layer_idx}...")

    # Initialize the probe with float32 dtype for standard training
    hidden_dim = model.config.hidden_size
    probe = SimpleProbe(hidden_dim, dtype=torch.float32).to(device)


    # Define loss function and optimizer
    # Note: BCELoss expects float labels and float predictions
    criterion = nn.BCELoss()
    optimizer = optim.Adam(probe.parameters(), lr=learning_rate)


    # Training loop
    probe.train() # Set probe to training mode
    # model.train() # Keep model in eval mode during activation extraction to save memory/compute


    best_epoch_accuracy = -1 # Track best accuracy within this layer's training
    best_epoch_probe_state = None # Store state_dict of the best probe in this layer


    for epoch in range(num_epochs):
        probe.train() # evaluate_probe() switches the probe to eval mode, so switch back each epoch
        total_loss = 0
        num_batches = 0
        # Wrap the train_dataloader with tqdm for a progress bar
        for batch in tqdm(train_dataloader, desc=f"Layer {layer_idx} Epoch {epoch+1}/{num_epochs} (Train)"):
            texts = batch['text']
            positions = batch['token_position']
            # Ensure labels are float and on the correct device
            labels = batch['label'].float().unsqueeze(1).to(device)


            # Get activations for the current layer and batch
            # get_position_activations_batch now returns float32 tensors on the device
            # Use no_grad() inside get_position_activations_batch
            activations = get_position_activations_batch(texts, positions, model, tokenizer, layer_idx, device)

            if activations is None or activations.numel() == 0: # Skip batch if any example was invalid or no activations returned
                continue

            # Forward pass
            outputs = probe(activations)

            # Calculate loss - inputs to criterion should be float32
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Log average loss per epoch
        if num_batches > 0:
            avg_loss = total_loss / num_batches
            logging.info(f"Layer {layer_idx}, Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
        else:
             logging.warning(f"Layer {layer_idx}, Epoch {epoch+1}/{num_epochs}, No valid batches processed.")

        # --- Evaluation during training to track best epoch probe ---
        # Evaluate after each epoch to get accuracy for saving best state
        # Note: This adds compute, can be done less frequently if needed.
        # Caveat: the best epoch is selected on the test set, so the reported accuracy is optimistic.
        current_epoch_accuracy = evaluate_probe(probe, test_dataloader, model, tokenizer, layer_idx, device)
        if current_epoch_accuracy > best_epoch_accuracy:
            best_epoch_accuracy = current_epoch_accuracy
            # state_dict() returns references to the live parameters, so clone them;
            # otherwise later optimizer steps would overwrite the saved "best" state
            best_epoch_probe_state = {k: v.detach().clone() for k, v in probe.state_dict().items()}
            logging.info(f"Layer {layer_idx}: New best epoch accuracy {best_epoch_accuracy:.4f} at epoch {epoch+1}. State saved.")


    # Evaluation loop at the end of training (using the best epoch's state if saved)
    logging.info(f"Evaluating final probe for layer {layer_idx}...")
    # Restore the state dict of the best probe from any epoch for this layer
    if best_epoch_probe_state is not None:
        probe.load_state_dict(best_epoch_probe_state)
        logging.info(f"Layer {layer_idx}: Loaded best epoch state for final evaluation.")
    else:
        logging.warning(f"Layer {layer_idx}: No best epoch state saved, evaluating probe state after final epoch.")


    final_test_accuracy = evaluate_probe(probe, test_dataloader, model, tokenizer, layer_idx, device)
    logging.info(f"Layer {layer_idx} Final Test Accuracy (Best Epoch): {final_test_accuracy:.4f}")

    # --- Save the best probe state for THIS layer ---
    if best_epoch_probe_state is not None:
        # Save with layer index
        save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{layer_idx}.pth")
        torch.save(best_epoch_probe_state, save_path)
        logging.info(f"Saved best probe state for layer {layer_idx} to {save_path}")
    # --- End Save ---


    return final_test_accuracy, best_epoch_probe_state # Return accuracy and the state_dict


# Helper function for evaluation to avoid repeating code
def evaluate_probe(probe: SimpleProbe, dataloader: DataLoader, model: AutoModelForCausalLM,
                   tokenizer: AutoTokenizer, layer_idx: int, device: torch.device) -> float:
    """
    Evaluate a probe on a dataset.

    :param probe: The probe model.
    :param dataloader: DataLoader for the evaluation data.
    :param model: The language model.
    :param tokenizer: The tokenizer.
    :param layer_idx: The layer index being probed.
    :param device: The device.
    :return: Accuracy on the dataset.
    """
    probe.eval() # Set probe to evaluation mode
    model.eval() # Ensure model is in evaluation mode

    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad(): # Keep no_grad for evaluation
        # No tqdm here to keep the training progress bar clean (wrapping with tqdm(..., disable=True) would also work)
        for batch in dataloader:
            texts = batch['text']
            positions = batch['token_position']
            labels = batch['label'].to(device)

            activations = get_position_activations_batch(texts, positions, model, tokenizer, layer_idx, device)

            if activations is None or activations.numel() == 0:
                continue

            outputs = probe(activations)
            predicted = (outputs > 0.5).reshape(-1).long() # Flatten to (batch,) so a single-item batch keeps its batch dimension

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    # logging.info(f"  Evaluation Accuracy: {accuracy:.4f}") # Don't log per epoch unless needed for debugging
    return accuracy

# Training Loop

- After the probe for a given layer finishes training, its best state is saved automatically to the probe_states directory.
- This means you can terminate the cell early once the layers you need have finished training; their saved states remain on disk.
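
Because each finished layer leaves a file named probe_layer_{layer_idx}.pth in the save directory, a rerun can check which layers are already done before training them again. The sketch below shows one way to do that. Note that find_saved_layers is a hypothetical helper that does not exist elsewhere in this notebook, and it assumes the filename pattern used by the save step.

```python
import os
import re
from typing import List

# Hypothetical helper (not defined elsewhere in this notebook).
# Assumes per-layer files are named "probe_layer_<idx>.pth"; the extra
# "probe_layer_<idx>_BEST.pth" copy is deliberately ignored.
_LAYER_FILE = re.compile(r"^probe_layer_(\d+)\.pth$")


def find_saved_layers(save_dir: str) -> List[int]:
    """Return the sorted layer indices that already have a saved probe state."""
    if not os.path.isdir(save_dir):
        return []  # Nothing has been saved yet
    layers = []
    for name in os.listdir(save_dir):
        match = _LAYER_FILE.match(name)
        if match:
            layers.append(int(match.group(1)))
    return sorted(layers)
```

For example, the layers to probe could be filtered with `[l for l in layers_to_probe if l not in find_saved_layers(PROBE_SAVE_DIR)]` before the loop starts.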

In [48]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import time
import logging
from typing import Dict


# Convert HuggingFace Datasets to custom ProbeDataset and then to DataLoaders
logging.info("Creating Probe Datasets and DataLoaders...")
train_probe_dataset = ProbeDataset(train_dataset_balanced)
test_probe_dataset = ProbeDataset(test_dataset_balanced)

# Define DataLoader batch size (can be adjusted)
batch_size = 128 # Example batch size, adjust based on GPU RAM


g = torch.Generator(device=device)
g.manual_seed(42) # Set seed for reproducibility if needed


train_dataloader = DataLoader(train_probe_dataset, batch_size=batch_size, shuffle=True, generator=g)
test_dataloader = DataLoader(test_probe_dataset, batch_size=batch_size, shuffle=False) # No shuffling needed for test
logging.info("DataLoaders created.")


layer_wise_accuracies = {}
best_accuracy = -1.0 # Initialize best_accuracy as a float for comparison
best_layer = -1
best_probe_state_overall = None # To store the state_dict of the best probe overall

# Iterate through layers to train and evaluate a probe for each
# Probing layers from 0 up to model.config.num_hidden_layers - 1
num_layers = model.config.num_hidden_layers

# You can choose a subset of layers if needed, e.g., range(0, num_layers, 2)
# Given compute constraints, probe a spread-out subset of layers in the latter half.
# With the settings below, range(14, num_layers, 4) probes layers 14, 18, 22, 26 (4 layers for a 28-layer model)
# Adjust the start and step based on the model's num_layers and your compute budget
start_layer_idx = 14 # Start probing from layer 14
# Ensure start_layer_idx is a valid layer index (less than num_layers)
if start_layer_idx >= num_layers:
    requested_start_layer_idx = start_layer_idx # Keep the original value for the warning message
    start_layer_idx = num_layers - 1 if num_layers > 0 else 0 # Fall back to the last layer; handle the 0-layer case
    logging.warning(f"Start layer index {requested_start_layer_idx} is out of bounds, probing only layer {start_layer_idx}.")
step = 4 # Probe every 4th layer to stay within the compute budget
# Ensure step is at least 1
if step < 1:
    step = 1

# Define the layers to probe based on calculated range
# Ensure the range is valid
if start_layer_idx < num_layers:
    layers_to_probe = range(start_layer_idx, num_layers, step)
else:
    layers_to_probe = range(num_layers - 1, num_layers) if num_layers > 0 else range(0) # Handle 0 layers case


logging.info(f"Probing layers: {list(layers_to_probe)}")


start_time_all_layers = time.time() # Start timing all layers


for layer_idx in tqdm(layers_to_probe, desc="Overall Layer Probing Progress"):
    layer_start_time = time.time() # Start timing for this layer

    # Train and evaluate the probe for the current layer
    # The function now returns accuracy AND the best probe state for this layer
    # Ensure unpacking here
    current_layer_accuracy, current_layer_best_state = train_and_evaluate_probe(
        layer_idx,
        train_dataloader,
        test_dataloader,
        model,
        tokenizer,
        device,
        num_epochs=10, # Use 10 epochs as defined
        learning_rate=0.001 # Use 0.001 learning rate as defined
    )

    layer_end_time = time.time() # End timing for this layer
    layer_duration = layer_end_time - layer_start_time
    logging.info(f"Layer {layer_idx} training and evaluation took {layer_duration:.2f} seconds.")


    # Store the accuracy for this layer
    layer_wise_accuracies[layer_idx] = current_layer_accuracy

    # Check if this layer is the best so far based on accuracy
    if current_layer_accuracy > best_accuracy: # Compare float accuracy
        best_accuracy = current_layer_accuracy
        best_layer = layer_idx
        best_probe_state_overall = current_layer_best_state # Store the state_dict of the best probe overall
        logging.info(f"New best layer found: Layer {best_layer} with accuracy {best_accuracy:.4f}")

    # The state_dict for the best probe of *this* layer is already saved inside train_and_evaluate_probe
    # So we don't need to save it again here unless we wanted a different naming convention


# --- After probing all layers ---
end_time_all_layers = time.time()
total_duration_all_layers = end_time_all_layers - start_time_all_layers
logging.info(f"\n--- Probing Complete ---")
logging.info(f"Probing across {len(layers_to_probe)} layers took {total_duration_all_layers:.2f} seconds.")

logging.info(f"\nBest layer found: Layer {best_layer} with Test Accuracy: {best_accuracy:.4f}")

# --- Special Save for the Overall Best Probe ---
# We already saved the best probe state for *each* layer inside the function.
# To specially label the overall best one, we save its state_dict again
# with a distinct name.

if best_probe_state_overall is not None and best_layer != -1:
    best_probe_save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{best_layer}_BEST.pth")
    torch.save(best_probe_state_overall, best_probe_save_path)
    logging.info(f"Saved overall best probe state for Layer {best_layer} to {best_probe_save_path}")
else:
    logging.warning("No layers were successfully probed, skipping saving overall best probe.")

# --- End Special Save ---


# Optional: Visualize results (e.g., plot accuracies per layer)
import matplotlib.pyplot as plt

# Collect accuracies in probing order, skipping layers that never finished (e.g., after an early interrupt)
plotted_layers = [l for l in layers_to_probe if l in layer_wise_accuracies]
plotted_accuracies = [layer_wise_accuracies[l] for l in plotted_layers]


if plotted_layers: # Only plot if there's data
    plt.figure(figsize=(10, 6))
    plt.plot(plotted_layers, plotted_accuracies, marker='o')
    plt.xlabel('Layer Index')
    plt.ylabel('Test Accuracy')
    plt.title('Probe Accuracy per Layer')
    # Set x-ticks to match the probed layers
    plt.xticks(list(layers_to_probe))
    plt.grid(True)
    plt.show()
else:
    logging.warning("No layers were successfully plotted.")

Overall Layer Probing Progress:   0%|          | 0/4 [00:00<?, ?it/s]

[Per-epoch training progress bars omitted: bars were shown for layers 14 and 18 (epochs 1 to 10, 17 batches per epoch) and for layer 22, epoch 1, before the cell was interrupted.]

KeyboardInterrupt: 
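
The layer selection in the cell above boils down to clamping the start index and step, then taking a range. A minimal sketch of that logic as a standalone function follows. Here select_layers is a hypothetical name introduced only for illustration, and 28 is assumed as the layer count, matching the 28-layer example in the cell's comments.

```python
from typing import List


def select_layers(num_layers: int, start: int = 14, step: int = 4) -> List[int]:
    """Pick the layer indices to probe, mirroring the clamping in the training cell."""
    if num_layers <= 0:
        return []  # Nothing to probe
    step = max(step, 1)  # range() raises on a step of 0 and yields nothing for a negative step
    if start >= num_layers:
        start = num_layers - 1  # Out-of-bounds start: fall back to the last layer
    return list(range(start, num_layers, step))


print(select_layers(28))          # [14, 18, 22, 26]
print(select_layers(28, step=3))  # [14, 17, 20, 23, 26]
print(select_layers(10))          # [9]
```

With step = 4 this yields four layers, which matches the 0/4 total shown by the overall progress bar.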

# Load Probe States

In [54]:
# This cell loads a trained probe state dictionary from disk.
# Adjust the parameters (like layer_index_to_load) to load a specific probe.

import torch
import os
from typing import Dict # Import Dict for type hinting

# Define the directory where probes were saved (must match the one used during saving)
PROBE_SAVE_DIR = "/content/probe_states/"

# Assume you want to load the probe for a specific layer
# Replace with the layer index you want to load (e.g., the best layer found during training)
layer_index_to_load = 14

# Define the path to the saved state dictionary file for that layer
save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{layer_index_to_load}.pth")

# # If loading the overall best probe, use its special filename:
# best_layer_index = 14 # Replace with the actual overall best layer index from your training logs
# save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{best_layer_index}_BEST.pth")

# Ensure 'save_path' variable is set correctly above depending on which file you want to load


# You need to know the hidden_dim of the model the probe was trained on
# This is model.config.hidden_size for Qwen3-0.6B
# If the model is already in memory:
hidden_dim = model.config.hidden_size
# OR, if the model is not in memory, load only its config instead of the full weights:
# from transformers import AutoConfig
# hidden_dim = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B").hidden_size


# Create a new instance of the SimpleProbe class on the correct device (e.g., 'cuda')
# Define 'device' here in case the model loading cell has not been run
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_probe = SimpleProbe(hidden_dim=hidden_dim).to(device) # Ensure SimpleProbe class is defined (e.g., from f290f3d7)

# Load the saved state dictionary into the new probe instance
# map_location ensures it's loaded onto the correct device
try:
    loaded_probe.load_state_dict(torch.load(save_path, map_location=device))
    print(f"Successfully loaded probe state from {save_path}")
    # Set the loaded probe to evaluation mode
    loaded_probe.eval()
    # The 'loaded_probe' instance now contains the trained weights and can be used for inference.
    # Remember to use 'with torch.no_grad():' when using the loaded_probe for inference.

except FileNotFoundError:
    print(f"Error: Probe state file not found at {save_path}. Make sure training completed successfully and the path is correct.")
except Exception as e:
    print(f"An error occurred while loading the probe state: {e}")


Successfully loaded probe state from /content/probe_states/probe_layer_14.pth
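
Saving and loading operate on the probe's state_dict, so the loading side has to rebuild a probe with the same architecture before calling load_state_dict. The round trip can be sketched in isolation as follows. TinyProbe is a hypothetical stand-in (one linear layer followed by a sigmoid); the actual SimpleProbe defined earlier in the notebook may differ, and the hidden size of 8 is arbitrary.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyProbe(nn.Module):
    """Hypothetical stand-in for SimpleProbe: linear layer + sigmoid."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.linear(x))


hidden_dim = 8  # Arbitrary size for the demo
probe = TinyProbe(hidden_dim)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "probe_layer_0.pth")
    torch.save(probe.state_dict(), path)  # Save only the weights, not the module

    restored = TinyProbe(hidden_dim)  # Same architecture, fresh random weights
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()

# Both probes should now give identical outputs on the same input
x = torch.randn(4, hidden_dim)
with torch.no_grad():
    same = torch.allclose(probe(x), restored(x))
print(same)
```

If the architecture on the loading side does not match (for example, a different hidden_dim), load_state_dict raises an error rather than loading partial weights.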


## Use the loaded probe to predict on sample text

In [99]:
# Using a loaded probe to predict on sample text

# Define your sample text
sample_text = "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"

# Define the token position you want to probe (0-indexed)
position_to_probe = 30 # Example: probe the token at index 30

# Define the layer index that your 'loaded_probe' was trained on
probed_layer_index = 14 # Replace with the actual layer index of your loaded probe


print(f"Sample Text: '{sample_text}'")
print(f"Probing Position: {position_to_probe}")
print(f"Using Probe from Layer: {probed_layer_index}")


# Prepare the input for get_position_activations_batch
# This function expects lists for batching, even for a single example
batch_texts = [sample_text]
batch_positions = [position_to_probe]

# Ensure model and probe are in evaluation mode and disable gradients
model.eval()
loaded_probe.eval()
with torch.no_grad():

    # Get the activation for the specified position and layer
    # Note: get_position_activations_batch handles moving inputs to device internally
    activations = get_position_activations_batch(
        batch_texts,
        batch_positions,
        model,
        tokenizer,
        probed_layer_index,
        device # Use the device variable
    )

    # Check if activations were successfully extracted
    if activations is None or activations.numel() == 0:
        print(f"Could not extract activations for position {position_to_probe} in the sample text.")
    else:
        # The activations tensor will have shape (batch_size, hidden_dim)
        # Since we have batch_size 1 here, it's (1, hidden_dim)
        # Feed the activation into the loaded probe
        # Ensure activations are on the same device as the probe (get_position_activations_batch does this)
        probe_output = loaded_probe(activations)

        # The output is a probability between 0 and 1
        predicted_probability = probe_output.item() # Get the scalar value

        print(f"\nProbe Prediction (Probability): {predicted_probability:.4f}")

        # Interpret the prediction (e.g., using a threshold like 0.5)
        prediction_label = "Pivotal" if predicted_probability > 0.5 else "Non-Pivotal"
        print(f"Predicted Label (threshold 0.5): {prediction_label}")

Sample Text: 'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?'
Probing Position: 30
Using Probe from Layer: 14

Probe Prediction (Probability): 0.9539
Predicted Label (threshold 0.5): Pivotal
