In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from architecture.classifier import DeepBinaryClassifier
from architecture.nodes.ripper import make_ripper_node
from architecture.nodes.lut import make_lut_node


# Load the dataset; the "class" column is the target and all remaining
# columns are the features. Both are cast to boolean NumPy arrays.
df = pd.read_csv("./data/100_bit_artificial/1a.csv")
y = df["class"].to_numpy(bool)
X = df.drop(columns="class").to_numpy(bool)

# Hold out 20% of the rows for testing, stratified on the label, with a fixed
# random_state so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

# Keyword arguments shared by every DeepBinaryClassifier built in this notebook:
# six layers with node counts 32, 32, 32, 32, 32, 1 and a bit count of 6 per
# layer (presumably the number of inputs each node reads -- confirm against
# DeepBinaryClassifier).
config = {
    "layer_node_counts": [32] * 5 + [1],
    "layer_bit_counts": [6] * 6,
    "seed": 42,
}

In [3]:
# the LUT network runs faster on a single thread

lut_net = DeepBinaryClassifier(**config, node_factory=make_lut_node, jobs=1)
%time lut_net.fit(X_train, y_train)
pred_test = lut_net.predict(X_test)
acc_lut = accuracy_score(y_test, pred_test)
print(f"LUT network accuracy: {acc_lut:.4f}")

CPU times: user 38.6 ms, sys: 5.02 ms, total: 43.7 ms
Wall time: 42.7 ms
LUT network accuracy: 0.7315


In [10]:
# the Ripper nodes profit from parallelization

rip_net = DeepBinaryClassifier(**config, node_factory=make_ripper_node, jobs=8)
%time rip_net.fit(X_train, y_train)
pred_test = rip_net.predict(X_test)
acc_rip = accuracy_score(y_test, pred_test)
print(f"Rule network accuracy: {acc_rip:.4f}")

CPU times: user 334 ms, sys: 243 ms, total: 577 ms
Wall time: 33.8 s
Rule network accuracy: 0.8825


In [13]:
# Inspect a single node of the fitted rule network (layer index 4, node index 5)
# and print the expression returned by its get_expression() method.
rip_layer = rip_net.layers[4]
rip_node = rip_layer[5]
rip_node_rule = rip_node.get_expression()
print(rip_node_rule)

L4N12 | L4N18 | (L4N0 & L4N15) | (L4N15 & L4N17) | (L4N0 & L4N15 & L4N18) | (L4N12 & L4N15 & L4N17) | (L4N0 & L4N15 & L4N18 & ~L4N12)


In [14]:
# Compare the input columns of the node at the same position (layer index 4,
# node index 5) in both networks. A RIPPER node may end up using fewer of its
# input columns than the LUT node at that position.
layer_idx, node_idx = 4, 5
rip_node = rip_net.layers[layer_idx][node_idx]
lut_node = lut_net.layers[layer_idx][node_idx]

print(f"RIPPER node uses following columns: {rip_node.input_names}")
print(f"LUT node uses following columns: {lut_node.input_names}")

RIPPER node uses following columns: ['L4N0', 'L4N12', 'L4N15', 'L4N17', 'L4N18']
LUT node uses following columns: ['L4N0', 'L4N12', 'L4N15', 'L4N17', 'L4N18', 'L4N28']
