In [None]:
# ============================================================================
# STEP 0: Install packages
# ============================================================================

import os

# Point the Hugging Face cache (HF_HOME) at a workspace directory. This cell is
# meant to run before the cells that import Hugging Face libraries.
cache_dir = "/workspace/huggingface_cache"
os.environ['HF_HOME'] = cache_dir

# Create the directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)

print(f"✅ Hugging Face cache directory is now set to: {os.environ['HF_HOME']}")

# Install the Python dependencies used by the notebook.
# NOTE(review): versions are unpinned, and `%pip install` is usually preferred
# over `!pip install` so packages land in the running kernel's environment.
# NOTE(review): hf_transfer is installed, but no visible cell sets
# HF_HUB_ENABLE_HF_TRANSFER -- confirm whether it is meant to be enabled.
!pip install huggingface_hub transformers accelerate einops hf_transfer

# Clone the repository if it doesn't exist, or pull the latest changes if it does.
# (Paths are relative to the kernel's current working directory.)
!if [ -d "repository/circuit-tracer" ]; then \
    echo "✅ Repository found. Pulling latest changes..."; \
    (cd repository/circuit-tracer && git pull); \
else \
    echo "Cloning repository for the first time..."; \
    mkdir -p repository && git clone https://github.com/safety-research/circuit-tracer repository/circuit-tracer; \
fi

# Install circuit-tracer from the local checkout created/updated above.
!pip install ./repository/circuit-tracer

In [4]:
import os
import sys
import subprocess
from huggingface_hub import login

# Add the cloned repository (and its demos folder) to the Python path.
# These are relative paths, so they resolve against the kernel's working directory.
sys.path.append('repository/circuit-tracer')
sys.path.append('repository/circuit-tracer/demos')

# Authenticate with the Hugging Face Hub.
# This will prompt you for your token
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
"""
Sequential CoT Attribution Analysis
Implementation for analyzing Chain-of-Thought reasoning using attribution graphs
"""

import torch
import numpy as np
from typing import List, Dict, Tuple, Optional, Union
from dataclasses import dataclass
from pathlib import Path
import networkx as nx
from collections import defaultdict

from circuit_tracer import attribute, ReplacementModel
from circuit_tracer.graph import Graph
from circuit_tracer.utils import create_graph_files
from circuit_tracer.graph import prune_graph

In [4]:
# ============================================================================
# STEP 1: Generate Attribution Graphs for Reasoning Steps
# ============================================================================

@dataclass
class ReasoningStep:
    """Represents a single reasoning step in a CoT sequence.

    Instances are produced by ``CoTAttributionGenerator.parse_cot_steps`` with
    ``graph`` left as ``None``; ``generate_sequential_graphs`` assigns the
    attribution graph afterwards.
    """
    # Zero-based position of the step within the parsed sequence.
    step_idx: int
    # Text of the step (parse_cot_steps stores it with surrounding whitespace stripped).
    text: str
    # Token offset at which the step starts, as computed by parse_cot_steps.
    start_token_idx: int
    # Token offset just past the step (start plus the step's token count).
    end_token_idx: int
    # Attribution graph for this step; None until one has been generated.
    graph: Optional[Graph] = None


class CoTAttributionGenerator:
    """Generate attribution graphs for each step in a CoT reasoning sequence.

    A chain-of-thought completion is split into steps, and
    ``circuit_tracer.attribute`` is called once per step on a context built
    from the prompt, the preceding steps and the step itself.
    """

    def __init__(
        self,
        model: ReplacementModel,
        max_n_logits: int = 10,
        desired_logit_prob: float = 0.95,
        batch_size: int = 512,
        max_feature_nodes: int = 4096,
        verbose: bool = True
    ):
        """Store the model and the settings forwarded to ``attribute``.

        Args:
            model: Replacement model; its ``tokenizer`` is used for step parsing
                and the model itself is passed to ``attribute``.
            max_n_logits: Forwarded to ``attribute(max_n_logits=...)``.
            desired_logit_prob: Forwarded to ``attribute(desired_logit_prob=...)``.
            batch_size: Forwarded to ``attribute(batch_size=...)``.
            max_feature_nodes: Forwarded to ``attribute(max_feature_nodes=...)``.
            verbose: Forwarded to ``attribute`` and also enables the per-step
                banner printed by ``generate_graph_for_step``.
        """
        self.model = model
        self.max_n_logits = max_n_logits
        self.desired_logit_prob = desired_logit_prob
        self.batch_size = batch_size
        self.max_feature_nodes = max_feature_nodes
        self.verbose = verbose

    def parse_cot_steps(
        self,
        prompt: str,
        cot_completion: str,
        step_delimiter: str = "\n"
    ) -> List[ReasoningStep]:
        """
        Parse CoT completion into individual reasoning steps

        Token boundaries are derived from character offsets in
        ``prompt + cot_completion``: the index of a boundary is the number of
        tokens produced by encoding the text up to that character. Encoding
        each stripped step separately (the previous approach) re-counts any
        special tokens the tokenizer adds per call and never counts the
        delimiter tokens, so offsets drifted from step to step.

        NOTE(review): prefix tokenisation assumes the tokenizer splits the full
        text the same way at these offsets -- confirm for the tokenizer in use.

        Args:
            prompt: The initial prompt/question
            cot_completion: The full CoT reasoning text
            step_delimiter: How to split steps (default: newline)

        Returns:
            List of ReasoningStep objects (blank segments are skipped;
            ``end_token_idx`` is exclusive)

        Raises:
            ValueError: If ``step_delimiter`` is empty (raised by ``str.split``).
        """
        full_text = prompt + cot_completion
        encode = self.model.tokenizer.encode

        reasoning_steps = []
        next_segment_offset = 0  # character offset of the next segment in cot_completion

        for segment in cot_completion.split(step_delimiter):
            segment_offset = next_segment_offset
            next_segment_offset += len(segment) + len(step_delimiter)

            step_text = segment.strip()
            if not step_text:
                continue

            # Locate the stripped step text inside the full text.
            leading_ws = len(segment) - len(segment.lstrip())
            char_start = len(prompt) + segment_offset + leading_ws
            char_end = char_start + len(step_text)

            reasoning_steps.append(
                ReasoningStep(
                    step_idx=len(reasoning_steps),
                    text=step_text,
                    start_token_idx=len(encode(full_text[:char_start])),
                    end_token_idx=len(encode(full_text[:char_end])),
                )
            )

        return reasoning_steps

    def generate_graph_for_step(
        self,
        prompt: str,
        reasoning_step: ReasoningStep,
        previous_steps: Optional[List[ReasoningStep]] = None
    ) -> Graph:
        """
        Generate attribution graph for a single reasoning step

        Args:
            prompt: The initial prompt
            reasoning_step: The step to analyze
            previous_steps: Previous steps for context

        Returns:
            Attribution graph for this step
        """
        # Build context: prompt + all previous steps + current step, joined by
        # single spaces.
        # NOTE(review): the pieces are joined with " " rather than the original
        # step delimiter, so this context is not byte-identical to
        # prompt + cot_completion and the token indices stored on the steps do
        # not index into it -- confirm this is intended.
        context_parts = [prompt]
        if previous_steps:
            context_parts.extend(s.text for s in previous_steps)
        context_parts.append(reasoning_step.text)
        context_text = " ".join(context_parts)

        if self.verbose:
            print(f"\n{'='*60}")
            print(f"Generating graph for Step {reasoning_step.step_idx}")
            print(f"Step text: {reasoning_step.text[:100]}...")
            print(f"{'='*60}")

        # Generate attribution graph using circuit-tracer
        return attribute(
            prompt=context_text,
            model=self.model,
            max_n_logits=self.max_n_logits,
            desired_logit_prob=self.desired_logit_prob,
            batch_size=self.batch_size,
            max_feature_nodes=self.max_feature_nodes,
            verbose=self.verbose
        )

    def generate_sequential_graphs(
        self,
        prompt: str,
        cot_completion: str,
        step_delimiter: str = "\n"
    ) -> List[ReasoningStep]:
        """
        Generate attribution graphs for all steps in a CoT sequence

        Args:
            prompt: The initial prompt
            cot_completion: The full CoT reasoning
            step_delimiter: How to split steps

        Returns:
            List of ReasoningSteps with graphs attached
        """
        steps = self.parse_cot_steps(prompt, cot_completion, step_delimiter)

        # One attribution run per step; step i sees steps 0..i-1 as context.
        for i, step in enumerate(steps):
            step.graph = self.generate_graph_for_step(
                prompt=prompt,
                reasoning_step=step,
                previous_steps=steps[:i] or None
            )

        return steps


In [5]:
# ============================================================================
# STEP 2: Aggregate and Combine Attribution Graphs
# based on https://github.com/safety-research/circuit-tracer/blob/main/circuit_tracer/graph.py
# ============================================================================

class GraphAggregator:
    """Combine and aggregate multiple attribution graphs.

    NOTE(review): every method assumes ``Graph`` exposes an iterable ``nodes``
    attribute, a dict-like ``edges`` attribute keyed by edge, and a
    ``Graph(nodes=..., edges=...)`` constructor. The original author flagged the
    construction step as needing validation -- confirm against
    ``circuit_tracer.graph.Graph`` before relying on these helpers.
    """

    @staticmethod
    def merge_sequential_graphs(
        reasoning_steps: List[ReasoningStep],
        merge_strategy: str = "union"
    ) -> Graph:
        """
        Merge multiple graphs into a single combined graph

        Steps without a graph (other than the first) are skipped.

        Args:
            reasoning_steps: Steps with graphs to merge
            merge_strategy: 'union' (combine all) or 'intersection' (common only)

        Returns:
            Merged graph. With a single graph-bearing step this is that step's
            own graph object, not a copy.

        Raises:
            ValueError: If there are no steps, the first step has no graph, or
                ``merge_strategy`` is not one of the supported values (it used
                to be ignored silently, returning the first graph unmerged).
        """
        if not reasoning_steps or not reasoning_steps[0].graph:
            raise ValueError("No graphs to merge")

        combiners = {
            "union": GraphAggregator._union_graphs,
            "intersection": GraphAggregator._intersect_graphs,
        }
        if merge_strategy not in combiners:
            raise ValueError(
                f"Unknown merge_strategy {merge_strategy!r}; "
                f"expected one of {sorted(combiners)}"
            )
        combine = combiners[merge_strategy]

        # Fold the remaining graphs into the first one, left to right.
        merged_graph = reasoning_steps[0].graph
        for step in reasoning_steps[1:]:
            if step.graph:
                merged_graph = combine(merged_graph, step.graph)

        return merged_graph

    @staticmethod
    def _union_graphs(graph1: Graph, graph2: Graph) -> Graph:
        """Combine two graphs (union of nodes and edges).

        Weights of edges present in both graphs are summed. This is a
        simplified implementation.
        """
        combined_nodes = set(graph1.nodes) | set(graph2.nodes)
        combined_edges = dict(graph1.edges)

        for edge, weight in graph2.edges.items():
            if edge in combined_edges:
                # Build a new value instead of `+=` so a mutable weight object
                # owned by graph1 is never modified in place.
                combined_edges[edge] = combined_edges[edge] + weight
            else:
                combined_edges[edge] = weight

        # NOTE(review): the Graph constructor call below still needs to be validated.
        return Graph(nodes=combined_nodes, edges=combined_edges)

    @staticmethod
    def _intersect_graphs(graph1: Graph, graph2: Graph) -> Graph:
        """Find common elements between two graphs.

        Keeps nodes and edges present in both; a shared edge gets the average
        of its two weights.
        """
        common_nodes = set(graph1.nodes) & set(graph2.nodes)
        common_edges = {
            edge: (weight1 + graph2.edges[edge]) / 2
            for edge, weight1 in graph1.edges.items()
            if edge in graph2.edges
        }

        return Graph(nodes=common_nodes, edges=common_edges)

    @staticmethod
    def _jaccard(items1: set, items2: set) -> float:
        """Return |items1 & items2| / |items1 | items2|, or 0.0 when both are empty."""
        union_size = len(items1 | items2)
        return len(items1 & items2) / union_size if union_size > 0 else 0.0

    @staticmethod
    def compute_graph_similarity(
        graph1: Graph,
        graph2: Graph,
        method: str = "jaccard"
    ) -> float:
        """
        Compute similarity between two graphs

        Args:
            graph1, graph2: Graphs to compare
            method: 'jaccard' (node overlap) or 'edge_overlap'

        Returns:
            Similarity score [0, 1]

        Raises:
            ValueError: If ``method`` is not supported. It used to return 0.0,
                which is indistinguishable from "no overlap".
        """
        if method == "jaccard":
            return GraphAggregator._jaccard(set(graph1.nodes), set(graph2.nodes))

        if method == "edge_overlap":
            return GraphAggregator._jaccard(
                set(graph1.edges.keys()), set(graph2.edges.keys())
            )

        raise ValueError(
            f"Unknown similarity method {method!r}; expected 'jaccard' or 'edge_overlap'"
        )


In [6]:
# ============================================================================
# STEP 3: Extract Features from Attribution Graphs
# based on https://github.com/safety-research/circuit-tracer/blob/main/circuit_tracer/graph.py
# ============================================================================

class GraphFeatureExtractor:
    """Extract interpretable features from attribution graphs.

    NOTE(review): the methods assume graph nodes carry ``id`` and ``type``
    attributes (with optional ``activation``, ``influence`` and ``layer``) and
    that ``graph.edges`` maps ``(source_id, target_id)`` to a weight -- confirm
    against the actual ``circuit_tracer.graph.Graph`` API.
    """

    # How many input/output nodes are sampled for the shortest-path features.
    _MAX_PATH_ENDPOINTS = 5

    def __init__(self, model: ReplacementModel, max_layers: int = 32):
        """
        Args:
            model: Replacement model; its ``transcoder`` attribute (if any)
                enables the transcoder-specific features.
            max_layers: Number of ``layer_<i>_count`` entries emitted by
                ``extract_node_features``. Adjust to the model's layer count.
        """
        self.model = model
        self.transcoder = getattr(model, 'transcoder', None)
        self.max_layers = max_layers

    def extract_all_features(self, graph: Graph) -> Dict[str, object]:
        """
        Extract comprehensive feature set from a graph

        Returns:
            Dictionary with all extracted features (global, node-level,
            topological and, when a transcoder is available, transcoder ones)
        """
        features = {}
        features.update(self.extract_global_statistics(graph))
        features.update(self.extract_node_features(graph))
        features.update(self.extract_topological_features(graph))

        # Transcoder-specific features
        if self.transcoder:
            features.update(self.extract_transcoder_features(graph))

        return features

    def extract_global_statistics(self, graph: Graph) -> Dict[str, float]:
        """Extract global graph-level statistics.

        Logit statistics are added only when ``graph.logit_probs`` exists and is
        non-empty, so an empty mapping no longer makes ``max()`` raise.
        """
        features = {
            'num_nodes': len(graph.nodes),
            'num_edges': len(graph.edges),
            'num_feature_nodes': sum(1 for n in graph.nodes if n.type == 'feature'),
            'num_token_nodes': sum(1 for n in graph.nodes if n.type == 'token'),
            'num_logit_nodes': sum(1 for n in graph.nodes if n.type == 'logit'),
        }

        logit_probs = getattr(graph, 'logit_probs', None)
        if logit_probs:
            probs = list(logit_probs.values())
            features['top_logit_prob'] = max(probs)
            features['logit_entropy'] = self._compute_entropy(probs)

        return features

    def extract_node_features(self, graph: Graph) -> Dict[str, float]:
        """Extract node-level statistics.

        Produces activation and influence summaries (when any node has the
        attribute) plus a per-layer node count for layers ``0..max_layers-1``.
        Nodes whose ``layer`` falls outside that range are not counted.
        """
        activations = [n.activation for n in graph.nodes if hasattr(n, 'activation')]
        influences = [n.influence for n in graph.nodes if hasattr(n, 'influence')]

        features = {}

        if activations:
            features['mean_activation'] = np.mean(activations)
            features['max_activation'] = np.max(activations)
            features['std_activation'] = np.std(activations)

        if influences:
            features['mean_influence'] = np.mean(influences)
            features['max_influence'] = np.max(influences)
            features['total_influence'] = np.sum(influences)

        # Layer-wise histogram
        layer_counts = defaultdict(int)
        for node in graph.nodes:
            if hasattr(node, 'layer'):
                layer_counts[node.layer] += 1

        # Fixed-length feature vector so every graph yields the same keys.
        for layer in range(self.max_layers):
            features[f'layer_{layer}_count'] = layer_counts.get(layer, 0)

        return features

    def extract_topological_features(self, graph: Graph) -> Dict[str, float]:
        """Extract graph topology features.

        Density, weakly connected component count, edge-weight statistics,
        degree centrality and sampled input-to-output shortest-path lengths.
        Features that cannot be computed (no edges, no nodes, no path) are
        omitted rather than hidden behind a blanket ``except``.
        """
        # Convert to NetworkX for analysis
        G = self._to_networkx(graph)

        features = {
            'graph_density': nx.density(G),
            'num_connected_components': nx.number_weakly_connected_components(G),
        }

        # Edge statistics
        edge_weights = [data.get('weight', 1.0) for _, _, data in G.edges(data=True)]
        if edge_weights:
            features['mean_edge_weight'] = np.mean(edge_weights)
            features['max_edge_weight'] = np.max(edge_weights)
            features['sum_edge_weights'] = np.sum(edge_weights)

        # Centrality measures (skipped for a graph without nodes)
        centrality_values = list(nx.degree_centrality(G).values())
        if centrality_values:
            features['mean_degree_centrality'] = np.mean(centrality_values)
            features['max_degree_centrality'] = np.max(centrality_values)

        # Path-based features: shortest paths from input to output nodes.
        # NOTE(review): endpoints are matched by the substrings 'input'/'output'
        # in the node id, while node types used elsewhere in this class are
        # 'token'/'logit'/'feature' -- confirm the id scheme.
        input_nodes = [n for n in G.nodes() if 'input' in str(n)]
        output_nodes = [n for n in G.nodes() if 'output' in str(n)]

        path_lengths = []
        for inp in input_nodes[:self._MAX_PATH_ENDPOINTS]:  # Sample a few
            for out in output_nodes[:self._MAX_PATH_ENDPOINTS]:
                try:
                    path_lengths.append(nx.shortest_path_length(G, inp, out))
                except nx.NetworkXNoPath:
                    # Unreachable pairs simply do not contribute.
                    continue

        if path_lengths:
            features['mean_path_length'] = np.mean(path_lengths)
            features['min_path_length'] = np.min(path_lengths)

        return features

    def extract_transcoder_features(self, graph: Graph) -> Dict[str, float]:
        """Extract features specific to transcoder activations.

        A feature node counts as active when its ``activation`` is > 0; nodes
        without an ``activation`` attribute are treated as inactive, matching
        the ``hasattr`` guards used elsewhere in this class.
        """
        features = {}

        # Feature sparsity
        active_features = [
            n for n in graph.nodes
            if n.type == 'feature' and getattr(n, 'activation', 0) > 0
        ]
        features['num_active_features'] = len(active_features)
        features['feature_sparsity'] = len(active_features) / max(len(graph.nodes), 1)

        # Feature activation patterns by layer
        layer_activations = defaultdict(list)
        for node in active_features:
            if hasattr(node, 'layer'):
                layer_activations[node.layer].append(node.activation)

        for layer, acts in layer_activations.items():
            features[f'layer_{layer}_mean_activation'] = np.mean(acts)
            features[f'layer_{layer}_num_active'] = len(acts)

        return features

    @staticmethod
    def _compute_entropy(probs: List[float]) -> float:
        """Compute Shannon entropy (base 2) of ``probs`` after normalising them.

        Returns 0.0 for an empty input or a non-positive total, where the
        normalisation would otherwise divide by zero and yield NaN.
        """
        probs = np.asarray(probs, dtype=float)
        total = probs.sum() if probs.size else 0.0
        if total <= 0:
            return 0.0
        probs = probs / total  # Normalize
        # The small epsilon keeps log2 finite for zero-probability entries.
        return float(-np.sum(probs * np.log2(probs + 1e-10)))

    @staticmethod
    def _to_networkx(graph: Graph) -> nx.DiGraph:
        """Convert Graph to NetworkX for analysis.

        Nodes are keyed by ``node.id`` and carry ``type``, ``activation`` and
        ``influence`` attributes (the last two default to 0); edges carry
        ``weight``.
        """
        G = nx.DiGraph()

        for node in graph.nodes:
            G.add_node(
                node.id,
                type=node.type,
                activation=getattr(node, 'activation', 0),
                influence=getattr(node, 'influence', 0)
            )

        for (src, tgt), weight in graph.edges.items():
            G.add_edge(src, tgt, weight=weight)

        return G

In [7]:
# ============================================================================
# STEP 4: Advanced Analysis Functions
# ============================================================================

class SequentialFeatureAnalyzer:
    """Analyze features across multiple reasoning steps"""

    def __init__(self, feature_extractor: GraphFeatureExtractor):
        """Keep the extractor used to turn each step's graph into a feature dict."""
        self.extractor = feature_extractor

    def _extract_per_step(
        self,
        reasoning_steps: List[ReasoningStep]
    ) -> List[Dict[str, float]]:
        """Return one feature dict per step that has a graph, in step order."""
        return [
            self.extractor.extract_all_features(step.graph)
            for step in reasoning_steps
            if step.graph
        ]

    @staticmethod
    def _ordered_feature_names(per_step_features: List[Dict[str, float]]) -> List[str]:
        """Return every feature name seen in any step, in order of first appearance."""
        names = {}
        for features in per_step_features:
            for feat_name in features:
                names.setdefault(feat_name, None)
        return list(names)

    def track_feature_evolution(
        self,
        reasoning_steps: List[ReasoningStep]
    ) -> Dict[str, List[float]]:
        """
        Track how features change across reasoning steps

        Steps without a graph are skipped. Every series has one entry per
        graph-bearing step; a feature missing from a step is recorded as NaN so
        the series stay aligned (they used to come out with different lengths).

        Returns:
            Dictionary mapping feature names to time series
        """
        per_step = self._extract_per_step(reasoning_steps)
        missing = float('nan')
        return {
            feat_name: [features.get(feat_name, missing) for features in per_step]
            for feat_name in self._ordered_feature_names(per_step)
        }

    def identify_persistent_features(
        self,
        reasoning_steps: List[ReasoningStep],
        threshold: float = 0.5
    ) -> List[Tuple[str, float]]:
        """
        Identify features that persist across multiple steps

        A feature node is counted for a step when its ``activation`` is > 0;
        nodes without an ``activation`` attribute are treated as inactive. The
        denominator is the total number of steps, including steps without a
        graph.

        Args:
            reasoning_steps: Steps to analyze
            threshold: Minimum proportion of steps where feature must appear

        Returns:
            List of (feature_id, persistence_score) tuples, highest score first
        """
        total_steps = len(reasoning_steps)
        if total_steps == 0:
            return []

        feature_presence = defaultdict(int)
        for step in reasoning_steps:
            if step.graph:
                active_features = {
                    n.id for n in step.graph.nodes
                    if n.type == 'feature' and getattr(n, 'activation', 0) > 0
                }
                for feat_id in active_features:
                    feature_presence[feat_id] += 1

        # Calculate persistence scores
        persistent_features = [
            (feat_id, count / total_steps)
            for feat_id, count in feature_presence.items()
            if count / total_steps >= threshold
        ]

        return sorted(persistent_features, key=lambda x: x[1], reverse=True)

    def detect_reasoning_transitions(
        self,
        reasoning_steps: List[ReasoningStep],
        similarity_threshold: float = 0.3
    ) -> List[int]:
        """
        Detect critical transition points in reasoning

        Consecutive steps are compared by node-set Jaccard similarity; pairs
        where either graph is missing are skipped.

        Args:
            reasoning_steps: Steps to analyze
            similarity_threshold: A step is a transition when its similarity to
                the previous step is strictly below this value (default 0.3)

        Returns:
            List of step indices where significant changes occur
        """
        transitions = []

        for i in range(1, len(reasoning_steps)):
            prev_graph = reasoning_steps[i-1].graph
            curr_graph = reasoning_steps[i].graph

            if prev_graph and curr_graph:
                similarity = GraphAggregator.compute_graph_similarity(
                    prev_graph,
                    curr_graph,
                    method="jaccard"
                )

                # If similarity drops significantly, it's a transition
                if similarity < similarity_threshold:
                    transitions.append(i)

        return transitions

    def compare_cot_vs_nocot(
        self,
        cot_steps: List[ReasoningStep],
        nocot_graph: Graph
    ) -> Dict[str, float]:
        """
        Compare CoT reasoning graphs with non-CoT direct answer

        For every feature of the non-CoT graph, the CoT value is the mean over
        graph-bearing CoT steps (a step lacking the feature contributes 0; with
        no CoT graphs at all the value is 0).

        Returns:
            Dictionary of comparison metrics: ``<feature>_diff`` always, and
            ``<feature>_ratio`` when the non-CoT value is non-zero
        """
        nocot_features = self.extractor.extract_all_features(nocot_graph)
        cot_feature_series = self._extract_per_step(cot_steps)

        comparison = {}
        for feat_name, nocot_val in nocot_features.items():
            cot_values = [f.get(feat_name, 0) for f in cot_feature_series]
            cot_val = np.mean(cot_values) if cot_values else 0

            if nocot_val != 0:
                comparison[f'{feat_name}_ratio'] = cot_val / nocot_val
            comparison[f'{feat_name}_diff'] = cot_val - nocot_val

        return comparison

    def aggregate_step_features(
        self,
        reasoning_steps: List[ReasoningStep],
        aggregation: str = "mean"
    ) -> Dict[str, float]:
        """
        Aggregate features across all steps

        Every feature seen in any step is aggregated (previously only the
        first step's features were kept); a step lacking a feature contributes
        0 for it.

        Args:
            reasoning_steps: Steps to aggregate
            aggregation: 'mean', 'sum', 'max', 'min'

        Returns:
            Aggregated feature dictionary (empty when no step has a graph)

        Raises:
            ValueError: If ``aggregation`` is not supported. It used to yield
                an empty dictionary silently.
        """
        reducers = {"mean": np.mean, "sum": np.sum, "max": np.max, "min": np.min}
        if aggregation not in reducers:
            raise ValueError(
                f"Unknown aggregation {aggregation!r}; expected one of {sorted(reducers)}"
            )
        reduce_values = reducers[aggregation]

        all_features = self._extract_per_step(reasoning_steps)
        if not all_features:
            return {}

        return {
            feat_name: reduce_values([f.get(feat_name, 0) for f in all_features])
            for feat_name in self._ordered_feature_names(all_features)
        }

In [8]:
# ============================================================================
# STEP 5: Visualization and Reporting
# ============================================================================

class FeatureVisualizer:
    """Visualize and report on extracted features"""

    @staticmethod
    def plot_feature_trajectories(
        feature_trajectories: Dict[str, List[float]],
        feature_names: Optional[List[str]] = None,
        save_path: Optional[Path] = None
    ):
        """Plot how features evolve across reasoning steps.

        One subplot is drawn per feature, stacked vertically.

        Args:
            feature_trajectories: Mapping of feature name to its series of values
            feature_names: Features to plot; defaults to the first 10 keys of
                ``feature_trajectories``
            save_path: If given, the figure is written there and then closed;
                otherwise it is shown

        Raises:
            KeyError: If a requested feature name is not in
                ``feature_trajectories`` (checked before any figure is created)
        """
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

        if feature_names is None:
            # Plot first 10 features
            feature_names = list(feature_trajectories.keys())[:10]

        # Nothing to draw: plt.subplots(0, 1) would raise.
        if not feature_names:
            return

        unknown = [name for name in feature_names if name not in feature_trajectories]
        if unknown:
            raise KeyError(f"Unknown feature name(s): {unknown}")

        n_plots = len(feature_names)
        # squeeze=False always returns a 2-D array of axes, so the single-plot
        # case needs no special handling.
        fig, axes = plt.subplots(n_plots, 1, figsize=(12, 3 * n_plots), squeeze=False)

        for ax, feat_name in zip(axes[:, 0], feature_names):
            ax.plot(feature_trajectories[feat_name], marker='o')
            ax.set_title(feat_name)
            ax.set_xlabel('Reasoning Step')
            ax.set_ylabel('Feature Value')
            ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path)
            # Release the figure once it is on disk so repeated calls do not
            # accumulate open figures.
            plt.close(fig)
        else:
            plt.show()

    @staticmethod
    def generate_feature_report(
        reasoning_steps: List[ReasoningStep],
        analyzer: SequentialFeatureAnalyzer
    ) -> str:
        """Generate a text report of feature analysis.

        The report lists up to 10 persistent features (analyzer default
        threshold), the detected transition steps, and the start/end/change of
        a few key feature trajectories when they are available.
        """
        separator = "=" * 60
        report = [
            separator,
            "CHAIN-OF-THOUGHT FEATURE ANALYSIS REPORT",
            separator,
            f"\nTotal Reasoning Steps: {len(reasoning_steps)}",
        ]

        # Persistent features. The analyzer keeps features whose share of steps
        # is >= its threshold (0.5 by default), hence ">=" in the heading.
        persistent = analyzer.identify_persistent_features(reasoning_steps)
        report.append("\n--- Persistent Features (appear in >=50% of steps) ---")
        for feat_id, score in persistent[:10]:
            report.append(f"  {feat_id}: {score:.2%}")

        # Transitions
        transitions = analyzer.detect_reasoning_transitions(reasoning_steps)
        report.append("\n--- Reasoning Transitions ---")
        report.append(f"Detected {len(transitions)} significant transitions at steps: {transitions}")

        # Feature evolution
        trajectories = analyzer.track_feature_evolution(reasoning_steps)
        report.append("\n--- Key Feature Trends ---")
        for feat_name in ['num_active_features', 'mean_influence', 'graph_density']:
            values = trajectories.get(feat_name)
            if values:
                report.append(f"  {feat_name}:")
                report.append(f"    Start: {values[0]:.4f}, End: {values[-1]:.4f}, Change: {values[-1] - values[0]:.4f}")

        report.append("\n" + separator)
        return "\n".join(report)


In [9]:
# ============================================================================
# STEP 6: End-to-End Pipeline
# ============================================================================

class CoTMechanisticAnalyzer:
    """Complete pipeline for analyzing CoT reasoning.

    Wires together graph generation, feature extraction, sequential analysis
    and visualisation/reporting for a single prompt + CoT completion.
    """

    def __init__(self, model: ReplacementModel):
        """Build the pipeline components around ``model`` with their defaults."""
        extractor = GraphFeatureExtractor(model)

        self.model = model
        self.graph_generator = CoTAttributionGenerator(model)
        self.feature_extractor = extractor
        self.analyzer = SequentialFeatureAnalyzer(extractor)
        self.visualizer = FeatureVisualizer()

    def analyze_cot_sequence(
        self,
        prompt: str,
        cot_completion: str,
        step_delimiter: str = "\n",
        generate_report: bool = True,
        save_visualizations: bool = False
    ) -> Dict:
        """
        Complete end-to-end analysis of a CoT sequence

        Args:
            prompt: The initial prompt
            cot_completion: The full CoT reasoning text
            step_delimiter: Separator used to split the completion into steps
            generate_report: When True, build the text report and print it
            save_visualizations: When True, save the trajectory plot to
                ``feature_trajectories.png``

        Returns:
            Dictionary with all analysis results
        """
        print("Starting CoT Mechanistic Analysis...")

        # Stage 1: one attribution graph per reasoning step
        print("\n[1/5] Generating attribution graphs...")
        steps = self.graph_generator.generate_sequential_graphs(
            prompt=prompt,
            cot_completion=cot_completion,
            step_delimiter=step_delimiter
        )

        # Stage 2: per-step feature dictionaries (steps without a graph are skipped)
        print("\n[2/5] Extracting features from graphs...")
        per_step_features = [
            self.feature_extractor.extract_all_features(s.graph)
            for s in steps
            if s.graph
        ]

        # Stage 3: cross-step analyses, collected straight into the result dict
        print("\n[3/5] Analyzing sequential patterns...")
        sequence_analyzer = self.analyzer
        results = {
            'reasoning_steps': steps,
            'step_features': per_step_features,
            'feature_trajectories': sequence_analyzer.track_feature_evolution(steps),
            'persistent_features': sequence_analyzer.identify_persistent_features(steps),
            'transitions': sequence_analyzer.detect_reasoning_transitions(steps),
            'aggregated_features': sequence_analyzer.aggregate_step_features(steps),
        }

        # Stage 4: optional plot written to disk
        if save_visualizations:
            print("\n[4/5] Generating visualizations...")
            self.visualizer.plot_feature_trajectories(
                results['feature_trajectories'],
                save_path=Path("feature_trajectories.png")
            )

        # Stage 5: optional text report, printed but not returned
        if generate_report:
            print("\n[5/5] Generating report...")
            print(self.visualizer.generate_feature_report(steps, sequence_analyzer))

        print("\n✓ Analysis complete!")
        return results


In [None]:
# ============================================================================
# USAGE EXAMPLE
# ============================================================================

# Base model identifier and transcoder set handed to ReplacementModel.from_pretrained.
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"
# Loaded at module level in bfloat16; example_usage() reads this global `model`.
model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16)

def example_usage():
    """Walk through the full pipeline on a small arithmetic question.

    Runs the CoT analysis on a hand-written chain of thought, prints the top
    persistent features and the detected transitions, then builds a graph for
    a direct (no-CoT) answer and prints the first few comparison metrics.
    Uses the module-level `model`; returns nothing.
    """
    pipeline = CoTMechanisticAnalyzer(model)

    # Toy question plus a worked solution with one reasoning step per line
    question = "What is 15 + 27?"
    worked_solution = """Let me solve this step by step.
Step 1: First, I'll add the ones place: 5 + 7 = 12
Step 2: I write down 2 and carry the 1
Step 3: Now add the tens place: 1 + 2 + 1 (carried) = 4
Step 4: Therefore, 15 + 27 = 42"""

    # Full analysis, with the textual report and the trajectory plot enabled
    analysis = pipeline.analyze_cot_sequence(
        prompt=question,
        cot_completion=worked_solution,
        step_delimiter="\n",
        generate_report=True,
        save_visualizations=True,
    )

    # Show at most five persistent features as (id, score) pairs
    print("\nPersistent Features:")
    top_persistent = analysis['persistent_features'][:5]
    for feat_id, score in top_persistent:
        print(f"  {feat_id}: {score:.2%}")

    print("\nTransition Points:")
    print(analysis['transitions'])

    # Baseline: the bare answer with no intermediate reasoning
    direct_answer = "42"
    direct_step = ReasoningStep(0, direct_answer, 0, 1)
    direct_graph = pipeline.graph_generator.generate_graph_for_step(
        prompt=question,
        reasoning_step=direct_step,
    )

    cot_vs_direct = pipeline.analyzer.compare_cot_vs_nocot(
        analysis['reasoning_steps'],
        direct_graph,
    )

    # Only the first five metrics are printed
    print("\nCoT vs Non-CoT Comparison:")
    leading_metrics = list(cot_vs_direct.items())[:5]
    for metric, value in leading_metrics:
        print(f"  {metric}: {value:.4f}")


# Run the example only when this code executes as the main program,
# not when it is imported from another module.
if __name__ == "__main__":
    example_usage()

### V2

In [5]:
# Checkpoint and transcoder set used by the examples in this section
model_name, transcoder_name = 'google/gemma-2-2b', "gemma"

# Load the replacement model with bfloat16 weights
model = ReplacementModel.from_pretrained(
    model_name,
    transcoder_name,
    dtype=torch.bfloat16,
)

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loaded pretrained model google/gemma-2-2b into HookedTransformer


#### Example 1 - Multi-hop reasoning

In [16]:
# Prompt to build the attribution graph for
prompt = "The capital of state containing Dallas is"

# Upper bound on the number of logits to attribute from; the count actually used is
# min(max_n_logits, number of logits needed to reach desired_logit_prob)
max_n_logits = 10

# Probability mass the attributed logits should cover (capped by max_n_logits)
desired_logit_prob = 0.95

# Cap on the number of feature nodes to attribute from. Lower is faster but
# drops more of the graph; None means no limit.
max_feature_nodes = 8192

# Batch size used while attributing
batch_size = 256

# Offload parts of the model during attribution to save memory:
# 'disk', 'cpu', or None (keep everything on the GPU)
offload = None

# Show a tqdm progress bar and a timing report
verbose = True

In [17]:
# Run attribution for `prompt` with the settings defined in the cell above;
# the resulting graph object is kept in `graph` for the following cells.
graph = attribute(
    model=model,
    prompt=prompt,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    max_feature_nodes=max_feature_nodes,
    batch_size=batch_size,
    offload=offload,
    verbose=verbose,
)

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.13s
Found 6347 active features
Phase 1: Running forward pass
Forward pass completed in 0.08s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.7188
Will include 6347 of 6347 feature nodes
Input vectors built in 0.02s
Phase 3: Computing logit attributions
Logit attributions completed in 0.10s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 6347/6347 [00:02<00:00, 2181.03it/s]
Feature attributions completed in 3.10s
Attribution completed in 4.03s


In [18]:
# Bare expression: the notebook shows the Graph object's repr as the cell output.
graph

<circuit_tracer.graph.Graph at 0x79439f6f1b80>

In [19]:
# Persist the raw attribution graph as graphs/example_graph.pt
graph_name = 'example_graph.pt'
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)  # no error if the directory already exists
graph_path = graph_dir.joinpath(graph_name)

graph.to_pt(graph_path)

In [22]:
# Name assigned to this graph
slug = "dallas-austin"

# Where the graph files are written; create_graph_files creates the directory itself
graph_file_dir = './graph_files'

# Keep only the minimum number of nodes whose cumulative influence is >= 0.8
node_threshold = 0.8
# Keep only the minimum number of edges whose cumulative influence is >= 0.98
edge_threshold = 0.98

create_graph_files(
    graph_or_path=graph_path,  # saved graph to create files for
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)

In [12]:
from circuit_tracer.frontend.local_server import serve


# Start the local frontend server over the generated graph files.
# NOTE(review): 23 is the well-known telnet port, and ports below 1024 usually
# need elevated privileges to bind. Confirm this is intended for your
# environment, or pick an unprivileged port.
port = 23
server = serve(data_dir='./graph_files/', port=port)

IN_COLAB = False

if IN_COLAB:
    # Colab cannot reach localhost directly; expose the kernel port as an iframe.
    from google.colab import output as colab_output  # noqa
    colab_output.serve_kernel_port_as_iframe(port, path='/index.html', height='800px', cache_in_notebook=True)
else:
    # Import `display` explicitly rather than relying on the IPython builtin.
    from IPython.display import IFrame, display

    # Build the URL once so the printed link and the embedded frame always agree.
    graph_url = f'http://localhost:{port}/index.html'
    # The message used to contain a nested f'...' that was printed literally.
    print(f"Use the IFrame below, or open your graph here: {graph_url}")
    display(IFrame(src=graph_url, width='100%', height='800px'))


Use the IFrame below, or open your graph here: f'http://localhost:23/index.html'


In [33]:
# Compute pruning masks for the attribution graph at the given node/edge thresholds
pg = prune_graph(
    graph,
    node_threshold=0.7,
    edge_threshold=0.95,
)

In [34]:
# Bare expression: shows the PruneResult (node_mask, edge_mask, cumulative_scores) as the cell output.
pg

PruneResult(node_mask=tensor([False,  True, False,  ...,  True,  True,  True]), edge_mask=tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False,  True, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False]]), cumulative_scores=tensor([0.7486, 0.6564, 0.8938,  ..., 1.0000, 1.0000, 1.0000]))

#### Example 2 - GSM8K

In [6]:
# GSM8K-style word problem to build the attribution graph for
prompt = "James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?"

# Attribution settings, passed to attribute() in the next cell
max_n_logits = 10            # maximum number of logits to attribute from
desired_logit_prob = 0.95    # target probability mass for the attributed logits
max_feature_nodes = 8192     # maximum number of feature nodes to attribute from
batch_size = 256             # batch size when attributing
offload = None               # no offloading
verbose = True               # verbose output enabled

In [7]:
# Run attribution for the GSM8K prompt with the settings from the cell above;
# this rebinds `graph` to the new result.
graph = attribute(
    model=model,
    prompt=prompt,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    max_feature_nodes=max_feature_nodes,
    batch_size=batch_size,
    offload=offload,
    verbose=verbose,
)

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.47s
Found 25788 active features
Phase 1: Running forward pass
Forward pass completed in 0.31s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.8672
Will include 8192 of 25788 feature nodes
Input vectors built in 0.29s
Phase 3: Computing logit attributions
Logit attributions completed in 0.32s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 8192/8192 [00:08<00:00, 1021.06it/s]
Feature attributions completed in 8.03s
Attribution completed in 9.77s


In [8]:
# Persist the raw attribution graph as graphs/example_graph_gsm8k.pt
graph_name = 'example_graph_gsm8k.pt'
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)  # no error if the directory already exists
graph_path = graph_dir.joinpath(graph_name)

graph.to_pt(graph_path)

In [9]:
# Name assigned to this graph
slug = "gsm8k-james-writes"

# Where the graph files are written; create_graph_files creates the directory itself.
# This previously pointed at './graphs', the directory holding the saved .pt graphs.
# Use './graph_files' instead, matching the "dallas-austin" example and the
# directory the local server cell serves, so this graph appears in the frontend.
graph_file_dir = './graph_files'

# Keep only the minimum number of nodes whose cumulative influence is >= 0.8
node_threshold = 0.8
# Keep only the minimum number of edges whose cumulative influence is >= 0.98
edge_threshold = 0.98

create_graph_files(
    graph_or_path=graph_path,  # saved graph to create files for
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)