To launch on SuperCloud from a compute node:


LLsub -i full #for an exclusive node

LLsub -i -s 40 #for node with 40 CPUs

LLsub -i -s 40 -g volta:1 #for node with 40 CPUs and 1 Volta GPU

salloc  --job-name=interactive --qos=high --time=00:60:00 --partition=debug-gpu --gres=gpu:volta:1 --cpus-per-task=40 srun    --pty bash -i

salloc  --job-name=interactive --qos=high --time=00:60:00 --partition=debug-cpu --cpus-per-task=40 srun  --pty bash -i

module load anaconda/2023a-pytorch

jupyter lab --no-browser --ip=0.0.0.0 --port=8890

Last server: http://d-5-3-4:8890/lab?token=<redacted>

In [1]:
import pickle, time
import matplotlib.pyplot as plt
import networkx as nx
import itertools
import re
from pathlib import Path
import pandas as pd
from typing import List, Dict
from tqdm.notebook import tqdm   
import numpy as np        

from dgd.environments.drl3env_loader4 import _apply_implicit_or, _compute_hash, _compute_truth_key

from dgd.utils.utils5 import (
    calculate_truth_table_v2,
    generate_one_hot_features_from_adj,
    resize_matrix,
    energy_score,
    check_implicit_OR_existence_v3,
    add_implicit_OR_to_dag_v2,
    exhaustive_cut_enumeration_dag,
    is_fanout_free_standalone,
    generate_subgraph,
)

Done loading action motifs. There are 15928 unique motifs.


In [6]:
# Location of the pickled motif bank. NOTE(review): pickle.load executes
# arbitrary code from the file, so only point this at a trusted artifact.
MOTIFS_PATH = "/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/scripts/action_motifs.pkl"
with open(MOTIFS_PATH, "rb") as f:
    action_motifs = pickle.load(f)

# "graphs" -> motif graphs, indexed by position; "lookup" -> truth-table key
# to the indices of the motifs stored under that key.
UNIQUE_GRAPHS, TTABLE_TO_ACTIONS = (
    action_motifs[field] for field in ("graphs", "lookup")
)

General utility functions

In [18]:
def load_registry(pkl_file):
    """Read a pickled registry and turn its node-link payloads back into graphs.

    The pickle is expected to hold a mapping ``hash -> bucket`` in which each
    bucket is an iterable of ``(canon_nl, orig_nl, e)`` triples.  Both
    node-link entries are converted with ``nx.node_link_graph``; ``e`` is
    passed through unchanged.

    Returns
    -------
    dict
        ``hash -> list of (canon_graph, orig_graph, e)``.
    """
    with open(pkl_file, "rb") as handle:
        raw_buckets = pickle.load(handle)

    rebuilt = {}
    for key, entries in raw_buckets.items():
        rebuilt[key] = [
            (nx.node_link_graph(canon_data), nx.node_link_graph(orig_data), value)
            for canon_data, orig_data, value in entries
        ]
    return rebuilt

def registry_size(reg):
    """Print and return the total number of entries over all buckets of `reg`."""
    total = 0
    for bucket in reg.values():
        total += len(bucket)
    print(f"Registry length: {total}")
    return total

def fast_registry_size(pkl_file):
    """Count registry entries straight from the pickle.

    Unlike `load_registry`, no NetworkX graphs are rebuilt: the pickled
    mapping is loaded and the lengths of its buckets (lists of
    ``(canon_nl, orig_nl, e)`` tuples) are summed.
    """
    with open(pkl_file, "rb") as handle:
        raw_buckets = pickle.load(handle)

    count = 0
    for entries in raw_buckets.values():
        count += len(entries)
    return count


def draw_pair(canon, orig, h, e, seed=42):
    """Draw the canonical and original graphs side by side in one figure.

    `h` and `e` are only used for the figure title (``hash=... energy=...``);
    `seed` is forwarded to ``nx.spring_layout`` so layouts are repeatable.
    """
    figure, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(7, 3.5))
    panels = (
        (left_ax, canon, "canonical"),
        (right_ax, orig, "original"),
    )
    for axis, graph, label in panels:
        layout = nx.spring_layout(graph, seed=seed)
        nx.draw(graph, layout, ax=axis, with_labels=True, node_size=100, font_size=7)
        axis.set_title(label)

    figure.suptitle(f"hash={h}   energy={e:.3f}", fontsize=10)
    # Leave a margin at the top so the suptitle does not overlap the panels.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def top_lowest_energy_plots(reg, n=20, pause=0):
    """
    Print a ranked table of the `n` lowest-energy registry entries and plot each.

    Parameters
    ----------
    reg : dict
        Registry mapping hash -> iterable of (canon, orig, e) triples,
        i.e. the structure produced by `load_registry`.
    n : int, optional (default 20)
        Number of graph pairs to display.
    pause : float, optional (default 0)
        Seconds to sleep after each figure; a falsy value disables the wait.
    """
    ranked = []
    for h, bucket in reg.items():
        for canon, orig, e in bucket:
            ranked.append((e, h, canon, orig))
    # Sort on energy only; ties keep their registry order (stable sort).
    ranked.sort(key=lambda entry: entry[0])

    print(f"\nTop-{n} by lowest energy:\n")
    print(f"{'rank':>4} │ {'energy':>10} │ {'hash':<16} │ nodes")
    print("────┼────────────┼────────────────┼──────")
    for rank, (e, h, canon, orig) in enumerate(ranked[:n], start=1):
        print(f"{rank:>4} │ {e:10.4f} │ {h:<16} │ {orig.number_of_nodes():>5}")
        draw_pair(canon, orig, h, e)
        if pause:
            time.sleep(pause)

            

def iso_pairs_lowest(reg, top=20):
    """
    Report isomorphic pairs among the `top` lowest-energy canonical graphs.

    The registry is flattened to (energy, hash, canonical graph) entries,
    sorted by energy and truncated to `top`.  Every pair in that slice is
    tested with ``nx.is_isomorphic``; each hit is printed.

    Returns the list of (i, j) positions (indices into the sorted slice).
    """
    candidates = sorted(
        ((e, h, canon) for h, bucket in reg.items() for canon, _, e in bucket),
        key=lambda entry: entry[0],
    )[:top]

    duplicates = []
    for i, j in itertools.combinations(range(len(candidates)), 2):
        e_i, h_i, g_i = candidates[i]
        e_j, h_j, g_j = candidates[j]
        if not nx.is_isomorphic(g_i, g_j):
            continue
        duplicates.append((i, j))
        print(f"Duplicate pair: idx {i} ↔ {j}   "
              f"energies {e_i:.4f} / {e_j:.4f}   "
              f"hashes {h_i} / {h_j}")

    if len(duplicates) == 0:
        print(f"No canonical duplicates among the lowest {top} energies.")

    return duplicates


def remove_redundant_edges(g):
    """Return a pruned copy of `g`; the input graph itself is not modified.

    An edge is dropped when removing it leaves the result of
    ``calculate_truth_table_v2`` equal to that of the starting graph.
    Passes over the remaining edges repeat until a full pass removes nothing.
    """
    pruned = g.copy()
    reference_tt = calculate_truth_table_v2(pruned)

    while True:
        removed_any = False
        # Iterate over a snapshot: edges are deleted from `pruned` in the loop.
        for src, dst in list(pruned.edges()):
            trial = pruned.copy()
            trial.remove_edge(src, dst)
            if calculate_truth_table_v2(trial) == reference_tt:
                pruned.remove_edge(src, dst)
                removed_any = True
        if not removed_any:
            break

    return pruned

def unique_by_isomorphism(graphs):
    """Keep the first graph of every isomorphism class, preserving input order."""
    representatives = []
    for candidate in graphs:
        for kept in representatives:
            if nx.is_isomorphic(candidate, kept):
                break
        else:
            # No already-kept graph matched: this one starts a new class.
            representatives.append(candidate)
    return representatives


def plot_graphs(graphs, cols=3, seed=42):
    """Draw every graph in `graphs` on a grid with `cols` columns.

    Parameters
    ----------
    graphs : iterable of networkx graphs
        Graphs to draw; each panel is titled with its node and edge counts.
    cols : int, optional (default 3)
        Number of grid columns; the row count is derived from it.
    seed : int, optional (default 42)
        Seed forwarded to ``nx.spring_layout`` for repeatable layouts.

    Returns
    -------
    None
        With no graphs, a short message is printed and nothing is drawn.
    """
    graphs = list(graphs)
    if not graphs:
        # A zero-row grid is not a valid subplot layout, so bail out early.
        print("No graphs to plot.")
        return

    rows = (len(graphs) + cols - 1) // cols  # ceil(len / cols)
    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid,
    # where plt.subplots would otherwise return a bare Axes without .ravel().
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows), squeeze=False)
    axes = axes.ravel()
    for ax, g in zip(axes, graphs):
        pos = nx.spring_layout(g, seed=seed)
        nx.draw(g, pos, ax=ax, with_labels=True, node_size=100, font_size=7)
        ax.set_title(f"nodes={g.number_of_nodes()}  edges={g.number_of_edges()}")
    # Hide the unused panels of the last row.
    for ax in axes[len(graphs):]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def prune_and_plot_optimal(reg):
    """
    Reduce a registry to its minimum-energy canonical graphs and show them.

    Steps:
      1. collect the canonical graphs whose energy equals the registry minimum
      2. prune each with `remove_redundant_edges`
      3. drop isomorphic duplicates with `unique_by_isomorphism`
      4. plot every survivor and print its truth table

    Returns the list of unique pruned graphs.
    """
    # 1) minimum-energy selection
    energy_graph_pairs = []
    for bucket in reg.values():
        for canon, _, e in bucket:
            energy_graph_pairs.append((e, canon))
    best_energy = min(pair[0] for pair in energy_graph_pairs)
    best_graphs = [canon for e, canon in energy_graph_pairs if e == best_energy]
    print(f"Optimal energy: {best_energy:.4f}   raw count: {len(best_graphs)}")

    # 2) edge pruning
    pruned = list(map(remove_redundant_edges, best_graphs))

    # 3) isomorphism de-duplication
    unique = unique_by_isomorphism(pruned)
    print(f"{len(unique)} unique graph(s) remain after pruning + iso check.\n")

    # 4) one figure plus one truth table per surviving graph
    for idx, g in enumerate(unique, start=1):
        plt.figure(figsize=(4, 4))
        layout = nx.spring_layout(g, seed=42)
        nx.draw(g, layout, with_labels=True, node_size=100, font_size=7)
        plt.title(f"Graph {idx}   nodes={g.number_of_nodes()}  edges={g.number_of_edges()}")
        plt.show()

        truth_table = calculate_truth_table_v2(g)
        print(f"Truth table for Graph {idx}:")
        for inputs, out in sorted(truth_table.items()):
            # Only the first element of each table value is displayed.
            print(f"  {inputs}  →  {out[0]}")
        print()

    return unique


def energies_from_log(txt_file):
    """
    Compute original and canonical energies for every circuit named in a log.

    A log line is picked up when it consists of optional leading whitespace,
    the word ``Selected`` (or ``selected``), whitespace, and a path ending in
    ``_NIG_unoptimized.pkl``.  Each path is loaded with `load_graph_pickle`
    and scored with ``energy_score`` before and after ``_apply_implicit_or``.

    Parameters
    ----------
    txt_file : str | Path
        Text log to scan.

    Returns
    -------
    pandas.DataFrame | None
        Columns ``truth_table_hex``, ``E_original``, ``E_canonical`` (one row
        per circuit), or ``None`` when no matching line is found.  A summary
        table is also printed.
    """
    # The previous pattern was r"(?:\Selected)": `\S` is the regex class
    # "any non-whitespace character", so it accepted e.g. "Xelected".
    pat = re.compile(r"^\s*[Ss]elected\s+(.*_NIG_unoptimized\.pkl)$")
    # Context manager so the log handle is closed deterministically.
    with open(txt_file) as log:
        paths = [m.group(1).strip() for line in log
                 if (m := pat.match(line))]
    if not paths:
        print("No circuit paths found in the file.")
        return None

    rows = []
    for fp in tqdm(paths, desc="Processing circuits", unit="circuit"):
        G = load_graph_pickle(fp)
        E_orig, _ = energy_score(G, check_implicit_OR_existence_v3)

        canon = _apply_implicit_or(G)
        E_canon, _ = energy_score(canon, check_implicit_OR_existence_v3)

        # File names look like "<hex>_NIG_unoptimized.pkl"; keep the hex prefix.
        hex_id = Path(fp).name.split("_")[0]
        rows.append((hex_id, E_orig, E_canon))

    df = pd.DataFrame(rows,
                      columns=["truth_table_hex",
                               "E_original",
                               "E_canonical"])

    # pretty print
    print(f"\nEnergy summary ({len(rows)} circuits):\n")
    print(f"{'hex':<6} │ {'orig':>8} │ {'canon':>8}")
    print("───────┼──────────┼──────────")
    for h, e1, e2 in rows:
        print(f"{h:<6} │ {e1:8.3f} │ {e2:8.3f}")

    return df
        
        

def load_graph_pickle(filename):
    """
    Rebuild a directed graph from a pickled ``(num_nodes, edges, node_attrs)`` tuple.

    Parameters
    ----------
    filename : str | Path
        Pickle file to read.  NOTE(review): ``pickle.load`` can execute
        arbitrary code, so only use this on trusted files.

    Returns
    -------
    networkx.DiGraph
        Each node in ``node_attrs`` is added, with a ``type`` attribute when
        its stored value is not ``None``; then all ``edges`` are added.
    """
    # Context manager so the file handle is closed even if unpickling fails.
    with open(filename, "rb") as fh:
        # The stored node count is not needed to rebuild the graph.
        _num_nodes, edges, node_attrs = pickle.load(fh)

    G = nx.DiGraph()
    for n, attr in node_attrs.items():
        if attr is not None:
            G.add_node(n, type=attr)
        else:
            G.add_node(n)
    G.add_edges_from(edges)
    return G

def energy_from_pickle(fp):
    """
    Score one pickled circuit before and after ``_apply_implicit_or``.

    Parameters
    ----------
    fp : str | Path
        Full path to a ``*_NIG_unoptimized.pkl`` file.

    Returns
    -------
    pandas.DataFrame
        A single row with columns ``truth_table_hex`` (the file-name prefix
        before the first underscore), ``E_original`` and ``E_canonical``.
    """
    graph = load_graph_pickle(fp)
    original_energy, _ = energy_score(graph, check_implicit_OR_existence_v3)

    canonical_graph = _apply_implicit_or(graph)
    canonical_energy, _ = energy_score(canonical_graph, check_implicit_OR_existence_v3)

    hex_id = Path(fp).name.split('_')[0]
    record = (hex_id, original_energy, canonical_energy)
    return pd.DataFrame(
        [record],
        columns=["truth_table_hex", "E_original", "E_canonical"],
    )

    
def compare_to_motif_bank_simple(unique_graphs, *, verbose = False):
    """Compare candidate graphs against the motif bank by energy.

    Parameters
    ----------
    unique_graphs
        Non-empty list of candidate graphs; all of them must yield the same
        ``_compute_truth_key`` value.
    verbose
        When true, print the key, the matching bank indices, both sets of
        energies and the final result records.

    Returns
    -------
    List[dict]
        One record per candidate, with the fields
        ``idx`` (position in `unique_graphs`), ``key`` (shared truth-table
        key), ``own_energy``, ``best_bank_energy`` (lowest energy among the
        bank motifs registered under the key, or ``None`` if there are none)
        and ``delta`` (``own_energy - best_bank_energy``, or ``None``).

    Raises
    ------
    ValueError
        If `unique_graphs` is empty or its graphs disagree on the key.
    """
    # Guard clauses: need at least one graph, and a single shared key.
    if not unique_graphs:
        raise ValueError("unique_graphs is empty")

    ref_key = _compute_truth_key(unique_graphs[0])
    for other in unique_graphs[1:]:
        if _compute_truth_key(other) != ref_key:
            raise ValueError("All graphs must share the same truth‑table key")

    # Bank motifs registered under this key (empty list when unknown).
    bank_idxs = TTABLE_TO_ACTIONS.get(ref_key, [])
    if verbose:
        print(f"Truth‑table key: {ref_key}\nBank motif indices: {bank_idxs}\n")

    # Energies: bank motifs first, then the candidates.
    bank_energies: Dict[int, float] = {}
    for bank_idx in bank_idxs:
        bank_energies[bank_idx] = energy_score(
            UNIQUE_GRAPHS[bank_idx], check_implicit_OR_existence_v3
        )[0]
    if bank_energies:
        best_bank_energy = min(bank_energies.values())
    else:
        best_bank_energy = None

    candidate_energies = []
    for graph in unique_graphs:
        candidate_energies.append(
            energy_score(graph, check_implicit_OR_existence_v3)[0]
        )

    # One record per candidate, in input order.
    results = [
        {
            "idx": idx,
            "key": ref_key,
            "own_energy": own_e,
            "best_bank_energy": best_bank_energy,
            "delta": (own_e - best_bank_energy) if best_bank_energy is not None else None,
        }
        for idx, own_e in enumerate(candidate_energies)
    ]

    if verbose:
        print("Candidate energies:", candidate_energies)
        print("Bank energies     :", bank_energies)
        print("\nResults list:")
        for record in results:
            print(record)

    return results


Check size of unoptimized NIG (3 inputs)

In [3]:
# Select one 3-input circuit by the hex prefix of its pickle file name.
circuit_hex = "0xD9"
# energy_from_pickle returns a one-row DataFrame
# (truth_table_hex, E_original, E_canonical) for that circuit.
df = energy_from_pickle(f"/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_3_inputs/{circuit_hex}_NIG_unoptimized.pkl")
print(df)

  truth_table_hex  E_original  E_canonical
0            0xD9          27           27


Check size of unoptimized NIG (4 inputs)

In [25]:
# Single circuit: select one 4-input circuit by the hex prefix of its pickle
# file name. NOTE: this rebinds the notebook-level `circuit_hex` and `df`.
circuit_hex = "0x22C6"
# One-row DataFrame (truth_table_hex, E_original, E_canonical).
df = energy_from_pickle(f"/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_4_inputs/{circuit_hex}_NIG_unoptimized.pkl")
print(df)

  truth_table_hex  E_original  E_canonical
0          0x22C6          43           43


In [26]:
# Rebuild the DiGraph for whatever `circuit_hex` currently holds; it must
# already be set by an earlier cell.
G = load_graph_pickle(f"/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_4_inputs/{circuit_hex}_NIG_unoptimized.pkl")
# Bare expression: shows the graph object's repr as the cell output.
G

<networkx.classes.digraph.DiGraph at 0x7fd56b23bd90>

In [27]:
# Compare the single graph `G` against the motif bank, printing the details.
# NOTE: despite its name, `df2` is a list of result dicts, not a DataFrame;
# it is wrapped in a DataFrame below only for display.
df2 = compare_to_motif_bank_simple([G], verbose=True)
pd.DataFrame(df2).head()

Truth‑table key: (4, 8902)
Bank motif indices: [3309, 4605]

Candidate energies: [43]
Bank energies     : {3309: 11, 4605: 11}

Results list:
{'idx': 0, 'key': (4, 8902), 'own_energy': 43, 'best_bank_energy': 11, 'delta': 32}


Unnamed: 0,idx,key,own_energy,best_bank_energy,delta
0,0,"(4, 8902)",43,11,32


Many graphs

In [15]:
#as a group
import warnings 

# Truth-table hex IDs of the 4-input circuits to score, grouped in blocks.
circuit_hexes = [
    # block 1
    "0x3133","0x5155","0x0D0F","0x4555","0x0B0F","0x2333","0x3313","0x5515",
    "0x00DF","0x00BF","0x0F07","0x00F7",
    # block 2
    "0x00EF","0x00FB","0x00FD","0x0E0F","0x0F0B","0x0F0D","0x3233","0x3323",
    "0x3331","0x5455","0x5545","0x5551",
    # block 3
    "0x0FD5","0x5B1B","0x3D1D","0x6727","0x3D35","0x6747","0x5B53","0x51BB",
    "0x31DD","0x6277","0x33D5","0x6477","0x55B3","0x45AF","0x4A5F","0x0DF5",
    "0x585F","0x558F","0x23CF","0x0BF3","0x2C3F","0x0FB3","0x383F","0x338F",
    # block 4
    "0x22C6","0x44A6","0x0AD2","0x509A","0x0CB4","0x309C","0x2C26","0x4A46",
    "0x0DA2","0x590A","0x0BC4","0x390C","0x381A","0x6252","0x318A","0x6522",
    "0x23D0","0x2D30","0x581C","0x6434","0x518C","0x6344","0x45B0","0x4B50",
]

base_dir = Path("/home/gridsan/spalacios/Designing complex biological circuits with deep neural networks/dgd/data/NIGs_4_inputs")

# ──────────────────────────────────────────────────────────────────────────────
# 3. Compute energies, skipping anything missing
# ──────────────────────────────────────────────────────────────────────────────
# `missing` keeps every skipped hex ID (absent file OR any other failure), as
# before; `skip_reasons` records why, so the summary no longer labels load or
# compute errors as "file not found".
rows, missing = [], []
skip_reasons = {}

for hx in tqdm(circuit_hexes, desc="Processing circuits", unit="circuit"):
    fp = base_dir / f"{hx}_NIG_unoptimized.pkl"
    try:
        df_one = energy_from_pickle(fp)
        rows.append(df_one.iloc[0])
    except FileNotFoundError:
        warnings.warn(f"❌ Skipping {hx}: file not found")
        missing.append(hx)
        skip_reasons[hx] = "file not found"
    except Exception as e:
        # Deliberate best-effort: one bad circuit must not abort the batch.
        warnings.warn(f"⚠️  Skipping {hx}: {e}")
        missing.append(hx)
        skip_reasons[hx] = f"{type(e).__name__}: {e}"

if not rows:
    raise RuntimeError("No circuits were processed successfully.")

df = pd.DataFrame(rows)

# ──────────────────────────────────────────────────────────────────────────────
# 4. Summary
# ──────────────────────────────────────────────────────────────────────────────
print(f"\nEnergy summary ({len(df)} circuits processed, {len(missing)} skipped):\n")
print(f"{'hex':<6} │ {'orig':>8} │ {'canon':>8}")
print("───────┼──────────┼──────────")
for h, e1, e2 in df.itertuples(index=False):
    print(f"{h:<6} │ {e1:8.3f} │ {e2:8.3f}")

if missing:
    # Report each skipped circuit with its actual reason.
    print("\nSkipped circuits:")
    for hx in missing:
        print(f"  {hx}: {skip_reasons[hx]}")

Processing circuits:   0%|          | 0/72 [00:00<?, ?circuit/s]




Energy summary (70 circuits processed, 2 skipped):

hex    │     orig │    canon
───────┼──────────┼──────────
0x3133 │   48.000 │   48.000
0x5155 │   57.000 │   57.000
0x0D0F │   43.000 │   43.000
0x4555 │   56.000 │   56.000
0x0B0F │   42.000 │   42.000
0x2333 │   47.000 │   47.000
0x3313 │   48.000 │   48.000
0x5515 │   57.000 │   57.000
0x00DF │   43.000 │   43.000
0x00BF │   42.000 │   42.000
0x0F07 │   43.000 │   43.000
0x00F7 │   43.000 │   43.000
0x00EF │   42.000 │   42.000
0x00FB │   42.000 │   42.000
0x00FD │   43.000 │   43.000
0x0E0F │   42.000 │   42.000
0x0F0B │   42.000 │   42.000
0x0F0D │   43.000 │   43.000
0x3233 │   47.000 │   47.000
0x3323 │   47.000 │   47.000
0x3331 │   48.000 │   48.000
0x5455 │   56.000 │   56.000
0x5545 │   56.000 │   56.000
0x5551 │   57.000 │   57.000
0x0FD5 │   60.000 │   60.000
0x5B1B │   63.000 │   63.000
0x3D1D │   62.000 │   62.000
0x3D35 │   62.000 │   62.000
0x5B53 │   64.000 │   64.000
0x51BB │   63.000 │   63.000
0x31DD │   62.000 