In [1]:
import pickle, time
import matplotlib.pyplot as plt
import networkx as nx
import itertools, re
from collections import defaultdict
import numpy as np

In [6]:
from dgd.utils.utils5 import (
    calculate_truth_table_v2,
    generate_one_hot_features_from_adj,
    resize_matrix,
    energy_score,
    check_implicit_OR_existence_v3,
    add_implicit_OR_to_dag_v2,
    exhaustive_cut_enumeration_dag,
    is_fanout_free_standalone,
    generate_subgraph,
)


def load_registry(pkl_file):
    """Load the pickle and rebuild NetworkX graphs."""
    with open(pkl_file, "rb") as f:
        saved = pickle.load(f)

    registry = {}
    for h, bucket in saved.items():
        restored = []
        for canon_nl, orig_nl, e in bucket:
            canon = nx.node_link_graph(canon_nl)
            orig  = nx.node_link_graph(orig_nl)
            restored.append((canon, orig, e))
        registry[h] = restored
    return registry


def registry_size(pkl_file):
    reg = load_registry(pkl_file)
    length = sum(len(b) for b in reg.values())
    print(f"Registry length: {length}")
    return length


def draw_pair(canon, orig, h, e, seed=42):
    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    for ax, g, ttl in zip(axes, (canon, orig), ("canonical", "original")):
        pos = nx.spring_layout(g, seed=seed)
        nx.draw(g, pos, ax=ax, with_labels=True, node_size=100, font_size=7)
        ax.set_title(ttl)
    fig.suptitle(f"hash={h}   energy={e:.3f}", fontsize=10)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def top_lowest_energy_plots(pkl_file, n=20, pause=0):
    """
    Show a ranked table and inline plots for the `n` lowest-energy entries.

    Parameters
    ----------
    pkl_file : str or Path
        Pickled registry file.
    n : int, optional (default 20)
        Number of graph pairs to display.
    pause : float, optional (default 0)
        Seconds to wait between figures.
    """
    reg = load_registry(pkl_file)
    flat = [
        (e, h, canon, orig)
        for h, bucket in reg.items()
        for canon, orig, e in bucket
    ]
    flat.sort(key=lambda t: t[0])

    print(f"\nTop-{n} by lowest energy:\n")
    print(f"{'rank':>4} │ {'energy':>10} │ {'hash':<16} │ nodes")
    print("────┼────────────┼────────────────┼──────")
    for i, (e, h, canon, orig) in enumerate(flat[:n], 1):
        print(f"{i:>4} │ {e:10.4f} │ {h:<16} │ {orig.number_of_nodes():>5}")
        draw_pair(canon, orig, h, e)
        if pause:
            time.sleep(pause)

            

def iso_pairs_lowest(pkl_file, top=20):
    """
    Among the `top` lowest-energy *canonical* graphs, print every pair that
    is isomorphic.  Returns a list of (i, j) index pairs.
    """
    reg = load_registry(pkl_file)

    # flatten and keep canonical graph
    flat = [(e, h, canon)
            for h, bucket in reg.items()
            for canon, _, e in bucket]
    flat.sort(key=lambda t: t[0])
    flat = flat[:top]                       # lowest-energy slice

    duplicates = []
    for (i, (e_i, h_i, g_i)), (j, (e_j, h_j, g_j)) in itertools.combinations(enumerate(flat), 2):
        if nx.is_isomorphic(g_i, g_j):
            duplicates.append((i, j))
            print(f"Duplicate pair: idx {i} ↔ {j}   "
                  f"energies {e_i:.4f} / {e_j:.4f}   "
                  f"hashes {h_i} / {h_j}")

    if not duplicates:
        print(f"No canonical duplicates among the lowest {top} energies.")

    return duplicates


# ────────────────── edge-pruning + iso-dedup ──────────────────
def remove_redundant_edges(g):
    """Return a copy of `g` with every non-essential edge pruned."""
    g = g.copy()
    tt_ref = calculate_truth_table_v2(g)
    changed = True
    while changed:
        changed = False
        for u, v in list(g.edges()):
            g_tmp = g.copy()
            g_tmp.remove_edge(u, v)
            if calculate_truth_table_v2(g_tmp) == tt_ref:
                g.remove_edge(u, v)
                changed = True
    return g

def unique_by_isomorphism(graphs):
    uniq = []
    for g in graphs:
        if not any(nx.is_isomorphic(g, h) for h in uniq):
            uniq.append(g)
    return uniq


def plot_graphs(graphs, cols=3, seed=42):
    rows = (len(graphs) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
    axes = axes.ravel()
    for ax, g in zip(axes, graphs):
        pos = nx.spring_layout(g, seed=seed)
        nx.draw(g, pos, ax=ax, with_labels=True, node_size=100, font_size=7)
        ax.set_title(f"nodes={g.number_of_nodes()}  edges={g.number_of_edges()}")
    for ax in axes[len(graphs):]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def prune_and_plot_optimal(pkl_file):
    """
    • keep every graph at the absolute minimum energy
    • remove redundant edges in each
    • drop isomorphic duplicates
    • plot & print truth tables
    • RETURN the list of unique pruned graphs
    """
    reg = load_registry(pkl_file)

    # -- keep only the optimal energy ----------------------------------------
    flat = [(e, canon) for bucket in reg.values() for canon, _, e in bucket]
    best_energy = min(e for e, _ in flat)
    best_graphs = [canon for e, canon in flat if e == best_energy]
    print(f"Optimal energy: {best_energy:.4f}   raw count: {len(best_graphs)}")

    # -- prune redundant edges -----------------------------------------------
    pruned = [remove_redundant_edges(g) for g in best_graphs]

    # -- drop isomorphic duplicates ------------------------------------------
    unique = unique_by_isomorphism(pruned)
    print(f"{len(unique)} unique graph(s) remain after pruning + iso check.\n")

    # -- plot each graph and print its truth table ---------------------------
    for idx, g in enumerate(unique, 1):
        # plot
        plt.figure(figsize=(4, 4))
        pos = nx.spring_layout(g, seed=42)
        nx.draw(g, pos, with_labels=True, node_size=100, font_size=7)
        plt.title(f"Graph {idx}   nodes={g.number_of_nodes()}  edges={g.number_of_edges()}")
        plt.show()

        # truth table
        tt = calculate_truth_table_v2(g)
        print(f"Truth table for Graph {idx}:")
        for inputs, out in sorted(tt.items()):
            print(f"  {inputs}  →  {out[0]}")
        print()

    return unique


def energies_from_log(txt_file):
    pat = re.compile(r"^\s*(?:\Selected)\s+(.*_NIG_unoptimized\.pkl)$")
    paths = [m.group(1).strip() for line in open(txt_file)
             if (m := pat.match(line))]
    if not paths:
        print("No circuit paths found in the file."); return

    rows = []
    for fp in tqdm(paths, desc="Processing circuits", unit="circuit"):
        G = load_graph_pickle(fp)
        E_orig, _ = energy_score(G, check_implicit_OR_existence_v3)

        canon = _apply_implicit_or(G)
        E_canon, _ = energy_score(canon, check_implicit_OR_existence_v3)

        hex_id = Path(fp).name.split("_")[0]
        rows.append((hex_id, E_orig, E_canon))

    # pretty print
    print(f"\nEnergy summary ({len(rows)} circuits):\n")
    print(f"{'hex':<6} │ {'orig':>8} │ {'canon':>8}")
    print("───────┼──────────┼──────────")
    for h, e1, e2 in rows:
        print(f"{h:<6} │ {e1:8.3f} │ {e2:8.3f}")

def load_graph_pickle(filename):
    num_nodes, edges, node_attrs = pickle.load(open(filename, "rb"))
    G = nx.DiGraph()
    for n, attr in node_attrs.items():
        G.add_node(n, type=attr) if attr is not None else G.add_node(n)
    G.add_edges_from(edges)
    return G



In [7]:

pkl_path = "runs/20250609_run1_design_circuits_10logicfunctions_noregistrysampling/shared_registry_17_600.pkl" 

In [8]:
# --- Load the registry ---
reg = load_registry(pkl_path)

print(f"Loaded registry with {len(reg)} original hash buckets (v1 truth tables).")

Loaded registry with 36 original hash buckets (v1 truth tables).


In [9]:
# --- Helper: group by truth table as returned by calculate_truth_table_v2 ---

def group_by_truth_table_v2(registry):
    """Re‑groups the registry using calculate_truth_table_v2.

    Parameters
    ----------
    registry : dict | iterable
        The registry as produced by the search (either a mapping `hash->bucket`
        or a flat iterable of `(canon_graph, orig_graph, energy)` tuples).

    Returns
    -------
    dict[tuple[int]] -> list[tuple]
        Mapping a canonical truth‑table (returned as a tuple of ints) to all
        design tuples that implement it.
    """
    groups = defaultdict(list)

    # Flatten if registry is the usual dict-of-buckets
    if isinstance(registry, dict):
        iterable = itertools.chain.from_iterable(registry.values())
    else:
        iterable = registry

    for canon, orig, energy in iterable:
        tt = tuple(calculate_truth_table_v2(canon))
        groups[tt].append((canon, orig, energy))

    return groups

# Group the registry
reg_by_tt = group_by_truth_table_v2(reg)
print(f"Grouped into {len(reg_by_tt)} distinct truth tables using calculate_truth_table_v2.")

Grouped into 1 distinct truth tables using calculate_truth_table_v2.


In [10]:
# --- Helper: run the original analysis pipeline for *one* truth table group ---

def analyse_truth_table(tt_key, bucket, n_top=5, max_show=5):
    """Re‑runs the original analysis cells for a single truth‑table group.

    Parameters
    ----------
    tt_key : tuple[int]
        Canonical representation of the truth table.
    bucket : list[tuple]
        List of `(canon_graph, orig_graph, energy)` triples for this truth table.
    """
    print("\n" + "="*100)
    print(f"Truth table: {tt_key}   |   designs = {len(bucket)}")
    print("="*100)

    # ----- Registry size -----
    print("Total designs in this group:", len(bucket))

    # ----- Visualise top designs by energy -----
    # Reuse your helper by wrapping the bucket in a temporary registry‑like dict
    temp_reg = {0: bucket}
    top_lowest_energy_plots(temp_reg, n=n_top)

    # ----- Isomorphism analysis -----
    iso_pairs_lowest(temp_reg, top=n_top)

    # ----- Redundancy pruning & optimal plots -----
    unique_graphs = prune_and_plot_optimal(temp_reg)

    # ----- Motif comparison (optional) -----
    try:
        check_vs_motif_bank(unique_graphs)
    except NameError:
        print("↳ motif‑bank helpers not available – skipping this step.")

In [11]:
# --- Example: analyse just the first truth table to sanity‑check ---
first_tt, first_bucket = next(iter(reg_by_tt.items()))
analyse_truth_table(first_tt, first_bucket)


Truth table: ((0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 1, 1), (0, 1, 0, 0), (0, 1, 0, 1), (0, 1, 1, 0), (0, 1, 1, 1), (1, 0, 0, 0), (1, 0, 0, 1), (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 0, 0), (1, 1, 0, 1), (1, 1, 1, 0), (1, 1, 1, 1))   |   designs = 36
Total designs in this group: 36


TypeError: unhashable type: 'dict'

In [12]:
# --- Full batch analysis ---
for tt_key, bucket in reg_by_tt.items():
    analyse_truth_table(tt_key, bucket)


Truth table: ((0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 1, 1), (0, 1, 0, 0), (0, 1, 0, 1), (0, 1, 1, 0), (0, 1, 1, 1), (1, 0, 0, 0), (1, 0, 0, 1), (1, 0, 1, 0), (1, 0, 1, 1), (1, 1, 0, 0), (1, 1, 0, 1), (1, 1, 1, 0), (1, 1, 1, 1))   |   designs = 36
Total designs in this group: 36


TypeError: unhashable type: 'dict'