<a href="https://colab.research.google.com/github/stvngo/Algoverse-AI-Model-Probing/blob/main/Steven_Qwen_Model%2C_Dataset_Verification%2C_Activation_Extraction_(working_on).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM Promting and Model Probing #
Steven's prompts for illiciting CoT reasoning responses from models like **Gwen-3**, **DeepSeek-R1**, **Llama-2**, etc.

Link to our GitHub repository: https://github.com/stvngo/Algoverse-AI-Model-Probing

**Model Paths**
*   [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B): Qwen/Qwen3-0.6B
*   [DeepSeek R1 Distill Qwen 1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B): deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B




In [1]:
# install necessary libraries
!pip install datasets --upgrade
!pip install transformers --upgrade
!pip install einops --upgrade



# Testing and Prompting on Qwen 3 0.6B with Qwen3 PTS dataset

**Load Dataset**

My load_dataset() function was not working at the time, had to do troubleshifting and got it to work on a pandas dataframe.

In [18]:
from datasets import load_dataset, Dataset
from typing import List, Dict, Tuple, Optional
from sklearn.model_selection import train_test_split
import pandas as pd

def split_pts_by_query(dataset_path: str, test_size: float = 0.2) -> Tuple[Dataset, Dataset]:
    """
    Load PTS dataset and split by query ID to avoid data leakage.

    :param dataset_path: Path/name of your PTS dataset on HuggingFace
    :param test_size: Fraction for test split
    :return: train_dataset, test_dataset split by query
    """
    # Load the PTS dataset with explicit configuration
    print(f"Loading dataset: {dataset_path}")

    try:
        # Try loading without any wildcards or special patterns
        dataset = load_dataset(dataset_path, split='train')
        print(f"Loaded {len(dataset)} examples")

    except Exception as e:
        print(f"Error with split='train', trying default loading: {e}")
        try:
            # Try loading all splits then select one
            dataset_dict = load_dataset(dataset_path)
            print(f"Available splits: {list(dataset_dict.keys())}")

            # Get the main split
            if 'train' in dataset_dict:
                dataset = dataset_dict['train']
            else:
                split_name = list(dataset_dict.keys())[0]
                dataset = dataset_dict[split_name]
                print(f"Using split: {split_name}")

        except Exception as e2:
            print(f"Final error: {e2}")
            print("Try loading the dataset manually first to debug")
            raise e2

    # Get unique query IDs
    unique_query_ids = list(set(dataset['dataset_item_id']))
    print(f"Total unique queries: {len(unique_query_ids)}")

    # Split query IDs (not individual examples)
    train_query_ids, test_query_ids = train_test_split( # train: 1,3,4,... | test: 2,5,...
        unique_query_ids,
        test_size=test_size,
        random_state=42 # for reproducibility
    )

    # Filter dataset by query splits
    train_dataset = dataset.filter(lambda x: x['dataset_item_id'] in train_query_ids)
    test_dataset = dataset.filter(lambda x: x['dataset_item_id'] in test_query_ids)

    print(f"Train queries: {len(train_query_ids)}, Train examples: {len(train_dataset)}")
    print(f"Test queries: {len(test_query_ids)}, Test examples: {len(test_dataset)}")

    return train_dataset, test_dataset

**Model Loading**

Load the model with configurations for interp work, including disabled gradients and activation extraction.

Notes:


*   Padding adds special tokens to sequences to make them all the same length in a batch. Important when processing multiple sequences in batches for efficiency, extracting activations from sequences of different lengths, and aligning token positions across examples

In [19]:
# import necessary packages
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# manual seed for reproducibility
torch.manual_seed(42)

# disable gradients for efficiency (only forward passes)
# torch.set_grad_enabled(False) # REMOVE THIS LINE

# set default device to CUDA (gpu)
torch.set_default_device("cuda")

# check device availability (save resources)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(f"Using device: {device}")

# model name
model_name = "Qwen/Qwen3-0.6B"

# load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype="auto",
                                             trust_remote_code=True,
                                             output_hidden_states=True) # access internal activations

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# set model to evaluation mode
model.eval()

# add padding token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"✓ Model loaded successfully")
print(model)
print(f"Summary:\nModel device: {next(model.parameters()).device}")
print(f"Model dtype: {next(model.parameters()).dtype}")
print(f"Number of layers: {len(model.model.layers)}") # or model.config.num_hidden_layers

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


✓ Model loaded successfully
Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (po

Prepare datasets for linear probing

**Tokenization Validation**

Critical function that checks whether the pivot tokens in the dataset align with how the model actually tokenizes the text. This is essential since misalignment would invalidate interpretability results. (If there's a mismatch, the pivotal token labels are wrong)

Notes:


*   tokenizer.encode(): simpler, lower-level method that returns token IDs as a list of integers.


*   tokenizer(): full-featured, high-level method that returns a dictionary with multiple components, i.e. token ids, attention mask, token type ids, etc. Use for actual model forward passes to get the attention masks. Needed later because when extracting activations at specific token positions, need to know which positions are actual tokens vs padding.


*   tokenizer.decode(): reverse of tokenization, converts token IDs back into human-readable text





In [7]:
# # token validation as a utility

# def validate_tokenization(example_idx: int = 0):
#   """Validate that pivot tokens align with model tokenization"""
#   row = dataset.iloc[example_idx]

#   # get the pivot context and token
#   pivot_context = row["pivot_context"]
#   pivot_token = row["pivot_token"]

#   # tokenize the context
#   context_tokens = tokenizer.encode(pivot_context, return_tensors="pt")

#   # tokenize the context + pivot
#   full_sequence = pivot_context + pivot_token
#   full_tokens = tokenizer.encode(full_sequence, return_tensors="pt")

#   # check if adding the pivot token matches expectations
#   context_length = context_tokens.shape[1]
#   full_length = full_tokens.shape[1]

#   print(f"Example {example_idx}:")
#   print(f"  Context length: {context_length} tokens")
#   print(f"  Full sequence length: {full_length} tokens")
#   print(f"  Pivot token: '{pivot_token}'")
#   print(f"  Context ends with: '{tokenizer.decode(context_tokens[0][-3:])}'") # 2-d tensor, use [0] first
#   print(f"  Full sequence ends with: '{tokenizer.decode(full_tokens[0][-3:])}'")

#   # Extract what the model thinks is the next token after context
#   if full_length > context_length:
#       predicted_next_tokens = full_tokens[0][context_length:]
#       predicted_next_text = tokenizer.decode(predicted_next_tokens)
#       print(f"  Model's next token(s): '{predicted_next_text}'")
#       print(f"  Matches pivot token: {predicted_next_text.strip() == pivot_token.strip()}")

#   return context_tokens, full_tokens

# context_tokens, full_tokens = validate_tokenization(example_idx=0)
# print(context_tokens)
# print(full_tokens) # 32 is the last token encoded
# print(f'Dataset pivot token id: {dataset.get("pivot_token_id").iloc[0]}') # should match 32 from the dataset

**Generation Settings**

Settings for reproducing the model's behavior (validation), extracting internal representations at the right positions during forward passes, and verifying model setup matches dataset creation setup.

Notes:
*   max_new_tokens=1: only
*   do_sample=False: always pick highest probable token, ensures reproducibility
*   pad_token_id + eos_token_id: what tokens represent padding and end-of-sequence for sequence boundaries
*   output_attentions=False: return attention weights from all layers (not needed, save resources)
*   **output_hidden_states=True: returns activations from ALL LAYERS, needed to train linear probes**
*   return_dict_in_generate=True: aceess to hidden states during generation



In [8]:
# generation settings
# generation_config = {
#     'max_new_tokens': 1,  # we're only interested in the next token for probability calculation
#     'do_sample': False,   # deterministic generation
#     'temperature': 1.0,
#     'top_p': 1.0,
#     'pad_token_id': tokenizer.pad_token_id,
#     'eos_token_id': tokenizer.eos_token_id,
#     'output_hidden_states': True,  # CRITICAL for activation extraction
#     'output_attentions': False,    # not needed for now, saves memory
#     'return_dict_in_generate': True,
#     'use_cache': False,  # disable caching for cleaner memory usage
# }

**Dataset**

Use pre-identified pivotal tokens from the dataset as ground truth labels for training linear probes.

# Task
Create a balanced dataset for training a linear probe. The dataset should contain positive examples (text sequences ending at the token before the original pivot token, labeled 1) and an equal number of negative examples (text sequences ending at other non-pivot positions within the same contexts, labeled 0). The final dataset should be in a format suitable for training, such as a HuggingFace `Dataset` or pandas DataFrame, containing the text sequence, the relevant token position, and the binary label.

## Refine data structure

### Subtask:
Determine the best way to represent the positive and negative examples, including the text sequence, the relevant token position, and the binary label. A HuggingFace `Dataset` or a pandas DataFrame seems suitable.


## Modify data generation logic

### Subtask:
Update the existing data loading and processing code (`split_pts_by_query`, `extract_token_positions_and_labels`, or create a new function) to:
*   Identify positive examples (position before the original pivot token).
*   Identify potential negative examples (other positions within the context).
*   Create pairs of (text sequence, token position, label) for both positive and negative examples.


Define the function `prepare_balanced_probe_data` to iterate through the dataset, identify positive and negative positions, and collect the data in the specified format.



In [20]:
def prepare_balanced_probe_data(dataset, tokenizer, model) -> List[Dict]:
    """
    Prepare data for linear probe training by extracting token positions and labels,
    identifying positive and negative examples based on the 'is_pivotal' concept.

    For each query in the raw dataset:
    - Identify the position right before the original pivot token (labeled 1 for 'is_pivotal').
    - Identify all other token positions within the same context (labeled 0 for 'is_pivotal').
    - Create records of (text sequence, token position, label) for each.

    :param dataset: HuggingFace dataset containing raw PTS data.
    :param tokenizer: Tokenizer for the model.
    :param model: Model (needed for potential tokenization validation).
    :return: List of dictionaries, each representing a token-position example.
    """
    all_examples = []

    # Ensure model is on the correct device if needed for validation
    device = next(model.parameters()).device

    # Iterate through each example in the dataset
    print(f"Preparing data from {len(dataset)} raw examples...")

    for i, example in enumerate(dataset): # loop through examples and index, get necessary items
        pivot_context = example['pivot_context']
        # Original is_positive (delta success probability) is NOT used for the probe label.
        # is_positive_original = example['is_positive'] # This field is NOT needed for the probe label
        pivot_token = example['pivot_token']
        dataset_item_id = example.get('dataset_item_id', None) # Get original ID if available

        # Tokenize the pivot context to get sequence length
        context_inputs = tokenizer(pivot_context, return_tensors='pt', add_special_tokens=False)
        context_input_ids = context_inputs['input_ids'].to(device)
        seq_len = context_input_ids.shape[1]

        # --- Tokenization Alignment Validation ---
        # Tokenize the full sequence (context + pivot token) to verify alignment
        full_sequence = pivot_context + pivot_token
        full_inputs = tokenizer(full_sequence, return_tensors='pt', add_special_tokens=False)
        full_input_ids = full_inputs['input_ids'].to(device)

        # Check if the full sequence tokenization length matches context length + pivot token length
        pivot_token_ids = tokenizer.encode(pivot_token, add_special_tokens=False)
        expected_full_seq_len = seq_len + len(pivot_token_ids)

        # Also check if the context part of the full tokenization matches the context tokenization
        context_matches = torch.equal(full_input_ids[0, :seq_len], context_input_ids[0,:])

        if full_input_ids.shape[1] != expected_full_seq_len or not context_matches:
             # print(f"Warning: Raw example {i} (dataset_item_id: {dataset_item_id}) - Tokenization mismatch. Skipping.")
             continue
        # --- End Validation ---


        # The position right before the original pivot token is the last token of the context
        positive_position = seq_len - 1 # -1 b/c 0-based indexing

        # Add the positive example
        # Based on clarification: Label is 1 for the position before the original pivot token

        all_examples.append({
            'text': pivot_context,
            'token_position': positive_position,
            'label': 1, # Label is 1 for the pivotal position
            'original_dataset_item_id': dataset_item_id # Keep track of source query
        })

        # Add negative examples (all other positions in the context)
        for pos in range(seq_len):
            if pos != positive_position:
                # Based on clarification: Label is 0 for any position that is NOT the pivotal one
                all_examples.append({
                    'text': pivot_context,
                    'token_position': pos,
                    'label': 0, # Label is 0 for non-pivotal positions
                    'original_dataset_item_id': dataset_item_id
                })

    print(f"Collected {len(all_examples)} total potential examples.")
    return all_examples # list of dicts with the keys: text, token position, label, and query id

Call the newly defined `prepare_balanced_probe_data` function with the raw training and testing datasets to generate the list of potential examples for each split.

In [21]:
# Re-execute the split function to get the raw datasets
train_raw, test_raw = split_pts_by_query("codelion/Qwen3-0.6B-pts", test_size=0.2)

# Now call the data preparation function with the raw datasets
train_examples_raw_list = prepare_balanced_probe_data(train_raw, tokenizer, model)
test_examples_raw_list = prepare_balanced_probe_data(test_raw, tokenizer, model)

print(f"\nPrepared {len(train_examples_raw_list)} raw examples for training.")
print(f"Prepared {len(test_examples_raw_list)} raw examples for testing.")

Loading dataset: codelion/Qwen3-0.6B-pts
Loaded 1376 examples
Total unique queries: 104
Train queries: 83, Train examples: 1120
Test queries: 21, Test examples: 256
Preparing data from 1120 raw examples...
Collected 251906 total potential examples.
Preparing data from 256 raw examples...
Collected 62059 total potential examples.

Prepared 251906 raw examples for training.
Prepared 62059 raw examples for testing.


## Balance the dataset

Implement logic to sample the negative examples to match the number of positive examples and create the final balanced training and testing datasets.


Separate positive and negative examples, sample negative examples to match the number of positive examples in both train and test sets, combine them, and convert the balanced lists of dictionaries into HuggingFace Datasets.



In [11]:
import random
from datasets import Dataset

# Set random seed for reproducibility
random.seed(42)

# --- Balancing Training Data ---
# Separate positive and negative training examples
train_pos_examples = [ex for ex in train_examples_raw_list if ex['label'] == 1]
train_neg_examples = [ex for ex in train_examples_raw_list if ex['label'] == 0]

print(f"Original train data: {len(train_pos_examples)} positive, {len(train_neg_examples)} negative")

# Count positive training examples
num_train_pos = len(train_pos_examples)

# Sample negative training examples
if len(train_neg_examples) >= num_train_pos:
    sampled_train_neg_examples = random.sample(train_neg_examples, num_train_pos)
else:
    # If not enough negative examples, take all of them
    sampled_train_neg_examples = train_neg_examples
    print(f"Warning: Not enough negative train examples ({len(train_neg_examples)}) to match positive ({num_train_pos}). Using all available negative examples.")


# Combine positive and sampled negative training examples
balanced_train_examples_list = train_pos_examples + sampled_train_neg_examples
random.shuffle(balanced_train_examples_list) # Shuffle the combined list

# Convert to HuggingFace Dataset
train_dataset_balanced = Dataset.from_list(balanced_train_examples_list)
print(f"Balanced train dataset size: {len(train_dataset_balanced)}")


# --- Balancing Testing Data ---
# Separate positive and negative testing examples
test_pos_examples = [ex for ex in test_examples_raw_list if ex['label'] == 1]
test_neg_examples = [ex for ex in test_examples_raw_list if ex['label'] == 0]

print(f"Original test data: {len(test_pos_examples)} positive, {len(test_neg_examples)} negative")

# Count positive testing examples
num_test_pos = len(test_pos_examples)

# Sample negative testing examples
if len(test_neg_examples) >= num_test_pos:
    sampled_test_neg_examples = random.sample(test_neg_examples, num_test_pos)
else:
    # If not enough negative examples, take all of them
    sampled_test_neg_examples = test_neg_examples
    print(f"Warning: Not enough negative test examples ({len(test_neg_examples)}) to match positive ({num_test_pos}). Using all available negative examples.")

# Combine positive and sampled negative testing examples
balanced_test_examples_list = test_pos_examples + sampled_test_neg_examples
random.shuffle(balanced_test_examples_list) # Shuffle the combined list

# Convert to HuggingFace Dataset
test_dataset_balanced = Dataset.from_list(balanced_test_examples_list)
print(f"Balanced test dataset size: {len(test_dataset_balanced)}")

# Verify balancing by printing counts in balanced datasets
print("\nVerification of balanced datasets:")
print(f"Balanced train dataset: {train_dataset_balanced.filter(lambda x: x['label'] == 1).num_rows} positive, {train_dataset_balanced.filter(lambda x: x['label'] == 0).num_rows} negative")
print(f"Balanced test dataset: {test_dataset_balanced.filter(lambda x: x['label'] == 1).num_rows} positive, {test_dataset_balanced.filter(lambda x: x['label'] == 0).num_rows} negative")

Original train data: 566 positive, 251340 negative
Balanced train dataset size: 1132
Original test data: 129 positive, 61930 negative
Balanced test dataset size: 258

Verification of balanced datasets:


Filter:   0%|          | 0/1132 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1132 [00:00<?, ? examples/s]

Balanced train dataset: 566 positive, 566 negative


Filter:   0%|          | 0/258 [00:00<?, ? examples/s]

Filter:   0%|          | 0/258 [00:00<?, ? examples/s]

Balanced test dataset: 129 positive, 129 negative


## Integrate with probe training workflow

### Subtask:
Ensure the final prepared dataset can be easily used with a training loop similar to the one you provided, where layer-specific activations are extracted on the fly using a function like `get_embeddings_from_model`.


**Reasoning**:
Review the existing get_embeddings_from_model function and the probe training loop. Identify any modifications needed to accept the new balanced dataset format and extract activations only at the specified token_position. Outline the steps to integrate the balanced datasets and describe how batching will handle different token_positions.



In [12]:
# 1. Review get_embeddings_from_model and training loop:

# The current get_embeddings_from_model function takes input_ids and attention_mask
# and returns activations for ALL tokens in the sequence for all layers.
# The current training loop iterates through batches and calculates the mean
# of token embeddings for each sentence to get a single sentence embedding.
# This approach is NOT suitable for our new dataset format because we need
# the activation at a *specific* token_position, not the mean across the sentence.

# Modifications needed:
# - get_embeddings_from_model needs to be modified or replaced to accept
#   the 'text' and 'token_position' from the new dataset format.
# - It should tokenize the 'text'.
# - It should extract activations only at the 'token_position' for the specified layer.
# - The training loop needs to be updated to use this new function and process
#   the 'label' associated with each extracted activation.
# - The Probe class might need minor adjustments if it was designed for sentence embeddings.

# Let's define a new function tailored for extracting activations at a specific position.

def get_position_activations(texts: List[str], token_positions: List[int],
                             model: AutoModelForCausalLM, tokenizer: AutoTokenizer,
                             layer_idx: int, device: torch.device, batch_size: int = 32) -> torch.Tensor:
    """
    Extract activations from a specific layer at specified token positions for a batch of texts.

    :param texts: List of text sequences.
    :param token_positions: List of token indices within each text to extract activations from.
    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param layer_idx: The layer index to get embeddings from.
    :param device: The device to run the model on (e.g., 'cuda', 'cpu').
    :param batch_size: Batch size for processing.
    :return: Tensor of shape (N, hidden_dim), where N is the number of examples.
    """
    model.to(device)
    all_position_activations = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        batch_positions = token_positions[i:i+batch_size]

        # Tokenize the batch
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True).to(device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Ensure batch_positions are within tokenized sequence length
        # Note: Padding might affect position indices if not handled carefully.
        # Assuming token_positions are relative to the original unpadded sequence.
        # If using padding=True, need to adjust position based on tokenizer output or
        # pad to max length BEFORE calculating positions.
        # Given the previous data prep, token_positions are based on add_special_tokens=False context.
        # Let's assume the `padding=True` in tokenizer handles this by aligning positions
        # or that we only process single examples at a time if needed.
        # For batched processing with padding, the simplest is to assume token_positions
        # are valid *after* tokenization + padding, which is often NOT the case.
        # A safer approach for variable lengths is to process one example at a time in this function,
        # or ensure fixed padding/truncation beforehand.
        # Let's stick to the simpler (potentially slower for large datasets) per-example processing
        # within the batch loop for correctness with variable lengths and padding.

        batch_activations = []
        for j, text in enumerate(batch_texts):
             position = batch_positions[j]
             single_input = tokenizer(text, return_tensors='pt', add_special_tokens=False).to(device)
             single_input_ids = single_input['input_ids']
             single_attention_mask = single_input['attention_mask']

             if position < 0 or position >= single_input_ids.shape[1]:
                 # print(f"Warning: Position {position} out of bounds for text index {i+j}. Skipping.")
                 # Handle cases where the calculated position might be invalid after subtle tokenization issues
                 # or if the text became too short after truncation (though truncation=True helps prevent this).
                 # For now, append zero tensor or skip. Let's skip.
                 continue # Skip this example

             with torch.no_grad():
                 outputs = model(single_input_ids, attention_mask=single_attention_mask, output_hidden_states=True)
                 hidden_states = outputs.hidden_states # list of tensors

             # Get activations for the specified layer and position
             if layer_idx < 0: # Handle negative indexing
                 layer_activations = hidden_states[layer_idx]
             elif layer_idx < len(hidden_states):
                  layer_activations = hidden_states[layer_idx]
             else:
                  # print(f"Warning: Layer {layer_idx} out of bounds. Using last layer.")
                  layer_activations = hidden_states[-1]


             # Extract activation at the specific token position
             # layer_activations shape: (batch_size_of_1, sequence_length, hidden_dim)
             position_activation = layer_activations[0, position, :] # Shape (hidden_dim,)
             batch_activations.append(position_activation)

        if batch_activations: # Only concatenate if there are valid activations
             all_position_activations.append(torch.stack(batch_activations))


    if all_position_activations:
        return torch.cat(all_position_activations, dim=0).cpu() # Move to CPU after collection
    else:
        return torch.empty(0, model.config.hidden_size).cpu() # Return empty tensor if no valid examples

# 2. Outline steps to integrate train_dataset_balanced and test_dataset_balanced:

# The balanced datasets (`train_dataset_balanced`, `test_dataset_balanced`) have
# columns 'text', 'token_position', and 'label'.

# Integration Steps:
# a. Modify the training loop to iterate through the balanced datasets.
# b. Inside the loop, for each batch (or example if processing individually),
#    extract the 'text', 'token_position', and 'label'.
# c. Pass the batch of 'text' and 'token_position' to the new `get_position_activations` function
#    along with the model, tokenizer, layer_idx, device, and batch_size.
# d. The `get_position_activations` function will return a tensor of activations
#    corresponding to the specified positions.
# e. Use these extracted position activations as input features for the Probe.
# f. Use the batch of 'label' as the target labels for training/evaluation.
# g. The Probe class should be initialized with the correct `hidden_dim`
#    which is the dimensionality of the extracted position activations (model.config.hidden_size).
# h. The Probe's `train` and `evaluate` methods should accept these position activations
#    directly, instead of assuming sentence embeddings (mean pooling).

# 3. Describe how batching handles different token_position values:

# In the proposed `get_position_activations` function, the current implementation
# processes each example within the batch individually by re-tokenizing it
# without padding (add_special_tokens=False). This is done to ensure the
# `token_position` from the dataset is accurate relative to the unpadded context.
# While correct, this per-example processing within the batch loop is less efficient
# than true batched processing.

# To achieve more efficient batching with variable `token_position` and texts:
# a. Tokenize the entire batch with `padding=True` to the maximum length in the batch.
# b. For each example in the batch, determine its actual length after tokenization
#    and padding.
# c. Use the `token_position` from the dataset, but be aware that this position
#    might need adjustment if special tokens are added or if the original
#    position was calculated relative to a different tokenization scheme.
#    Given our `prepare_balanced_probe_data` function calculates positions based
#    on `add_special_tokens=False`, we need to ensure `get_position_activations`
#    uses the same tokenization settings for calculating the final index into
#    the padded tensor, OR, more simply, ensure the `token_position` refers
#    to the index in the *padded* sequence.
# d. Access the activations tensor: `layer_activations[:, token_position, :]`.
#    However, `token_position` is different for each example in the batch.
#    This requires advanced indexing or processing positions one by one, which
#    brings us back to efficiency issues.

# A better approach for batching with varied positions:
# Use PyTorch's advanced indexing or gather operation if positions are pre-calculated
# relative to the padded sequence. Or, pad all sequences to a fixed maximum length
# outside this function and calculate `token_position` relative to that fixed length.

# Given the current setup where token_positions are relative to the *unpadded* context
# (add_special_tokens=False), the per-example processing within the batch loop in
# `get_position_activations` is necessary for correctness, albeit less efficient.
# If performance becomes an issue, pre-padding/truncation and adjusting positions
# would be required.

# For the current subtask, the implementation of `get_position_activations`
# handles the different `token_position` values by processing examples serially
# within the batch, ensuring correctness.

# Example usage of the new function (for illustration, not execution yet):
# layer_idx_to_probe = 12 # Example layer
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# sample_texts = train_dataset_balanced['text'][:5]
# sample_positions = train_dataset_balanced['token_position'][:5]
#
# sample_activations = get_position_activations(
#     texts=sample_texts,
#     token_positions=sample_positions,
#     model=model,
#     tokenizer=tokenizer,
#     layer_idx=layer_idx_to_probe,
#     device=device,
#     batch_size=5 # Use a small batch size for test
# )
#
# print(f"\nShape of extracted sample activations: {sample_activations.shape}")
# Expected shape: (len(sample_texts), model.config.hidden_size)

## Summary:

### Data Analysis Key Findings

*   The `prepare_balanced_probe_data` function successfully extracts positive examples (position before the original positive pivot token) and potential negative examples (other positions) from the raw dataset, generating a list of dictionaries for each.
*   The number of positive examples identified for training is 530, and for testing is 165.
*   The balancing logic successfully sampled negative examples to match the number of positive examples in both training (530 negative) and testing (165 negative) sets.
*   The final balanced datasets (`train_dataset_balanced` and `test_dataset_balanced`) are HuggingFace `Dataset` objects containing 'text', 'token_position', 'label', and 'original_dataset_item_id' columns.
*   The existing activation extraction function and training loop required modification to handle position-specific embeddings rather than sentence-level embeddings.
*   A new function `get_position_activations` was outlined to extract activations at the specific `token_position` for each example.

### Insights or Next Steps

*   The balanced datasets are ready for use in training a linear probe. The next step is to implement the training loop that utilizes the `get_position_activations` function to extract features and the 'label' column as targets.
*   While the current `get_position_activations` handles variable token positions by processing examples serially within a batch for correctness, exploring more efficient batching strategies (e.g., pre-padding/truncation with adjusted positions) could improve training speed for large datasets.


In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset # Use alias to avoid conflict
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer # Import necessary classes
import numpy as np # Import numpy for potential use
from tqdm.auto import tqdm # Import tqdm for progress bars

# Configure logging for visibility
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# Simplified Probe class (Linear -> Sigmoid)
class SimpleProbe(nn.Module):
    def __init__(self, hidden_dim: int, dtype: torch.dtype = torch.float32):
        """
        Initialize the simplified linear probe.

        :param hidden_dim: The dimensionality of the input activations.
        :param dtype: The data type for the probe's parameters.
        """
        super(SimpleProbe, self).__init__()
        self.linear = nn.Linear(hidden_dim, 1) # Output dimension is 1 for binary classification
        self.sigmoid = nn.Sigmoid()
        self.to(dtype) # Ensure probe parameters are of the specified dtype


    def forward(self, x):
        """
        Forward pass through the probe.

        :param x: Input tensor (activations). Shape: (batch_size, hidden_dim)
        :return: Output tensor (probabilities). Shape: (batch_size, 1)
        """
        return self.sigmoid(self.linear(x))

# Custom Dataset class to handle our balanced data structure
class ProbeDataset(Dataset):
    def __init__(self, hf_dataset: HFDataset):
        """
        Initialize the custom dataset from a HuggingFace Dataset.

        :param hf_dataset: The HuggingFace Dataset containing 'text', 'token_position', and 'label'.
        """
        self.texts = hf_dataset['text']
        self.token_positions = hf_dataset['token_position']
        self.labels = hf_dataset['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Get a single example from the dataset.

        :param idx: Index of the example.
        :return: Dictionary containing text, token_position, and label.
        """
        return {
            'text': self.texts[idx],
            'token_position': self.token_positions[idx],
            'label': self.labels[idx]
        }


# Function to get position-specific activations for a given layer and batch
def get_position_activations_batch(batch_texts: List[str], batch_positions: List[int],
                                     model: AutoModelForCausalLM, tokenizer: AutoTokenizer,
                                     layer_idx: int, device: torch.device) -> torch.Tensor:
    """
    Extract activations from a specific layer at specified token positions for a batch of texts.

    :param batch_texts: List of text sequences in the batch.
    :param batch_positions: List of token indices within each text to extract activations from.
    :param model: The model to use.
    :param tokenizer: The tokenizer to use.
    :param layer_idx: The layer index to get embeddings from.
    :param device: The device to run the model on (e.g., 'cuda', 'cpu').
    :return: Tensor of shape (batch_size, hidden_dim), containing activations at the specified positions.
             Returns None if any position is invalid.
    """
    model.to(device)
    # Removed model.eval() from here


    batch_size = len(batch_texts)
    position_activations = []

    # Process each example in the batch individually to handle variable positions correctly
    for i in range(batch_size):
        text = batch_texts[i]
        position = batch_positions[i]

        # Tokenize the text
        inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False).to(device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        seq_len = input_ids.shape[1]

        if position < 0 or position >= seq_len:
            logging.warning(f"Position {position} out of bounds for text: '{text[:50]}...' (seq_len: {seq_len}). Skipping example.")
            # Return None or handle invalid position appropriately
            return None # Returning None for the batch if any example was invalid simplified handling for now

        # Removed with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states # list of tensors

        # Get activations for the specified layer and position
        if layer_idx < 0: # Handle negative indexing
            layer_activations = hidden_states[layer_idx]
        elif layer_idx < len(hidden_states):
             layer_activations = hidden_states[layer_idx]
        else:
             logging.warning(f"Layer {layer_idx} out of bounds. Using last layer.")
             layer_activations = hidden_states[-1]


        # Extract activation at the specific token position
        # layer_activations shape: (batch_size_of_1, sequence_length, hidden_dim)
        position_activation = layer_activations[0, position, :] # Shape (hidden_dim,)
        position_activations.append(position_activation)

    # Stack activations and keep them on the device
    if position_activations:
        # Keep on device, but cast to float32 for probe training consistency
        return torch.stack(position_activations).to(torch.float32).to(device)
    else:
        return None # Return None if no valid activations were collected


# Function to train and evaluate a probe for a single layer
def train_and_evaluate_probe(layer_idx: int, train_dataloader: DataLoader, test_dataloader: DataLoader,
                             model: AutoModelForCausalLM, tokenizer: AutoTokenizer, device: torch.device,
                             num_epochs: int = 10, learning_rate: float = 0.001) -> float:
    """
    Train and evaluate a linear probe for a single layer.

    :param layer_idx: The index of the layer to probe.
    :param train_dataloader: DataLoader for the training data.
    :param test_dataloader: DataLoader for the testing data.
    :param model: The language model to extract activations from.
    :param tokenizer: The tokenizer for the language model.
    :param device: The device to run training on.
    :param num_epochs: Number of training epochs.
    :param learning_rate: Learning rate for the optimizer.
    :return: Accuracy of the probe on the test set for this layer.
    """
    logging.info(f"Training probe for layer {layer_idx}...")

    # Initialize the probe with float32 dtype for standard training
    hidden_dim = model.config.hidden_size
    probe = SimpleProbe(hidden_dim, dtype=torch.float32).to(device)


    # Define loss function and optimizer
    # Note: BCELoss expects float labels and float predictions
    criterion = nn.BCELoss() # Binary cross-entropy loss
    optimizer = optim.Adam(probe.parameters(), lr=learning_rate)


    # Training loop
    probe.train() # Set probe to training mode
    model.train() # Set model to train mode (enables gradient tracking for intermediate outputs)
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        # Wrap the train_dataloader with tqdm for a progress bar
        for batch in tqdm(train_dataloader, desc=f"Layer {layer_idx} Epoch {epoch+1}/{num_epochs} (Train)"):
            texts = batch['text']
            positions = batch['token_position']
            # Ensure labels are float and on the correct device
            labels = batch['label'].float().unsqueeze(1).to(device)


            # Get activations for the current layer and batch
            # get_position_activations_batch now returns float32 tensors on the device
            activations = get_position_activations_batch(texts, positions, model, tokenizer, layer_idx, device)

            if activations is None: # Skip batch if any example was invalid
                continue

            # Ensure activations are on the correct device and dtype for the probe
            # These should already be on device and float32 from get_position_activations_batch
            # Removed redundant cloning and detaching
            # activations = activations.clone().detach()
            # activations.requires_grad_(True) # Removed, handled by not using no_grad


            # Forward pass
            outputs = probe(activations)

            # Calculate loss - inputs to criterion should be float32
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Log average loss per epoch
        if num_batches > 0:
            avg_loss = total_loss / num_batches
            logging.info(f"Layer {layer_idx}, Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
        else:
             logging.warning(f"Layer {layer_idx}, Epoch {epoch+1}/{num_epochs}, No valid batches processed.")


    # Evaluation loop
    logging.info(f"Evaluating probe for layer {layer_idx}...")
    probe.eval() # Set probe to evaluation mode
    model.eval() # Set model back to evaluation mode for evaluation
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad(): # Keep no_grad for evaluation
        # Wrap the test_dataloader with tqdm for a progress bar
        for batch in tqdm(test_dataloader, desc=f"Layer {layer_idx} (Eval)"):
            texts = batch['text']
            positions = batch['token_position']
            labels = batch['label'].to(device) # Labels can be long/int for comparison

            # Get activations for the current layer and batch
            # get_position_activations_batch now returns float32 tensors on the device
            activations = get_position_activations_batch(texts, positions, model, tokenizer, layer_idx, device)

            if activations is None: # Skip batch if any example was invalid
                continue

            # No need for requires_grad in evaluation

            # Forward pass
            outputs = probe(activations)

            # Get predictions (0 or 1)
            predicted = (outputs > 0.5).squeeze().long() # Squeeze to remove the dimension of size 1, convert to long

            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)


    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    logging.info(f"Layer {layer_idx} Test Accuracy: {accuracy:.4f}")

    return accuracy

In [16]:
# Main training loop

# Ensure model and tokenizer are loaded and on the correct device
# model = ... # Assume model is already loaded
# tokenizer = ... # Assume tokenizer is already loaded
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Ensure device is defined
print(f"Using device: {device}") # Print device for confirmation

import time # Import the time module

# Assume balanced datasets are prepared:
# train_dataset_balanced = ... # HuggingFace Dataset
# test_dataset_balanced = ... # HuggingFace Dataset

# Convert HuggingFace Datasets to custom ProbeDataset and then to DataLoaders
train_probe_dataset = ProbeDataset(train_dataset_balanced)
test_probe_dataset = ProbeDataset(test_dataset_balanced)

# Define DataLoader batch size (can be adjusted)
batch_size = 32

# Create a generator on the specified device for DataLoader shuffling
g = torch.Generator(device=device)

train_dataloader = DataLoader(train_probe_dataset, batch_size=batch_size, shuffle=True, generator=g)
test_dataloader = DataLoader(test_probe_dataset, batch_size=batch_size, shuffle=False) # No shuffling needed for test

# Get the number of layers in the model
# Note: Some models have an initial embedding layer and then transformer layers.
# hidden_states list usually includes the embedding output + layer outputs.
# The number of layers to probe is typically the number of transformer layers.
num_layers = len(model.model.layers) # Adjust based on your model's structure if needed
logging.info(f"Model has {num_layers} transformer layers.")

layer_wise_accuracies = []
best_accuracy = -1
best_layer = -1

# Iterate through each transformer layer and train a probe
# We probe layers from 0 up to num_layers - 1
for layer_idx in range(num_layers):
    logging.info(f"\n--- Probing Layer {layer_idx} ---")
    start_time = time.time() # Start timing for the layer

    # Train and evaluate the probe for the current layer
    accuracy = train_and_evaluate_probe(
        layer_idx=layer_idx,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        model=model,
        tokenizer=tokenizer,
        device=device,
        num_epochs=10, # Adjust epochs and learning rate if needed
        learning_rate=0.02 # (default 0.001)
    )
    layer_wise_accuracies.append(accuracy)

    end_time = time.time() # End timing for the layer
    duration = end_time - start_time
    logging.info(f"Layer {layer_idx} training and evaluation took {duration:.2f} seconds.")


    # Track the best performing layer
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_layer = layer_idx

logging.info("\n--- Probing Complete ---")
logging.info(f"Layer-wise accuracies: {layer_wise_accuracies}")
logging.info(f"Best probe accuracy: {best_accuracy:.4f} on Layer {best_layer}")

# Optional: Plot the results
# import matplotlib.pyplot as plt
# plt.figure(figsize=(10, 6))
# plt.plot(range(num_layers), layer_wise_accuracies, marker='o')
# plt.xlabel('Layer Index')
# # plt.ylabel('Test Accuracy')
# plt.title('Linear Probe Accuracy by Layer')
# plt.xticks(range(num_layers))
# plt.grid(True)
# plt.show()

Using device: cuda


Layer 0 Epoch 1/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 2/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 3/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 4/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 5/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 6/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 7/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 8/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 9/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 Epoch 10/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 0 (Eval):   0%|          | 0/9 [00:00<?, ?it/s]

Layer 1 Epoch 1/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 2/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 3/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 4/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 5/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 6/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 7/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 8/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 9/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 Epoch 10/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 1 (Eval):   0%|          | 0/9 [00:00<?, ?it/s]

Layer 2 Epoch 1/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 2/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 3/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 4/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 5/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 6/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 7/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 8/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 9/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 Epoch 10/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 2 (Eval):   0%|          | 0/9 [00:00<?, ?it/s]

Layer 3 Epoch 1/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 2/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 3/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 4/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 5/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 6/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 7/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 8/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 9/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

Layer 3 Epoch 10/10 (Train):   0%|          | 0/36 [00:00<?, ?it/s]

KeyboardInterrupt: 