In [1]:
from __future__ import annotations

import itertools
import time 
import os
import pickle
import random
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from networkx.algorithms import weisfeiler_lehman_graph_hash as wl_hash
import multiprocessing           
import contextlib 

In [4]:
from dgd.utils.utils5 import (
    calculate_truth_table_v2,
    generate_one_hot_features_from_adj,
    resize_matrix,
    energy_score,
    check_implicit_OR_existence_v2,
    add_implicit_OR_to_dag_v2,
    exhaustive_cut_enumeration_dag,
    is_fanout_free_standalone,
    generate_subgraph,
    substitute_subgraph, 
)

In [6]:
# ===============================================================
# 1. Static action catalogue (loaded once) ----------------------
# ===============================================================
CATALOG_PATH = "/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/action_motifs.pkl"

with open(CATALOG_PATH, "rb") as f:
    _cat = pickle.load(f)

UNIQUE_GRAPHS: List[nx.DiGraph] = _cat["graphs"]          # iso‑deduped actions
TTABLE_TO_ACTIONS: Dict[Tuple[int, int], List[int]] = _cat["lookup"]
NUM_ACTIONS = len(UNIQUE_GRAPHS)

print(f"[DRL3env] loaded {NUM_ACTIONS} iso‑actions from {CATALOG_PATH}")



[DRL3env] loaded 15928 iso‑actions from /home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/action_motifs.pkl


In [7]:
def load_graph_pickle(filename):
    """
    Load a graph from a pickle file and convert back to NetworkX format.
    
    Args:
        filename (str): Pickle file to load
        
    Returns:
        nx.DiGraph: Reconstructed graph
    """
    # Load the list from pickle
    with open(filename, 'rb') as f:
        graph_list = pickle.load(f)
    
    # Extract components
    num_nodes, edges, node_attrs = graph_list
    
    # Create new graph
    G = nx.DiGraph()
    
    # Add nodes with attributes
    for node, attr in node_attrs.items():
        if attr is not None:
            G.add_node(node, type=attr)
        else:
            G.add_node(node)
    
    # Add edges
    G.add_edges_from(edges)
    
    return G

# ===============================================================
# 2. Helper utilities -------------------------------------------
# ===============================================================
_perm_cache: DefaultDict[Tuple[int, Tuple[int, int]], nx.DiGraph] = defaultdict(dict)

def canon_hash(g: nx.DiGraph) -> str:
    """WL hash (directed) – identical for isomorphic DAGs."""
    return wl_hash(g, node_attr=None, edge_attr=None, iterations=3, digest_size=16)

def _truth_key(g: nx.DiGraph) -> Tuple[int, int]:
    """Return (#inputs, truth‑table‑int) signature for DAG `g`."""
    tt = calculate_truth_table_v2(g)
    bits = "".join(str(o[0]) for _, o in sorted(tt.items()))
    return (len(tt).bit_length() - 1, int(bits, 2))

def _permute_until_match(base: nx.DiGraph, key: Tuple[int, int], aid: int):
    """Return a *permuted copy* of `base` matching `key` or None if impossible."""
    #if (cached := _perm_cache[aid].get(key)) is not None:
    #    return cached.copy()
    inputs = [n for n in base if base.in_degree(n) == 0]
    if len(inputs) != key[0]:
        return None
    for perm in itertools.permutations(inputs):
        g2 = nx.relabel_nodes(base, dict(zip(inputs, perm)), copy=True)
        if _truth_key(g2) == key:
            #_perm_cache[aid][key] = g2.copy()
            return g2
    return None

def _apply_implicit_or(G: nx.DiGraph, fanin_size: int = 2) -> nx.DiGraph:
    """Return an optimised copy of `G` with an implicit‑OR inserted if beneficial."""
    G_opt = G.copy()
    sinks = [n for n in G_opt if G_opt.out_degree(n) == 0]
    if not sinks:
        return G_opt
    output = sinks[0]
    res = check_implicit_OR_existence_v2(G_opt, output, fanin_size)
    best_key, best_rm = None, 0
    for k, v in res.items():
        if v["is_there_an_implicit_OR"] and v["number_of_nodes_available_for_removal"] > best_rm:
            best_key, best_rm = k, v["number_of_nodes_available_for_removal"]
    if best_key is None:
        return G_opt
    cut, cone = res[best_key]["cut"], res[best_key]["cone"]
    return add_implicit_OR_to_dag_v2(G_opt, output, cut, cone)


def _definitely_same(g1: nx.DiGraph, g2: nx.DiGraph) -> bool:
    """Return True iff g1 ≅ g2 (isomorphic as directed graphs)."""
    if g1.number_of_nodes() != g2.number_of_nodes():
        return False
    if g1.number_of_edges() != g2.number_of_edges():
        return False
    return nx.is_isomorphic(g1, g2)

In [8]:
#to see all files in the folder

'''
from pathlib import Path

# directory that holds the pickle you just loaded
root = Path("/home/gridsan/spalacios/DRL1/supercloud-testing/ABC-and-PPO-testing1"
            "/Verilog_files_for_all_4_input_1_output_truth_tables_as_NIGs")

def _human(n):
    """bytes → human-readable"""
    for unit in ("B","KiB","MiB","GiB","TiB"):
        if n < 1024:
            return f"{n:.1f} {unit}"
        n /= 1024
    return f"{n:.1f} PiB"

print(f"\nListing files under: {root}\n")
for p in sorted(root.iterdir()):
    kind = "dir" if p.is_dir() else "file"
    size = _human(p.stat().st_size) if p.is_file() else "-"
    print(f"{kind:4}  {size:>8}  {p.name}")
'''

'\nfrom pathlib import Path\n\n# directory that holds the pickle you just loaded\nroot = Path("/home/gridsan/spalacios/DRL1/supercloud-testing/ABC-and-PPO-testing1"\n            "/Verilog_files_for_all_4_input_1_output_truth_tables_as_NIGs")\n\ndef _human(n):\n    """bytes → human-readable"""\n    for unit in ("B","KiB","MiB","GiB","TiB"):\n        if n < 1024:\n            return f"{n:.1f} {unit}"\n        n /= 1024\n    return f"{n:.1f} PiB"\n\nprint(f"\nListing files under: {root}\n")\nfor p in sorted(root.iterdir()):\n    kind = "dir" if p.is_dir() else "file"\n    size = _human(p.stat().st_size) if p.is_file() else "-"\n    print(f"{kind:4}  {size:>8}  {p.name}")\n'

In [12]:
from pathlib import Path
import networkx as nx
from tqdm import tqdm        # pip install tqdm if missing

# ----------------------------------------------------------------------
# config
# ----------------------------------------------------------------------
ROOT       = Path("/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_4_inputs")
N_MATCHES  = 10

# ----------------------------------------------------------------------
# helper – yield every nx.DiGraph contained in the pickle object
# ----------------------------------------------------------------------
def _extract_graphs(obj):
    if isinstance(obj, nx.DiGraph):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _extract_graphs(item)
    elif isinstance(obj, dict):
        for v in obj.values():
            yield from _extract_graphs(v)

# ----------------------------------------------------------------------
# main scan
# ----------------------------------------------------------------------
matches = []   # [(file_path, graph, po, pred, fanin)]
all_pickles = sorted(ROOT.glob("*_NIG_unoptimized.pkl"))

for pkl_path in tqdm(all_pickles, desc="Scanning graphs", unit="file"):
    graph_obj = load_graph_pickle(pkl_path)  # your existing loader

    for G in _extract_graphs(graph_obj):
        PO = [n for n, deg in G.out_degree() if deg == 0]

        for po in PO:
            for pred in G.predecessors(po):
                if G.in_degree(pred) == 2:
                    fanin = list(G.predecessors(pred))
                    matches.append((pkl_path, G, po, pred, fanin))

                    idx = len(matches)
                    print(f"\n=== MATCH #{idx} ===")
                    print(f"file   : {pkl_path.name}")
                    print(f"PO     : {po}")
                    print(f"pred   : {pred} (in_degree=2)")
                    print(f"fan-in : {fanin}")
                    print("======================")

                    if idx >= N_MATCHES:
                        break
            if len(matches) >= N_MATCHES:
                break
        if len(matches) >= N_MATCHES:
            break
    if len(matches) >= N_MATCHES:
        break

print(f"\nCollected {len(matches)} matching graphs in `matches`.")


Scanning graphs:   0%|          | 3/65535 [00:00<43:28, 25.12file/s]


=== MATCH #1 ===
file   : 0x0001_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [12, 13]

=== MATCH #2 ===
file   : 0x0002_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [3, 12]

=== MATCH #3 ===
file   : 0x0004_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [11, 12]

=== MATCH #4 ===
file   : 0x0008_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [3, 11]


Scanning graphs:   0%|          | 24/65535 [00:00<24:22, 44.80file/s]


=== MATCH #5 ===
file   : 0x0010_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [11, 12]


Scanning graphs:   0%|          | 41/65535 [00:00<21:31, 50.70file/s]


=== MATCH #6 ===
file   : 0x0020_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [3, 11]


Scanning graphs:   0%|          | 70/65535 [00:01<25:05, 43.49file/s]


=== MATCH #7 ===
file   : 0x0040_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [10, 11]


Scanning graphs:   0%|          | 134/65535 [00:02<24:59, 43.63file/s]


=== MATCH #8 ===
file   : 0x0080_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [3, 10]


Scanning graphs:   0%|          | 266/65535 [00:05<18:54, 57.55file/s]


=== MATCH #9 ===
file   : 0x0100_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [11, 12]


Scanning graphs:   1%|          | 511/65535 [00:10<22:15, 48.68file/s]


=== MATCH #10 ===
file   : 0x0200_NIG_unoptimized.pkl
PO     : 4
pred   : 7 (in_degree=2)
fan-in : [3, 11]

Collected 10 matching graphs in `matches`.





In [11]:
#circuit_hex = "0x4B9E"
circuit_hex = "0x3AC7"
current_solution = load_graph_pickle(f'/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_4_inputs/{circuit_hex}_NIG_unoptimized.pkl')


In [14]:
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue  # skip inputs

    if (
        current_solution.out_degree(tgt) == 0
    ):
        continue

    # ----- enumerate size-≤4 cuts ---------------------------------
    for cut in exhaustive_cut_enumeration_dag(
        current_solution, 4, tgt, filter_redundant=True
    ):
        if len(cut) < 4:
            continue
        
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False) 
        
        #find and print isomorphic graphs in catalog that can match sg, call them replacepent_graphs
      
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue
            
        #test the substitution of each graph in replacepent_graph    
            
        new_solution = substitute_subgraph(current_solution, sg, replacepent_graph)
        
        #test that the subsitution is correct 
        
        calculate_truth_table_v2(new_solution) == calculate_truth_table_v2(current_solution)
            


NameError: name 'replacepent_graph' is not defined

In [None]:
sg = generate_subgraph(current_solution, 9, (7, 36, 33, 5), draw=True)

Similar to current environment 

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:
        continue                               # skip primary output

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  Try every (action, tgt, cut) pair once and test equivalence
# ----------------------------------------------------------------------
tt_orig = calculate_truth_table_v2(current_solution)
n_pass  = n_fail = 0

for aid, triples in subgraphs_for_action.items():
    for tgt, cut, key in triples:

        repl = _permute_until_match(UNIQUE_GRAPHS[aid], key, aid)
        #repl = UNIQUE_GRAPHS[aid]
        if repl is None:
            print(f"  aid {aid:<4}  permutation ❌  (skipped)")
            n_fail += 1
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)
        new_sol = substitute_subgraph(current_solution, cand, repl)

        ok = (calculate_truth_table_v2(new_sol) == tt_orig)
        print(f"  aid {aid:<4}  tgt {tgt:<3}  cut {cut}  "
              f"{'✔︎ PASS' if ok else '✗ FAIL'}")

        n_pass += ok
        n_fail += (not ok)

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL")

Trying to include all Boolean functions by sg

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:
        continue                               # skip primary output

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        --> Update to include all keys sg can produce when the source nodes in sg are permuted. For now, avoid cuts that include source_nodes.
        
        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  Try every (action, tgt, cut) pair once and test equivalence
# ----------------------------------------------------------------------
tt_orig = calculate_truth_table_v2(current_solution)
n_pass  = n_fail = 0

for aid, triples in subgraphs_for_action.items():
    for tgt, cut, key in triples:

        repl = _permute_until_match(UNIQUE_GRAPHS[aid], key, aid)
        if repl is None:
            print(f"  aid {aid:<4}  permutation ❌  (skipped)")
            n_fail += 1
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)
        new_sol = substitute_subgraph(current_solution, cand, repl)

        ok = (calculate_truth_table_v2(new_sol) == tt_orig)
        print(f"  aid {aid:<4}  tgt {tgt:<3}  cut {cut}  "
              f"{'✔︎ PASS' if ok else '✗ FAIL'}")

        n_pass += ok
        n_fail += (not ok)

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL")



In [None]:
TTABLE_TO_ACTIONS.get((2, 2), [])

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------

def _all_perm_keys(sg: nx.DiGraph) -> set[int]:
    """Return every truth-table key the sub-graph `sg` can realise under
    all permutations of its primary-input nodes."""
    in_nodes = [n for n in sg.nodes() if sg.in_degree(n) == 0]
    if len(in_nodes) <= 1:                       # nothing to permute
        return {_truth_key(sg)}

    keys = set()
    for perm in itertools.permutations(in_nodes):
        mapping = dict(zip(in_nodes, perm))      # old → new
        sg_perm  = nx.relabel_nodes(sg, mapping, copy=True)
        keys.add(_truth_key(sg_perm))
    return keys


for tgt in current_solution.nodes():
    # Skip primary inputs and primary outputs
    if tgt in source_nodes or current_solution.out_degree(tgt) == 0:
        continue

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):

        # (1)  avoid cones that “grab” a primary input
        if any(n in source_nodes for n in cut):
            continue

        # (2)  standalone / fan-out-free check
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        # (3)  build the candidate sub-graph
        sg = generate_subgraph(current_solution, tgt, cut, draw=False)

        # ensure every input in sg really comes from `cut`
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        # ------------------------------------------------------------------
        # NEW: consider **every** truth-table key producible by permuting
        #      the sources in `sg`
        # ------------------------------------------------------------------
        for key in _all_perm_keys(sg):
            for aid in TTABLE_TO_ACTIONS.get(key, []):
                mask[aid] = True
                subgraphs_for_action[aid].append((tgt, tuple(cut), key))
            tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  Try every (action, tgt, cut) pair once and test equivalence
# ----------------------------------------------------------------------
tt_orig = calculate_truth_table_v2(current_solution)
n_pass  = n_fail = 0

for aid, triples in subgraphs_for_action.items():
    for tgt, cut, key in triples:

        repl = _permute_until_match(UNIQUE_GRAPHS[aid], key, aid)
        #repl = UNIQUE_GRAPHS[aid]
        if repl is None:
            print(f"  aid {aid:<4}  permutation ❌  (skipped)")
            n_fail += 1
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)
        
        
                
        new_sol = substitute_subgraph(current_solution, cand, repl)

        ok = (calculate_truth_table_v2(new_sol) == tt_orig)
        print(f"  aid {aid:<4}  tgt {tgt:<3}  cut {cut}  "
              f"{'✔︎ PASS' if ok else '✗ FAIL'}"
             f"     key {key}"   )

        n_pass += ok
        n_fail += (not ok)

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL")


Test 1 if DAG permutations are missing


In [None]:
from itertools import permutations
import networkx as nx
from tqdm import tqdm          # or `from tqdm.auto import tqdm` if you switch between notebook/terminal

missing = 0
for aid, g in enumerate(tqdm(UNIQUE_GRAPHS, desc="Scanning DAGs")):
    pins  = [n for n in g if g.in_degree(n) == 0]

    # compute every permutation key
    keys  = {
        _truth_key(nx.relabel_nodes(g, dict(zip(pins, p))))
        for p in permutations(pins)
    }

    known = sum(k in TTABLE_TO_ACTIONS for k in keys)
    if known < len(keys):
        missing += 1
        print(f"DAG {aid}: {known}/{len(keys)} permutations present")

print(f"{missing} unique DAGs still missing at least one permutation key")



Test 2 if DAG permutations are missing 

In [None]:
from itertools import permutations
import networkx as nx
from tqdm import tqdm

missing = 0
for aid, g in enumerate(tqdm(UNIQUE_GRAPHS, desc="Scanning DAGs")):
    pins = [n for n in g if g.in_degree(n) == 0]
    for p in permutations(pins):
        k = _truth_key(nx.relabel_nodes(g, dict(zip(pins, p))))
        if aid not in TTABLE_TO_ACTIONS.get(k, []):
            missing += 1
            print(f"aid {aid}: permutation {p} (key={k}) missing")
            break            # one report per aid is enough
print(f"{missing} actions missing at least one permutation key–mapping")


Test if same DAGs are available to the agent whether we permute sg or not

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Permutation-coverage test (informational)

Goal:
    Show whether looking up ONE permutation of a cone’s inputs already
    exposes every action that appears for ANY permutation.

For each cone:
    actions_default ← actions for ONE pin-ordering   (what RL uses)
    actions_all     ← union over ALL pin-orderings
    If actions_all ⊈ actions_default → print coverage gap

The script never raises; it just reports and continues processing.
"""

import itertools
import copy
from collections import defaultdict

import networkx as nx
import numpy as np

# ── project-specific symbols expected from the surrounding codebase ────
# NUM_ACTIONS
# current_solution : nx.DiGraph
# exhaustive_cut_enumeration_dag(...)
# is_fanout_free_standalone(...)
# generate_subgraph(...)
# _truth_key(...)
# TTABLE_TO_ACTIONS : dict[tt_key → list[aid]]
# UNIQUE_GRAPHS : dict[aid → nx.DiGraph]
# _permute_until_match(...)
# substitute_subgraph(...)
# calculate_truth_table_v2(...)

# ══════════════════════════════════════════════════════════════════════
# helper – gather action IDs under one key vs all N! keys
# ══════════════════════════════════════════════════════════════════════
def _all_perm_actions(sg: nx.DiGraph, lookup):
    """Return (actions_default, actions_all) for a cone `sg`."""
    pins = [n for n in sg if sg.in_degree(n) == 0]

    # actions for ONE pin ordering (default key)
    actions_default = set(lookup.get(_truth_key(sg), []))

    # actions for ALL N! pin permutations (union)
    actions_all = set()
    for σ in itertools.permutations(pins):
        sg_perm = nx.relabel_nodes(sg, dict(zip(pins, σ)), copy=True)
        actions_all |= set(lookup.get(_truth_key(sg_perm), []))

    return actions_default, actions_all


# ══════════════════════════════════════════════════════════════════════
# 0 .  Allocate mask & bookkeeping
# ══════════════════════════════════════════════════════════════════════
mask                 = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action = defaultdict(list)
tt_counter           = defaultdict(list)

# ══════════════════════════════════════════════════════════════════════
# 1 .  Identify source nodes
# ══════════════════════════════════════════════════════════════════════
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ══════════════════════════════════════════════════════════════════════
# 2 .  Enumerate ≤4-input cones and test permutation coverage
# ══════════════════════════════════════════════════════════════════════
missing_any = False
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:          # skip primary outputs
        continue

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        # ── permutation-coverage test (informational) ─────────────────
        a_def, a_all = _all_perm_actions(sg, TTABLE_TO_ACTIONS)

        if not a_all:                 # cone not in catalogue → skip
            continue

        if not a_all.issubset(a_def):
            hidden = sorted(a_all - a_def)
            print(f"[coverage gap] tgt={tgt} cut={cut}  hidden actions={hidden}")
            missing_any = True

        # keep single-key behaviour for downstream RL loop
        key = _truth_key(sg)
        for aid in a_def:                      # identical to lookup(key)
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")
if not missing_any:
    print("✓ Every cone: one permutation already exposes ALL actions")
else:
    print("⚠ Some cones hide actions when relying on a single permutation")

# ══════════════════════════════════════════════════════════════════════
# 3 .  Rewrite once per (aid, tgt, cut) and verify equivalence
#     (unchanged from your original pipeline)
# ══════════════════════════════════════════════════════════════════════
tt_orig   = calculate_truth_table_v2(current_solution)
seen_site = set()
n_pass = n_fail = 0

for aid, triples in subgraphs_for_action.items():
    for tgt, cut, key in triples:
        if (aid, tgt, cut) in seen_site:
            continue
        seen_site.add((aid, tgt, cut))

        repl = _permute_until_match(UNIQUE_GRAPHS[aid], key, aid)
        if repl is None:
            print(f"  aid {aid:<4}  permutation ❌  (skipped)")
            n_fail += 1
            continue

        cand    = generate_subgraph(current_solution, tgt, cut, draw=False)
        new_sol = substitute_subgraph(current_solution, cand, repl)

        ok = calculate_truth_table_v2(new_sol) == tt_orig
        print(f"  aid {aid:<4}  tgt {tgt:<3}  cut {cut}  "
              f"{'✔︎ PASS' if ok else '✗ FAIL'}")
        n_pass += ok
        n_fail += (not ok)

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL")


Check if for a given action, it would be able to replace the sg regardless of the Boolean function tested for the sg

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Per-cone / per-action permutation-persistence test

For every candidate cone (≤4 inputs) in `current_solution`:
    • Get the default key’s action list  (what your RL loop sees).
    • For **each** action in that list, verify that permuting the cone’s
      inputs never loses the action in the catalogue lookup.
Any missing (cone, action, permutation) triples are logged.
"""

import itertools
from collections import defaultdict

import networkx as nx
import numpy as np

# ---------------------------------------------------------------------
# expected globals in the environment:
#   NUM_ACTIONS
#   current_solution : nx.DiGraph
#   exhaustive_cut_enumeration_dag(...)
#   is_fanout_free_standalone(...)
#   generate_subgraph(...)
#   _truth_key(...)
#   TTABLE_TO_ACTIONS
# ---------------------------------------------------------------------


import itertools
import networkx as nx

def _check_cone_action_persistence(
        sg: nx.DiGraph,
        lookup: dict,
        *,
        verbose: bool = True,
        indent: str = "") -> dict[int, list[int]]:
    """
    Verify that every action in the *default* key's action list
    remains present for every permutation of the cone's input pins.

    Parameters
    ----------
    sg : nx.DiGraph
        The cone (sub-graph) to test.
    lookup : dict[(num_inputs, tt_int) -> list[int]]
        TTABLE_TO_ACTIONS catalogue.
    verbose : bool, default False
        If True, prints internal details for debugging.
    indent : str, default ""
        String prepended to every verbose line (helps nesting).

    Returns
    -------
    gaps : dict[int, list[int]]
        {aid: [perm_idx, …]}  – empty dict ⇒ no gaps.
    """
    pins = [n for n in sg if sg.in_degree(n) == 0]

    key_default = _truth_key(sg)
    actions_default = lookup.get(key_default, [])

    if verbose:
        print(f"{indent}pins           : {pins}")
        print(f"{indent}default key    : {key_default}")
        print(f"{indent}actions_default: {sorted(actions_default)}")

    # early exit – no actions at all
    if not actions_default:
        if verbose:
            print(f"{indent}(cone not in catalogue – skipping)")
        return {}

    gaps = {aid: [] for aid in actions_default}

    for idx, perm in enumerate(itertools.permutations(pins)):
        sg_perm = nx.relabel_nodes(sg, dict(zip(pins, perm)), copy=True)
        key_perm = _truth_key(sg_perm)
        actions_perm = lookup.get(key_perm, [])

        if verbose:
            print(f"{indent}perm {idx:>2}: {perm}  "
                  f"key={key_perm}  actions={sorted(actions_perm)}")

        for aid in actions_default:
            if aid not in actions_perm:
                gaps[aid].append(idx)
                if verbose:
                    print(f"{indent}  └─ missing aid {aid} in perm {idx}")

    # remove aids with no gaps
    gaps = {aid: idxs for aid, idxs in gaps.items() if idxs}

    if verbose:
        if gaps:
            print(f"{indent}GAPS → {gaps}\n")
        else:
            print(f"{indent}✓ all default actions persist across permutations\n")

    return gaps



# ---------------------------------------------------------------------
# driver – enumerate cones and run the test
# ---------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}
total_cones   = 0
cones_with_gap = 0

for tgt in current_solution.nodes():
    if tgt in source_nodes or current_solution.out_degree(tgt) == 0:
        continue  # skip primary inputs and outputs

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        total_cones += 1
        gaps = _check_cone_action_persistence(sg, TTABLE_TO_ACTIONS)

        if gaps:
            cones_with_gap += 1
            print(f"[gap] tgt={tgt:>3} cut={cut}  "
                  f"missing={{aid:perm_idx_list}} → {gaps}")

print("\n=== summary ===")
if cones_with_gap == 0:
    print(f"✓ All {total_cones} cones: every default action persists "
          "across all input-pin permutations.")
else:
    print(f"⚠ {cones_with_gap} / {total_cones} cones lose at least one "
          "default action for some permutation (see logs above).")


Testing if random permutations change topology (using hashing)

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:
        continue                               # skip primary output

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  For every (action, tgt, cut) pair:
#     • generate *all* matching permutations
#     • apply each one
#     • test functional + topological equivalence
# ----------------------------------------------------------------------

def _all_perms_matching(base: nx.DiGraph,
                        key: tuple[int, int]) -> list[nx.DiGraph]:
    """Return *all* permuted copies of `base` whose truth-table = `key`."""
    inputs = [n for n in base if base.in_degree(n) == 0]
    if len(inputs) != key[0]:
        return []

    matches = []
    for perm in itertools.permutations(inputs):
        g2 = nx.relabel_nodes(base, dict(zip(inputs, perm)), copy=True)
        if _truth_key(g2) == key:
            matches.append(g2)
    return matches


tt_orig   = calculate_truth_table_v2(current_solution)
n_pass    = n_fail = 0
multi_top = 0                       # how many triples produced ≥2 topologies?

for aid, triples in subgraphs_for_action.items():
    template = UNIQUE_GRAPHS[aid]

    for tgt, cut, key in triples:
        # -------- enumerate every matching permutation -------------
        rep_list = _all_perms_matching(template, key)

        if not rep_list:
            print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  no-match ❌")
            n_fail += 1
            continue

        topo_hashes = set()
        all_ok = True

        # -------- apply each permutation ---------------------------
        cand = generate_subgraph(current_solution, tgt, cut, draw=False)

        for repl in rep_list:
            new_sol  = substitute_subgraph(current_solution, cand, repl)
            ok = (calculate_truth_table_v2(new_sol) == tt_orig)

            # canonicalise so label shuffles collapse
            topo_hashes.add(canon_hash(new_sol))
            all_ok &= ok

        # -------- reporting ----------------------------------------
        label = "✔︎ PASS" if all_ok else "✗ FAIL"
        n_pass += all_ok
        n_fail += (not all_ok)

        topo_note = ""
        if len(topo_hashes) > 1:
            topo_note = f" ({len(topo_hashes)} distinct topologies!)"
            multi_top += 1

        print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  {label}{topo_note}")

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL   "
      f"{multi_top} triples produced ≥2 distinct topologies")


Testing if random permutations change topology (using iso test)

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:
        continue                               # skip primary output

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")


# ----------------------------------------------------------------------
# 3 .  For every (aid, tgt, cut) triple:
#     • generate *all* matching permutations
#     • apply each one
#     • test functional equivalence
#     • cluster by *isomorphism* (no hashing)
# ----------------------------------------------------------------------

from networkx.algorithms import isomorphism as iso

def _all_perms_matching(base: nx.DiGraph,
                        key: tuple[int, int]) -> list[nx.DiGraph]:
    """Return every permuted copy of `base` that matches `key`."""
    ins = [n for n in base if base.in_degree(n) == 0]
    if len(ins) != key[0]:
        return []
    out = []
    for perm in itertools.permutations(ins):
        g2 = nx.relabel_nodes(base, dict(zip(ins, perm)), copy=True)
        if _truth_key(g2) == key:
            out.append(g2)
    return out


tt_orig   = calculate_truth_table_v2(current_solution)
n_pass    = n_fail = 0
multi_top = 0          # #triples that yield ≥ 2 non-isomorphic topologies

for aid, triples in subgraphs_for_action.items():
    template = UNIQUE_GRAPHS[aid]

    for tgt, cut, key in triples:
        # --- all template permutations that satisfy `key`
        perm_list = _all_perms_matching(template, key)
        if not perm_list:
            print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  no-match ❌")
            n_fail += 1
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)
        topo_reps: list[nx.DiGraph] = []   # distinct topologies collected so far
        all_ok = True

        for repl in perm_list:
            new_sol = substitute_subgraph(current_solution, cand, repl)
            # functional check
            ok = (calculate_truth_table_v2(new_sol) == tt_orig)
            all_ok &= ok

            # --- isomorphism clustering ---------------------------------
            placed = False
            for rep in topo_reps:
                if iso.is_isomorphic(new_sol, rep):
                    placed = True
                    break
            if not placed:
                topo_reps.append(new_sol)

        # ----- reporting ------------------------------------------------
        label = "✔︎ PASS" if all_ok else "✗ FAIL"
        extra = ""
        if len(topo_reps) > 1:
            extra = f" ({len(topo_reps)} distinct topologies!)"
            multi_top += 1

        print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  {label}{extra}")
        n_pass += all_ok
        n_fail += (not all_ok)

print(f"\nSummary: {n_pass} PASS   {n_fail} FAIL   "
      f"{multi_top} triples produced ≥2 distinct topologies")


Test if checking for the boolean function of the subgraph is enough to cover all possible permutation of the template that can be used to replace the subgraph

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes or current_solution.out_degree(tgt) == 0:
        continue                               # skip primary inputs / outputs

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  Try all template permutations and analyse every EXTRA HIT
# ----------------------------------------------------------------------

from networkx.algorithms import isomorphism as iso
import math, itertools

# ---------- tiny helpers ------------------------------------------------
def _all_perms(base: nx.DiGraph):
    ins = [n for n in base if base.in_degree(n) == 0]
    return [
        nx.relabel_nodes(base, dict(zip(ins, p)), copy=True)
        for p in itertools.permutations(ins)
    ]

def _truth_bits(tt_int, n):
    size = 1 << n
    return [(tt_int >> k) & 1 for k in range(size)]

def influential_inputs(tt_int, n):
    """Return set of input indices that affect the output."""
    bits = _truth_bits(tt_int, n)
    infl = set()
    for i in range(n):
        stride = 1 << i
        for base in range(0, 1 << n, 2 * stride):
            for off in range(stride):
                if bits[base + off] != bits[base + off + stride]:
                    infl.add(i)
                    break
            if i in infl:
                break
    return infl
# ------------------------------------------------------------------------

tt_orig      = calculate_truth_table_v2(current_solution)

n_pass = n_fail = 0
multi_top = extra_hit = extra_newtop = 0
reason_tally = defaultdict(int)   # REDUND / SYM / RENAME counters

for aid, triples in subgraphs_for_action.items():
    template           = UNIQUE_GRAPHS[aid]
    all_template_perms = _all_perms(template)          # 1× per aid

    for tgt, cut, key in triples:

        # ---------- permutations that *match* the truth-key --------------
        prim_perms = [g for g in all_template_perms if _truth_key(g) == key]
        if not prim_perms:
            print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  no-match ❌")
            n_fail += 1
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)

        prim_tops, all_tops = [], []
        all_ok = True

        for repl in prim_perms:
            new_sol = substitute_subgraph(current_solution, cand, repl)
            ok      = calculate_truth_table_v2(new_sol) == tt_orig
            all_ok &= ok
            if not any(iso.is_isomorphic(new_sol, rep) for rep in prim_tops):
                prim_tops.append(new_sol)
                all_tops.append(new_sol)

        # ---------- permutations that *do NOT* match the truth-key -------
        extra_ok = extra_new_ok = False
        reason   = None          # will hold REDUND / SYM / RENAME

        for repl in (g for g in all_template_perms if _truth_key(g) != key):
            new_sol = substitute_subgraph(current_solution, cand, repl)
            if calculate_truth_table_v2(new_sol) != tt_orig:
                continue

            extra_ok = True

            if not reason:       # analyse only the first qualifying repl
                n_in  = len(cut)
                tt_cone_int   = key[1]                    # cone’s TT int
                infl = influential_inputs(tt_cone_int, n_in)

                if len(infl) < n_in:
                    reason = "REDUND"                     # unused inputs
                else:
                    # is the function symmetric?
                    ins_cone = [n for n in cand if cand.in_degree(n) == 0]
                    tt_set = {
                        _truth_key(
                            nx.relabel_nodes(cand,
                                             dict(zip(ins_cone, p)),
                                             copy=True)
                        )[1]
                        for p in itertools.permutations(ins_cone)
                    }
                    if len(tt_set) < math.factorial(n_in):
                        reason = "SYM"
                    else:
                        reason = "RENAME"
                reason_tally[reason] += 1

            # isomorphic to something we already saw?
            if not any(iso.is_isomorphic(new_sol, rep) for rep in all_tops):
                extra_new_ok = True
                all_tops.append(new_sol)

        # ---------- accounting -------------------------------------------
        topo_cnt = len(all_tops)
        if topo_cnt > 1:
            multi_top += 1

        if extra_ok:
            extra_hit += 1
            if extra_new_ok:
                extra_newtop += 1

        # ---------- per-triple report -----------------------------------
        label   = "✔︎ PASS" if all_ok else "✗ FAIL"
        details = []

        if topo_cnt > 1:
            details.append(f"{topo_cnt} topologies")

        if extra_ok:
            tag = "✨ EXTRA HIT"
            if extra_new_ok:
                tag += " + NEW topo"
            if reason:
                tag += f" [{reason}]"
            details.append(tag)

        suffix = ("  " + " · ".join(details)) if details else ""
        print(f"  aid {aid:<4} tgt {tgt:<3} cut {cut}  {label}{suffix}")

        n_pass += all_ok
        n_fail += (not all_ok)

# ---------------- final summary ----------------------------------------
print("\nSummary:")
print(f"  {n_pass} PASS   {n_fail} FAIL")
print(f"  {multi_top} triples → ≥ 2 topologies")
print(f"  {extra_hit} triples had EXTRA HITs "
      f"({extra_newtop} introduced NEW topologies)")
if extra_hit:
    print("  Reasons for EXTRA HITs:")
    for r, cnt in reason_tally.items():
        print(f"     {r:<6}: {cnt}")


[mask] 1195 rewrite actions enabled
  aid 18   tgt 5   cut (0, 1)  ✔︎ PASS
  aid 18   tgt 6   cut (33, 34)  ✔︎ PASS
  aid 18   tgt 7   cut (3, 35)  ✔︎ PASS
  aid 18   tgt 8   cut (36, 37)  ✔︎ PASS
  aid 18   tgt 9   cut (7, 8)  ✔︎ PASS
  aid 18   tgt 10  cut (0, 38)  ✔︎ PASS
  aid 18   tgt 11  cut (2, 39)  ✔︎ PASS
  aid 18   tgt 12  cut (3, 40)  ✔︎ PASS
  aid 18   tgt 13  cut (12, 41)  ✔︎ PASS
  aid 18   tgt 14  cut (42, 43)  ✔︎ PASS
  aid 18   tgt 15  cut (3, 44)  ✔︎ PASS
  aid 18   tgt 16  cut (15, 45)  ✔︎ PASS
  aid 18   tgt 17  cut (1, 46)  ✔︎ PASS
  aid 18   tgt 18  cut (2, 47)  ✔︎ PASS
  aid 18   tgt 19  cut (3, 48)  ✔︎ PASS
  aid 18   tgt 20  cut (19, 49)  ✔︎ PASS
  aid 18   tgt 21  cut (50, 51)  ✔︎ PASS
  aid 18   tgt 22  cut (21, 52)  ✔︎ PASS
  aid 18   tgt 23  cut (53, 54)  ✔︎ PASS
  aid 18   tgt 24  cut (2, 55)  ✔︎ PASS
  aid 18   tgt 25  cut (56, 57)  ✔︎ PASS
  aid 18   tgt 26  cut (25, 58)  ✔︎ PASS
  aid 18   tgt 27  cut (59, 60)  ✔︎ PASS
  aid 18   tgt 28  cut (3, 61)  ✔︎

Visual testing if random permutations change topology

In [None]:
# ──────────────────────────────────────────────────────────────────────────
#  VISUALISE non-isomorphic TOPOLOGIES  – edges flush with nodes
# ──────────────────────────────────────────────────────────────────────────
import itertools
import matplotlib.pyplot as plt
import networkx as nx
from networkx.algorithms.isomorphism import DiGraphMatcher

MAX_EXAMPLES = 50
SEED         = 42

# ------------------------------------------------------------------ helpers
def _all_perms_matching(base: nx.DiGraph, key):
    ins = [n for n in base if base.in_degree(n) == 0]
    if len(ins) != key[0]:
        return []
    out = []
    for perm in itertools.permutations(ins):
        g2 = nx.relabel_nodes(base, dict(zip(ins, perm)), copy=True)
        if _truth_key(g2) == key:
            out.append(g2)
    return out


def _energy(g):
    e, _ = energy_score(g, check_implicit_OR_existence_v2)
    return e


def _edge_colors_diff(g1, g2):
    """black = common red = only in g1 blue = only in g2"""
    GM = DiGraphMatcher(g1, g2)
    mapping = next(GM.isomorphisms_iter(), None)

    def canon_edges(G, m=None):
        if m is None:
            return {tuple(sorted(edge)) for edge in G.edges()}
        return {tuple(sorted((m[u], m[v]))) for u, v in G.edges()}

    E1 = canon_edges(g1)
    E2 = canon_edges(g2, mapping) if mapping else canon_edges(g2)
    common = E1 & E2
    col1 = ["black" if tuple(sorted(e)) in common else "red"
            for e in g1.edges()]
    col2 = ["black" if tuple(sorted(e)) in common else "blue"
            for e in g2.edges()]
    return col1, col2


def _draw_graph(ax, G, title, pos=None, edge_colors="black"):
    if pos is None:
        pos = nx.spring_layout(G, seed=SEED)

    # draw edges first, nodes on top
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors,
                           width=2.4, arrows=True, ax=ax)
    nx.draw_networkx_nodes(G, pos, node_size=150, alpha=0.8,
                           node_color="#1f78b4", edgecolors="#333",
                           linewidths=0.6, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=7, ax=ax)
    ax.set_title(title, fontsize=10)
    ax.axis("off")
    return pos

# ------------------------------------------------ find non-iso variants
def _find_non_iso_variants(template, cand, key):
    reps = []
    for repl in _all_perms_matching(template, key):
        g = substitute_subgraph(current_solution, cand, repl)
        if not any(nx.is_isomorphic(g, r) for _, r in reps):
            reps.append((_energy(g), g))
    return reps

# ------------------------------------------------ main driver
def show_multi_topologies(limit=MAX_EXAMPLES):
    shown = 0
    for aid, triples in subgraphs_for_action.items():
        template = UNIQUE_GRAPHS[aid]

        for tgt, cut, key in triples:
            cand = generate_subgraph(current_solution, tgt, cut, draw=False)
            reps = _find_non_iso_variants(template, cand, key)
            if len(reps) < 2:
                continue

            (e1, g1), (e2, g2) = reps[0], reps[1]
            col1, col2 = _edge_colors_diff(g1, g2)

            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))
            fig.suptitle(f"aid {aid}   tgt {tgt}   cut={cut}", fontsize=12)

            shared_pos = nx.spring_layout(g1, seed=SEED)
            _draw_graph(ax1, g1, f"orientation ①  E={e1:.2f}",
                        pos=shared_pos, edge_colors=col1)
            _draw_graph(ax2, g2, f"orientation ②  E={e2:.2f}",
                        pos=shared_pos, edge_colors=col2)
            _draw_graph(ax3, template, "template")

            fig.legend([plt.Line2D([0], [0], color="red",  lw=2),
                        plt.Line2D([0], [0], color="blue", lw=2),
                        plt.Line2D([0], [0], color="black", lw=2)],
                       ["only in orient ①", "only in orient ②", "common"],
                       loc="lower center", ncol=3, fontsize=9)
            plt.tight_layout()
            plt.show()

            shown += 1
            if shown >= limit:
                return

# ───────────── call it (repeat to paginate) ─────────────
show_multi_topologies()   # first batch
# show_multi_topologies() # call again for more


Check if energies change when different permutations are selected

In [None]:
import itertools, time, numpy as np
from collections import defaultdict

# ----------------------------------------------------------------------
# 0 .  Allocate mask & storage (same shapes as in the env)
# ----------------------------------------------------------------------
mask                  = np.zeros(NUM_ACTIONS + 1, dtype=bool)
subgraphs_for_action  = defaultdict(list)      # aid → [(tgt, cut, key), …]
tt_counter            = defaultdict(list)      # key → [(tgt, cut)]

# ----------------------------------------------------------------------
# 1 .  Primary-input nodes
# ----------------------------------------------------------------------
source_nodes = {n for n, deg in current_solution.in_degree() if deg == 0}

# ----------------------------------------------------------------------
# 2 .  Enumerate targets and  ≤4-input cones
# ----------------------------------------------------------------------
for tgt in current_solution.nodes():
    if tgt in source_nodes:
        continue
    if current_solution.out_degree(tgt) == 0:
        continue                               # skip primary output

    for cut in exhaustive_cut_enumeration_dag(
            current_solution, 4, tgt, filter_redundant=True):
        if not is_fanout_free_standalone(current_solution, tgt, cut):
            continue

        sg = generate_subgraph(current_solution, tgt, cut, draw=False)
        if len([n for n in sg if sg.in_degree(n) == 0]) != len(cut):
            continue

        key = _truth_key(sg)                   # (#inputs , tt-int)
        for aid in TTABLE_TO_ACTIONS.get(key, []):
            mask[aid] = True
            subgraphs_for_action[aid].append((tgt, tuple(cut), key))
        tt_counter[key].append((tgt, tuple(cut)))

print(f"[mask] {mask[:-1].sum()} rewrite actions enabled")

# ----------------------------------------------------------------------
# 3 .  For every (action, tgt, cut) pair, try *all* matching permutations
#     and record the spread of energy values they produce.
# ----------------------------------------------------------------------

def _all_perms_matching(base: nx.DiGraph, key: tuple[int, int]) -> list[nx.DiGraph]:
    """Return every permuted copy of `base` whose truth-table matches `key`."""
    inputs = [n for n in base if base.in_degree(n) == 0]
    if len(inputs) != key[0]:
        return []
    out = []
    for perm in itertools.permutations(inputs):
        g2 = nx.relabel_nodes(base, dict(zip(inputs, perm)), copy=True)
        if _truth_key(g2) == key:
            out.append(g2)
    return out


spread_cnt = 0                # how many triples show Δenergy > 0
same_cnt   = 0
max_delta  = 0.0

for aid, triples in subgraphs_for_action.items():
    template = UNIQUE_GRAPHS[aid]

    for tgt, cut, key in triples:
        rep_list = _all_perms_matching(template, key)
        if not rep_list:          # should not happen
            continue

        cand = generate_subgraph(current_solution, tgt, cut, draw=False)
        energies = []
        for repl in rep_list:
            new_sol = substitute_subgraph(current_solution, cand, repl)
            e, _ = energy_score(new_sol, check_implicit_OR_existence_v2)
            energies.append(e)

        e_min, e_max = min(energies), max(energies)
        delta = e_max - e_min
        if delta > 0:
            spread_cnt += 1
            max_delta  = max(max_delta, delta)
            note = f"Δ={delta:.3f}"
        else:
            same_cnt += 1
            note = "same"
        print(f"aid {aid:<4} tgt {tgt:<3} cut {cut}  {note}")

print(f"\nTriples with identical energy: {same_cnt}")
print(f"Triples with different energy: {spread_cnt}")
if spread_cnt:
    print(f"Largest Δenergy observed: {max_delta:.3f}")



Testing the rewriting for the full graph 

In [None]:
import networkx as nx
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   • current_solution               : nx.DiGraph
#   • generate_subgraph(G, tgt, cut, draw=False) -> nx.DiGraph
#   • calculate_truth_table_v2(g)    : truth-table helper
#   • _truth_key(g)                  : (num_inputs, tt_int)
#   • TTABLE_TO_ACTIONS              : catalogue dict
# ----------------------------------------------------------------------

def _pretty_print_tt(tt):
    """Print the first few rows for readability."""
    if isinstance(tt, dict):
        rows = list(sorted(tt.items()))[:16]
        pprint(dict(rows))
        if len(tt) > 16:
            print("... (truncated)")
    else:
        print(tt)

PI = {n for n, deg in current_solution.in_degree() if deg == 0}
PO = [n for n, deg in current_solution.out_degree() if deg == 0]

MAX_PO = 5               # examine at most N primary outputs
cut     = tuple(PI)      # the “whole-graph” cut: every input node

for i, tgt in enumerate(PO[:MAX_PO]):
    sg = generate_subgraph(current_solution, tgt, cut, draw=True)

    # Warn if generate_subgraph dropped some PIs
    cone_sources = [n for n in sg if sg.in_degree(n) == 0]
    if len(cone_sources) != len(cut):
        dropped = set(cut) - set(cone_sources)
        print(f"[note] PO {tgt}: {len(dropped)} PIs not in cone:"
              f" {sorted(dropped)}")

    tt  = calculate_truth_table_v2(sg)
    key = _truth_key(sg)
    actions = TTABLE_TO_ACTIONS.get(key, [])

    print(f"\n======= PRIMARY OUTPUT #{i+1} (node {tgt}) =======")
    print(f"cut (ALL PIs): {cut}")
    print("truth table (first rows shown):")
    _pretty_print_tt(tt)
    print(f"key          : {key}")
    print(f"actions@key  : {actions if actions else '(none)'}")
    print("==================================================")


In [None]:
import networkx as nx
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   • current_solution               : nx.DiGraph
#   • generate_subgraph(G, tgt, cut, draw=False) -> nx.DiGraph
#   • calculate_truth_table_v2(g)
#   • _truth_key(g)
#   • TTABLE_TO_ACTIONS
# ----------------------------------------------------------------------

def _pretty_print_tt(tt):
    """Pretty-print first few rows of a dict TT, or show raw if other type."""
    if isinstance(tt, dict):
        rows = list(sorted(tt.items()))[:16]
        pprint(dict(rows))
        if len(tt) > 16:
            print("... (truncated)")
    else:
        print(tt)

# ----------------------------------------------------------------------
# identify PIs, POs, and run the test
# ----------------------------------------------------------------------
PI = {n for n, deg in current_solution.in_degree() if deg == 0}
PO = [n for n, deg in current_solution.out_degree() if deg == 0]

MAX_CASES = 5          # examine at most this many (po, predecessor) pairs
cut       = tuple(PI)  # fixed cut: ALL inputs of the current solution

case_cnt = 0
for po in PO:
    preds = list(current_solution.predecessors(po))
    if not preds:
        continue                                # PO is dangling – skip

    for pred_idx, tgt in enumerate(preds):
        if case_cnt >= MAX_CASES:
            break

        sg = generate_subgraph(current_solution, tgt, cut, draw=True)

        # warn if some PIs are not part of the cone actually returned
        cone_sources = [n for n in sg if sg.in_degree(n) == 0]
        if len(cone_sources) != len(cut):
            dropped = set(cut) - set(cone_sources)
            print(f"[note] PO {po} pred {tgt}: "
                  f"{len(dropped)} PIs not in cone {sorted(dropped)}")

        tt  = calculate_truth_table_v2(sg)
        key = _truth_key(sg)
        actions = TTABLE_TO_ACTIONS.get(key, [])

        case_cnt += 1
        print(f"\n======= CASE {case_cnt}: PO {po}  •  target (pre-PO) {tgt} "
              f"({pred_idx+1}/{len(preds)}) =======")
        print(f"cut (ALL PIs): {cut}")
        print("truth table (first rows shown):")
        _pretty_print_tt(tt)
        print(f"key          : {key}")
        print(f"actions@key  : {actions if actions else '(none)'}")
        print("==================================================")

    if case_cnt >= MAX_CASES:
        break


In [None]:
import networkx as nx
import copy
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   current_solution                 : nx.DiGraph
#   generate_subgraph(...)
#   _truth_key(...)
#   TTABLE_TO_ACTIONS
#   UNIQUE_GRAPHS                    : list[nx.DiGraph]
#   _permute_until_match(template, key, aid)
#   substitute_subgraph(original, old_sub, new_sub)
#   calculate_truth_table_v2(graph)
# ----------------------------------------------------------------------

PI = {n for n, deg in current_solution.in_degree() if deg == 0}
PO = [n for n, deg in current_solution.out_degree() if deg == 0]

cut = tuple(PI)
orig_tt = calculate_truth_table_v2(current_solution)

checked   = 0   # number of (po, aid) pairs examined
diff_tt   = 0   # pairs whose truth tables differ
diff_iso  = 0   # truth-table-identical but non-isomorphic graphs

for po in PO:
    preds = list(current_solution.predecessors(po))
    if not preds:
        continue
    tgt_B = preds[0]                     # pick first predecessor

    # --- build cones --------------------------------------------------
    sg_A = generate_subgraph(current_solution, po,    cut, draw=False)
    sg_B = generate_subgraph(current_solution, tgt_B, cut, draw=False)

    key_A = _truth_key(sg_A)
    key_B = _truth_key(sg_B)

    actions_A = set(TTABLE_TO_ACTIONS.get(key_A, []))
    actions_B = set(TTABLE_TO_ACTIONS.get(key_B, []))

    for aid in sorted(actions_A & actions_B):
        repl_A = _permute_until_match(UNIQUE_GRAPHS[aid], key_A, aid)
        repl_B = _permute_until_match(UNIQUE_GRAPHS[aid], key_B, aid)
        if repl_A is None or repl_B is None:
            continue                                     # skip if no match

        checked += 1

        sol_A = substitute_subgraph(current_solution, sg_A, repl_A)
        sol_B = substitute_subgraph(current_solution, sg_B, repl_B)

        tt_A  = calculate_truth_table_v2(sol_A)
        tt_B  = calculate_truth_table_v2(sol_B)

        same_tt  = (tt_A == tt_B)
        same_iso = nx.is_isomorphic(sol_A, sol_B)

        if not same_tt:
            diff_tt += 1
            print(f"\n[TT-DIFF] PO {po}  aid {aid}")
            print(f"  key_A {key_A}  key_B {key_B}")
        elif not same_iso:
            diff_iso += 1
            print(f"\n[ISO-DIFF] PO {po}  aid {aid}  (truth tables equal)")
            print(f"  key_A {key_A}  key_B {key_B}")

print("\n=== summary ===")
print(f"checked pairs   : {checked}")
print(f"truth-table diff: {diff_tt}")
print(f"isomorphic diff : {diff_iso}")
if diff_tt == 0 and diff_iso == 0:
    print("✓ All common rewrites produce identical truth tables "
          "and isomorphic graphs for Case A vs Case B.")


Testing the rewriting for the full graph with a NOR gate before the output 

In [None]:
import networkx as nx
from copy import deepcopy

def add_second_fanin_to_predecessor(G: nx.DiGraph) -> nx.DiGraph:
    """Return a copy of G where (at most) one PO’s predecessor now has 2 inputs."""
    G2 = deepcopy(G)

    # 1) identify PIs and POs
    PIs = [n for n, deg in G2.in_degree()  if deg == 0]
    POs = [n for n, deg in G2.out_degree() if deg == 0]

    for po in POs:
        preds = list(G2.predecessors(po))
        if not preds:
            continue
        pred = preds[0]                    # take first predecessor

        if G2.in_degree(pred) >= 2:
            continue                       # already has ≥2 inputs

        # 2) choose a PI not already feeding 'pred' and not creating a cycle
        for pi in PIs:
            if pi in G2.predecessors(pred):
                continue
            # adding pi→pred can only create a cycle if pred can reach pi
            if nx.has_path(G2, pred, pi):
                continue
            G2.add_edge(pi, pred)
            print(f"Added edge {pi} → {pred} (predecessor of PO {po})")
            return G2                      # done – return modified graph

    # nothing modified – just return a copy
    print("Every PO’s predecessor already has ≥2 fan-ins; graph unchanged.")
    return G2


# ----------------------------------------------------------------------
# Example usage
# ----------------------------------------------------------------------
current_solution_mod = add_second_fanin_to_predecessor(current_solution)

# Verify
for po in [n for n, d in current_solution_mod.out_degree() if d == 0]:
    preds = list(current_solution_mod.predecessors(po))
    print(f"PO {po}   predecessor(s): {preds}   "
          f"in-degrees: {[current_solution_mod.in_degree(p) for p in preds]}")


In [None]:
import networkx as nx
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   • current_solution               : nx.DiGraph
#   • generate_subgraph(G, tgt, cut, draw=False) -> nx.DiGraph
#   • calculate_truth_table_v2(g)    : truth-table helper
#   • _truth_key(g)                  : (num_inputs, tt_int)
#   • TTABLE_TO_ACTIONS              : catalogue dict
# ----------------------------------------------------------------------

def _pretty_print_tt(tt):
    """Print the first few rows for readability."""
    if isinstance(tt, dict):
        rows = list(sorted(tt.items()))[:16]
        pprint(dict(rows))
        if len(tt) > 16:
            print("... (truncated)")
    else:
        print(tt)

PI = {n for n, deg in current_solution_mod.in_degree() if deg == 0}
PO = [n for n, deg in current_solution_mod.out_degree() if deg == 0]

MAX_PO = 5               # examine at most N primary outputs
cut     = tuple(PI)      # the “whole-graph” cut: every input node

for i, tgt in enumerate(PO[:MAX_PO]):
    sg = generate_subgraph(current_solution_mod, tgt, cut, draw=True)

    # Warn if generate_subgraph dropped some PIs
    cone_sources = [n for n in sg if sg.in_degree(n) == 0]
    if len(cone_sources) != len(cut):
        dropped = set(cut) - set(cone_sources)
        print(f"[note] PO {tgt}: {len(dropped)} PIs not in cone:"
              f" {sorted(dropped)}")

    tt  = calculate_truth_table_v2(sg)
    key = _truth_key(sg)
    actions = TTABLE_TO_ACTIONS.get(key, [])

    print(f"\n======= PRIMARY OUTPUT #{i+1} (node {tgt}) =======")
    print(f"cut (ALL PIs): {cut}")
    print("truth table (first rows shown):")
    _pretty_print_tt(tt)
    print(f"key          : {key}")
    print(f"actions@key  : {actions if actions else '(none)'}")
    print("==================================================")


In [None]:
import networkx as nx
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   • current_solution               : nx.DiGraph
#   • generate_subgraph(G, tgt, cut, draw=False) -> nx.DiGraph
#   • calculate_truth_table_v2(g)
#   • _truth_key(g)
#   • TTABLE_TO_ACTIONS
# ----------------------------------------------------------------------

def _pretty_print_tt(tt):
    """Pretty-print first few rows of a dict TT, or show raw if other type."""
    if isinstance(tt, dict):
        rows = list(sorted(tt.items()))[:16]
        pprint(dict(rows))
        if len(tt) > 16:
            print("... (truncated)")
    else:
        print(tt)

# ----------------------------------------------------------------------
# identify PIs, POs, and run the test
# ----------------------------------------------------------------------
PI = {n for n, deg in current_solution_mod.in_degree() if deg == 0}
PO = [n for n, deg in current_solution_mod.out_degree() if deg == 0]

MAX_CASES = 5          # examine at most this many (po, predecessor) pairs
cut       = tuple(PI)  # fixed cut: ALL inputs of the current solution

case_cnt = 0
for po in PO:
    preds = list(current_solution_mod.predecessors(po))
    if not preds:
        continue                                # PO is dangling – skip

    for pred_idx, tgt in enumerate(preds):
        if case_cnt >= MAX_CASES:
            break

        sg = generate_subgraph(current_solution_mod, tgt, cut, draw=True)

        # warn if some PIs are not part of the cone actually returned
        cone_sources = [n for n in sg if sg.in_degree(n) == 0]
        if len(cone_sources) != len(cut):
            dropped = set(cut) - set(cone_sources)
            print(f"[note] PO {po} pred {tgt}: "
                  f"{len(dropped)} PIs not in cone {sorted(dropped)}")

        tt  = calculate_truth_table_v2(sg)
        key = _truth_key(sg)
        actions = TTABLE_TO_ACTIONS.get(key, [])

        case_cnt += 1
        print(f"\n======= CASE {case_cnt}: PO {po}  •  target (pre-PO) {tgt} "
              f"({pred_idx+1}/{len(preds)}) =======")
        print(f"cut (ALL PIs): {cut}")
        print("truth table (first rows shown):")
        _pretty_print_tt(tt)
        print(f"key          : {key}")
        print(f"actions@key  : {actions if actions else '(none)'}")
        print("==================================================")

    if case_cnt >= MAX_CASES:
        break


In [None]:
import networkx as nx
import copy
from pprint import pprint

# ----------------------------------------------------------------------
# expected helpers / globals already in scope:
#   current_solution                 : nx.DiGraph
#   generate_subgraph(...)
#   _truth_key(...)
#   TTABLE_TO_ACTIONS
#   UNIQUE_GRAPHS                    : list[nx.DiGraph]
#   _permute_until_match(template, key, aid)
#   substitute_subgraph(original, old_sub, new_sub)
#   calculate_truth_table_v2(graph)
# ----------------------------------------------------------------------

PI = {n for n, deg in current_solution_mod.in_degree() if deg == 0}
PO = [n for n, deg in current_solution_mod.out_degree() if deg == 0]

cut = tuple(PI)
orig_tt = calculate_truth_table_v2(current_solution_mod)

checked   = 0   # number of (po, aid) pairs examined
diff_tt   = 0   # pairs whose truth tables differ
diff_iso  = 0   # truth-table-identical but non-isomorphic graphs

for po in PO:
    preds = list(current_solution_mod.predecessors(po))
    if not preds:
        continue
    tgt_B = preds[0]                     # pick first predecessor

    # --- build cones --------------------------------------------------
    sg_A = generate_subgraph(current_solution_mod, po,    cut, draw=False)
    sg_B = generate_subgraph(current_solution_mod, tgt_B, cut, draw=False)

    key_A = _truth_key(sg_A)
    key_B = _truth_key(sg_B)

    actions_A = set(TTABLE_TO_ACTIONS.get(key_A, []))
    actions_B = set(TTABLE_TO_ACTIONS.get(key_B, []))

    for aid in sorted(actions_A & actions_B):
        repl_A = _permute_until_match(UNIQUE_GRAPHS[aid], key_A, aid)
        repl_B = _permute_until_match(UNIQUE_GRAPHS[aid], key_B, aid)
        if repl_A is None or repl_B is None:
            continue                                     # skip if no match

        checked += 1

        sol_A = substitute_subgraph(current_solution_mod, sg_A, repl_A)
        sol_B = substitute_subgraph(current_solution_mod, sg_B, repl_B)

        tt_A  = calculate_truth_table_v2(sol_A)
        tt_B  = calculate_truth_table_v2(sol_B)

        same_tt  = (tt_A == tt_B)
        same_iso = nx.is_isomorphic(sol_A, sol_B)

        if not same_tt:
            diff_tt += 1
            print(f"\n[TT-DIFF] PO {po}  aid {aid}")
            print(f"  key_A {key_A}  key_B {key_B}")
        elif not same_iso:
            diff_iso += 1
            print(f"\n[ISO-DIFF] PO {po}  aid {aid}  (truth tables equal)")
            print(f"  key_A {key_A}  key_B {key_B}")

print("\n=== summary ===")
print(f"checked pairs   : {checked}")
print(f"truth-table diff: {diff_tt}")
print(f"isomorphic diff : {diff_iso}")
if diff_tt == 0 and diff_iso == 0:
    print("✓ All common rewrites produce identical truth tables "
          "and isomorphic graphs for Case A vs Case B.")


Gemini explanation of checking tha sg permutation leads to same action

In [None]:
# Assumed to be previously defined:
# pins: a list of the input node names/identifiers for the current cone (sg).
# sg: the networkx DiGraph object representing the current cone.
# _truth_key: a function that computes a canonical key (e.g., from a truth table) for a cone.
# lookup: the TTABLE_TO_ACTIONS dictionary mapping truth keys to lists of action IDs.
# actions_default: a list of action IDs associated with the cone 'sg' in its original input order.
# gaps: a dictionary, likely initialized like {aid: [] for aid in actions_default},
#       to store permutation indices where an action 'aid' goes missing.
# verbose: a boolean flag to control detailed printing.
# indent: a string for formatting verbose output.

# Start of the loop in question:
for idx, perm in enumerate(itertools.permutations(pins)):
    # 1. Generate all permutations of input pins
    #    - itertools.permutations(pins): Generates all possible ordered arrangements of the items in the 'pins' list.
    #      For example, if pins = ['A', 'B', 'C'], one 'perm' would be ('A', 'C', 'B'), another ('B', 'A', 'C'), etc.
    #    - enumerate(...): Adds a counter 'idx' to each permutation. So, 'idx' is the 0-based index of the current
    #      permutation 'perm'.

    sg_perm = nx.relabel_nodes(sg, dict(zip(pins, perm)), copy=True)
    # 2. Create a new cone graph with permuted inputs
    #    - zip(pins, perm): Creates pairs from the original pin names and the current permuted pin names.
    #      If pins = ['in1', 'in2'] and perm = ('in2', 'in1'), then zip gives [('in1', 'in2'), ('in2', 'in1')].
    #    - dict(...): Converts these pairs into a mapping dictionary. For the example above: {'in1': 'in2', 'in2': 'in1'}.
    #      This dictionary tells relabel_nodes how to rename the nodes: "rename original 'in1' to 'in2'", etc.
    #    - nx.relabel_nodes(sg, ..., copy=True): Creates a *new* graph 'sg_perm' by taking the original cone 'sg'
    #      and applying the node renaming. 'copy=True' ensures 'sg' itself is not modified.
    #      Effectively, 'sg_perm' is the same logical cone as 'sg' but with its inputs "rewired" according to the permutation 'perm'.

    key_perm = _truth_key(sg_perm)
    # 3. Get the truth key of the permuted cone
    #    - _truth_key(sg_perm): Computes the canonical truth key for this newly created cone 'sg_perm'
    #      with its reordered inputs. If the logic function is sensitive to input order,
    #      this key might be different from the key of the original 'sg'.

    actions_perm = lookup.get(key_perm, [])
    # 4. Get the actions associated with the permuted cone
    #    - lookup.get(key_perm, []): Looks up this new 'key_perm' in the 'TTABLE_TO_ACTIONS' dictionary.
    #      - If 'key_perm' exists, 'actions_perm' becomes the list of action IDs associated with this permuted logic.
    #      - If 'key_perm' is not found (meaning this permuted cone's logic function isn't in the catalogue
    #        or has no actions defined), 'actions_perm' defaults to an empty list.

    if verbose:
        print(f"{indent}perm {idx:>2}: {perm}  "
              f"key={key_perm}  actions={sorted(actions_perm)}")
    # 5. Optional: Print debugging information
    #    - If 'verbose' is True, it prints the current permutation index, the permutation itself,
    #      the truth key of the permuted cone, and its associated actions (sorted for consistent display).

    for aid in actions_default:
        # 6. Check if default actions are present in the permuted cone's actions
        #    - This inner loop iterates through each action ID ('aid') that was present in the
        #      'actions_default' list (the actions of the original, unpermuted cone).

        if aid not in actions_perm:
            # 7. If a default action is MISSING in the permuted cone's actions:
            #    - This condition is true if an action that the original cone could perform
            #      CANNOT be performed by the cone when its inputs are ordered according to 'perm'.

            gaps[aid].append(idx)
            # 8. Record the "gap"
            #    - The index 'idx' of the current permutation 'perm' (which caused the action 'aid' to be lost)
            #      is appended to the list associated with 'aid' in the 'gaps' dictionary.
            #      This logs that for action 'aid', permutation number 'idx' makes it unavailable.

            if verbose:
                print(f"{indent}  └─ missing aid {aid} in perm {idx}")
            # 9. Optional: Print information about the missing action.

P-classes counter

In [None]:
from collections import defaultdict
import itertools, networkx as nx
from tqdm import tqdm              # pip install tqdm if missing

# ------------------------------------------------------------------
# helper – encode a truth-table dict → integer (same as _truth_key)
# ------------------------------------------------------------------
def _tt_dict_to_int(tt_dict):
    bits = "".join(str(o[0]) for _, o in sorted(tt_dict.items()))
    return int(bits, 2)

# ------------------------------------------------------------------
# compute P-class signatures with a progress bar
# ------------------------------------------------------------------
p_classes_per_arity = defaultdict(set)   # n_in → {frozenset(tt_ints)}

for g in tqdm(UNIQUE_GRAPHS,
              desc="Computing P-class signatures",
              unit="graph"):
    pins = [n for n in g if g.in_degree(n) == 0]
    n_in = len(pins)

    tt_ints = set()
    for perm in itertools.permutations(pins):
        g_perm  = nx.relabel_nodes(g, dict(zip(pins, perm)), copy=True)
        tt_dict = calculate_truth_table_v2(g_perm)
        tt_ints.add(_tt_dict_to_int(tt_dict))

    signature = frozenset(tt_ints)       # permutation class
    p_classes_per_arity[n_in].add(signature)

# ------------------------------------------------------------------
# report
# ------------------------------------------------------------------
for n in sorted(p_classes_per_arity):
    print(f"{n}-input P-classes : {len(p_classes_per_arity[n])}")

total = sum(len(s) for s in p_classes_per_arity.values())
print(f"TOTAL distinct P-classes across all arities: {total}")


In [None]:
# --- helper ------------------------------------------------------------------
def _tt_int(tt_dict):
    """Convert truth-table dict to integer (little endian)."""
    bits = "".join(str(o[0]) for _, o in sorted(tt_dict.items()))
    return int(bits, 2)

# --- build map: (arity -> {P-class signature: [action_ids]}) -----------------
pclass_to_aids = {1: defaultdict(list), 2: defaultdict(list),
                  3: defaultdict(list), 4: defaultdict(list)}

for aid, g in enumerate(tqdm(UNIQUE_GRAPHS,
                             desc="Collecting P-classes", unit="graph")):
    pins = [n for n in g if g.in_degree(n) == 0]
    n_in = len(pins)
    if n_in not in pclass_to_aids:
        continue

    # signature = set of truth-table ints over all permutations
    sig = frozenset(
        _tt_int(
            calculate_truth_table_v2(
                nx.relabel_nodes(g, dict(zip(pins, perm)), copy=True)
            )
        )
        for perm in itertools.permutations(pins)
    )
    pclass_to_aids[n_in][sig].append(aid)

# --- plot one bar chart per input arity --------------------------------------
for n_in in sorted(pclass_to_aids):
    counts = [len(aids) for aids in pclass_to_aids[n_in].values()]
    if not counts:
        continue

    counts.sort(reverse=True)
    x_vals = range(1, len(counts)+1)

    plt.figure(figsize=(10, 4))
    plt.bar(x_vals, counts, edgecolor="black")
    plt.title(f'{n_in}-input: unique topologies per P-class')
    plt.xlabel('P-class (sorted by size)')
    plt.ylabel('# unique topologies')
    plt.yscale('log')        # comment out if linear scale preferred
    plt.tight_layout()
    plt.show()


NPN-class


In [None]:
from collections import defaultdict
import itertools, networkx as nx
from tqdm import tqdm    # pip install tqdm if missing

# --- helper: TT-dict ➜ integer --------------------------------------
def _tt_dict_to_int(tt_dict):
    bits = "".join(str(o[0]) for _, o in sorted(tt_dict.items()))
    return int(bits, 2)

# --- enumerate all NPN-equivalent TT ints for n inputs --------------
def enumerate_npn(tt_dict, n):
    """
    Return a set of truth-table integers reachable from `tt_dict`
    by (permutations) × (input negations) × (optional output negation).
    """
    tt_ints = set()

    # pre-compute lexicographic list of input tuples once
    inputs_lex = [tuple((i >> b) & 1 for b in range(n))  # little endian
                  for i in range(2 ** n)]

    for perm in itertools.permutations(range(n)):          # P
        # tuples reordered according to perm
        permuted_inputs = [tuple(t[i] for i in perm) for t in inputs_lex]

        for mask in range(1 << n):                         # N on inputs
            mask_bits = [(mask >> i) & 1 for i in range(n)]

            bits0 = []     # original output polarity
            bits1 = []     # inverted output polarity

            for tup in permuted_inputs:
                looked_up = tuple(t ^ m for t, m in zip(tup, mask_bits))
                bit = tt_dict[looked_up][0]
                bits0.append(str(bit))
                bits1.append(str(bit ^ 1))

            tt_ints.add(int("".join(bits0), 2))
            tt_ints.add(int("".join(bits1), 2))            # N on output
    return tt_ints

# --- build signatures ------------------------------------------------
npn_classes_per_arity = defaultdict(set)   # n_in -> {frozenset(tt_ints)}

for g in tqdm(UNIQUE_GRAPHS,
              desc="Computing NPN-class signatures",
              unit="graph"):
    pins = [n for n in g if g.in_degree(n) == 0]
    n_in = len(pins)

    base_tt = calculate_truth_table_v2(g)
    signature = frozenset(enumerate_npn(base_tt, n_in))
    npn_classes_per_arity[n_in].add(signature)

# --- report ----------------------------------------------------------
print()
for n in sorted(npn_classes_per_arity):
    print(f"{n}-input NPN-classes : {len(npn_classes_per_arity[n])}")

total = sum(len(s) for s in npn_classes_per_arity.values())
print(f"TOTAL distinct NPN-classes across all arities: {total}")
