In [1]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from architecture.deep_binary_classifier import DeepBinaryClassifier
from architecture.lut_node import make_lut_node
from architecture.ripper_node import make_ripper_node

In [2]:
df   = pd.read_csv("./data/100_bit_artificial/1a.csv")
X = df.drop(columns="class").to_numpy(bool)
y = df["class"].to_numpy(bool)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"Dataset Shape               : {df.shape}")
print(f"Train-Test-Split            : {X_train.shape[0]} vs. {X_test.shape[0]}")
print(f"Train label distribution    : {y_train.sum()} (True) vs. {y_train.size - y_train.sum()} (False)")

Dataset Shape               : (10000, 101)
Train-Test-Split            : 8000 vs. 2000
Train label distribution    : 4605 (True) vs. 3395 (False)


In [3]:
# Training a raw LUT network

hidden_layer_count  = 4
layer_node_count    = 32
layer_bit_count     = 4

lut_net = DeepBinaryClassifier(
    layer_node_counts   = [layer_node_count]*hidden_layer_count + [1],
    layer_bit_counts    = [layer_bit_count]*(hidden_layer_count + 1),
    node_factory        = make_lut_node,
    seed                = 42
)
lut_net.fit(X_train, y_train)
lut_pred_test = lut_net.predict(X_test)
lut_acc_test = accuracy_score(y_test, lut_pred_test)
print(f"LUT network accuracy on test set: {lut_acc_test:.4f}")

LUT network accuracy on test set: 0.7585


In [9]:
lut_node = lut_net.layers[0][3]
lut_node_truth_table = lut_node.get_truth_table()

lut_node_column_names = [f"bit_{i}" for i in lut_node.X_cols] + ["pred_lut"]
lut_node_truth_table_df = pd.DataFrame(lut_node_truth_table, columns=lut_node_column_names)
lut_node_truth_table_df

Unnamed: 0,bit_54,bit_48,bit_4,bit_45,pred_lut
0,False,False,False,False,True
1,False,False,False,True,True
2,False,False,True,False,True
3,False,False,True,True,False
4,False,True,False,False,True
5,False,True,False,True,True
6,False,True,True,False,True
7,False,True,True,True,True
8,True,False,False,False,True
9,True,False,False,True,True


In [10]:
def distil_ripper_node(lut_node) -> "RipperNode":
    lut_node_truth_table = lut_node.get_truth_table()

    # RIPPER requires multiple samples of each pattern to learn from else it will predict False for everything
    double_lut_node_truth_table = np.vstack([lut_node_truth_table, lut_node_truth_table])

    X_lut_node = double_lut_node_truth_table[:, :-1]
    y_lut_node = double_lut_node_truth_table[:, -1]

    rip_node = make_ripper_node(X_cols=lut_node.X_cols, X_node=X_lut_node, y_node=y_lut_node, seed=lut_node.seed)
    return rip_node


rip_node = distil_ripper_node(lut_node)

rip_node_truth_table = rip_node.get_truth_table()

rip_node_column_names = [f"bit_{i}" for i in rip_node.X_cols] + ["pred_rip"]
rip_node_truth_table_df = pd.DataFrame(rip_node_truth_table, columns=rip_node_column_names)
rip_node_truth_table_df

Unnamed: 0,bit_4,bit_45,bit_54,pred_rip
0,False,False,False,True
1,False,False,True,True
2,False,True,False,True
3,False,True,True,True
4,True,False,False,True
5,True,False,True,True
6,True,True,False,False
7,True,True,True,True


In [7]:
rip_node_rule = rip_node.get_ruleset(disjunction_str=' V ')
print(rip_node_rule)

[x_4=False] V [x_45=False] V [x_54=True]


In [6]:
# compare both on full column set

bits_lut = {c for c in lut_node_truth_table_df.columns if c.startswith("bit_")}
bits_rip = {c for c in rip_node_truth_table_df.columns if c.startswith("bit_")}
bits_common = sorted(bits_lut & bits_rip, key=lambda x: int(x.split("_")[1]))

comp_node_truth_table_df = lut_node_truth_table_df.copy()
comp_node_truth_table_df = comp_node_truth_table_df.merge(
    rip_node_truth_table_df[bits_common + ["pred_rip"]],
    how="left",
    on=bits_common
)
comp_node_truth_table_df["mismatch"] = comp_node_truth_table_df["pred_lut"] != comp_node_truth_table_df["pred_rip"]
comp_node_truth_table_df

Unnamed: 0,bit_0,bit_1,bit_2,bit_3,pred_lut,pred_rip,mismatch
0,False,False,False,False,True,True,False
1,False,False,False,True,True,True,False
2,False,False,True,False,True,True,False
3,False,False,True,True,False,True,True
4,False,True,False,False,True,True,False
5,False,True,False,True,True,True,False
6,False,True,True,False,True,True,False
7,False,True,True,True,True,True,False
8,True,False,False,False,True,True,False
9,True,False,False,True,True,True,False


In [7]:
# replace every LUT node in the network with its distilled RipperNode
# warning! this will overwrite the original LUT network

def distil_ripper_net(lut_net) -> "DeepBinaryClassifier":
    for layer_idx, layer in enumerate(lut_net.layers):
        for i, lut_node in enumerate(layer):
            rip_node = distil_ripper_node(lut_node)
            lut_net.layers[layer_idx][i] = rip_node
    return lut_net


rip_net = distil_ripper_net(lut_net)
rip_pred_test = rip_net.predict(X_test)
rip_acc_test = accuracy_score(y_test, rip_pred_test)
print(f"Ripper network accuracy on test set: {rip_acc_test:.4f}")

No negative samples. Existing target labels=[True].

Ruleset is empty. All predictions it makes with method .predict will be negative. It may be untrained or was trained on a dataset split lacking positive examples.

Ruleset is empty. All predictions it makes with method .predict will be negative. It may be untrained or was trained on a dataset split lacking positive examples.

No negative samples. Existing target labels=[True].

Ruleset is empty. All predictions it makes with method .predict will be negative. It may be untrained or was trained on a dataset split lacking positive examples.

Ruleset is empty. All predictions it makes with method .predict will be negative. It may be untrained or was trained on a dataset split lacking positive examples.

No negative samples. Existing target labels=[True].

Ruleset is empty. All predictions it makes with method .predict will be negative. It may be untrained or was trained on a dataset split lacking positive examples.

Ruleset is empty. All

Ripper network accuracy on test set: 0.6765
