In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

# NOTE(review): ReplacementModel and attribute are not used in the visible cells;
# presumably left over from the step that produced the graph file.
from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files
from circuit_tracer.utils.create_graph_files import load_graph_data

# Location of the precomputed attribution graph that the EA will prune.
graph_dir = Path('graphs')
graph_name = 'ndag.pt'
graph_dir.mkdir(exist_ok=True)
graph_path = graph_dir / graph_name

# mkdir above may have just created an empty directory; fail early with an
# actionable message instead of an opaque error from the loader.
if not graph_path.is_file():
    raise FileNotFoundError(
        f"Attribution graph not found at {graph_path}; "
        "generate and save the graph there before running this notebook."
    )

graph = load_graph_data(graph_path)

In [3]:
import random

import numpy as np
import torch

from algorithm1 import run_ea_optimization
from experiments.models import EAHyperparameters

# The evolutionary search is stochastic: seed the common RNGs so that
# Restart & Run All reproduces the same best individual.
# NOTE(review): assumes the EA draws randomness from random/numpy/torch in this
# process; per-GPU worker processes (if any) may need their own seeding —
# confirm inside run_ea_optimization.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

hp = EAHyperparameters(
    population_size=50,
    n_generations=10,
)

result = run_ea_optimization(
    graph,
    verbose=True,
    max_batch_per_gpu=8,
    hp=hp,
)

MultiGPUGraphEvaluator initialized with 4 GPUs: [0, 1, 2, 3]
  Max batch per GPU: 8
Precomputing graph influence scores...
Precomputing graph influence scores...
Starting Evolutionary Algorithm...
Initialized population of size 50
Starting Evolutionary Algorithm...
Initialized population of size 50


  cutoff_indices = torch.searchsorted(cum_fractions, layer_thresholds)


Initial evaluation complete.

=== Generation 1/10 ===
Best Individual:
  Fitness: -2.5554
  Completeness: 0.8961 | Replacement: 0.5859
  Nodes: 1660 (log: 7.41) | Edges: 177539 (log: 12.09)
  Mean Node Threshold: 0.7292 | Mean Edge Threshold: 0.7517
  Overrides: +0/-0 nodes, +0/-0 edges

=== Generation 2/10 ===
Best Individual:
  Fitness: -2.4772
  Completeness: 0.9087 | Replacement: 0.6473
  Nodes: 1369 (log: 7.22) | Edges: 253621 (log: 12.44)
  Mean Node Threshold: 0.6753 | Mean Edge Threshold: 0.7641
  Overrides: +0/-0 nodes, +0/-0 edges

=== Generation 2/10 ===
Best Individual:
  Fitness: -2.4772
  Completeness: 0.9087 | Replacement: 0.6473
  Nodes: 1369 (log: 7.22) | Edges: 253621 (log: 12.44)
  Mean Node Threshold: 0.6753 | Mean Edge Threshold: 0.7641
  Overrides: +0/-0 nodes, +0/-0 edges

=== Generation 3/10 ===
Best Individual:
  Fitness: -2.4506
  Completeness: 0.9006 | Replacement: 0.6050
  Nodes: 1448 (log: 7.28) | Edges: 135810 (log: 11.82)
  Mean Node Threshold: 0.7298 | M

In [4]:
# Summary of the best pruned graph returned by the EA, formatted like the
# EA's own per-generation log so the numbers are self-describing.
print(f"Completeness: {result['completeness']:.4f} | Replacement: {result['replacement']:.4f}")
print(f"Edges: {result['n_edges']} | Nodes: {result['n_nodes']}")

0.8955317735671997 0.632867157459259
58374 752


In [5]:
from experiments.create_graph_files_ea import create_graph_files_from_ea_result

# Both the EA-pruned graph and the baseline graph are written to the same
# directory so the local viewer can show them side by side.
graph_file_dir = './graph_files'
# The EA export runs before create_graph_files (the function documented to make
# this directory), so create it explicitly rather than relying on either helper.
Path(graph_file_dir).mkdir(parents=True, exist_ok=True)

# --- EA-optimized pruning ---
ea_slug = "ndag_ea"  # name under which the EA graph appears in the viewer

create_graph_files_from_ea_result(
    graph_or_path=graph,
    ea_result=result,
    slug=ea_slug,
    output_path=graph_file_dir,
)
print(f"Graph files created in {graph_file_dir}/{ea_slug}.json")

# --- Baseline: default cumulative-influence pruning ---
baseline_slug = "ndag"  # name under which the baseline graph appears in the viewer
node_threshold = 0.8    # keep the minimum # of nodes whose cumulative influence is >= 0.8
edge_threshold = 0.98   # keep the minimum # of edges whose cumulative influence is >= 0.98

create_graph_files(
    graph_or_path=graph,
    slug=baseline_slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold,
)
print(f"Graph files created in {graph_file_dir}/{baseline_slug}.json")


Graph files created in ./graph_files/ndag_ea.json
pruning graph
Graph files created in ./graph_files/ndag.json
Graph files created in ./graph_files/ndag.json


In [8]:
# Import display explicitly rather than relying on IPython injecting it into
# the namespace (it is undefined when the notebook is run as a plain script).
from IPython.display import IFrame, display

from circuit_tracer.frontend.local_server import serve

# Start the local graph viewer; it presumably serves the graph files written
# to ./graph_files/ above. Call server.stop() when finished to free the port.
port = 8194
server = serve(data_dir='./graph_files/', port=port)

# Embed the viewer inline in the notebook.
display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))


In [7]:
# Shut down the local viewer server started by the serve() cell.
# NOTE: this cell sits below the serve cell, so Restart & Run All would stop the
# server right after starting it — run it manually once done with the viewer.
stop_viewer = server.stop
stop_viewer()