In [11]:
import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List, TypedDict
from utils import Question, QsDataset
from functools import partial
import plotly.express as px
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np

import torch
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from tqdm import tqdm
from transformer_lens import HookedTransformer

In [12]:
with open('results/deepseek_r1_cots_iphr_eval.json', 'r') as f:
    cots = json.load(f)

# NOTE: I'm not sure about the below way of getting faithful CoTs
# cots_faithful = [x for x in cots if x['is_faithful_iphr'] == True]
cots_faithful = [x for x in cots if x['is_faithful_iphr'] == True or (x['is_faithful_iphr'] == None and x['eval_final_answer'] == x['ground_truth_answer'])]
cots_unfaithful = [x for x in cots if x['is_faithful_iphr'] == False]
print(f"Found {len(cots_faithful)} faithful and {len(cots_unfaithful)} unfaithful CoTs out of {len(cots)} total CoTs.")

Found 642 faithful and 179 unfaithful CoTs out of 1238 total CoTs.


In [13]:
with open('results/deepseek_r1_cots_iphr.json', 'r') as f:
    cots_raw = json.load(f)

list(set([x['prop_id'] for x in cots_raw]))

['wm-book-length',
 'boiling-points',
 'celebrity-heights',
 'wm-nyt-pubdate',
 'element-numbers',
 'wm-us-city-long',
 'wm-person-age',
 'first-flights',
 'melting-points',
 'wm-person-death',
 'river-lengths',
 'wm-nyc-place-long',
 'sound-speeds',
 'wm-movie-length',
 'wm-song-release',
 'tech-releases',
 'wm-nyc-place-lat',
 'wm-us-city-lat',
 'tunnel-lengths',
 'wm-person-birth',
 'wm-movie-release',
 'animals-speed',
 'bridge-lengths',
 'wm-book-release',
 'skyscraper-heights',
 'structure-completion',
 'aircraft-speeds',
 'mountain-heights',
 'wm-us-city-dens',
 'element-densities',
 'sea-depths',
 'train-speeds',
 'satellite-launches']

: 

In [3]:
# Model parameters
# NOTE: Ensure the model name is added to OFFICIAL_MODEL_NAMES in TransformerLens loading_from_pretrained.py. Also consider adjusting n_ctx if needed.
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
context_length = 2048
batch_size = 4  # Number of questions to process per batch
temperature = 0.6  # Sampling temperature
max_new_tokens = 1024  # Maximum number of new tokens to generate
top_p = 0.92  # Nucleus sampling top-p value
    
# Disable gradient computation for faster inference
torch.set_grad_enabled(False)

# Model loading
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
model = HookedTransformer.from_pretrained(model_name)
model = model.to(device)
model.cfg.n_ctx = context_length # Adjust context length if necessary
print(f"Model loaded: {model.cfg.model_name}")
print(f"  Context length: {model.cfg.n_ctx}")
print(f"  Layers: {model.cfg.n_layers}")
print(f"  Vocab size: {model.cfg.d_vocab}")
print(f"  Hidden dim: {model.cfg.d_model}")
print(f"  Attention heads: {model.cfg.n_heads}")

Using device: cuda


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B into HookedTransformer
Moving model to device:  cuda
Model loaded: DeepSeek-R1-Distill-Qwen-1.5B
  Context length: 2048
  Layers: 28
  Vocab size: 151936
  Hidden dim: 1536
  Attention heads: 12


In [None]:
example_unfaithful = cots_unfaithful[0]
logits, cache = model.run_with_cache(example_unfaithful['generated_cot'], remove_batch_dim=True)
print(logits.shape)

In [10]:
# First, let's examine our data
print(f"Total CoTs loaded: {len(cots)}")
print(f"Faithful CoTs: {len(cots_faithful)}")
print(f"Unfaithful CoTs: {len(cots_unfaithful)}")

# Let's look at the properties represented in our dataset
prop_counts = defaultdict(int)
for cot in cots:
    prop_counts[cot.get('prop_id', 'unknown')] += 1

print("\nProperty distribution in dataset:")
for prop, count in sorted(prop_counts.items(), key=lambda x: x[1], reverse=True):
    print(f"  {prop}: {count}")

# Following the pairing logic from step2_evaluate_cots.py
# First, create a mapping from reversed question IDs to original question IDs
reversed_qid_map = {}

for cot in cots:
    qid = cot['qid']
    prop_id = cot['prop_id']
    comparison = cot['comparison']
    question_text = cot['question_text']
    
    # Create a reversed question ID by flipping the comparison
    reversed_comparison = 'lt' if comparison == 'gt' else 'gt'
    
    # Find the reversed question in the dataset
    for other_cot in cots:
        if (other_cot['prop_id'] == prop_id and 
            other_cot['comparison'] == reversed_comparison and
            other_cot['qid'] != qid):
            
            # Check if this is a reversed pair (same entities but reversed comparison)
            # This is a simplified check - the actual logic might need to be more sophisticated
            if set(question_text.lower().split()) == set(other_cot['question_text'].lower().split()):
                reversed_qid_map[qid] = other_cot['qid']
                reversed_qid_map[other_cot['qid']] = qid
                break

print(f"\nFound {len(reversed_qid_map) // 2} reversed question pairs")

# Now organize CoTs by question ID
question_pairs = defaultdict(dict)
for cot in cots:
    qid = cot['qid']
    is_faithful = cot.get('is_faithful_iphr')
    
    # Only include examples where faithfulness is explicitly marked
    if is_faithful is not None:
        if is_faithful:
            question_pairs[qid]['faithful'] = cot
        else:
            question_pairs[qid]['unfaithful'] = cot

# Find questions that have both faithful and unfaithful examples
paired_questions = {qid: pair for qid, pair in question_pairs.items() 
                   if 'faithful' in pair and 'unfaithful' in pair}

print(f"\nFound {len(paired_questions)} questions with both faithful and unfaithful CoTs")

# Let's also find reversed pairs where one has faithful and the other has unfaithful
reversed_faithful_unfaithful_pairs = []

for qid, pair in question_pairs.items():
    # Skip if this question already has both faithful and unfaithful
    if 'faithful' in pair and 'unfaithful' in pair:
        continue
    
    # Check if there's a reversed question
    reversed_qid = reversed_qid_map.get(qid)
    if reversed_qid and reversed_qid in question_pairs:
        reversed_pair = question_pairs[reversed_qid]
        
        # Check if we can form a faithful/unfaithful pair across the two questions
        if 'faithful' in pair and 'unfaithful' in reversed_pair:
            reversed_faithful_unfaithful_pairs.append((qid, reversed_qid))
        elif 'unfaithful' in pair and 'faithful' in reversed_pair:
            reversed_faithful_unfaithful_pairs.append((reversed_qid, qid))

print(f"\nFound {len(reversed_faithful_unfaithful_pairs)} additional pairs where reversed questions have faithful/unfaithful CoTs")

# Let's examine one example pair
if paired_questions:
    example_pair = next(iter(paired_questions.values()))
    print("\nExample question with both faithful and unfaithful CoTs:")
    print(f"Question: {example_pair['faithful']['question_text']}")
    print(f"Ground truth answer: {example_pair['faithful']['ground_truth_answer']}")
    print(f"Property ID: {example_pair['faithful']['prop_id']}")
    
    print("\nFaithful reasoning:")
    faithful_cot = example_pair['faithful']['generated_cot']
    # Print just the first few lines to avoid overwhelming output
    print('\n'.join(faithful_cot.split('\n')[:10]) + '\n...')
    
    print("\nUnfaithful reasoning:")
    unfaithful_cot = example_pair['unfaithful']['generated_cot']
    print('\n'.join(unfaithful_cot.split('\n')[:10]) + '\n...')

# Let's also examine a reversed pair example
if reversed_faithful_unfaithful_pairs:
    faithful_qid, unfaithful_qid = reversed_faithful_unfaithful_pairs[0]
    faithful_cot = question_pairs[faithful_qid]['faithful']
    unfaithful_cot = question_pairs[unfaithful_qid]['unfaithful']
    
    print("\nExample of reversed questions with faithful/unfaithful CoTs:")
    print(f"Faithful question: {faithful_cot['question_text']}")
    print(f"Faithful ground truth: {faithful_cot['ground_truth_answer']}")
    
    print(f"\nUnfaithful question: {unfaithful_cot['question_text']}")
    print(f"Unfaithful ground truth: {unfaithful_cot['ground_truth_answer']}")
    
    print("\nFaithful reasoning:")
    print('\n'.join(faithful_cot['generated_cot'].split('\n')[:10]) + '\n...')
    
    print("\nUnfaithful reasoning:")
    print('\n'.join(unfaithful_cot['generated_cot'].split('\n')[:10]) + '\n...')

# Create a combined dataset of all paired examples for analysis
all_paired_examples = []

# Add direct pairs
for qid, pair in paired_questions.items():
    all_paired_examples.append({
        'faithful': pair['faithful'],
        'unfaithful': pair['unfaithful'],
        'pair_type': 'direct'
    })

# Add reversed pairs
for faithful_qid, unfaithful_qid in reversed_faithful_unfaithful_pairs:
    all_paired_examples.append({
        'faithful': question_pairs[faithful_qid]['faithful'],
        'unfaithful': question_pairs[unfaithful_qid]['unfaithful'],
        'pair_type': 'reversed'
    })

print(f"\nTotal paired examples for analysis: {len(all_paired_examples)}")

# Let's also analyze the distribution of properties in our paired examples
paired_props = defaultdict(int)
for pair in all_paired_examples:
    paired_props[pair['faithful']['prop_id']] += 1

print("\nProperty distribution in paired examples:")
for prop, count in sorted(paired_props.items(), key=lambda x: x[1], reverse=True):
    print(f"  {prop}: {count}")

Total CoTs loaded: 1239
Faithful CoTs: 642
Unfaithful CoTs: 180

Property distribution in dataset:
  wm-book-length: 300
  element-numbers: 84
  aircraft-speeds: 57
  boiling-points: 56
  bridge-lengths: 56
  element-densities: 56
  first-flights: 56
  satellite-launches: 56
  sea-depths: 56
  skyscraper-heights: 56
  tech-releases: 56
  animals-speed: 52
  celebrity-heights: 44
  sound-speeds: 43
  train-speeds: 42
  tunnel-lengths: 42
  mountain-heights: 36
  river-lengths: 36
  melting-points: 28
  structure-completion: 27

Found 0 reversed question pairs

Found 0 questions with both faithful and unfaithful CoTs

Found 0 additional pairs where reversed questions have faithful/unfaithful CoTs

Total paired examples for analysis: 0

Property distribution in paired examples:


In [None]:
def patch_residual_component(corrupted_residual_component, hook, pos, clean_cache):
    """Patch a specific position in the residual stream with values from the clean cache."""
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component

def normalize_patched_logit_diff(patched_logit_diff, original_logit_diff, corrupted_logit_diff):
    """Normalize the patched logit difference to measure improvement."""
    # 0 means no change, 1 means fully recovered performance
    return (patched_logit_diff - corrupted_logit_diff) / (original_logit_diff - corrupted_logit_diff)

def run_residual_stream_patching(faithful_example, unfaithful_example, model):
    """Run residual stream patching between faithful and unfaithful examples."""
    # Tokenize both examples
    faithful_tokens = model.to_tokens(faithful_example['generated_cot'], prepend_bos=True)
    unfaithful_tokens = model.to_tokens(unfaithful_example['generated_cot'], prepend_bos=True)
    print(f"faithful_tokens.shape: {faithful_tokens.shape}, unfaithful_tokens.shape: {unfaithful_tokens.shape}")
    
    # Get the ground truth answer tokens
    answer_token = model.to_single_token(faithful_example['ground_truth_answer'])
    print(f"answer_token: {answer_token}")
    
    # Run the model on both examples and cache activations
    faithful_logits, faithful_cache = model.run_with_cache(faithful_tokens)
    unfaithful_logits, unfaithful_cache = model.run_with_cache(unfaithful_tokens)
    print(f"faithful_logits.shape: {faithful_logits.shape}, unfaithful_logits.shape: {unfaithful_logits.shape}")
    
    # Get the original logit values for the correct answer
    faithful_answer_logit = faithful_logits[0, -1, answer_token].item()
    unfaithful_answer_logit = unfaithful_logits[0, -1, answer_token].item()
    
    # Initialize results matrix
    patched_results = torch.zeros(model.cfg.n_layers, min(faithful_tokens.shape[1], unfaithful_tokens.shape[1]))
    
    # For each layer and position, patch the unfaithful run with the faithful activation
    for layer in tqdm(range(model.cfg.n_layers), desc="Patching layers"):
        for position in tqdm(range(min(faithful_tokens.shape[1], unfaithful_tokens.shape[1])), desc="Patching positions"):
            hook_fn = partial(patch_residual_component, pos=position, clean_cache=faithful_cache)
            
            patched_logits = model.run_with_hooks(
                unfaithful_tokens,
                fwd_hooks=[(f"blocks.{layer}.hook_resid_pre", hook_fn)],
                return_type="logits"
            )
            
            patched_answer_logit = patched_logits[0, -1, answer_token].item()
            
            # Normalize the improvement
            patched_results[layer, position] = normalize_patched_logit_diff(
                patched_answer_logit, 
                faithful_answer_logit, 
                unfaithful_answer_logit
            )
    
    return patched_results, faithful_tokens, unfaithful_tokens

# Run the analysis on a pair of examples
faithful_example = cots_faithful[0]
unfaithful_example = cots_unfaithful[0]

patched_results, faithful_tokens, unfaithful_tokens = run_residual_stream_patching(
    faithful_example, 
    unfaithful_example, 
    model
)

# Visualize the results
faithful_str_tokens = model.to_str_tokens(faithful_tokens[0])
position_labels = [f"{tok}_{i}" for i, tok in enumerate(faithful_str_tokens)]

fig = px.imshow(
    patched_results.cpu().numpy(),
    x=position_labels[:patched_results.shape[1]],
    labels={"x": "Position", "y": "Layer"},
    title="Impact of Patching Residual Stream from Faithful to Unfaithful CoT",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0
)
fig.show()