In [None]:

def prune(self) -> None:
    """
    Drop every node that the final output does not (transitively) depend on.

    The final output is taken to be node 0 of the last layer. Surviving nodes
    keep their relative order inside each layer. For layers 1..end, every
    surviving node's ``X_cols`` is rewritten to point at the new, compacted
    positions in the previous layer. ``X_cols`` of ``self.layers[0]`` index
    the raw input columns and are left untouched.

    Raises:
        RuntimeError: if ``self.layers`` is empty (``fit()`` not called yet).
    """
    if not self.layers:
        raise RuntimeError("Cannot prune before calling fit(); no layers present.")

    depth = len(self.layers)

    # Mark phase: seed with the output node, then propagate "needed" backwards
    # through each node's X_cols into the layer below it.
    needed = [set() for _ in range(depth)]
    needed[-1].add(0)
    for layer_idx in reversed(range(1, depth)):
        current = self.layers[layer_idx]
        below = needed[layer_idx - 1]
        for node_idx in needed[layer_idx]:
            below.update(int(col) for col in current[node_idx].X_cols)

    # Sweep phase: filter each layer, remembering where each survivor moved.
    remaps: list[dict[int, int]] = []
    for layer_idx, wanted in enumerate(needed):
        order = sorted(wanted)
        remaps.append(dict(zip(order, range(len(order)))))
        old_layer = self.layers[layer_idx]
        self.layers[layer_idx] = [old_layer[k] for k in order]

    # Rewire phase: translate inputs of layers 1.. into the compacted indexing
    # of the layer directly below.
    for layer_idx in range(1, depth):
        remap = remaps[layer_idx - 1]
        for node in self.layers[layer_idx]:
            node.X_cols = np.array([remap[int(col)] for col in node.X_cols], dtype=int)

def describe_architecture_graph(self) -> dict[str, dict | list]:
    """
    Snapshot only the nodes (including inputs) that actually reach the final output.
    """
    # 1) build full connection map for hidden+output
    conn = {
        (l+1, i): node.X_cols.tolist()
        for l, layer in enumerate(self.layers)
        for i, node in enumerate(layer)
    }

    # 2) traverse backwards from final (layer L, idx 0)
    last = len(self.layers)
    stack = [(last, 0)]
    visited = set(stack)
    while stack:
        l, i = stack.pop()
        for c in conn.get((l, i), []):
            prev = (l-1, c)
            if prev not in visited:
                visited.add(prev)
                stack.append(prev)

    # 3) collect relevant inputs
    relevant_inputs = sorted(j for (l, j) in visited if l == 0)

    # 4) build filtered layers_info & connections
    layers_info = [{'nodes': len(relevant_inputs), 'bits_per_node': None}]
    for l, layer in enumerate(self.layers, start=1):
        kept = [i for i, _ in enumerate(layer) if (l, i) in visited]
        layers_info.append({'nodes': len(kept), 'bits_per_node': layer[0].X_cols.size if layer else 0})

    filtered_conn = {n: cols for n, cols in conn.items() if n in visited}

    return {
        'layers': layers_info,
        'connections': filtered_conn,
        'relevant_inputs': relevant_inputs
    }



def relevant_inputs(self) -> list[int]:
    """
    Return the sorted list of distinct input-column indices that actually feed
    into the network (after any pruning).

    Indices are gathered from the layer-1 entries of the ``'connections'`` map
    returned by ``self.describe_architecture_graph()``. Each index appears
    once, even when several layer-1 nodes read the same input column.
    """
    connections = self.describe_architecture_graph()["connections"]
    # Collect into a set first: different layer-1 nodes may share input
    # columns, and the result should list each column only once.
    return sorted({
        c
        for (layer, _), cols in connections.items()
        if layer == 1
        for c in cols
    })

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

desc = lut_net.describe_architecture_graph()
conns = desc['connections']

# 1) group node indices by column: column 0 holds the relevant input columns,
#    every later column holds the node indices that own a connection entry.
layers = {0: desc['relevant_inputs']}
for col in range(1, len(desc['layers'])):
    layers[col] = sorted(idx for (lyr, idx) in conns if lyr == col)

# 2) lay each column out vertically, evenly spaced from y=0 down to y=-1,
#    remembering which column each node label belongs to for colouring.
G = nx.DiGraph()
pos = {}
column_of = {}
for col in sorted(layers):
    members = layers[col]
    if not members:
        continue
    for y, idx in zip(np.linspace(0, -1, len(members)), members):
        name = f"{col}-{idx}"
        G.add_node(name)
        pos[name] = (col, y)
        column_of[name] = col

# 3) add an edge from each source node in column l-1 to its consumer in column l
for (lyr, idx), srcs in conns.items():
    target = f"{lyr}-{idx}"
    for src in srcs:
        source = f"{lyr - 1}-{src}"
        if G.has_node(source):
            G.add_edge(source, target)

# 4) render: input nodes in sky blue, all other nodes in light gray
plt.figure(figsize=(8, 4))
node_colors = ['skyblue' if column_of[name] == 0 else 'lightgray' for name in G.nodes()]
nx.draw(
    G,
    pos,
    with_labels=True,
    node_color=node_colors,
    node_size=500,
    font_size=10,
    edge_color='k',
)
plt.title("DeepBinaryClassifier Architecture (pruned)")
plt.axis('off')
plt.show()
