# Steering Gemma-3 Toward Chain-of-Thought Reasoning

## What Are We Doing?

We are going to **hack the residual stream** of a language model at inference time to force it to "think more" about math problems.

The technique is called **Contrastive Activation Addition (CAA)**. The idea is simple:

1. Show the model examples of **detailed, step-by-step reasoning** (positive) and **blunt, direct answers** (negative).
2. Record the internal activations (hidden states) for both.
3. Subtract negative from positive to get a **"reasoning direction"** vector.
4. During generation, **inject** that vector into the residual stream so the model is nudged toward verbose reasoning.
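
The four steps above can be sketched with plain Python lists standing in for hidden states. This is a toy illustration with made-up 3-dimensional activations (real Gemma-3 activations have 1152 dimensions), not the notebook's PyTorch implementation:

```python
import math

def mean_vector(vectors):
    """Element-wise mean of equally sized vectors (step 2: average the recorded activations)."""
    n = len(vectors)
    return [sum(components) / n for components in zip(*vectors)]

def steering_direction(pos_acts, neg_acts):
    """Step 3: mean(pos) - mean(neg), scaled to unit length."""
    diff = [p - q for p, q in zip(mean_vector(pos_acts), mean_vector(neg_acts))]
    norm = math.sqrt(sum(d * d for d in diff))
    return [d / norm for d in diff]

def inject(hidden, direction, alpha):
    """Step 4: nudge a hidden state along the direction, h + alpha * v."""
    return [h + alpha * v for h, v in zip(hidden, direction)]

# Made-up activations: the "reasoning" examples sit higher on the first axis
pos = [[2.0, 1.0, 0.0], [4.0, 1.0, 0.0]]
neg = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]

v = steering_direction(pos, neg)
print(v)                                      # [1.0, 0.0, 0.0]
print(inject([0.5, 0.5, 0.5], v, alpha=2.0))  # [2.5, 0.5, 0.5]
```

The shared second coordinate cancels in the subtraction, which is exactly why the method uses paired examples: whatever the two sides have in common drops out of the direction.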

This can be viewed as a lightweight **inference-time alternative to fine-tuning methods like RLHF**: instead of updating weights to maximize a reasoning reward, we directly perturb the model's hidden states to encourage reasoning behavior. No fine-tuning and no few-shot prompting at generation time, just raw tensor surgery.

**Target model:** `google/gemma-3-1b-it` (26 layers, 1152 hidden dim)

## Step 1: Setup and Model Loading

We load Gemma-3-1B-IT with automatic device detection. On a CUDA GPU (e.g. Colab) we use `bfloat16`, which halves weight memory compared with `float32`. On Apple Silicon Macs (MPS) we use `float32` to be safe, since `bfloat16` support on MPS depends on your PyTorch and macOS versions. The model has **26 transformer layers**; we'll hook into a middle layer, where high-level task representations are often reported to emerge.
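
The VRAM argument is simple arithmetic: weight memory is roughly parameter count times bytes per parameter. A back-of-the-envelope sketch, assuming about 1.0 billion parameters (taken from the "1b" in the model name, not an exact count) and ignoring activations and the KV cache:

```python
# Bytes per parameter for the two dtypes used in the loading cell
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2}

def approx_weight_gb(num_params, dtype_name):
    """Approximate weight memory in GB (1 GB = 1e9 bytes); excludes activations and KV cache."""
    return num_params * BYTES_PER_PARAM[dtype_name] / 1e9

ASSUMED_PARAMS = 1.0e9  # assumption: "1b" from the model name, not an exact parameter count
for dtype_name in ("bfloat16", "float32"):
    print(dtype_name, approx_weight_gb(ASSUMED_PARAMS, dtype_name), "GB")
# bfloat16 2.0 GB
# float32 4.0 GB
```

Either fits comfortably on a Colab GPU or a recent Mac, so the `float32` fallback costs memory but is unlikely to be a blocker for a model this size.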

In [None]:
# Gemma 3 support requires a recent transformers release (4.50 or newer)
!pip install -q "transformers>=4.50" accelerate torch huggingface_hub

from huggingface_hub import login
login(token="YOUR-TOKEN")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gc

model_id = "google/gemma-3-1b-it"
HF_TOKEN = "YOUR-TOKEN"

# Device selection: CUDA > MPS > CPU
if torch.cuda.is_available():
    device = "cuda"
    model_dtype = torch.bfloat16
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
    model_dtype = torch.float32  # MPS has limited bfloat16 support
else:
    device = "cpu"
    model_dtype = torch.float32

print(f"Using device: {device}, dtype: {model_dtype}")
print("Loading model...")

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=model_dtype,
    token=HF_TOKEN,
    device_map="auto" if device == "cuda" else None,
)

# For MPS/CPU, manually move the model to device
if device != "cuda":
    model = model.to(device)

model.eval()

print(f"Model loaded. Layers: {len(model.model.layers)}, Hidden dim: {model.config.hidden_size}")

## Step 2: The Contrastive Dataset

This is the heart of CAA. We build paired examples:

- **Positive (CoT):** The model sees a question answered with full step-by-step breakdown.
- **Negative (Direct):** The same question, but answered with just the final number.

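When we subtract the negative activations from the positive ones, we get a **direction in activation space** that separates "reasoning verbosely" from "answering bluntly". Treat this difference vector as an approximation of the Chain-of-Thought concept encoded as a tensor: it captures everything that systematically differs between the two sets, which includes style and length, not only reasoning.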
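
Because the difference vector absorbs *everything* that differs between the two sides, it is worth checking that each pair differs only in its answer. A small sanity check, assuming every example uses the `Question: ... Answer: ...` layout of the dataset below; the helpers `check_pairs` and `length_ratio` are hypothetical, not part of the original notebook:

```python
def check_pairs(pairs, marker="Answer:"):
    """Return indices of pairs whose text before `marker` differs, or that lack the marker."""
    bad = []
    for i, (pos, neg) in enumerate(pairs):
        if marker not in pos or marker not in neg:
            bad.append(i)
            continue
        if pos.split(marker, 1)[0] != neg.split(marker, 1)[0]:
            bad.append(i)
    return bad

def length_ratio(pairs):
    """Mean character-length ratio of positive to negative examples.

    Values well above 1 mean the positives are systematically longer, so the
    steering vector may partly encode "longer text" as well as "reasoning".
    """
    return sum(len(p) / len(n) for p, n in pairs) / len(pairs)

example = [
    ("Question: What is 15 * 4? Answer: 10 * 4 = 40. 5 * 4 = 20. The answer is 60.",
     "Question: What is 15 * 4? Answer: 60."),
    ("Question: What is 2 + 2? Answer: 4.", "Question: What is 3 + 3? Answer: 6."),
]
print(check_pairs(example))  # [1]
```

Running `check_pairs(pairs)` on the real dataset should return an empty list.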

In [None]:
pairs = [
    (
        "Question: What is 15 * 4? Answer: Let's break it down. 10 * 4 = 40. 5 * 4 = 20. 40 + 20 = 60. The answer is 60.",
        "Question: What is 15 * 4? Answer: 60."
    ),
    (
        "Question: If I have 3 apples and buy 5 more, then give away 2, how many do I have? Answer: Start with 3. Buy 5, so 3 + 5 = 8. Give away 2, so 8 - 2 = 6. The answer is 6.",
        "Question: If I have 3 apples and buy 5 more, then give away 2, how many do I have? Answer: 6."
    ),
    (
        "Question: Solve 2x + 5 = 15. Answer: First, subtract 5 from both sides to get 2x = 10. Then, divide by 2 to get x = 5. The answer is x=5.",
        "Question: Solve 2x + 5 = 15. Answer: x=5."
    ),
    (
        "Question: What is 144 / 12? Answer: I know that 12 * 12 = 144, so 144 divided by 12 is 12. The answer is 12.",
        "Question: What is 144 / 12? Answer: 12."
    ),
    (
        "Question: A train travels 60 km/h for 2.5 hours. How far does it go? Answer: Distance = speed * time. So distance = 60 * 2.5. 60 * 2 = 120 and 60 * 0.5 = 30. 120 + 30 = 150. The answer is 150 km.",
        "Question: A train travels 60 km/h for 2.5 hours. How far does it go? Answer: 150 km."
    ),
]

# Format each prompt through the chat template so the model sees them
# in its expected conversational format
pos_prompts = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": p[0]}],
        tokenize=False,
        add_generation_prompt=False
    ) for p in pairs
]
neg_prompts = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": p[1]}],
        tokenize=False,
        add_generation_prompt=False
    ) for p in pairs
]

print(f"Created {len(pairs)} contrastive pairs.")
print(f"\nExample positive prompt (raw):\n{pos_prompts[0][:200]}...")
print(f"\nExample negative prompt (raw):\n{neg_prompts[0][:200]}...")

## Step 3: The Vector Extractor (Representation Reading)

This is where the actual science happens. Here's the algorithm:

1. Attach a **`register_forward_hook`** to a specific transformer layer.
2. Run each positive and negative prompt through the model (forward pass only, no generation).
3. The hook intercepts the hidden states and grabs the **last token's activation**. Because attention is causal, this is the only position that attends to the entire prompt, so it carries the most context.
4. Average the positive activations, average the negative activations.
5. Subtract: `v_reason = mean(pos) - mean(neg)`
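6. Normalize the resulting vector to unit length. This does not by itself prevent activation blow-ups; it makes the injection strength depend only on the multiplier we choose later, not on how far apart the two activation means happened to be.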

The result is a single vector in $\mathbb{R}^{1152}$ that points in the direction of "step-by-step reasoning".
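
Once the unit vector exists, "reading" a representation amounts to a dot product: projecting a hidden state onto $\vec{v}_{\text{reason}}$ gives a scalar score, and positive examples should score higher than negative ones. A toy sketch with made-up 2-D activations (in the notebook these would be the 1152-dimensional vectors collected by the hook):

```python
import math

def unit(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def project(hidden, direction):
    """Scalar projection of a hidden state onto a unit direction."""
    return sum(h * d for h, d in zip(hidden, direction))

# Made-up activations standing in for the hook's cached last-token states
pos_acts = [[3.0, 4.0], [5.0, 4.0]]
neg_acts = [[1.0, 4.0], [1.0, 4.0]]

# mean(pos) - mean(neg) = [4.0 - 1.0, 4.0 - 4.0] = [3.0, 0.0] -> unit vector [1.0, 0.0]
mean_pos = [sum(c) / len(pos_acts) for c in zip(*pos_acts)]
mean_neg = [sum(c) / len(neg_acts) for c in zip(*neg_acts)]
v_reason = unit([p - n for p, n in zip(mean_pos, mean_neg)])

pos_scores = [project(h, v_reason) for h in pos_acts]
neg_scores = [project(h, v_reason) for h in neg_acts]
print(pos_scores, neg_scores)  # [3.0, 5.0] [1.0, 1.0]
```

If the real positive and negative scores overlap heavily, the extracted direction is probably a weak signal and steering along it is unlikely to do much.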

In [None]:
def get_steering_vector(model, tokenizer, pos_prompts, neg_prompts, layer_idx):
    """Extract a steering vector via contrastive activation differencing."""
    pos_acts = []
    neg_acts = []
    cache = []

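    def cache_hook(module, input, output):
        # Depending on the transformers version, a decoder layer may return a tuple whose
        # first element is the hidden states, or the hidden-states tensor itself.
        # Shape: [batch, seq_len, hidden_dim]
        hidden_states = output[0] if isinstance(output, tuple) else output
        # Keep only the last token's activation, upcast to float32 for a stable mean,
        # and move it to CPU to save VRAM
        cache.append(hidden_states[:, -1, :].detach().float().cpu())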

    # Hook into the target decoder layer
    handle = model.model.layers[layer_idx].register_forward_hook(cache_hook)

    print(f"Extracting activations from layer {layer_idx}...")
    try:
        with torch.no_grad():
            for i, (pos, neg) in enumerate(zip(pos_prompts, neg_prompts)):
                # The chat template already starts with <bos>, so skip the tokenizer's
                # own special tokens to avoid a duplicated <bos>
                pos_inputs = tokenizer(pos, return_tensors="pt", add_special_tokens=False).to(device)
                model(**pos_inputs)
                pos_acts.append(cache.pop())

                # Process the negative example the same way
                neg_inputs = tokenizer(neg, return_tensors="pt", add_special_tokens=False).to(device)
                model(**neg_inputs)
                neg_acts.append(cache.pop())

                print(f"  Pair {i+1}/{len(pos_prompts)} done.")
    finally:
        # Remove the hook even if a forward pass fails, so it never fires during generation later
        handle.remove()

    # Stack into tensors and compute mean activations
    pos_tensor = torch.stack(pos_acts).mean(dim=0)  # [1, hidden_dim]
    neg_tensor = torch.stack(neg_acts).mean(dim=0)  # [1, hidden_dim]

    # The steering vector: direction from "blunt" toward "reasoning"
    steering_vec = pos_tensor - neg_tensor

    # Normalize to unit length to prevent activation magnitude issues
    steering_vec = steering_vec / torch.norm(steering_vec)

    print(f"Steering vector extracted. Shape: {steering_vec.shape}, Norm: {torch.norm(steering_vec).item():.4f}")
    return steering_vec

### Extract the Vector

Gemma-3-1B has **26 layers** (indices 0 to 25). We target **layer 13**, right in the middle. Middle layers are often reported to carry high-level semantic features (like "am I reasoning step-by-step?") before the upper layers specialize toward predicting specific vocabulary tokens. Treat 13 as a sensible starting point worth sweeping, not a guaranteed optimum.
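
The layer choice can also be checked empirically. One hypothetical heuristic (not part of the original method): for each candidate layer, project the positive and negative activations onto that layer's difference-of-means direction and measure how far apart the two groups are relative to their spread. The sketch assumes you have already collected per-layer activation lists by hooking several layers, and it uses made-up numbers:

```python
import math
from statistics import mean, pstdev

def separation_score(pos_acts, neg_acts):
    """Gap between projected group means divided by the pooled spread (a Cohen's-d-style score)."""
    mean_pos = [mean(c) for c in zip(*pos_acts)]
    mean_neg = [mean(c) for c in zip(*neg_acts)]
    diff = [p - n for p, n in zip(mean_pos, mean_neg)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm == 0.0:
        return 0.0
    direction = [d / norm for d in diff]
    pos_proj = [sum(h * d for h, d in zip(act, direction)) for act in pos_acts]
    neg_proj = [sum(h * d for h, d in zip(act, direction)) for act in neg_acts]
    spread = math.sqrt((pstdev(pos_proj) ** 2 + pstdev(neg_proj) ** 2) / 2)
    gap = mean(pos_proj) - mean(neg_proj)
    return gap / spread if spread > 0 else math.inf

def best_layer(acts_by_layer):
    """acts_by_layer maps layer index -> (pos_acts, neg_acts); returns the highest-scoring layer."""
    return max(acts_by_layer, key=lambda layer: separation_score(*acts_by_layer[layer]))

# Made-up 2-D activations for three hypothetical layers
acts_by_layer = {
    6: ([[1.0, 0.0], [3.0, 0.0]], [[0.0, 0.0], [2.0, 0.0]]),   # groups overlap
    13: ([[5.0, 0.0], [5.2, 0.0]], [[0.0, 0.0], [0.2, 0.0]]),  # well separated
    20: ([[2.0, 0.0], [4.0, 0.0]], [[0.0, 0.0], [2.0, 0.0]]),
}
print(best_layer(acts_by_layer))  # 13
```

With only five pairs this score is noisy, and the most separable layer is not necessarily the one where steering works best, so treat it as a way to shortlist layers for the alpha sweep rather than as a final answer.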

In [None]:
TARGET_LAYER = 13
reasoning_vector = get_steering_vector(model, tokenizer, pos_prompts, neg_prompts, TARGET_LAYER)

# Free up memory from the extraction phase
gc.collect()
if device == "cuda":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

## Step 4: The Injection Controller (Representation Control)

Now we use the extracted vector during generation. The injection hook does one thing:

$$\tilde{h}_l = h_l + \alpha \cdot \vec{v}_{\text{reason}}$$

Where:
- $h_l$ = the hidden state output by decoder layer $l$ (the residual stream) at the position being steered
- $\vec{v}_{\text{reason}}$ = our normalized steering vector
- $\alpha$ = a scalar multiplier controlling how hard we push

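We only modify the **last position** (`[:, -1, :]`) of each forward pass. During the initial prefill pass that is the final prompt token; during decoding, each pass processes just the newest token, so every token fed back into the model receives the push once. Earlier positions are never re-modified, although the cached keys and values of previously generated tokens (at layers after $l$) already reflect the steering applied when those tokens were processed.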
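
Because $\vec{v}_{\text{reason}}$ has unit norm, $\alpha$ is exactly the L2 size of the perturbation, which makes it easy to compare against the size of the hidden state itself:

$$\|\tilde{h}_l - h_l\|_2 = \|\alpha \, \vec{v}_{\text{reason}}\|_2 = |\alpha| \, \|\vec{v}_{\text{reason}}\|_2 = |\alpha|, \qquad \frac{\|\tilde{h}_l - h_l\|_2}{\|h_l\|_2} = \frac{|\alpha|}{\|h_l\|_2}$$

So the same $\alpha$ is a gentle nudge at a layer where $\|h_l\|_2$ is large and a strong push where it is small, which is why any fixed alpha ranges can only be rough guides.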

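**Alpha tuning guide** (rough starting points rather than measured thresholds; the right scale depends on the typical hidden-state norm at the chosen layer):
- `0`: No effect (baseline)
- `5–15`: Subtle nudge toward more reasoning
- `15–30`: Strong reasoning push
- `50+`: Likely gibberish, because the activations are pushed so far off their usual distribution that the model can't recover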
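
One way to make such a guide less model-specific is to choose alpha as a fraction of the typical hidden-state norm at the target layer. A minimal sketch, assuming you have already measured some per-token L2 norms at that layer (for example with `hidden_states[:, -1, :].norm(dim=-1)` inside a hook like `cache_hook`); the helper names, norms, and fractions below are illustrative, not tuned values:

```python
from statistics import mean

def alphas_from_norms(token_norms, fractions):
    """Map target relative strengths (alpha / ||h||) to absolute alphas using the mean norm."""
    typical_norm = mean(token_norms)
    return [f * typical_norm for f in fractions]

def relative_strength(alpha, token_norms):
    """Inverse view: how large a given alpha is relative to the mean hidden-state norm."""
    return alpha / mean(token_norms)

# Illustrative norms only; measure real ones for Gemma-3 at your chosen layer
norms = [150.0, 250.0]
print(alphas_from_norms(norms, [0.05, 0.1]))  # [10.0, 20.0]
print(relative_strength(15.0, norms))         # 0.075
```

Sweeping relative strengths instead of raw alphas makes results easier to transfer when you change `TARGET_LAYER`.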

In [None]:
def generate_with_steering(model, tokenizer, prompt, steering_vector, layer_idx, alpha, max_new_tokens=256):
    """Generate text with a steering vector injected into the residual stream."""
    # Tokenize through the chat template
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(device)

    # Move the steering vector to the model's device and dtype
    vec = steering_vector.to(device, dtype=model.dtype)

    def injection_hook(module, input, output):
        hidden_states = output[0] if isinstance(output, tuple) else output
        # Add the scaled reasoning vector at the last position of this forward pass:
        # the final prompt token during prefill, then the newest token at each decoding step
        hidden_states[:, -1, :] += alpha * vec
        # Return the modified output so downstream layers see the steered states
        if isinstance(output, tuple):
            return (hidden_states,) + output[1:]
        return hidden_states

    # Register the injection hook
    handle = model.model.layers[layer_idx].register_forward_hook(injection_hook)

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7
            )
    finally:
        # CRITICAL: always remove the hook, even if generation raises
        handle.remove()

    # Decode only the newly generated tokens (skip the prompt)
    input_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

## Step 5: The Experiment

Let's put it all together. We'll throw a genuinely hard multi-step problem at the model — one that requires setting up equations, substitution, and careful arithmetic. We compare:

1. **Baseline** (`alpha=0`): The model generates normally, no intervention.
2. **Steered** (`alpha=15`): The reasoning vector is injected at every generation step.

If CAA works, the steered output should be noticeably more verbose and step-by-step compared to the baseline.
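
To judge whether an output is correct and not merely verbose, it helps to know the answer in advance. Worked out step by step (night shift 120 widgets, day shift 3 × 120, revenue only from non-defective widgets), the expected result is $4,896:

```python
night = 120
day = 3 * night                  # 360
total = night + day              # 480
defective = total * 15 // 100    # 72 (15% of 480 is a whole number here)
good = total - defective         # 408
revenue = good * 12              # 4896
print(f"Expected revenue: ${revenue:,}")  # Expected revenue: $4,896
```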

In [None]:
test_prompt = "A factory produces widgets in two shifts. The day shift produces 3 times as many widgets as the night shift. If the night shift produces x widgets, and 15% of all widgets from both shifts are defective, and each non-defective widget sells for $12, how much revenue does the factory make if the night shift produces 120 widgets?"

print("=" * 60)
print("BASELINE (alpha=0, no steering)")
print("=" * 60)
# Reseed before each run so both start from the same random state and
# differences are less likely to be plain sampling noise
torch.manual_seed(0)
baseline = generate_with_steering(
    model, tokenizer, test_prompt, reasoning_vector, TARGET_LAYER, alpha=0.0
)
print(baseline)

print("\n" + "=" * 60)
print("STEERED (alpha=15, reasoning vector injected)")
print("=" * 60)
torch.manual_seed(0)
steered = generate_with_steering(
    model, tokenizer, test_prompt, reasoning_vector, TARGET_LAYER, alpha=15.0
)
print(steered)

## Step 6: Sweep Across Alpha Values

Let's see how the model's behavior changes as we crank up the steering strength. This gives us intuition about the "reasoning manifold" — at what point does the model go from normal to verbose to incoherent?
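
As with the factory problem, computing the answer up front makes it easier to spot where steering helps and where it starts to hurt. By 10:00 AM the first train has covered 80 km, leaving 480 km that closes at 80 + 120 = 200 km/h, so the trains meet 2.4 hours (2 h 24 min) later, at 12:24 PM. A sketch using exact fractions to avoid float rounding:

```python
from datetime import datetime, timedelta
from fractions import Fraction

distance_km = 560
speed_a, speed_b = 80, 120              # km/h
start_a = datetime(2024, 1, 1, 9, 0)    # arbitrary date; only the clock time matters
start_b = datetime(2024, 1, 1, 10, 0)

head_start_h = Fraction(int((start_b - start_a).total_seconds()), 3600)  # 1 hour
remaining_km = distance_km - speed_a * head_start_h                      # 480 km
hours_after_b = remaining_km / (speed_a + speed_b)                       # 12/5 h
meet = start_b + timedelta(minutes=int(hours_after_b * 60))              # 144 minutes
print(meet.strftime("%I:%M %p"))  # 12:24 PM
```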

In [None]:
sweep_prompt = "A train leaves City A at 9:00 AM traveling at 80 km/h toward City B. Another train leaves City B at 10:00 AM traveling at 120 km/h toward City A. If the cities are 560 km apart, at what time do the trains meet?"

for alpha in [0, 5, 10, 20, 35]:
    print(f"\n{'=' * 60}")
    print(f"Alpha = {alpha}")
    print(f"{'=' * 60}")
    torch.manual_seed(0)  # reseed so each alpha starts from the same random state
    result = generate_with_steering(
        model, tokenizer, sweep_prompt, reasoning_vector, TARGET_LAYER,
        alpha=float(alpha), max_new_tokens=300
    )
    print(result)

## What Just Happened?

We built a **steering mechanism from scratch** using only PyTorch hooks:

1. **Extracted** a reasoning direction from contrastive examples
2. **Injected** it into the residual stream during autoregressive generation
3. **Set up** a sweep to observe how increasing $\alpha$ moves the model from terse answers toward verbose step-by-step reasoning, and where it tips over into gibberish

This is the core idea behind **representation engineering** — you don't need to retrain the model, you just need to know which direction to push in activation space.