In [1]:
# 1) DATA IMPORT
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Load the dataset once; the "class" column becomes the target array and every
# remaining column becomes a feature. Both arrays are cast to bool.
df = pd.read_csv("./data/100_bit_artificial/1a.csv")
y = df["class"].to_numpy(dtype=bool)
X = df.drop(columns="class").to_numpy(dtype=bool)

# 80/20 hold-out split, stratified on the target and seeded so it is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

# Report the shapes of all four arrays, features first and targets second.
print(
    f"X_train: {X_train.shape}, X_test: {X_test.shape}",
    f"y_train: {y_train.shape}, y_test: {y_test.shape}",
    sep="\n",
)

X_train: (8000, 100), X_test: (2000, 100)
y_train: (8000,), y_test: (2000,)


In [2]:
# 2) NETWORK TRAINING
from sklearn.metrics import accuracy_score
from architecture.deep_binary_classifier import DeepBinaryClassifier
from architecture.ripper_node import make_ripper_node

# Network layout: five layers of 32 nodes followed by a single-node final layer.
# layer_bit_counts gives 6 for each of the six layers (presumably the number of
# inputs each node may read — confirm against DeepBinaryClassifier).
config = {
    "layer_node_counts": [32] * 5 + [1],
    "layer_bit_counts": [6] * 6,
    "seed": 42,
}

# Build the network with RIPPER-based nodes and fit it on the training split;
# the value returned by fit() is deliberately discarded.
rip_net = DeepBinaryClassifier(
    **config,
    node_factory=make_ripper_node,
    jobs=8,
)
_ = rip_net.fit(X_train, y_train)

# Baseline test-set accuracy, kept in acc_before for comparison after pruning.
pred_test = rip_net.predict(X_test)
acc_before = accuracy_score(y_test, pred_test)

print(f"Rule network accuracy (before pruning): {acc_before:.4f}")

Rule network accuracy (before pruning): 0.8830


In [4]:
# 4) PRUNING CODE CELL

from typing import Iterable

def prune_network(
        net: DeepBinaryClassifier,
        outputs_to_keep: Iterable[int] | None = None,
        verbose: bool = True
) -> None:
    """
    In-place pruning:
      - Walks the graph backwards from the final layer outputs_to_keep.
      - Keeps only nodes that contribute to those outputs.
      - Reindexes X_cols at each layer to reflect pruned indices.

    Assumptions:
      - net.layers is a list[list[BaseNode]]
      - Each node's X_cols refers to indices within the *previous* layer's node list.
      - Final prediction uses net.layers[-1] outputs; by default keep node 0 only.

    Parameters
    ----------
    net : DeepBinaryClassifier
    outputs_to_keep : iterable of ints or None
        Which nodes in the final layer constitute the network's "outputs".
        If None, defaults to {0}.
    verbose : bool
        If True, prints a compact summary.
    """
    if not hasattr(net, "layers") or not net.layers:
        raise RuntimeError("Cannot prune: the network has no layers. Did you call fit()?")

    n_layers = len(net.layers)
    if outputs_to_keep is None:
        outputs_to_keep = {0}
    else:
        outputs_to_keep = set(int(i) for i in outputs_to_keep)

    # Validate outputs_to_keep
    last_count = len(net.layers[-1])
    for i in outputs_to_keep:
        if i < 0 or i >= last_count:
            raise IndexError(f"outputs_to_keep index {i} out of range for final layer of size {last_count}")

    if verbose:
        before = [len(L) for L in net.layers]
        print(f"Before pruning: {before}")

    # 1) Seed keep-sets with final layer outputs
    keep: list[set[int]] = [set() for _ in range(n_layers)]
    keep[-1] = set(outputs_to_keep)

    # 2) Walk backwards: mark all dependencies in prior layers
    for L in range(n_layers - 1, 0, -1):
        prev_layer_size = len(net.layers[L - 1])
        for j in keep[L]:
            node = net.layers[L][j]
            # X_cols indexes must be within [0, prev_layer_size)
            for prev_idx in node.X_cols:
                prev_i = int(prev_idx)
                if prev_i < 0 or prev_i >= prev_layer_size:
                    raise IndexError(
                        f"Layer {L} node {j} references invalid prev index {prev_i} "
                        f"(prev layer size {prev_layer_size})"
                    )
                keep[L - 1].add(prev_i)

    # 3) For each layer, prune to survivors and build old->new index mapping
    index_maps: list[dict[int, int]] = []
    for L in range(n_layers):
        survivors = sorted(keep[L])
        # Edge case: if a layer ends up empty (shouldn’t happen if outputs_to_keep is valid), raise
        if len(survivors) == 0:
            raise RuntimeError(f"Pruning resulted in empty layer {L}; cannot reindex network.")
        mapper = {old: new for new, old in enumerate(survivors)}
        index_maps.append(mapper)

        # Keep only survivor nodes, preserving order by new index
        net.layers[L] = [net.layers[L][old_idx] for old_idx in survivors]

    # 4) Reindex X_cols of layers 1..end using the maps of previous layers
    for L in range(1, n_layers):
        prev_map = index_maps[L - 1]
        for node in net.layers[L]:
            node.X_cols = np.array([prev_map[int(c)] for c in node.X_cols], dtype=int)

    if verbose:
        after = [len(L) for L in net.layers]
        print(f"After pruning:  {after}")

In [5]:
# 5) PRUNING EXECUTION

# Keep only the first output (index 0) of the final layer. prune_network
# modifies rip_net in place and prints the layer sizes before and after.
prune_network(rip_net, outputs_to_keep={0}, verbose=True)

# Sanity check: pruning must not change what the network predicts.
from sklearn.metrics import accuracy_score

pred_test_pruned = rip_net.predict(X_test)
acc_after = accuracy_score(y_test, pred_test_pruned)

# Equal accuracy alone is not proof: two different prediction vectors can score
# the same. Compare against the pre-pruning predictions (pred_test) directly.
predictions_match = np.array_equal(pred_test, pred_test_pruned)

print(f"\nAccuracy after pruning: {acc_after:.4f}")
if predictions_match:
    print("✅ Accuracy preserved exactly.")
elif abs(acc_after - acc_before) < 1e-12:
    # Same score, different individual predictions: the pruning still altered behaviour.
    print("⚠️ Accuracy unchanged but individual predictions differ — investigate dependencies and X_cols integrity.")
else:
    print("⚠️ Accuracy changed — investigate dependencies and X_cols integrity.")

Before pruning: [32, 32, 32, 32, 32, 1]
After pruning:  [32, 31, 29, 16, 5, 1]

Accuracy after pruning: 0.8830
✅ Accuracy preserved exactly.


In [6]:
# 6) ARCHITECTURE INSPECTION (AFTER PRUNING)

# NOTE(review): describe_net is not defined or imported in any cell visible in
# this notebook (execution count 3 is missing) — confirm it still exists after
# Restart & Run All.
describe_net(rip_net)

# Optional, best-effort preview: print the ruleset of node 0 in the final layer
# (already pruned at this point) when that node exposes get_ruleset(). Any
# exception is caught and reported as a one-line message instead of stopping
# the cell.
try:
    node = rip_net.layers[-1][0]
    if hasattr(node, "get_ruleset"):
        # The header is printed before get_ruleset() runs, so it can appear
        # even when the call itself then fails.
        print("\nFinal output node ruleset:")
        print(node.get_ruleset(disjunction_str=' V '))
except Exception as e:
    print(f"(Ruleset preview skipped: {e})")

Layers (count per layer): [32, 31, 29, 16, 5, 1]

Layer 0: 32 node(s)
  Node 0: uses input indices -> [84, 86, 6, 82, 53]
  Node 1: uses input indices -> [88, 94, 43, 33, 6, 67]
  Node 2: uses input indices -> [4]

Layer 1: 31 node(s)
  Node 0: uses input indices -> [15, 3, 28, 26]
  Node 1: uses input indices -> [2, 21, 20]
  Node 2: uses input indices -> [15, 2, 9, 4, 28]

Layer 2: 29 node(s)
  Node 0: uses input indices -> [8, 21, 24, 13, 26, 0]
  Node 1: uses input indices -> [8, 18, 24, 1, 28, 22]
  Node 2: uses input indices -> [20, 27, 2, 30, 17, 3]

Layer 3: 16 node(s)
  Node 0: uses input indices -> [25, 28, 22, 16, 18, 27]
  Node 1: uses input indices -> [25, 19, 13, 23, 18, 8]
  Node 2: uses input indices -> [26, 16, 18, 21, 27]

Layer 4: 5 node(s)
  Node 0: uses input indices -> [12, 9, 8, 0]
  Node 1: uses input indices -> [12, 14, 0, 13]
  Node 2: uses input indices -> [11, 9, 3, 5, 1, 13]

Layer 5: 1 node(s)
  Node 0: uses input indices -> [2, 1, 4, 3, 0]

Final output n