To launch in SuperCloud from a Compute Node:


LLsub -i full #for an exclusive node

LLsub -i -s 40 #for node with 40 CPUs

LLsub -i -s 40 -g volta:1 #for node with 40 CPUs and 1 Volta GPU

salloc  --job-name=interactive --qos=high --time=00:60:00 --partition=debug-gpu --gres=gpu:volta:1 --cpus-per-task=40 srun    --pty bash -i

salloc  --job-name=interactive --qos=high --time=00:60:00 --partition=debug-cpu --cpus-per-task=40 srun  --pty bash -i

module load anaconda/2023a-pytorch

jupyter lab --no-browser --ip=0.0.0.0 --port=8890



In [None]:
import pickle, time
import matplotlib.pyplot as plt
import networkx as nx
import itertools
import re
from pathlib import Path
from tqdm.notebook import tqdm   
from collections import defaultdict
import itertools
from collections import defaultdict
from pathlib import Path
import re, itertools, pickle, networkx as nx
from tqdm import tqdm
import matplotlib.pyplot as plt
import itertools
from collections import defaultdict
import networkx as nx    
import pandas as pd
from datetime import datetime
from pathlib import Path       

In [None]:
from dgd.utils.utils5 import (
    calculate_truth_table_v2,
    energy_score,
    check_implicit_OR_existence_v3
)

from dgd.environments.drl3env_loader5 import _apply_implicit_or, _compute_truth_key, _compute_hash, _apply_implicit_or

Folder with the biological circuit designs 

In [None]:
run_dir = "/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/runs/Fig3_4input_200_logic_functions_registry_sampling_drl3env_loader5/seed_1"
run_dir = Path(run_dir) 

Load action space motifs

In [None]:
MOTIFS_PATH = "/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/action_motifs.pkl"
with open(MOTIFS_PATH, "rb") as f:
    action_motifs = pickle.load(f)
    
UNIQUE_GRAPHS = action_motifs["graphs"] 

General utility functions

In [None]:
def build_motif_canonicals():
    """
    Populate the module-level list ``UNIQUE_GRAPHS_canonical``.

    Every graph in the global ``UNIQUE_GRAPHS`` is passed through
    ``_apply_implicit_or`` (iteration is wrapped in a tqdm progress bar) and
    the results are stored, in the same order, in the global
    ``UNIQUE_GRAPHS_canonical``.  A one-line summary is printed at the end.
    """
    global UNIQUE_GRAPHS_canonical
    canonical_bank = []
    for motif in tqdm(UNIQUE_GRAPHS, desc="Canonicalising motifs", unit="motif"):
        canonical_bank.append(_apply_implicit_or(motif))
    # Assign only once the whole bank has been built.
    UNIQUE_GRAPHS_canonical = canonical_bank
    print(f"Built canonical bank for {len(UNIQUE_GRAPHS_canonical)} motifs.")

def check_vs_motif_bank(graphs):
    """
    Report, for each graph in `graphs`, whether it matches a bank motif.

    The canonical bank (global ``UNIQUE_GRAPHS_canonical``) is built on first
    use.  Each input graph is passed through ``_apply_implicit_or`` and then
    compared with ``nx.is_isomorphic`` against the bank until the first hit.
    One line per graph is printed ("MATCH" or "NEW") and the latest status is
    shown in the progress bar.  Nothing is returned.
    """
    if "UNIQUE_GRAPHS_canonical" not in globals():
        build_motif_canonicals()

    progress = tqdm(enumerate(graphs, 1), total=len(graphs), desc="Matching", unit="graph")
    for position, graph in progress:
        candidate = _apply_implicit_or(graph)

        # any() stops at the first isomorphic motif, like an explicit break.
        found = any(nx.is_isomorphic(candidate, motif)
                    for motif in UNIQUE_GRAPHS_canonical)

        label = "MATCH" if found else "NEW  "
        progress.set_postfix({"last": label})
        print(f"Graph {position:>2}: {label}   "
              f"nodes={graph.number_of_nodes():>2}  edges={graph.number_of_edges():>2}")

def load_registry(pkl_file):
    """
    Unpickle a saved registry and rebuild its NetworkX graphs.

    The pickle is expected to map each hash to a list of
    (canonical node-link data, original node-link data, energy) triples.
    Both node-link payloads are converted with ``nx.node_link_graph``.

    Returns a dict {hash → list of (canon_graph, orig_graph, energy)}.
    """
    with open(pkl_file, "rb") as handle:
        raw = pickle.load(handle)

    return {
        key: [
            (nx.node_link_graph(canon_data), nx.node_link_graph(orig_data), energy)
            for canon_data, orig_data, energy in entries
        ]
        for key, entries in raw.items()
    }


def registry_size(reg):
    """Print and return the total number of entries across all buckets of `reg`."""
    total = 0
    for bucket in reg.values():
        total += len(bucket)
    print(f"Registry length: {total}")
    return total

def fast_registry_size(pkl_file):
    """
    Count the contents of a saved registry file without rebuilding graphs.

    Parameters
    ----------
    pkl_file : str or path-like
        Pickled registry mapping hash → list of entries.

    Prints the number of entries and the number of hash buckets, and returns
    the tuple (n_hashes, n_items).
    """
    import os
    import pickle

    # os.fspath lets callers pass either a Path or a plain string.
    with open(os.fspath(pkl_file), "rb") as handle:
        buckets = pickle.load(handle)

    bucket_count = len(buckets)
    item_count = 0
    for entries in buckets.values():
        item_count += len(entries)

    print(f"{item_count:,} items across {bucket_count:,} hash buckets")
    return bucket_count, item_count


def draw_pair(canon, orig, h, e, seed=42):
    """
    Draw `canon` and `orig` side by side in one figure.

    Both graphs use a spring layout with the given `seed`.  The figure title
    shows the hash `h` and the energy `e` (three decimals).  The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(7, 3.5))
    panels = (
        (left_ax, canon, "canonical"),
        (right_ax, orig, "original"),
    )
    for axis, graph, title in panels:
        layout = nx.spring_layout(graph, seed=seed)
        nx.draw(graph, layout, ax=axis, with_labels=True, node_size=100, font_size=7)
        axis.set_title(title)
    fig.suptitle(f"hash={h}   energy={e:.3f}", fontsize=10)
    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def top_lowest_energy_plots(reg, n=20, pause=0):
    """
    Print a ranked table and draw the `n` lowest-energy registry entries.

    Parameters
    ----------
    reg : dict
        Registry mapping hash → list of (canon, orig, energy) triples.
    n : int, optional (default 20)
        Number of graph pairs to display.
    pause : float, optional (default 0)
        Seconds to sleep after each figure; a falsy value disables the wait.
    """
    ranked = []
    for hash_id, bucket in reg.items():
        for canon, orig, energy in bucket:
            ranked.append((energy, hash_id, canon, orig))
    ranked.sort(key=lambda entry: entry[0])

    print(f"\nTop-{n} by lowest energy:\n")
    print(f"{'rank':>4} │ {'energy':>10} │ {'hash':<16} │ nodes")
    print("────┼────────────┼────────────────┼──────")
    rank = 0
    for energy, hash_id, canon, orig in ranked[:n]:
        rank += 1
        print(f"{rank:>4} │ {energy:10.4f} │ {hash_id:<16} │ {orig.number_of_nodes():>5}")
        draw_pair(canon, orig, hash_id, energy)
        if pause:
            time.sleep(pause)

def iso_pairs_lowest(reg, top=20):
    """
    Find isomorphic pairs among the `top` lowest-energy canonical graphs.

    The registry is flattened, sorted by energy, and cut to `top` entries.
    Every pair of canonical graphs in that slice is tested with
    ``nx.is_isomorphic``; each hit is printed.

    Returns a list of (i, j) index pairs into the sorted slice.
    """
    ranked = sorted(
        ((energy, hash_id, canon)
         for hash_id, bucket in reg.items()
         for canon, _, energy in bucket),
        key=lambda entry: entry[0],
    )[:top]

    duplicates = []
    for (i, first), (j, second) in itertools.combinations(enumerate(ranked), 2):
        e_i, h_i, g_i = first
        e_j, h_j, g_j = second
        if not nx.is_isomorphic(g_i, g_j):
            continue
        duplicates.append((i, j))
        print(f"Duplicate pair: idx {i} ↔ {j}   "
              f"energies {e_i:.4f} / {e_j:.4f}   "
              f"hashes {h_i} / {h_j}")

    if len(duplicates) == 0:
        print(f"No canonical duplicates among the lowest {top} energies.")

    return duplicates

def remove_redundant_edges(g):
    """
    Return a copy of `g` without edges that do not affect its truth table.

    The reference truth table is computed once with
    ``calculate_truth_table_v2``.  Edges are then tried one at a time: an edge
    is dropped when the truth table of the graph without it equals the
    reference.  Passes over the edge list repeat until a full pass removes
    nothing.  The input graph itself is not modified.
    """
    pruned = g.copy()
    reference = calculate_truth_table_v2(pruned)
    while True:
        removed_any = False
        # Snapshot the edge list: the graph is mutated inside the loop.
        for src, dst in list(pruned.edges()):
            trial = pruned.copy()
            trial.remove_edge(src, dst)
            if calculate_truth_table_v2(trial) == reference:
                pruned.remove_edge(src, dst)
                removed_any = True
        if not removed_any:
            return pruned

def unique_by_isomorphism(graphs):
    """
    Return the graphs of `graphs` with isomorphic duplicates removed.

    Order is preserved; the first graph of each isomorphism class (as decided
    by ``nx.is_isomorphic``) is kept.
    """
    representatives = []
    for candidate in graphs:
        for kept in representatives:
            if nx.is_isomorphic(candidate, kept):
                break
        else:
            # No kept graph matched: this one starts a new class.
            representatives.append(candidate)
    return representatives


def plot_graphs(graphs, cols=3, seed=42):
    """
    Draw every graph of `graphs` on a grid of subplots.

    Parameters
    ----------
    graphs : iterable of NetworkX graphs
        Graphs to draw; each gets a spring layout seeded with `seed`.
    cols : int, optional (default 3)
        Number of columns in the grid; rows are added as needed.
    seed : int, optional (default 42)
        Seed passed to ``nx.spring_layout``.

    Returns None.  An empty input returns immediately without opening a
    figure.
    """
    graphs = list(graphs)
    if not graphs:
        # A 0-row grid is not a valid subplot layout; nothing to draw.
        return

    rows = (len(graphs) + cols - 1) // cols
    # squeeze=False always yields a 2-D array of Axes, so .ravel() also works
    # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    axes = axes.ravel()
    for ax, g in zip(axes, graphs):
        pos = nx.spring_layout(g, seed=seed)
        nx.draw(g, pos, ax=ax, with_labels=True, node_size=100, font_size=7)
        ax.set_title(f"nodes={g.number_of_nodes()}  edges={g.number_of_edges()}")
    # Hide the unused cells of the last row.
    for ax in axes[len(graphs):]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def prune_and_plot_optimal(reg):
    """
    Plot the distinct minimum-energy designs of a registry.

    Steps:
      • collect every canonical graph whose energy equals the registry minimum
      • prune each with ``remove_redundant_edges``
      • drop isomorphic duplicates with ``unique_by_isomorphism``
      • plot each survivor and print its truth table

    Returns the list of unique pruned graphs.  Raises ValueError (from
    ``min``) when the registry holds no entries.
    """
    # -- collect (energy, canonical graph) pairs ------------------------------
    scored = []
    for bucket in reg.values():
        for canon, _, energy in bucket:
            scored.append((energy, canon))

    best_energy = min(energy for energy, _ in scored)
    best_graphs = [canon for energy, canon in scored if energy == best_energy]
    print(f"Optimal energy: {best_energy:.4f}   raw count: {len(best_graphs)}")

    # -- prune, then deduplicate ----------------------------------------------
    unique = unique_by_isomorphism(
        [remove_redundant_edges(graph) for graph in best_graphs]
    )
    print(f"{len(unique)} unique graph(s) remain after pruning + iso check.\n")

    # -- one figure and one truth table per surviving graph -------------------
    for number, graph in enumerate(unique, start=1):
        plt.figure(figsize=(4, 4))
        layout = nx.spring_layout(graph, seed=42)
        nx.draw(graph, layout, with_labels=True, node_size=100, font_size=7)
        plt.title(f"Graph {number}   nodes={graph.number_of_nodes()}  edges={graph.number_of_edges()}")
        plt.show()

        table = calculate_truth_table_v2(graph)
        print(f"Truth table for Graph {number}:")
        for inputs, out in sorted(table.items()):
            # Only the first output value of each row is printed.
            print(f"  {inputs}  →  {out[0]}")
        print()

    return unique


def energies_from_log(txt_file):
    """
    Print original vs. canonical energies for the circuits listed in a log.

    Lines of `txt_file` of the form ``Selected <path>_NIG_unoptimized.pkl``
    are parsed.  Each referenced pickle is loaded with ``load_graph_pickle``
    and scored with ``energy_score`` twice: as loaded, and after
    ``_apply_implicit_or``.  The results are printed as a table keyed by the
    part of the file name before its first underscore.

    Returns None.  If no matching line is found, a message is printed and the
    function returns early.
    """
    # The literal word "Selected" must precede the path.  (The pattern was
    # previously written with a stray backslash, `\Selected`, i.e. `\S` +
    # "elected", which accepted any non-space character before "elected".)
    pat = re.compile(r"^\s*Selected\s+(.*_NIG_unoptimized\.pkl)$")
    with open(txt_file) as log:
        paths = [m.group(1).strip() for line in log
                 if (m := pat.match(line))]
    if not paths:
        print("No circuit paths found in the file.")
        return

    rows = []
    for fp in tqdm(paths, desc="Processing circuits", unit="circuit"):
        G = load_graph_pickle(fp)
        E_orig, _ = energy_score(G, check_implicit_OR_existence_v3)

        canon = _apply_implicit_or(G)
        E_canon, _ = energy_score(canon, check_implicit_OR_existence_v3)

        hex_id = Path(fp).name.split("_")[0]
        rows.append((hex_id, E_orig, E_canon))

    # pretty print
    print(f"\nEnergy summary ({len(rows)} circuits):\n")
    print(f"{'hex':<6} │ {'orig':>8} │ {'canon':>8}")
    print("───────┼──────────┼──────────")
    for h, e1, e2 in rows:
        print(f"{h:<6} │ {e1:8.3f} │ {e2:8.3f}")

def load_graph_pickle(filename):
    """
    Load a pickled (num_nodes, edges, node_attrs) triple as an ``nx.DiGraph``.

    Every key of `node_attrs` becomes a node; a non-None value is stored as
    the node's ``type`` attribute.  `edges` is added with ``add_edges_from``.
    The stored node count is not used.

    NOTE: pickle can execute arbitrary code; only load trusted files.
    """
    # Context manager so the file handle is closed deterministically.
    with open(filename, "rb") as handle:
        _num_nodes, edges, node_attrs = pickle.load(handle)

    G = nx.DiGraph()
    for n, attr in node_attrs.items():
        if attr is not None:
            G.add_node(n, type=attr)
        else:
            G.add_node(n)
    G.add_edges_from(edges)
    return G

def truth_table_signature(tt):
    """
    Flatten a truth-table dict {inputs → outputs} into one hashable tuple.

    Rows are visited in sorted input order and the output values of each row
    are concatenated, so multi-output tables contribute several values per
    row.  Example result: (0, 1, 1, 0, 1, 0, 0, 1).
    """
    return tuple(bit for row in sorted(tt) for bit in tt[row])

# ── 2.  regroup the existing registry by truth table ──────────────────────────
def split_registry_by_truth_table(registry):
    """
    Regroup registry entries by the truth table of their canonical graph.

    Returns a defaultdict {signature → list of (hash_id, canon, orig, energy)}
    and prints how many distinct signatures were found.
    """
    grouped = defaultdict(list)

    flattened = (
        (hash_id, entry)
        for hash_id, bucket in registry.items()
        for entry in bucket
    )
    for hash_id, (canon, orig, energy) in flattened:
        signature = truth_table_signature(calculate_truth_table_v2(canon))
        grouped[signature].append((hash_id, canon, orig, energy))

    print(f"Found {len(grouped)} distinct truth tables.")
    return grouped



def split_registry_by_truth_table_with_inputs_permutations(registry):
    """
    Regroup registry entries by truth table, also trying input permutations.

    First the set of signatures present in the registry (one per canonical
    graph, no permutation) is collected.  Then, for every canonical graph,
    its zero-in-degree nodes are relabelled among themselves in every possible
    order; each relabelled graph whose signature is in that set is added to
    the bucket of that signature.

    Returns a defaultdict
    {signature → list of (hash_id, relabelled_canon, orig, energy)}.
    """
    print("Calculating allowed boolean functions")
    allowed_signatures = set()
    for bucket in registry.values():
        for canon, _, _ in bucket:
            allowed_signatures.add(
                truth_table_signature(calculate_truth_table_v2(canon))
            )
    print(f"Done calculating allowed boolean functions. Found {len(allowed_signatures)}")

    grouped = defaultdict(list)

    for hash_id, bucket in tqdm(registry.items(), total=len(registry), desc="Hash groups", unit="group"):
        for canon, orig, energy in bucket:
            # Nodes without incoming edges are treated as the inputs.
            sources = [node for node in canon if canon.in_degree(node) == 0]

            for ordering in itertools.permutations(sources):
                relabelled = nx.relabel_nodes(
                    canon, dict(zip(sources, ordering)), copy=True
                )
                signature = truth_table_signature(
                    calculate_truth_table_v2(relabelled)
                )
                if signature not in allowed_signatures:
                    continue
                grouped[signature].append((hash_id, relabelled, orig, energy))

    print(f"Found {len(grouped)} distinct truth tables (after input permutations).")
    return grouped



def run_analysis_per_bucket(registry, permute_inputs = True):
    """
    Run the pruning and motif-bank analysis once per distinct truth table.

    The registry is split by truth table (with input permutations when
    `permute_inputs` is true).  For each signature a registry-shaped dict
    {hash → list of (canon, orig, energy)} is rebuilt and passed to
    ``prune_and_plot_optimal``; the resulting graphs go to
    ``check_vs_motif_bank``.  Nothing is returned.
    """
    splitter = (
        split_registry_by_truth_table_with_inputs_permutations
        if permute_inputs
        else split_registry_by_truth_table
    )
    buckets = splitter(registry)

    bucket_count = len(buckets)
    position = 0
    for sig, entries in buckets.items():
        position += 1
        print("\n" + "═"*60)
        print(f"Truth table {position}/{bucket_count}  –  {len(entries)} circuit(s)")
        print("Signature:", sig)
        print("Hex ID   :", hex_from_signature(sig))
        print("═"*60)

        # Same shape as the full registry so the helpers need no changes.
        mini = defaultdict(list)
        for hash_id, canon, orig, energy in entries:
            mini[hash_id].append((canon, orig, energy))

        check_vs_motif_bank(prune_and_plot_optimal(mini))
        
def plot_bucket_by_hex(
    registry,
    hex_id: str,
    *,
    permute_inputs=True,
    show=True,
):
    """
    Run the per-bucket analysis for the truth table with the given hex ID.

    Parameters
    ----------
    registry : dict
        Registry mapping hash → list of (canon, orig, energy) triples.
    hex_id : str
        Truth-table ID, with or without a "0x" prefix, any letter case
        (e.g. "3c96", "0X3C96").
    permute_inputs : bool, optional
        Whether buckets are built with input permutations.
    show : bool, optional
        Currently unused; kept for interface compatibility.

    Raises
    ------
    ValueError
        If `hex_id` is not valid hexadecimal or no bucket matches it.

    Matching prefers a bucket whose hex digits equal the request exactly
    (leading zeros included).  If none does, the first bucket with the same
    numeric value is used, so IDs given without leading zeros still work.
    """
    def _hex_digits(text):
        # Remove only a literal "0x" prefix.  str.lstrip("0x") would strip
        # any run of the characters '0' and 'x', i.e. leading zero digits too.
        text = text.strip().lower()
        if text.startswith("0x"):
            text = text[2:]
        return text or "0"

    hex_id = _hex_digits(hex_id)
    try:
        requested_value = int(hex_id, 16)
    except ValueError:
        raise ValueError(f"Invalid hex ID: {hex_id!r}") from None

    # 1) build buckets -------------------------------------------------------
    buckets = (split_registry_by_truth_table_with_inputs_permutations(registry)
               if permute_inputs else
               split_registry_by_truth_table(registry))

    # 2) locate the signature that matches the requested hex ----------------
    exact_sig = None
    value_sig = None
    for sig in buckets:
        digits = _hex_digits(hex_from_signature(sig))
        if digits == hex_id:
            exact_sig = sig
            break
        if value_sig is None and int(digits, 16) == requested_value:
            value_sig = sig

    target_sig = exact_sig if exact_sig is not None else value_sig
    if target_sig is None:
        raise ValueError(f"Hex ID {hex_id} not found in registry.")

    entries = buckets[target_sig]
    print("\n" + "═" * 60)
    # Show the ID as derived from the signature so zero padding is preserved.
    print(f"Hex ID   : {hex_from_signature(target_sig)}")
    print(f"Signature: {target_sig}")
    print(f"Circuits : {len(entries)}")
    print("═" * 60)

    # 3) rebuild mini-registry ----------------------------------------------
    mini = defaultdict(list)
    for entry in entries:
        h, canon, orig, e = entry[:4]
        mini[h].append((canon, orig, e))

    # -------- run the analysis helpers --------------------------------------
    unique_graphs = prune_and_plot_optimal(mini)
    check_vs_motif_bank(unique_graphs)

'''     
def _apply_implicit_or(G, fanin_size: int = 2):
    G_copy = G.copy()
    output_nodes = [n for n in G_copy if G_copy.out_degree(n) == 0]
    if not output_nodes:
        return G_copy
    output_node = output_nodes[0]
    results_check_implicit_OR_existence = check_implicit_OR_existence_v3(G_copy, output_node, fanin_size)
    best_node_reduction_found, best_node_reduction_found_key = 0, None
    for key, value in results_check_implicit_OR_existence.items():
        if value["is_there_an_implicit_OR"] and value["number_of_nodes_available_for_removal"] > best_node_reduction_found:
            best_node_reduction_found_key, best_node_reduction_found = key, value["number_of_nodes_available_for_removal"]
    if best_node_reduction_found_key is None:
        return G_copy
    cut = results_check_implicit_OR_existence[best_node_reduction_found_key]["cut"]
    cone = results_check_implicit_OR_existence[best_node_reduction_found_key]["cone"]
    return add_implicit_OR_to_dag_v2(G_copy, output_node, cut, cone)
'''

def tag_value(p: Path) -> int:
    """
    Return the numeric tag embedded in a registry file name.

    The file name is matched against the module-level pattern ``rx``; the
    captured group has its underscores removed and is converted to int,
    e.g. '4_800' -> 4800.  Raises ValueError when the name does not match.
    """
    match = rx.fullmatch(p.name)
    if match is None:
        raise ValueError(f"Not a registry file: {p}")
    digits = match.group(1).replace("_", "")
    return int(digits)


def load_selected_paths(txt_file):
    """
    Read a list of pickle paths from a text file.

    Accepted line formats:
        Selected /path/with spaces/0x66_NIG_unoptimized.pkl
        /full/or/relative/0x23_NIG_unoptimized.pkl

    • Blank lines are skipped.
    • A line starting with "Selected " contributes everything after that word
      (spaces inside the path are kept).
    • Any other line is taken as the path itself.
    • Relative paths are anchored at the folder containing `txt_file`.

    Returns a list of resolved Path objects, in file order.

    Raises FileNotFoundError if a listed path does not exist, and
    RuntimeError if the file lists no paths at all.
    """
    listing = Path(txt_file)
    anchor = listing.parent
    resolved = []

    for raw in listing.read_text().splitlines():
        entry = raw.strip()
        if entry == "":
            continue

        if entry.startswith("Selected "):
            # maxsplit=1 keeps any spaces that belong to the path.
            entry = entry.split(maxsplit=1)[1]

        candidate = Path(entry)
        if not candidate.is_absolute():
            candidate = anchor / candidate
        candidate = candidate.resolve()

        if not candidate.exists():
            raise FileNotFoundError(f"{candidate} (from {listing}) does not exist")
        resolved.append(candidate)

    if not resolved:
        raise RuntimeError(f"No pickle paths found in {listing}")

    return resolved


def bucket_unoptimised(txt_file):
    """
    Group the unoptimised graphs listed in `txt_file` by truth table.

    Each listed pickle is loaded and scored with ``energy_score`` as loaded.
    The signature is taken from the truth table of the graph after
    ``_apply_implicit_or``.

    Returns a defaultdict {signature → list of (file_path, G_orig, E_orig)}.
    """
    grouped = defaultdict(list)
    listed = load_selected_paths(txt_file)

    for path in tqdm(listed, desc="Loading unoptimised graphs", unit="graph"):
        graph = load_graph_pickle(path)
        energy, _ = energy_score(graph, check_implicit_OR_existence_v3)

        # The signature comes from the canonicalised graph, not the raw one.
        canonical = _apply_implicit_or(graph)
        table = calculate_truth_table_v2(canonical)

        grouped[truth_table_signature(table)].append((path, graph, energy))
    return grouped

def compare_registry_vs_unoptimised(registry, selected_txt, permute_inputs = True):
    """
    Compare the best registry design with the best unoptimised design.

    For every truth table that exists in *either* source:
        • print the lowest energy found in the registry   (optimised / canon)
        • print the lowest energy found in `selected_txt` (unoptimised)
        • print ΔE = E_unopt − E_canon (NaN when one side is missing)
        • when both sides exist, plot the two winning graphs side by side;
          the unoptimised one is passed through ``_apply_implicit_or`` first

    Parameters
    ----------
    registry : dict
        Registry mapping hash → list of (canon, orig, energy) triples.
    selected_txt : str or Path
        Text file listing the unoptimised pickles.
    permute_inputs : bool, optional
        Whether registry buckets are built with input permutations.
    """
    import math  # function-scope import: used only for the NaN checks below

    if permute_inputs:
        can_buckets = split_registry_by_truth_table_with_inputs_permutations(registry)
    else:
        can_buckets = split_registry_by_truth_table(registry)

    unopt_buckets = bucket_unoptimised(selected_txt)

    header = f"{'TT#':>4} │ {'E canon':>9} │ {'hash':<12} │ "\
             f"{'E unopt':>9} │ {'file':<28} │ ΔE"
    sep = "────┼──────────┼──────────────┼──────────┼────────────────────────────┼────────"
    print("\n"+header);  print(sep)

    # iterate over union of truth-tables so nothing is missed
    for idx, sig in enumerate(sorted(set(can_buckets) | set(unopt_buckets)), 1):

        # -- best canonical ---------------------------------------------------
        if sig in can_buckets:
            h_c, g_canon, _, e_canon = min(can_buckets[sig], key=lambda t: t[3])
        else:                                      # TT not in registry
            h_c, g_canon, e_canon = "—", None, float("nan")

        # -- best unoptimised -------------------------------------------------
        if sig in unopt_buckets:
            fp_u, g_unopt, e_unopt = min(unopt_buckets[sig], key=lambda t: t[2])
        else:                                      # TT not in selected list
            fp_u, g_unopt, e_unopt = "—", None, float("nan")

        both_known = not (math.isnan(e_canon) or math.isnan(e_unopt))
        dE = (e_unopt - e_canon) if both_known else float("nan")
        print(f"{idx:>4} │ {e_canon:9.4f} │ {h_c:<12} │ "
              f"{e_unopt:9.4f} │ {Path(fp_u).name:<28} │ {dE:+.4f}")

        # -- plot the graphs so the difference can be inspected ---------------
        # Explicit None test: a graph's truthiness depends on its node count.
        if g_canon is not None and g_unopt is not None:
            g_unopt_can = _apply_implicit_or(g_unopt)

            # A single subplots() call creates the figure and its axes.  A
            # separate plt.figure() before it would open an extra empty figure
            # and the figsize would not apply to the one actually drawn.
            fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
            for ax, g, ttl in zip(
                    axes,
                    (g_canon, g_unopt_can),
                    ("optimised canonical", "unoptimised canonical")):
                pos = nx.spring_layout(g, seed=42)
                nx.draw(g, pos, ax=ax,
                        node_size=140, font_size=7, with_labels=True)
                ax.set_title(ttl)

            fig.suptitle(f"Truth table {idx}   ΔE = {dE:+.4f}", fontsize=10)
            plt.tight_layout()
            plt.show()


def hex_from_signature(sig):
    """
    Render a truth-table signature tuple as an uppercase hex string.

    The bits of `sig` are read most-significant first, e.g.
    (0,1,1,0,0,1,1,0) → "0x66".  The result is zero-padded to
    ``len(sig) // 4`` hex digits, with a minimum of one digit.
    """
    value = 0
    for bit in sig:
        value = (value << 1) | bit
    # `or 1` supplies the one-digit minimum for signatures shorter than 4 bits.
    digits = len(sig) // 4 or 1
    return "0x" + format(value, f"0{digits}X")

def bucket_unoptimised(txt_file):
    """
    Group the unoptimised graphs listed in `txt_file` by truth-table signature.

    NOTE(review): this redefines `bucket_unoptimised`, which is already
    defined earlier in this file with the same logic; once this cell has run,
    this later definition is the one in effect.

    Returns a defaultdict {signature → list of (file_path, G_orig, E_orig)}.
    """
    buckets = defaultdict(list)
    paths   = load_selected_paths(txt_file)

    for fp in tqdm(paths, desc="Loading unoptimised graphs", unit="graph"):
        G_orig   = load_graph_pickle(fp)
        # Energy is scored on the graph as loaded.
        E_orig, _ = energy_score(G_orig, check_implicit_OR_existence_v3)

        # The signature is computed on the graph after _apply_implicit_or.
        canon    = _apply_implicit_or(G_orig)
        sig      = truth_table_signature(calculate_truth_table_v2(canon))

        buckets[sig].append((fp, G_orig, E_orig))
    return buckets

def export_energies_optimized_versus_unoptimized(registry, selected_txt, csv_path=None, skip_incomplete=True, permute_inputs = True):
    """
    Write a CSV comparing optimised and unoptimised energies per truth table.

    CSV columns:
        truth_table_hex, E_unoptimised, E_optimised

    Parameters
    ----------
    registry : dict
        Registry mapping hash → list of (canon, orig, energy) triples.
    selected_txt : str or Path
        Text file listing the unoptimised pickles.
    csv_path : str or Path, optional
        Output file.  When None, a time-stamped ``energy_comparison_*.csv`` is
        written next to `selected_txt`.
    skip_incomplete : bool, optional
        Drop truth tables that are missing from either source.
    permute_inputs : bool, optional
        Whether registry buckets are built with input permutations.

    Returns the DataFrame that was written.
    """
    splitter = (
        split_registry_by_truth_table_with_inputs_permutations
        if permute_inputs
        else split_registry_by_truth_table
    )
    can_buckets = splitter(registry)
    unopt_buckets = bucket_unoptimised(selected_txt)

    records = []
    for sig in sorted(set(can_buckets) | set(unopt_buckets)):
        hex_tt = hex_from_signature(sig)

        # Lowest energy on each side; None marks a missing truth table.
        e_c = None
        if sig in can_buckets:
            e_c = min(can_buckets[sig], key=lambda t: t[3])[3]

        e_u = None
        if sig in unopt_buckets:
            _, _, e_u = min(unopt_buckets[sig], key=lambda t: t[2])

        incomplete = e_c is None or e_u is None
        if skip_incomplete and incomplete:
            continue

        records.append({
            "truth_table_hex": hex_tt,
            "E_unoptimised" : e_u,
            "E_optimised"   : e_c
        })

    df = pd.DataFrame(records)

    if csv_path is not None:
        csv_path = Path(csv_path)
    else:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_path = Path(selected_txt).with_name(f"energy_comparison_{ts}.csv")

    df.to_csv(csv_path, index=False)
    print(f"Saved {len(df)} rows to {csv_path}")

    return df

In [None]:
def lowest_bank_energy(graph, *, debug=False):
    """
    Return the lowest energy among bank motifs that share `graph`'s truth key.

    The key is computed with ``_compute_truth_key(graph)`` (no input
    permutations are tried here) and looked up in ``TTABLE_TO_ACTIONS``; the
    indices found there select motifs from ``UNIQUE_GRAPHS``, each scored with
    ``energy_score``.

    NOTE(review): ``TTABLE_TO_ACTIONS`` is not defined in the visible part of
    this file — confirm it exists before this function is called.

    Returns None when the bank has no entry for the key.
    """
    key = _compute_truth_key(graph)
    idxs = TTABLE_TO_ACTIONS.get(key, [])

    if debug:
        print(f"[DEBUG] key = {key}  → mapped indices = {idxs}")

    if not idxs:
        return None

    energies = [
        energy_score(UNIQUE_GRAPHS[j], check_implicit_OR_existence_v3)[0]
        for j in idxs
    ]
    return min(energies)



def export_energy_as_compared_to_action_motifs(
    registry,
    *,
    csv_path="energy_vs_bank.csv",
    skip_incomplete=True,
    debug=False,
    permute_inputs=True,
):
    """
    Write a CSV comparing registry energies with motif-bank energies.

    For each truth-table bucket in `registry`, one row is written:

        truth_table_hex, E_candidates, E_bank

    ``E_candidates`` is the lowest ``energy_score`` over the bucket's graphs;
    ``E_bank`` comes from ``lowest_bank_energy`` applied to the first graph of
    the bucket.

    Parameters
    ----------
    registry : dict
        Registry mapping hash → list of (canon, orig, energy) triples.
    csv_path : str or Path, optional
        Output file; its suffix is forced to ".csv".
    skip_incomplete : bool, optional
        Skip truth tables for which the bank has no entry.
    debug : bool, optional
        Forwarded to ``lowest_bank_energy``.
    permute_inputs : bool, optional
        Whether buckets are built with input permutations.

    Returns the DataFrame that was written.
    """
    if permute_inputs:
        buckets = split_registry_by_truth_table_with_inputs_permutations(registry)
    else:
        buckets = split_registry_by_truth_table(registry)

    print(f"Found {len(buckets)} distinct truth tables.")

    rows = []
    for designs in buckets.values():
        g0 = designs[0][1]                      # any graph in the bucket
        sig = truth_table_signature(calculate_truth_table_v2(g0))
        hex_key = hex_from_signature(sig)

        best_cand_E = min(
            energy_score(g[1], check_implicit_OR_existence_v3)[0]
            for g in designs
        )
        best_bank_E = lowest_bank_energy(g0, debug=debug)

        print(f"  key {hex_key} – bank energy: {best_bank_E}")

        if skip_incomplete and best_bank_E is None:
            continue

        rows.append({
            "truth_table_hex": hex_key,
            "E_candidates":    best_cand_E,
            "E_bank":          best_bank_E,
        })

    df = pd.DataFrame(rows)

    # Resolve the final path once so the message below reports the file that
    # was actually written (with_suffix may change the name given by callers).
    out_path = Path(csv_path).with_suffix(".csv")
    out_path.write_text(df.to_csv(index=False))
    print(f"Saved {len(df)} rows to {out_path}")

    return df


Point to the latest shared registry in the folder

In [None]:
# File names look like shared_registry_<tag>.pkl, where <tag> is made of
# digits and underscores (see tag_value).
rx = re.compile(r"shared_registry_([\d_]+)\.pkl$")

# Keep only the files whose tag can actually be parsed.  This way a stray
# file that merely matches the glob (e.g. shared_registry_final.pkl) is
# ignored rather than making the whole lookup report "nothing found".
snapshots = []
for p in run_dir.glob("shared_registry_*.pkl"):
    try:
        tag_value(p)
    except ValueError:
        continue
    snapshots.append(p)

if snapshots:
    latest = max(snapshots, key=tag_value)
    print("Latest snapshot:", latest, "steps =", tag_value(latest))
    pkl_path = latest
else:
    # Define pkl_path even on failure so later cells see an explicit None
    # and not an undefined name.
    pkl_path = None
    print("No shared_registry_*.pkl found in", run_dir)

Calculate the registry size

In [None]:
fast_registry_size(pkl_path)

Load the registry

In [None]:
registry = load_registry(pkl_path)     

Testing speed of operations on registry

In [None]:
import multiprocessing
import math

# Manager-backed objects that can be shared between worker processes.
manager   = multiprocessing.Manager()
registry_across_workers = manager.dict()
multiprocessing_lock  = manager.Lock()
best_energy_across_workers  = manager.Value('d', math.inf)

reg_path = Path(pkl_path).expanduser()
print(f"Loading registry at {reg_path}")

with reg_path.open("rb") as f:
    reg = pickle.load(f)

# Track the minimum in a local variable and scan the locally built list:
# reading a bucket back through the manager proxy would send the graphs
# across the process boundary again for no benefit.
lowest = math.inf
for h, bucket in reg.items():
    restored = [
        (nx.node_link_graph(canon_nl),
         nx.node_link_graph(orig_nl),
         e)
        for canon_nl, orig_nl, e in bucket
    ]
    registry_across_workers[h] = restored
    for _c, _o, e in restored:
        if e < lowest:
            lowest = e
best_energy_across_workers.value = lowest

# The two f-string pieces are concatenated, so the separator must be explicit.
print(f"Loaded registry with {len(registry_across_workers)} hash buckets, "
      f"best Energy = {best_energy_across_workers.value:.3f}")

In [None]:
len(registry_across_workers)

In [55]:
first_bucket = list(registry_across_workers.values())[1]

In [57]:
first_bucket

[(<networkx.classes.digraph.DiGraph at 0x7f37a3562be0>,
  <networkx.classes.digraph.DiGraph at 0x7f37a355dd60>,
  50)]

In [58]:
first_bucket_cannon_design = first_bucket[0][0]
first_bucket_original_design = first_bucket[0][1]

In [59]:
canon = _apply_implicit_or(first_bucket_original_design)
h     = _compute_hash(canon)

In [60]:
nx.is_isomorphic(canon, first_bucket_cannon_design)

True

In [61]:
h

'69c6ea5a6682880170d8f612ccda2c50'

In [62]:
bucket = registry_across_workers.setdefault(h, [])
bucket


[(<networkx.classes.digraph.DiGraph at 0x7f36897fe7f0>,
  <networkx.classes.digraph.DiGraph at 0x7f36897feca0>,
  50)]

In [63]:
bucket

[(<networkx.classes.digraph.DiGraph at 0x7f36897fe7f0>,
  <networkx.classes.digraph.DiGraph at 0x7f36897feca0>,
  50)]

In [64]:
print(nx.is_isomorphic(canon, bucket[0][0]))
print(nx.is_isomorphic(first_bucket_cannon_design, bucket[0][0]))

True
True


In [74]:
h = 'nonexistenhash'
bucket = registry_across_workers.setdefault(h, [])
bucket

[('test', 'test', 1), ('test', 'test', 2)]

In [75]:
bucket.append(('test', 'test', 3))
registry_across_workers[h] = bucket

In [76]:
h = 'nonexistenhash'
bucket = registry_across_workers.setdefault(h, [])
bucket

[('test', 'test', 1), ('test', 'test', 2), ('test', 'test', 3)]

In [None]:
def only_singletons(registry):
    """
    Return True when every bucket of `registry` holds exactly one item.

    Buckets are fetched one key at a time and the scan stops at the first
    bucket whose length is not 1.  An empty registry yields True.
    """
    return all(len(registry[key]) == 1 for key in registry.keys())

print("Only one item per bucket?", only_singletons(registry_across_workers))


In [None]:
total = 0
for k in registry_across_workers.keys():
    total += len(registry_across_workers[k])
print(total)

In [None]:
len(registry_across_workers)

In [None]:
import numpy as np
rng = np.random.default_rng(123)



In [None]:
import numpy as np

In [None]:
keys = tuple(registry_across_workers.keys())
h = keys[np.random.randint(len(keys))]
bucket = registry_across_workers[h]
canon, orig, e = bucket[np.random.randint(len(bucket))] 
current_solution = orig.copy()     

In [None]:
keys

In [None]:
keys = list(registry_across_workers.keys())
#k = int(self.np_random.integers(len(keys)))   # reproducible with your RNG
#h = keys[k]

In [None]:
import networkx as nx
from collections import Counter
import numpy as np

def analyze_buckets_iso_after_canon(
    registry,
    canon_fn,                   # e.g., _apply_implicit_or
    check_attr_key="type",      # set to None to skip attribute-aware check
    sample=None,                # int: randomly sample this many multi-item buckets (for speed)
    rng_seed=0,
    max_examples=10
):
    """
    Check whether the graphs sharing a registry bucket are mutually isomorphic.

    For each bucket with size > 1:
      - Recompute canonical form: canon_fn(orig.copy())
      - Check structural isomorphism among canonical graphs
      - Optionally check attribute-aware isomorphism (node attribute == check_attr_key)

    Every graph in a bucket is compared against the first graph of that bucket.
    Cheap prechecks (node/edge counts, sorted degree sequences) run before the
    isomorphism test so obviously different graphs are rejected early.

    Parameters
    ----------
    registry : mapping
        Maps a key to a list of 3-tuples; the middle element of each tuple is
        the original graph that is copied and passed to ``canon_fn``.
    canon_fn : callable
        Takes a graph copy and returns its canonical graph.
    check_attr_key : str or None
        Node attribute compared in the attribute-aware pass; None skips it.
    sample : int or None
        If given and smaller than the number of multi-item buckets, only that
        many randomly chosen buckets are analyzed.
    rng_seed : int
        Seed for the sampling generator.
    max_examples : int
        Cap on the number of failing buckets stored in the summary.

    Returns summary dict and prints a brief report.
    """
    # 1) Which buckets have >1?
    gt1_keys = [k for k, b in registry.items() if len(b) > 1]
    total_buckets = len(registry)
    num_gt1 = len(gt1_keys)

    # Optional subsample for speed.
    # Positions are drawn instead of the keys themselves: passing the key list
    # to rng.choice would convert it to a NumPy array, turning str keys into
    # np.str_ and tuple keys into unhashable 2-D rows.
    sampled = sample is not None and sample < num_gt1
    if sampled:
        rng = np.random.default_rng(rng_seed)
        picked = rng.choice(num_gt1, size=sample, replace=False)
        gt1_keys = [gt1_keys[int(i)] for i in picked]

    size_hist = Counter(len(registry[k]) for k in gt1_keys)

    all_iso_structural = 0
    all_iso_with_attr = 0
    not_iso_examples = []

    def record(key, size, reason):
        # Store a failing bucket, keeping at most max_examples entries.
        if len(not_iso_examples) < max_examples:
            not_iso_examples.append((key, size, reason))

    def degsig(g):
        # Sorted degree sequences: equal for isomorphic graphs, cheap to compare.
        if g.is_directed():
            indeg = sorted(d for _, d in g.in_degree())
            outdeg = sorted(d for _, d in g.out_degree())
            return (tuple(indeg), tuple(outdeg))
        deg = sorted(d for _, d in g.degree())
        return (tuple(deg), ())

    node_match = None
    if check_attr_key is not None:
        node_match = lambda a, b: a.get(check_attr_key, None) == b.get(check_attr_key, None)

    for h in gt1_keys:
        bucket = registry[h]
        # Recompute canonical graphs from the originals
        canons = [canon_fn(orig.copy()) for _, orig, _ in bucket]
        g0 = canons[0]
        others = canons[1:]

        # fast prechecks
        same_counts = all(
            (g.number_of_nodes() == g0.number_of_nodes() and
             g.number_of_edges() == g0.number_of_edges())
            for g in others
        )
        if not same_counts:
            record(h, len(bucket), "node/edge mismatch after canon")
            continue

        sig0 = degsig(g0)
        if any(degsig(g) != sig0 for g in others):
            record(h, len(bucket), "degree signature mismatch after canon")
            continue

        # structural iso on canonical graphs
        if not all(nx.is_isomorphic(g0, g) for g in others):
            record(h, len(bucket), "failed structural isomorphism after canon")
            continue

        all_iso_structural += 1

        # attribute-aware iso (optional); a failure here does not undo the
        # structural count above.
        if node_match is not None:
            if all(nx.is_isomorphic(g0, g, node_match=node_match) for g in others):
                all_iso_with_attr += 1
            else:
                record(h, len(bucket), f"failed attr isomorphism on '{check_attr_key}' after canon")

    denom = len(gt1_keys)
    summary = {
        "total_buckets": total_buckets,
        "buckets_with_size_gt1": num_gt1,
        "analyzed_buckets": denom,
        "sampled": sampled,
        "size_hist_gt1": dict(sorted(size_hist.items())),
        "all_iso_structural_count": all_iso_structural,
        "all_iso_structural_pct": (100.0 * all_iso_structural / denom) if denom else 0.0,
        "all_iso_with_attr_count": all_iso_with_attr if check_attr_key is not None else None,
        "all_iso_with_attr_pct": (100.0 * all_iso_with_attr / denom) if (denom and check_attr_key is not None) else None,
        "not_iso_examples": not_iso_examples,
    }

    print(f"Total buckets: {total_buckets}")
    print(f"Buckets with >1 graph: {num_gt1}")
    print(f"Analyzed buckets: {denom}" + (" (sampled)" if sampled else ""))
    if denom:
        print("Bucket-size histogram (sizes >1):", dict(sorted(size_hist.items())))
        print(f"All-iso after canon (structural): {all_iso_structural}/{denom} "
              f"({summary['all_iso_structural_pct']:.2f}%)")
        if check_attr_key is not None:
            print(f"All-iso after canon (node_match '{check_attr_key}'): "
                  f"{all_iso_with_attr}/{denom} "
                  f"({summary['all_iso_with_attr_pct']:.2f}%)")
        if not_iso_examples:
            print("\nExamples of non-isomorphic buckets (up to first "
                  f"{len(not_iso_examples)}):")
            for key, sz, reason in not_iso_examples:
                print(f"  key={key!r}  size={sz}  reason={reason}")

    return summary

summary = analyze_buckets_iso_after_canon(
    registry_across_workers,                       # your loaded registry dict
    canon_fn=_apply_implicit_or,    # <- your canon transform
    check_attr_key="type",          # or None
    sample=None,                    # or e.g. 2000 to subsample for speed
    rng_seed=0
)

In [None]:
from dgd.environments.drl3env_loader5 import _apply_implicit_or, _compute_hash, _compute_truth_key

In [None]:
import networkx as nx
from itertools import combinations

# Inspect one bucket: print the hash of every stored canonical graph, then
# test every pair of those graphs for isomorphism.
buket_test = registry_across_workers["42eb45f79ef21f4d75c9210008069fc3"]
cannon_graphs = []
for cannon, original, energy in buket_test:
    # Recompute the hash from the stored canonical graph and print it.
    h = _compute_hash(cannon)
    print(h)
    cannon_graphs.append(cannon)

isomorphic_pairs = []
n = len(cannon_graphs)
# Number of unordered pairs among n graphs.
total = n * (n - 1) // 2
print(f"Checking {total} pairs...")

# combinations over enumerate() yields each unordered pair once, with indices.
for (i, Gi), (j, Gj) in combinations(enumerate(cannon_graphs), 2):
    # Called without node_match/edge_match, so only structure is compared.
    iso = nx.is_isomorphic(Gi, Gj)
    print(f"({i}, {j}): {'isomorphic' if iso else 'not isomorphic'}")
    if iso:
        isomorphic_pairs.append((i, j))

print("\nSummary:")
print(f"Graphs: {n}")
print(f"Pairs checked: {total}")
print(f"Isomorphic pairs: {len(isomorphic_pairs)}")
print(isomorphic_pairs)


In [None]:
for g in cannon_graphs:
    print(_compute_truth_key(g))   
    

In [None]:
def plot_side_by_side(graphs, k=None, indices=None, seed=7, with_labels=True):
    """
    Draw the selected graphs in a single row, one subplot per graph.

    Selection rules:
    - ``indices`` given: exactly those positions of ``graphs`` are drawn.
    - otherwise: the first ``k`` graphs (all of them when ``k`` is None;
      ``k`` is capped at ``len(graphs)``).

    ``seed`` is passed to the spring layout and ``with_labels`` to the drawing
    call. Raises ValueError when ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("No graphs provided.")

    if indices is None:
        count = len(graphs) if k is None else min(k, len(graphs))
        chosen = list(range(count))
    else:
        chosen = list(indices)

    n_panels = len(chosen)
    _fig, axes = plt.subplots(1, n_panels, figsize=(4*n_panels, 4))
    # With a single panel, subplots() returns one Axes instead of a sequence.
    panels = [axes] if n_panels == 1 else axes

    for panel, idx in zip(panels, chosen):
        graph = graphs[idx]
        layout = nx.spring_layout(graph, seed=seed)
        nx.draw_networkx(
            graph,
            pos=layout,
            ax=panel,
            with_labels=with_labels,
            node_size=400,
            font_size=8,
            arrows=graph.is_directed(),
        )
        panel.set_title(f"G{idx}: n={graph.number_of_nodes()}, m={graph.number_of_edges()}")
        panel.axis("off")

    plt.tight_layout()
    plt.show()

# Examples:
# First 3:
plot_side_by_side(cannon_graphs, k=3, seed=7)

In [None]:
# Probe the registry for a hash that may not exist.
# get() is used rather than setdefault() so that a miss does not insert an
# empty bucket into the registry as a side effect of the lookup.
h = "432f"
bucket = registry_across_workers.get(h, [])
if not bucket:
    print("No encontre")

In [None]:
from networkx.algorithms import weisfeiler_lehman_graph_hash 
# Print a Weisfeiler-Lehman hash for every graph in cannon_graphs.
# node_attr=None and edge_attr=None: no node/edge attribute is passed to the hash.
for g in cannon_graphs:      
    h = weisfeiler_lehman_graph_hash(g, node_attr=None, edge_attr=None, iterations=100, digest_size=16)
    print(h)

In [None]:
import numpy as np   

# Draw a random bucket, then a random item inside it.
# The registry is a mapping keyed by hash strings, so the random integer is a
# position in the key list and must be translated to a key before indexing.
keys = list(registry_across_workers.keys())
bi = np.random.randint(len(keys))
bucket = registry_across_workers[keys[bi]]
item_idx = np.random.randint(len(bucket))
canon, orig, e = bucket[item_idx]

In [None]:
import numpy as np   

In [None]:
keys = list(registry_across_workers.keys())
len(keys)


In [None]:
# Pick one item uniformly over ALL items (not uniformly over buckets):
# draw a flat index into the concatenation of every bucket, then walk the
# buckets to find which one that index falls in.
# NOTE(review): `buckets` is not defined in the nearby cells — presumably
# list(registry_across_workers.values()); confirm before running.
total = sum(len(b) for b in buckets)
if total == 0:
    raise ValueError("Registry is empty")

# pick an index in [0, total)
r = np.random.randint(total)   # instead of int(np.random.integers(total))

for b in buckets:
    if r < len(b):
        # r is now the position inside this bucket.
        _, orig, _ = b[r]
        current_solution = orig.copy()
        break
    # Not in this bucket: shift the index past it.
    r -= len(b)

In [None]:
reg_items  = [item for bucket in registry_across_workers.values() for item in bucket]



In [None]:
pool     = [orig for _, orig, _ in reg_items]
energies = [e    for _, _,    e in reg_items]

In [None]:
# Sample one graph from `pool` with probability proportional to 1 / energy**p.
# With p = 0 every weight equals 1.0, so the draw below is uniform.
p = 0
weights = [1.0 / (e ** p) for e in energies]

# self.np_random draw proportional to the weights
#self.current_solution = random.choices(pool, weights=weights, k=1)[0].copy()
weights = np.asarray(weights, dtype=np.float64)
prob    = weights / weights.sum()          # must sum to 1 for NumPy
# Draw an index rather than an element, then copy the selected graph.
idx     = np.random.choice(len(pool), p=prob)
current_solution = pool[idx].copy()

In [None]:
# Keys to match against. Each key is its own list element so that the
# `in existing_keys` membership test compares whole keys; a single
# comma-joined string would never equal an individual key.
existing_keys = ["0x13CE", "0x4A32"]

In [None]:
# Try input relabelings of current_solution in random order and keep the first
# one whose truth key is listed in existing_keys.
# Nodes with no incoming edges are treated as the inputs.
inputs = [n for n in current_solution if current_solution.in_degree(n) == 0]       
# All orderings of the inputs are materialized up front (factorial in len(inputs)).
perms = list(itertools.permutations(inputs))
np.random.shuffle(perms)
for perm in perms:                     # iterate without replacement
    # Map each input node to the node at the same position in the permutation.
    mapping  = dict(zip(inputs, perm))
    g_perm   = nx.relabel_nodes(current_solution, mapping, copy=True)
    
    if _compute_truth_key(g_perm) in existing_keys: 
        current_solution = g_perm
        break # stop at the first valid permutation     
# If no permutation matches, current_solution is left unchanged.

In [None]:
current_solution

Calculate the registry size after loading

In [None]:
registry_size(registry)

Specify file with database graph paths

In [None]:
selected_txt = run_dir / "selected_graphs.txt"

Compare optimized and unoptimized graphs and save the result

In [None]:
file = run_dir/"energies_optimized_versus_unoptimized.csv"
df = export_energies_optimized_versus_unoptimized(registry, selected_txt, file, permute_inputs = False)
df

Compare optimized graphs versus action motifs

In [None]:
MOTIFS_PATH = "/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/action_motifs.pkl"
with open(MOTIFS_PATH, "rb") as f:
    action_motifs = pickle.load(f)

UNIQUE_GRAPHS = action_motifs["graphs"]       
TTABLE_TO_ACTIONS = action_motifs["lookup"]


In [None]:
file = run_dir/"energies_optimized_versus_action_motifs.csv"
df2 = export_energy_as_compared_to_action_motifs(registry, csv_path=file, debug = True, permute_inputs = False)
df2

Other functions to expand on later

In [None]:
stop

In [None]:
buckets_by_truth_table = split_registry_by_truth_table_with_inputs_permutations(registry)
print(len(buckets_by_truth_table))

run_analysis_per_bucket(registry, permute_inputs = False)   

hex_id = "0xB3"
plot_bucket_by_hex(registry, hex_id = hex_id, permute_inputs=False, show=True)