# Steering Vector Experiments for Vision-Language Models



## Workflow Overview
1. **Setup & Imports** - Load necessary libraries
2. **Model Loading** - Load LLaVA model and inspect architecture
3. **Helper Functions** - Define utility functions for prompts and inference
4. **Load M3CoT Dataset** - Load and inspect the first 10 validation samples
5. **Baseline Generation** - Generate responses without steering
6. **Activation Extraction** - Extract hidden states from correct/incorrect examples
7. **Steering Vector Computation** - Calculate direction vectors
8. **Steered Generation** - Apply steering with hooks
9. **Comprehensive Evaluation** - Test across layers and scales
10. **Results Analysis** - Analyze and visualize results

---
## 1. Setup & Imports
Import all required libraries for model handling, data processing, and visualization.

In [1]:
# Core libraries
import os
import json
import csv
from datetime import datetime
from tqdm import tqdm

# PyTorch and transformers
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Dataset and image handling
import datasets
from PIL import Image

# Memory management
import gc

# Configuration
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Using device: cuda:0
PyTorch version: 2.8.0+cu126
CUDA available: True


---
## 2. Model Loading & Architecture Inspection
Load the LLaVA-1.5-7B model and examine its layer structure.
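
The loader below requests FP16 weights for memory efficiency. As a rough, back-of-the-envelope sketch of why that matters, the snippet estimates weight memory from a parameter count. The helper `estimate_weight_memory_gib` is hypothetical, and the 7-billion figure is only the nominal size implied by the model name, not a measured value; activations and the KV cache are not included.

```python
def estimate_weight_memory_gib(num_params, bytes_per_param=2):
    """Approximate memory needed for the model weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30


# Assumption: nominal "7B" parameter count, not measured from the checkpoint.
ASSUMED_PARAMS = 7_000_000_000

fp16_gib = estimate_weight_memory_gib(ASSUMED_PARAMS, bytes_per_param=2)
fp32_gib = estimate_weight_memory_gib(ASSUMED_PARAMS, bytes_per_param=4)

print(f"FP16 weights: ~{fp16_gib:.1f} GiB")
print(f"FP32 weights: ~{fp32_gib:.1f} GiB")
```

Once the model is loaded, the same helper can be fed the real count from `sum(p.numel() for p in model.parameters())`.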

In [2]:
def load_model(model_id="llava-hf/llava-1.5-7b-hf", device="cuda:0"):
    """
    Load LLaVA model with memory-efficient settings.
    
    Args:
        model_id: HuggingFace model identifier
        device: Device to load model on
    
    Returns:
        model: Loaded LLaVA model
        processor: Corresponding processor for text/image inputs
    """
    # Clear GPU and Python memory
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    print(f"Loading model: {model_id}")
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Use FP16 for memory efficiency
        low_cpu_mem_usage=True,
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_id)
    
    print("✓ Model loaded successfully")
    return model, processor

# Load the model
model, processor = load_model()

`torch_dtype` is deprecated! Use `dtype` instead!


Loading model: llava-hf/llava-1.5-7b-hf


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.

✓ Model loaded successfully


In [3]:
print("="*80)
print("MODEL ARCHITECTURE INFORMATION")
print("="*80)

# ------------------------------
# Language model access
# ------------------------------
lm = model.language_model

layers = None
if hasattr(lm, "model") and hasattr(lm.model, "layers"):
    layers = lm.model.layers
elif hasattr(lm, "layers"):
    layers = lm.layers

# ------------------------------
# Transformer depth
# ------------------------------
if layers is not None:
    num_layers = len(layers)
    print(f"\nNumber of transformer layers: {num_layers}")
else:
    print("\nCould not detect transformer layers")

# ------------------------------
# Hidden dimension
# ------------------------------
if hasattr(lm, "config") and hasattr(lm.config, "hidden_size"):
    hidden_size = lm.config.hidden_size
    print(f"Hidden state dimension: {hidden_size}")

# ------------------------------
# Model components
# ------------------------------
vision_tower = getattr(model, "vision_tower", None)
mm_projector = getattr(model, "multi_modal_projector", None)

print(f"\nModel Components:")
print(f"  Vision Encoder (Tower): {type(vision_tower).__name__}")
print(f"  Multimodal Projector  : {type(mm_projector).__name__}")
print(f"  Language Model        : {type(lm).__name__}")

# ------------------------------
# Vision Encoder
# ------------------------------
print(f"\n------> Vision Encoder:")
if vision_tower is not None:
    print("  Converts images into visual patch embeddings")
    if hasattr(vision_tower, "config") and hasattr(vision_tower.config, "hidden_size"):
        print(f"  Output embedding dimension: {vision_tower.config.hidden_size}")
else:
    print("  Vision encoder not found")

# ------------------------------
# Multimodal Projector
# ------------------------------
print(f"\n------> Multimodal Projector:")
if mm_projector is not None:
    print("  Projects vision embeddings into language model hidden space")
    print(f"  Module: {mm_projector}")
else:
    print("  Multimodal projector not found")

# ------------------------------
# Language decoder (LM head)
# ------------------------------
print(f"\n------> Language Decoder (LM Head):")

lm_head_candidates = [
    getattr(model, "lm_head", None),
    getattr(lm, "lm_head", None),
    getattr(lm, "model", None) and getattr(lm.model, "lm_head", None),
    getattr(model, "language_model", None) and getattr(model.language_model, "lm_head", None),
]

lm_head = None
for candidate in lm_head_candidates:
    if candidate is not None:
        lm_head = candidate
        break

if lm_head is None:
    print("  Language decoder not found")
else:
    if hasattr(lm_head, "in_features") and hasattr(lm_head, "out_features"):
        print(f"  Decoder input dimension : {lm_head.in_features}")
        print(f"  Vocabulary size         : {lm_head.out_features}")
    else:
        print(f"  Decoder module          : {type(lm_head).__name__}")


# ------------------------------
# Language model layer inspection
# ------------------------------
if layers is not None:
    N = min(10, len(layers))
    print(f"\nLanguage Model Layer Structure (first {N} layers):")
    for i in range(N):
        layer = layers[i]
        print(f"\n  Layer {i}:")
        print(f"    Self Attention: {type(layer.self_attn).__name__}")
        print(f"    MLP           : {type(layer.mlp).__name__}")

# ------------------------------
# Parameter count
# ------------------------------
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: ~{total_params / 1e9:.2f}B")

print("="*80)


MODEL ARCHITECTURE INFORMATION

Number of transformer layers: 32
Hidden state dimension: 4096

Model Components:
  Vision Encoder (Tower): CLIPVisionModel
  Multimodal Projector  : LlavaMultiModalProjector
  Language Model        : LlamaModel

------> Vision Encoder:
  Converts images into visual patch embeddings
  Output embedding dimension: 1024

------> Multimodal Projector:
  Projects vision embeddings into language model hidden space
  Module: LlavaMultiModalProjector(
  (linear_1): Linear(in_features=1024, out_features=4096, bias=True)
  (act): GELUActivation()
  (linear_2): Linear(in_features=4096, out_features=4096, bias=True)
)

------> Language Decoder (LM Head):
  Decoder input dimension : 4096
  Vocabulary size         : 32064

Language Model Layer Structure (first 10 layers):

  Layer 0:
    Self Attention: LlamaAttention
    MLP           : LlamaMLP

  Layer 1:
    Self Attention: LlamaAttention
    MLP           : LlamaMLP

  Layer 2:
    Self Attention: LlamaAttention
 

---
## 3. Helper Functions
Define utility functions for prompt formatting and inference.

In [4]:
def format_choices_for_prompt(question_text, choices_list):
    """
    Format multiple-choice question into a structured prompt.
    
    Args:
        question_text: The question string
        choices_list: List of answer choices
    
    Returns:
        Formatted prompt string with choices labeled A, B, C, etc.
    """
    formatted_choices = []
    for i, choice in enumerate(choices_list):
        letter = chr(ord('A') + i)  # Convert 0,1,2... to A,B,C...
        formatted_choices.append(f"{letter}. {choice}")

    choices_string = "\n".join(formatted_choices)
    
    final_prompt = (
        f"{question_text}\n\n"
        f"{choices_string}\n\n"
        "Please provide the best choice among the options (e.g., \"A. Choice Text\") "
        "and explain your detailed rationale for selecting it."
    )
    
    return final_prompt

def get_language_layers(model):
    """
    Get the language model layers from LLaVA model.
    Handles different model structures.
    
    Args:
        model: LLaVA model
    
    Returns:
        ModuleList of transformer layers
    """
    if hasattr(model.language_model, 'model'):
        # Structure: model.language_model.model.layers
        return model.language_model.model.layers
    elif hasattr(model.language_model, 'layers'):
        # Structure: model.language_model.layers
        return model.language_model.layers
    else:
        raise AttributeError("Could not find layers in language model")

# Test the functions
test_question = "What is the capital of France?"
test_choices = ["London", "Paris", "Berlin", "Madrid"]
print("Example formatted prompt:")
print("="*60)
print(format_choices_for_prompt(test_question, test_choices))
print("="*60)
print("\n✓ Helper functions defined")

Example formatted prompt:
What is the capital of France?

A. London
B. Paris
C. Berlin
D. Madrid

Please provide the best choice among the options (e.g., "A. Choice Text") and explain your detailed rationale for selecting it.

✓ Helper functions defined


In [5]:
def run_inference(image, text_prompt, model, processor, max_new_tokens=200):
    """
    Run inference on a single image-text pair (baseline, no steering).
    
    Args:
        image: PIL Image object
        text_prompt: Text prompt string
        model: LLaVA model
        processor: Model processor
        max_new_tokens: Maximum tokens to generate
    
    Returns:
        Generated response string
    """
    # Apply chat template
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text_prompt},
                {"type": "image"},
            ],
        },
    ]
    formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Process inputs and move them to the model's device
    inputs = processor(images=image, text=formatted_prompt, return_tensors='pt').to(model.device, torch.float16)

    # Generate output (greedy decoding)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # Decode only the newly generated tokens, skipping the prompt
    prompt_length = inputs["input_ids"].shape[1]
    response = processor.decode(output[0][prompt_length:], skip_special_tokens=True)
    return response.strip()

print("✓ Inference function defined")

✓ Inference function defined


---
## 4. Load M3CoT Dataset
Load the M3CoT dataset from the Hugging Face Hub and inspect the first 10 validation samples.

In [6]:
# Configuration
NUM_TEST_SAMPLES = 10  # Number of samples to use for testing

print("Loading M3CoT dataset from HuggingFace...")
print("="*80)

# Load M3CoT validation dataset
m3cot_dataset = datasets.load_dataset("LightChen2333/M3CoT", split="validation")

print(f"✓ Dataset loaded successfully!")
print(f"Total samples in validation set: {len(m3cot_dataset)}")

# Take first 10 samples for testing
test_samples = m3cot_dataset.select(range(NUM_TEST_SAMPLES))
print(f"Using first {NUM_TEST_SAMPLES} samples for testing\n")

# Display information about the first 10 samples
print("="*80)
print("SAMPLE PREVIEW")
print("="*80)

for i in range(NUM_TEST_SAMPLES):
    sample = test_samples[i]
    
    print(f"\n📋 Sample {i+1}/{NUM_TEST_SAMPLES}:")
    print(f"  ID: {sample['id']}")
    print(f"  Question: {sample['question'][:100]}..." if len(sample['question']) > 100 else f"  Question: {sample['question']}")
    print(f"  Choices: {sample['choices']}")
    print(f"  Correct Answer: {sample['answer']}")
    
    # Check image
    if sample['image'] is not None:
        img = sample['image']
        print(f"  Image size: {img.size[0]}x{img.size[1]}")
        print(f"  Image mode: {img.mode}")
        print(f"  Image available: ✓")
    else:
        print(f"  Image available: ✗ (missing!)")

print("\n" + "="*80)
print(f"✓ Dataset loaded successfully!")
print(f"✓ Ready to process {NUM_TEST_SAMPLES} samples")
print("="*80)

Loading M3CoT dataset from HuggingFace...


✓ Dataset loaded successfully!
Total samples in validation set: 1108
Using first 10 samples for testing

SAMPLE PREVIEW

📋 Sample 1/10:
  ID: physical-commonsense-1370
  Question: What is the reason for someone slicing and grabbing a pizza with their hands?
  Choices: ['They are trying to demonstrate their skill in handling food', 'They cannot find a pizza cutter and have to improvise', 'They think using a knife is quicker than a pizza cutter', 'They want to savor the smell of the pizza']
  Correct Answer: B
  Image size: 500x333
  Image mode: RGB
  Image available: ✓

📋 Sample 2/10:
  ID: physical-commonsense-1409
  Question: What might be the possible function of the areay?
  Choices: ['Putting on makeup', 'Conducting business meetings', 'Taking a rest', 'Displaying artwork']
  Correct Answer: D
  Image size: 500x333
  Image mode: RGB
  Image available: ✓

📋 Sample 3/10:
  ID: physical-commonsense-1400
  Question: Why is the man wearing an orange hat on the tennis cour

---
## 5. Baseline Response Generation
Generate baseline responses on the test samples (no steering applied).

In [7]:
# Output configuration
output_json = "generated_responses.json"

print(f"Generating baseline responses for {NUM_TEST_SAMPLES} samples...")
print("="*80)

generated_responses = []

# Process each test sample
for i in tqdm(range(NUM_TEST_SAMPLES), desc="Generating baseline responses"):
    entry = test_samples[i]
    print(f"\nProcessing sample {i+1}/{NUM_TEST_SAMPLES}")
    
    # Extract fields (HuggingFace dataset format)
    image = entry['image']  # PIL Image directly from dataset
    question = entry['question']
    choices = entry['choices']
    answer = entry['answer']
    sample_id = entry['id']
    
    # Format prompt
    text_prompt = format_choices_for_prompt(question, choices)
    print(f"  Question: {question[:60]}...")
    
    # Run inference
    response = run_inference(image, text_prompt, model, processor)
    print(f"  Response: {response[:80]}...")
    print(f"  Correct answer: {answer}")
    
    # Store results
    generated_responses.append({
        'id': sample_id,
        'question': question,
        'choices': choices,
        'answer': answer,
        'generated_answer': response
    })

# Save final results
with open(output_json, 'w') as f:
    json.dump(generated_responses, f, indent=4)

print(f"\n✓ Finished processing {len(generated_responses)} samples")
print(f"✓ Results saved to: {output_json}")

Generating baseline responses for 10 samples...

Processing sample 1/10
  Question: What is the reason for someone slicing and grabbing a pizza ...
  Response: C. They think using a knife is quicker than a pizza cutter....
  Correct answer: B

Processing sample 2/10
  Question: What might be the possible function of the areay?...
  Response: D. Displaying artwork

The area appears to be a part of a museum or an art galle...
  Correct answer: D

Processing sample 3/10
  Question: Why is the man wearing an orange hat on the tennis court?...
  Response: D. It's part of his tennis uniform. The man is wearing an orange hat on the tenn...
  Correct answer: C

Processing sample 4/10
  Question: What type of door is shown in the image?...
  Response: C. A screen door

The image shows a dog standing in front of a screen door, whic...
  Correct answer: B

Processing sample 5/10
  Question: What is the likely reason for the presence of boats in the w...
  Response: C. Boats were made for fishing. The presence of numerous boats in the water body...
  Correct answer: C

Processing sample 6/10
  Question: What can be inferred about the beach from the image?...
  Response: B. The beach is windy and stormy.

The image shows a surfboard on the beach, whi...
  Correct answer: C

Processing sample 7/10
  Question: Based on the image, what can be inferred about the warehouse...
  Response: C. The warehouse has not been cleaned recently. The image shows two trains parke...
  Correct answer: C

Processing sample 8/10
  Question: What can you infer about the weather based on the given imag...
  Response: B. It is hot and sunny.

In the image, a group of young men is playing with a wh...
  Correct answer: D

Processing sample 9/10
  Question: Which property do these four objects have in common?...
  Response: A. Sweet...
  Correct answer: A

Processing sample 10/10
  Question: Which property do these three objects have in common?...
  Response: A. Translucent...
  Correct answer: B

✓ Finished processing 10 samples
✓ Results saved to: generated_responses.json

---
## 6. Activation Extraction
Extract hidden states from the model for correct and incorrect answer examples.

In [8]:
def extract_hidden_states(image, prompt, model, processor):
    """
    Extract hidden states from all transformer layers.
    
    Args:
        image: PIL Image
        prompt: Text prompt
        model: LLaVA model
        processor: Model processor
    
    Returns:
        List of hidden states (one per layer), each of shape [hidden_dim]
    """
    # Prepare conversation
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image"},
            ],
        },
    ]
    
    # Format the prompt WITHOUT the generation prompt (hidden states are read from the prompt tokens only)
    formatted_prompt = processor.apply_chat_template(
        conversation, 
        add_generation_prompt=False, 
        tokenize=False
    )
    
    # Process inputs and move them to the model's device
    inputs = processor(images=image, text=formatted_prompt, return_tensors='pt').to(model.device, torch.float16)
    
    # Forward pass with hidden states
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    
    # Extract the last-token hidden state from each layer.
    # outputs.hidden_states is a tuple: (embedding_output, layer_0_output, layer_1_output, ...).
    # Skip the embedding output (index 0) so that list index i holds the output of decoder layer i.
    last_token_hidden_states = []
    for layer_hidden_states in outputs.hidden_states[1:]:
        # layer_hidden_states shape: [batch_size, seq_len, hidden_dim]
        # Take the last token of the first (only) batch item: [hidden_dim]
        last_token_state = layer_hidden_states[0, -1, :].cpu()
        last_token_hidden_states.append(last_token_state)
    
    return last_token_hidden_states

print("✓ Activation extraction function defined")

✓ Activation extraction function defined


In [9]:
# Load generated responses to identify correct/incorrect examples
with open(output_json, 'r') as f:
    generated_responses = json.load(f)

print(f"Loaded {len(generated_responses)} generated responses")

# Separate correct and incorrect examples
# Store both the response AND the original dataset entry
correct_examples = []
incorrect_examples = []

for i, response_entry in enumerate(generated_responses):
    # Get the original dataset entry
    dataset_entry = test_samples[i]
    
    # Extract predicted answer (first letter)
    generated = response_entry['generated_answer']
    if generated and len(generated) > 0:
        predicted_letter = generated[0].upper()
        correct_letter = response_entry['answer']
        
        # Store the dataset entry (which has the image)
        if predicted_letter == correct_letter:
            correct_examples.append(dataset_entry)
        else:
            incorrect_examples.append(dataset_entry)

print(f"\n📊 Classification Results:")
print(f"  ✓ Correct: {len(correct_examples)}")
print(f"  ✗ Incorrect: {len(incorrect_examples)}")
if len(generated_responses) > 0:
    print(f"  Accuracy: {len(correct_examples) / len(generated_responses) * 100:.2f}%")

Loaded 10 generated responses

📊 Classification Results:
  ✓ Correct: 4
  ✗ Incorrect: 6
  Accuracy: 40.00%
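
The classification cell above treats the first character of each response as the predicted option. That works for the responses shown here, but it would also read "A dog is standing..." as option A. A slightly stricter parser could be sketched as follows; `extract_choice_letter` is a hypothetical helper, not part of the notebook, and the accepted formats (`B.`, `(b)`, `C:`, or a bare letter) are an assumption about how the model phrases its answers.

```python
import re

# A leading option letter, optionally parenthesised, that is followed by
# punctuation or is the entire response. A bare article such as "A dog"
# does not match because the letter is followed by a space and more text.
_LETTER_RE = re.compile(r"^\s*\(?([A-Za-z])\)?(?=[.:)]|\s*$)")


def extract_choice_letter(response, num_choices=4):
    """Return the predicted option letter (e.g. 'B'), or None if none is found."""
    match = _LETTER_RE.match(response or "")
    if match is None:
        return None
    letter = match.group(1).upper()
    valid_letters = {chr(ord("A") + i) for i in range(num_choices)}
    return letter if letter in valid_letters else None


print(extract_choice_letter("C. They think using a knife is quicker"))
print(extract_choice_letter("A dog is standing in front of a door"))
```

Responses that return `None` can then be counted as incorrect or logged for manual review, rather than being classified by accident.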


In [10]:
# Extract activations for correct and incorrect examples
NUM_SAMPLES_PER_CLASS = min(len(correct_examples), len(incorrect_examples))  # Use available samples

print(f"\nExtracting hidden states for {NUM_SAMPLES_PER_CLASS} correct examples...")
correct_hidden_states_list = []

for entry in tqdm(correct_examples[:NUM_SAMPLES_PER_CLASS], desc="Correct examples"):
    # Get image directly from dataset entry
    image = entry['image']
    prompt = format_choices_for_prompt(entry['question'], entry['choices'])
    
    hidden_states = extract_hidden_states(image, prompt, model, processor)
    correct_hidden_states_list.append(hidden_states)

print(f"\nExtracting hidden states for {NUM_SAMPLES_PER_CLASS} incorrect examples...")
incorrect_hidden_states_list = []

for entry in tqdm(incorrect_examples[:NUM_SAMPLES_PER_CLASS], desc="Incorrect examples"):
    # Get image directly from dataset entry  
    image = entry['image']
    prompt = format_choices_for_prompt(entry['question'], entry['choices'])
    
    hidden_states = extract_hidden_states(image, prompt, model, processor)
    incorrect_hidden_states_list.append(hidden_states)

print(f"\n✓ Extraction complete!")
print(f"  Correct: {len(correct_hidden_states_list)} samples × {len(correct_hidden_states_list[0])} layers")
print(f"  Incorrect: {len(incorrect_hidden_states_list)} samples × {len(incorrect_hidden_states_list[0])} layers")


Extracting hidden states for 4 correct examples...

Extracting hidden states for 4 incorrect examples...

✓ Extraction complete!
  Correct: 4 samples × 32 layers
  Incorrect: 4 samples × 32 layers

---
## 7. Steering Vector Computation
Compute the steering vector as the difference between mean activations of correct and incorrect examples.

In [11]:
def compute_steering_vector(correct_states_list, incorrect_states_list):
    """
    Compute steering vector as mean(correct) - mean(incorrect).
    
    Args:
        correct_states_list: List of hidden states for correct examples
                            Shape: [num_correct_samples, num_layers, hidden_dim]
        incorrect_states_list: List of hidden states for incorrect examples
                              Shape: [num_incorrect_samples, num_layers, hidden_dim]
    
    Returns:
        steering_vector: Tensor of shape [num_layers, hidden_dim]
    """
    num_layers = len(correct_states_list[0])
    hidden_dim = correct_states_list[0][0].shape[0]
    
    steering_vector = torch.zeros(num_layers, hidden_dim)
    
    print(f"Computing steering vector...")
    print(f"  Layers: {num_layers}")
    print(f"  Hidden dimension: {hidden_dim}")
    
    for layer_idx in range(num_layers):
        # Stack all samples for this layer
        correct_layer = torch.stack([states[layer_idx] for states in correct_states_list])
        incorrect_layer = torch.stack([states[layer_idx] for states in incorrect_states_list])
        
        # Compute means
        mean_correct = correct_layer.mean(dim=0)
        mean_incorrect = incorrect_layer.mean(dim=0)
        
        # Steering direction: correct - incorrect
        steering_vector[layer_idx] = mean_correct - mean_incorrect
        
        # Print statistics for first few layers
        if layer_idx < 3:
            magnitude = torch.norm(steering_vector[layer_idx]).item()
            print(f"  Layer {layer_idx}: magnitude = {magnitude:.4f}")
    
    print(f"\n✓ Steering vector computed")
    return steering_vector

# Compute the steering vector
steer_vector = compute_steering_vector(correct_hidden_states_list, incorrect_hidden_states_list)

# Save steering vector
torch.save(steer_vector, "steering_vector.pt")
print(f"✓ Steering vector saved to: steering_vector.pt")
print(f"  Shape: {steer_vector.shape}")

Computing steering vector...
  Layers: 32
  Hidden dimension: 4096
  Layer 0: magnitude = 0.0805
  Layer 1: magnitude = 0.1442
  Layer 2: magnitude = 0.2642

✓ Steering vector computed
✓ Steering vector saved to: steering_vector.pt
  Shape: torch.Size([32, 4096])
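
To make the mean-difference computation in the cell above concrete, here is a toy sketch for a single layer with two-dimensional "hidden states". It uses plain Python lists instead of tensors so the arithmetic can be checked by hand; the helper names and the numbers are illustrative assumptions, not values from the experiment.

```python
def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def mean_difference(correct, incorrect):
    """Steering direction for one layer: mean(correct) - mean(incorrect)."""
    mean_correct = mean_vector(correct)
    mean_incorrect = mean_vector(incorrect)
    return [c - i for c, i in zip(mean_correct, mean_incorrect)]


# Toy last-token states for one layer (made-up numbers).
correct_states = [[1.0, 2.0], [3.0, 4.0]]    # mean = [2.0, 3.0]
incorrect_states = [[0.0, 1.0], [2.0, 1.0]]  # mean = [1.0, 1.0]

direction = mean_difference(correct_states, incorrect_states)
print(direction)
```

In the notebook this same subtraction is repeated for each of the 32 layers, which gives the `[32, 4096]` tensor reported above.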


---
## 8. Steered Generation with Hooks
Apply the steering vector during generation using forward hooks.
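
Before the full function, the core mechanism can be sketched on a toy module. A PyTorch forward hook that returns a value replaces the module's output, so adding a scaled vector inside the hook shifts every position's activation. In this sketch `add_steering_hook` is a hypothetical helper and a small `nn.Linear` stands in for a decoder layer; real decoder layers may return tuples, which the notebook's function handles separately.

```python
import torch
import torch.nn as nn


def add_steering_hook(module, vec, scale=1.0):
    """Register a forward hook that adds scale * vec to the module's output."""
    def hook(mod, inputs, output):
        shift = vec.to(device=output.device, dtype=output.dtype)
        # Returning a tensor from a forward hook replaces the module output.
        return output + scale * shift
    return module.register_forward_hook(hook)


torch.manual_seed(0)
layer = nn.Linear(4, 4)           # toy stand-in for one transformer layer
x = torch.randn(1, 3, 4)          # [batch, seq_len, hidden_dim]
vec = torch.tensor([1.0, 0.0, 0.0, 0.0])

with torch.no_grad():
    baseline = layer(x)
    handle = add_steering_hook(layer, vec, scale=2.0)
    try:
        steered = layer(x)
    finally:
        handle.remove()           # always remove hooks so later calls are unaffected
    restored = layer(x)

print((steered - baseline)[0, 0])  # shift applied at the first position
```

The `[hidden_dim]` vector broadcasts over the batch and sequence dimensions, and removing the handle in a `finally` block keeps a failed generation from leaving a hook attached to the model.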

In [31]:
# ============================================================
# ALTERNATIVE: Modified function that handles BOTH tensor and dict
# ============================================================

def do_multimodal_steering(
    model,
    processor,
    images,
    prompts,
    steering_vec=None,
    scale=1.0,
    normalize=True,
    layer=None,
    max_new_tokens=200
):
    """
    Generate text with optional steering vector applied to hidden states.
    
    Args:
        model: LLaVA model
        processor: Model processor
        images: List of PIL Images
        prompts: List of text prompts
        steering_vec: Steering vector - can be:
                     - None (no steering)
                     - dict {layer_0: vec, layer_1: vec, ...} for dict mode
                     - tensor [num_layers, hidden_dim] for tensor mode
        scale: Scaling factor for steering vector
        normalize: Whether to normalize the steering vector
        layer: Specific layer to apply steering (int) or None for all layers
        max_new_tokens: Maximum tokens to generate
    
    Returns:
        Generated text string
    """
    hooks = []
    
    def make_modify_activation(resolved_vec):
        """
        Create a hook function that adds steering vector to activations.
        """
        def modify_activation(module, input, output):
            if isinstance(output, tuple):
                activations = output[0]
            else:
                activations = output
            
            # Move vector to correct device and dtype
            vec = resolved_vec.to(device=activations.device, dtype=activations.dtype)
            
            # Normalize if requested
            if normalize:
                vec = torch.nn.functional.normalize(vec, dim=-1)
            
            # KEY FIX: Use unsqueeze(0).unsqueeze(0) to broadcast correctly
            steered_activations = activations + scale * vec.unsqueeze(0).unsqueeze(0)
            
            # Return in same format as input
            if isinstance(output, tuple):
                return (steered_activations,) + output[1:]
            else:
                return steered_activations
        
        return modify_activation
    
    # Register hooks if steering is enabled
    if steering_vec is not None:
        language_layers = get_language_layers(model)  # Get all language-model layers
        
        # OPTION 1: steering_vec is a dict
        if isinstance(steering_vec, dict):
            # Case 1a: dict + layer=None → steer ALL layers
            if layer is None:
                for k, vec_np in steering_vec.items():
                    layer_idx = int(k.replace("layer_", ""))
                    
                    resolved_vec = torch.tensor(vec_np, dtype=torch.float32)
                    
                    if not (0 <= layer_idx < len(language_layers)):
                        raise ValueError(f"Layer index {layer_idx} out of bounds")
                    
                    target_layer = language_layers[layer_idx]
                    hook = target_layer.register_forward_hook(
                        make_modify_activation(resolved_vec)
                    )
                    hooks.append(hook)
                
                print(f"Registered steering hooks on ALL {len(hooks)} layers")
            
            # Case 1b: dict + specific layer
            else:
                layer_key = f"layer_{layer}"
                if layer_key not in steering_vec:
                    raise KeyError(f"{layer_key} not found in steering_vec dict")
                
                resolved_vec = torch.tensor(steering_vec[layer_key], dtype=torch.float32)
                target_layer = language_layers[layer]
                
                hook = target_layer.register_forward_hook(
                    make_modify_activation(resolved_vec)
                )
                hooks.append(hook)
                
                print(f"Registered steering hook on layer {layer}")
        
        # OPTION 2: steering_vec is a tensor [num_layers, hidden_dim]
        elif isinstance(steering_vec, torch.Tensor) and steering_vec.dim() == 2:
            num_layers = steering_vec.shape[0]
            
            # Case 2a: tensor + layer=None → steer ALL layers
            if layer is None:
                for layer_idx in range(num_layers):
                    if not (0 <= layer_idx < len(language_layers)):
                        raise ValueError(f"Layer index {layer_idx} out of bounds")
                    
                    resolved_vec = steering_vec[layer_idx]
                    target_layer = language_layers[layer_idx]
                    
                    hook = target_layer.register_forward_hook(
                        make_modify_activation(resolved_vec)
                    )
                    hooks.append(hook)
                
                print(f"Registered steering hooks on ALL {num_layers} layers")
            
            # Case 2b: tensor + specific layer
            else:
                if not (0 <= layer < num_layers):
                    raise ValueError(f"Layer {layer} out of bounds (tensor has {num_layers} layers)")
                
                resolved_vec = steering_vec[layer]
                target_layer = language_layers[layer]
                
                hook = target_layer.register_forward_hook(
                    make_modify_activation(resolved_vec)
                )
                hooks.append(hook)
                
                print(f"Registered steering hook on layer {layer}")
        
        # OPTION 3: steering_vec is a single vector [hidden_dim] (backward compatible)
        else:
            if layer is None:
                raise ValueError("layer must be specified when steering_vec is a single vector")
            
            # Accept NumPy arrays or lists as well as tensors; the hook calls .to() on this
            resolved_vec = torch.as_tensor(steering_vec)
            target_layer = language_layers[layer]
            
            hook = target_layer.register_forward_hook(
                make_modify_activation(resolved_vec)
            )
            hooks.append(hook)
            
            print(f"Registered steering hook on layer {layer}")
    
    try:
        # Prepare conversation
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompts[0]},
                    {"type": "image"},
                ],
            },
        ]
        formatted_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        
        # Process inputs
        # Send inputs to the model's own device instead of hard-coding GPU 0
        inputs = processor(images=images[0], text=formatted_prompt, return_tensors='pt').to(model.device, torch.float16)
        
        # Generate with steering
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        
        # Decode response
        response = processor.decode(output[0][2:], skip_special_tokens=True)
        result = response.split("ASSISTANT:")[-1].strip()
        
    finally:
        # Always remove hooks
        for hook in hooks:
            hook.remove()
        if hooks:
            print("Removed steering hooks.")
    
    return result

print("✓ Flexible steered generation function defined (handles tensor and dict)")

✓ Flexible steered generation function defined (handles tensor and dict)
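
The hook mechanism above can be checked in isolation on a toy module. The sketch below is not part of the notebook's pipeline: `make_steering_hook` is a hypothetical stand-in for `make_modify_activation`, and `nn.Identity` stands in for a decoder layer. It shows that a `[hidden_dim]` vector broadcasts over a `[batch, seq_len, hidden_dim]` activation and that removing the hook restores the original behaviour.

```python
import torch
import torch.nn as nn


def make_steering_hook(vec, scale=1.0, normalize=True):
    """Return a forward hook that adds scale * vec to a module's output.

    Hypothetical helper mirroring the notebook's modify_activation logic.
    """
    def hook(module, inputs, output):
        acts = output[0] if isinstance(output, tuple) else output
        v = vec.to(device=acts.device, dtype=acts.dtype)
        if normalize:
            v = torch.nn.functional.normalize(v, dim=-1)
        # [hidden_dim] -> [1, 1, hidden_dim], broadcast over batch and sequence
        steered = acts + scale * v.unsqueeze(0).unsqueeze(0)
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered
    return hook


hidden_dim = 8
toy_layer = nn.Identity()           # stand-in for a decoder layer (assumption)
direction = torch.zeros(hidden_dim)
direction[0] = 3.0                  # normalizes to the unit vector along dim 0

x = torch.zeros(2, 5, hidden_dim)   # [batch, seq_len, hidden_dim]
handle = toy_layer.register_forward_hook(make_steering_hook(direction, scale=2.0))
try:
    steered_out = toy_layer(x)      # hook is active: dim 0 is shifted by the scale
finally:
    handle.remove()                 # same cleanup pattern as the finally block above
plain_out = toy_layer(x)            # hook removed: output is unchanged
```

Because the vector is normalized before scaling, `scale` directly sets the size of the shift applied at every token position.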


---
## 9. Comprehensive Steering Evaluation
Test steering across different layers and scales on the test dataset.
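
Each example gets one baseline run, one run per (layer, scale) pair, and one all-layers run per scale, so the sweep grows quickly. The sketch below works through that arithmetic for the values used in this notebook (5 examples, 9 scales, 32 language layers); `count_runs` is a hypothetical helper added for illustration only.

```python
def count_runs(num_examples, num_layers, num_scales):
    """Count generation calls per mode for the evaluation sweep."""
    per_example = {
        "baseline": 1,                        # one unsteered run
        "layer": num_layers * num_scales,     # every (layer, scale) pair
        "all_layers": num_scales,             # all layers steered at once, per scale
    }
    totals = {mode: n * num_examples for mode, n in per_example.items()}
    totals["total"] = sum(totals.values())
    return totals


runs = count_runs(num_examples=5, num_layers=32, num_scales=9)
print(runs)
# {'baseline': 5, 'layer': 1440, 'all_layers': 45, 'total': 1490}
```

These counts give a quick sanity check on the number of rows expected in the results CSV.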

In [25]:
# Load test dataset
print("Loading M3CoT test dataset...")
test_dataset = datasets.load_dataset("LightChen2333/M3CoT", split="test")
print(f"✓ Loaded {len(test_dataset)} test examples")

# Configuration
NUM_TEST_EXAMPLES = 5  # Number of examples to evaluate
SCALES = [0.3, 0.6, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0]
NUM_LAYERS = len(get_language_layers(model))

# Create output files
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_path = f"steering_results_{timestamp}.csv"
image_dir = f"steering_images_{timestamp}"
os.makedirs(image_dir, exist_ok=True)

print(f"\n📋 Evaluation Configuration:")
print(f"  Test examples: {NUM_TEST_EXAMPLES}")
print(f"  Scales: {SCALES}")
print(f"  Layers: 0-{NUM_LAYERS-1} + all-layers mode")
print(f"  Output CSV: {csv_path}")
print(f"  Images: {image_dir}/")

# CSV structure
fieldnames = [
    "example_idx",
    "image_html",
    "question",
    "rationale",
    "mode",        # baseline | layer | all_layers
    "layer",       # int or None
    "scale",
    "output",
    "error"
]

print(f"\n🚀 Starting evaluation...")
print("="*80)

Loading M3CoT test dataset...
✓ Loaded 2318 test examples

📋 Evaluation Configuration:
  Test examples: 5
  Scales: [0.3, 0.6, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0]
  Layers: 0-31 + all-layers mode
  Output CSV: steering_results_20260210_000709.csv
  Images: steering_images_20260210_000709/

🚀 Starting evaluation...


In [32]:
# ============================================================
# CELL 2: Fixed Evaluation Loop
# ============================================================

with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()

    for ex_idx, entry in enumerate(test_dataset.select(range(NUM_TEST_EXAMPLES))):
        print(f"\n{'='*80}")
        print(f"Processing Example {ex_idx + 1}/{NUM_TEST_EXAMPLES}")
        print(f"{'='*80}")
        
        # Extract data
        image = entry["image"]
        question = entry["question"]
        rationale = entry.get("rationale", None)
        prompt = format_choices_for_prompt(question, entry["choices"])

        # Save image
        image_path = os.path.join(image_dir, f"example_{ex_idx}.png")
        image.save(image_path)
        image_html = f'<img src="{image_path}" width="256">'
        
        print(f"Question: {question[:80]}...")

        # ========== BASELINE (no steering) ==========
        print(f"\n[1/3] Running baseline (no steering)...")
        try:
            output = do_multimodal_steering(
                model, processor,
                [image], [prompt],
                steering_vec=None,
                scale=1.0,
                normalize=True,
                layer=None
            )
            writer.writerow({
                "example_idx": ex_idx,
                "image_html": image_html,
                "question": question,
                "rationale": rationale,
                "mode": "baseline",
                "layer": None,
                "scale": 1.0,
                "output": output,
                "error": None
            })
            print(f"  ✓ Baseline: {output[:60]}...")
        except Exception as e:
            print(f"  ✗ Error: {repr(e)}")
            writer.writerow({
                "example_idx": ex_idx,
                "image_html": image_html,
                "question": question,
                "rationale": rationale,
                "mode": "baseline",
                "layer": None,
                "scale": 1.0,
                "output": None,
                "error": repr(e)
            })

        # ========== PER-LAYER STEERING ==========
        print(f"\n[2/3] Running per-layer steering ({NUM_LAYERS} layers × {len(SCALES)} scales)...")
        total_runs = NUM_LAYERS * len(SCALES)
        completed = 0
        
        for scale in SCALES:
            for layer in range(NUM_LAYERS):
                try:
                    output = do_multimodal_steering(
                        model, processor,
                        [image], [prompt],
                        steering_vec=steer_vector,  # Pass dict directly
                        scale=scale,
                        normalize=True,
                        layer=layer  # Specify which layer
                    )
                    writer.writerow({
                        "example_idx": ex_idx,
                        "image_html": image_html,
                        "question": question,
                        "rationale": rationale,
                        "mode": "layer",
                        "layer": layer,
                        "scale": scale,
                        "output": output if output.strip() else None,
                        "error": None
                    })
                except Exception as e:
                    writer.writerow({
                        "example_idx": ex_idx,
                        "image_html": image_html,
                        "question": question,
                        "rationale": rationale,
                        "mode": "layer",
                        "layer": layer,
                        "scale": scale,
                        "output": None,
                        "error": repr(e)
                    })
                
                completed += 1
                if completed % 50 == 0:
                    print(f"  Progress: {completed}/{total_runs} ({completed/total_runs*100:.1f}%)")
        
        print(f"  ✓ Completed {total_runs} per-layer runs")

        # ========== ALL-LAYERS STEERING ==========
        # KEY FIX: Pass steer_vector as dict with layer=None
        print(f"\n[3/3] Running all-layers steering ({len(SCALES)} scales)...")
        for scale in SCALES:
            try:
                output = do_multimodal_steering(
                    model, processor,
                    [image], [prompt],
                    steering_vec=steer_vector,  # Pass dict (not individual tensors)
                    scale=scale,
                    normalize=True,
                    layer=None  # None means apply to all layers when dict is provided
                )
                writer.writerow({
                    "example_idx": ex_idx,
                    "image_html": image_html,
                    "question": question,
                    "rationale": rationale,
                    "mode": "all_layers",
                    "layer": None,
                    "scale": scale,
                    "output": output if output.strip() else None,
                    "error": None
                })
                print(f"  ✓ Scale {scale}: {output[:60]}...")
            except Exception as e:
                print(f"  ✗ Scale {scale} error: {repr(e)}")
                writer.writerow({
                    "example_idx": ex_idx,
                    "image_html": image_html,
                    "question": question,
                    "rationale": rationale,
                    "mode": "all_layers",
                    "layer": None,
                    "scale": scale,
                    "output": None,
                    "error": repr(e)
                })

print(f"\n{'='*80}")
print(f"✓ Evaluation complete!")
print(f"  Results: {csv_path}")
print(f"  Images: {image_dir}/")
print(f"{'='*80}")


Processing Example 1/5
Question: What is the likely purpose of the troll statue under the bridge?...

[1/3] Running baseline (no steering)...
  ✓ Baseline: C. To honor a local legend...

[2/3] Running per-layer steering (32 layers × 9 scales)...
Registered steering hook on layer 0
Removed steering hooks.
Registered steering hook on layer 1
Removed steering hooks.
Registered steering hook on layer 2
Removed steering hooks.
Registered steering hook on layer 3
Removed steering hooks.
Registered steering hook on layer 4
Removed steering hooks.
Registered steering hook on layer 5
Removed steering hooks.
Registered steering hook on layer 6
Removed steering hooks.
Registered steering hook on layer 7
Removed steering hooks.
Registered steering hook on layer 8
Removed steering hooks.
Registered steering hook on layer 9
Removed steering hooks.
Registered steering hook on layer 10
Removed steering hooks.
Registered steering hook on layer 11
Removed steering hooks.
Registered steering hook on 

---
## 10. Results Analysis
Load and analyze the steering results.

In [33]:
import pandas as pd

# Load results
df = pd.read_csv(csv_path)

print(f"Loaded {len(df)} result rows")
print(f"\nBreakdown by mode:")
print(df['mode'].value_counts())

print(f"\nError rate:")
print(f"  Errors: {df['error'].notna().sum()}")
print(f"  Success: {df['error'].isna().sum()}")

# Show sample outputs
print(f"\n" + "="*80)
print("Sample Results")
print("="*80)

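for mode in ['baseline', 'layer', 'all_layers']:
    subset = df[df['mode'] == mode]
    if subset.empty:
        continue  # No rows were written for this mode
    sample = subset.iloc[0]
    print(f"\n{mode.upper()}:")
    if mode == 'layer':
        print(f"  Layer: {sample['layer']}, Scale: {sample['scale']}")
    elif mode == 'all_layers':
        print(f"  Scale: {sample['scale']}")
    # Empty outputs were written as None and load back from the CSV as NaN
    output_text = sample['output'] if isinstance(sample['output'], str) else ""
    print(f"  Output: {output_text[:100]}...")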

df.head()

Loaded 1490 result rows

Breakdown by mode:
mode
layer         1440
all_layers      45
baseline         5
Name: count, dtype: int64

Error rate:
  Errors: 0
  Success: 1490

Sample Results

BASELINE:
  Output: C. To honor a local legend...

LAYER:
  Layer: 0.0, Scale: 0.3
  Output: C. To honor a local legend...

ALL_LAYERS:
  Scale: 0.3
  Output: B. To bring attention to the city's tourist attractions. The troll statue under the bridge serves as...


Unnamed: 0,example_idx,image_html,question,rationale,mode,layer,scale,output,error
0,0,"<img src=""steering_images_20260210_000709/exam...",What is the likely purpose of the troll statue...,"The words ""Troll Ave N"" on a sign suggest that...",baseline,,1.0,C. To honor a local legend,
1,0,"<img src=""steering_images_20260210_000709/exam...",What is the likely purpose of the troll statue...,"The words ""Troll Ave N"" on a sign suggest that...",layer,0.0,0.3,C. To honor a local legend,
2,0,"<img src=""steering_images_20260210_000709/exam...",What is the likely purpose of the troll statue...,"The words ""Troll Ave N"" on a sign suggest that...",layer,1.0,0.3,C. To honor a local legend,
3,0,"<img src=""steering_images_20260210_000709/exam...",What is the likely purpose of the troll statue...,"The words ""Troll Ave N"" on a sign suggest that...",layer,2.0,0.3,C. To honor a local legend,
4,0,"<img src=""steering_images_20260210_000709/exam...",What is the likely purpose of the troll statue...,"The words ""Troll Ave N"" on a sign suggest that...",layer,3.0,0.3,C. To honor a local legend,
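
One way to summarise a table like this is to check how often steering changes the chosen answer letter relative to the baseline for the same example. The sketch below is an illustration under assumptions, not the notebook's analysis: `answer_letter` and `flip_rate_by_scale` are hypothetical helpers, the regex assumes outputs begin with a letter followed by `.`, `:` or `)`, and the toy frame only reuses the three sample outputs shown above.

```python
import re

import pandas as pd


def answer_letter(text):
    """Return the leading choice letter ('C' from 'C. To honor...'), or '' if none."""
    if not isinstance(text, str):
        return ""  # empty outputs load back from the CSV as NaN
    match = re.match(r"\s*([A-Z])[.:)]", text)
    return match.group(1) if match else ""


def flip_rate_by_scale(results):
    """Fraction of steered runs whose answer letter differs from the baseline's."""
    results = results.copy()
    results["answer"] = results["output"].map(answer_letter)
    # One baseline answer per example, indexed by example_idx
    baseline = results[results["mode"] == "baseline"].set_index("example_idx")["answer"]
    steered = results[results["mode"] != "baseline"].copy()
    steered["flipped"] = steered["answer"] != steered["example_idx"].map(baseline)
    return steered.groupby(["mode", "scale"])["flipped"].mean()


# Toy rows with the same columns as the results CSV (subset only)
toy = pd.DataFrame({
    "example_idx": [0, 0, 0, 0],
    "mode": ["baseline", "layer", "layer", "all_layers"],
    "layer": [None, 0, 1, None],
    "scale": [1.0, 0.3, 0.3, 0.3],
    "output": [
        "C. To honor a local legend",
        "C. To honor a local legend",
        "C. To honor a local legend",
        "B. To bring attention to the city's tourist attractions.",
    ],
})

rates = flip_rate_by_scale(toy)
print(rates)
```

On the real results, calling `flip_rate_by_scale(df)` would give one rate per mode and scale. Adding `"layer"` to the `groupby` keys would break the per-layer runs down further. Note that a changed letter is not the same as a correct answer, so this measures sensitivity to steering rather than accuracy.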
