# Graph creation: attribute a prompt and save the graph

In [1]:
%load_ext autoreload
%autoreload 2

In [28]:
# Imports: standard library, then third-party, then the circuit_tracer package.
from pathlib import Path

import torch
from dotenv import load_dotenv

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

# Populate environment variables from a local .env file (if one exists) before any
# model download happens below.
load_dotenv()

# Base model plus the transcoder set that ReplacementModel is built from,
# loaded in bfloat16.
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"
model = ReplacementModel.from_pretrained(
    model_name,
    transcoder_name,
    dtype=torch.bfloat16,
)

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [35]:
# --- Attribution settings ---------------------------------------------------
prompt = "The war lasted from the year 1711 to 17"  # text whose next-token prediction we build the graph for
max_n_logits = 10          # upper bound on logits attributed from: min(max_n_logits, logits needed to reach desired_logit_prob)
desired_logit_prob = 0.95  # attribute from the fewest logits whose probability mass reaches this (capped by max_n_logits)
max_feature_nodes = 8192   # cap on feature nodes attributed from; lower is faster but drops more of the graph; None = no limit
batch_size = 256           # batch size used during attribution
verbose = True             # show a tqdm progress bar and a timing report

graph = attribute(
    prompt=prompt,
    model=model,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=None,
    verbose=verbose,
)

# --- Persist the graph so the pruning section can reload it from disk --------
graph_name = 'war.pt'
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)
graph_path = graph_dir.joinpath(graph_name)

graph.to_pt(graph_path)



Phase 0: Precomputing activations and vectors
Precomputation completed in 0.19s
Found 12392 active features
Phase 1: Running forward pass
Forward pass completed in 1.13s
Phase 2: Building input vectors
Selected 3 logits with cumulative probability 0.9766
Will include 8192 of 12392 feature nodes
Input vectors built in 0.75s
Phase 3: Computing logit attributions
Logit attributions completed in 0.30s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 8192/8192 [00:08<00:00, 942.68it/s] 
Feature attributions completed in 8.69s
Attribution completed in 11.21s


In [36]:
# Size of the active-feature table. Its first dimension counts every active feature
# found (12392 in the attribution log), more than the 8192 feature nodes attributed from.
graph.active_features.size()

torch.Size([12392, 3])

# Pruning: reload the saved graph, prune it, and export the pruned graph

In [3]:
from pathlib import Path

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files
from circuit_tracer.utils.create_graph_files import load_graph_data


# Reload the graph written by the creation section, so pruning can run in a fresh
# kernel without recomputing attribution.
graph_dir = Path('graphs')
graph_name = 'war.pt'
graph_path = graph_dir / graph_name

# This cell only reads. Fail with a clear message instead of creating an empty
# directory and then erroring inside the loader when the file is missing.
if not graph_path.is_file():
    raise FileNotFoundError(
        f"{graph_path} not found; run the graph-creation section first to produce it."
    )

graph = load_graph_data(graph_path)

In [12]:
from circuit_tracer.graph import compute_graph_scores, compute_graph_scores_masked, prune_graph

# Thresholds are cumulative-influence fractions (the export cell uses 0.8 / 0.98).
# 1.0 asks for the full cumulative influence, i.e. essentially no pruning. That makes
# this a sanity check: the masked scores should match the unpruned ones.
node_threshold = 1
edge_threshold = 1

node_mask, edge_mask, cumulative_scores = prune_graph(graph, node_threshold, edge_threshold)

# Sparsity stats. Nodes: fraction of all graph nodes kept by the mask.
print(f"Nodes kept: {node_mask.sum().item() / node_mask.numel():.2%}")

# Edges: dividing by edge_mask.numel() counts every adjacency slot, including
# structurally-zero ones, so the ratio does not reflect real edges. Measure only over
# edges that exist (nonzero adjacency entries).
# NOTE(review): assumes edge_mask is a bool tensor with the same shape/device as
# graph.adjacency_matrix — confirm.
existing_edges = graph.adjacency_matrix != 0
n_edges = int(existing_edges.sum().item())
n_kept_edges = int((edge_mask & existing_edges).sum().item())
if n_edges:
    print(f"Edges kept: {n_kept_edges / n_edges:.2%} ({n_kept_edges}/{n_edges} nonzero edges)")
else:
    print("Edges kept: n/a (graph has no nonzero edges)")

# Compare scores: original vs pruned. compute_graph_scores returns
# (replacement_score, completeness_score); the masked variant is assumed to use the
# same order — confirm.
orig_replacement, orig_completeness = compute_graph_scores(graph)
pruned_replacement, pruned_completeness = compute_graph_scores_masked(graph, node_mask, edge_mask)
print(f"\nOriginal graph scores: replacement={orig_replacement}, completeness={orig_completeness}")
print(f"Pruned graph scores:   replacement={pruned_replacement}, completeness={pruned_completeness}")

pruning graph
Nodes kept: 99.54%
Edges kept: 100.00%
Nodes kept: 99.54%
Edges kept: 100.00%

Original graph scores: (0.7420152425765991, 0.9321903586387634)

Original graph scores: (0.7420152425765991, 0.9321903586387634)
Pruned graph scores:   (0.7420152425765991, 0.9321903586387634)
Pruned graph scores:   (0.7420152425765991, 0.9321903586387634)


In [None]:
# Scratch cell (never executed): a hand-copied score tuple. It equals the
# "Original graph scores" printed by the pruning cell; presumably
# (replacement_score, completeness_score) — confirm. Because it is hard-coded, it goes
# stale whenever the graph is regenerated; consider deleting it.
(0.7420152425765991, 0.9321903586387634)

In [27]:
# Size of the active-feature table of the graph loaded from disk.
# NOTE(review): the recorded output (6359 rows, In[27]) predates the creation cell
# (In[35]/[36], 12392 active features), so it likely reflects an older war.pt.
# Restart and run top to bottom before trusting it.
graph.active_features.size()

torch.Size([6359, 3])

In [18]:
0.9225512528473804 0.19385499706253542

1276.5351556567957

In [None]:
# Scratch cell (never executed): a hand-copied score tuple with no producing code in
# this notebook. NOTE(review): origin unknown — possibly scores from an earlier graph
# or from different pruning thresholds. Document where it came from, or delete it.
(0.7179655432701111, 0.925042450428009)


In [5]:
# Prune the saved graph with the thresholds below and write its graph files.
# The slug is the name assigned to the exported graph. Derive it from the saved file
# name so it matches the graph being exported ("war" for graphs/war.pt).
slug = graph_path.stem
graph_file_dir = './graph_files'  # output directory; create_graph_files creates it, no need to make it here
node_threshold = 0.8   # keep the minimum number of nodes whose cumulative influence is >= 0.8
edge_threshold = 0.98  # keep the minimum number of edges whose cumulative influence is >= 0.98

create_graph_files(
    graph_or_path=graph_path,  # the graph to create files for (passed as a path)
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)


