In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import random
import uuid
from graphviz import Source
from itertools import chain
import os
import subprocess
import re

from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from tqdm.notebook import tqdm

import pdb

In [None]:
class Node:
    """Base class of all graph nodes; gives every node a unique ``id``."""

    def __init__(self):
        # A fresh UUID (version 1) is generated for every node instance.
        self.id = uuid.uuid1()


class InputNode(Node):
    """Node representing one input variable of the network.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, var_id):
        super().__init__()
        # (layer, index) position inside the network; assigned by the builder.
        self.loc = None
        # ``loc`` tuples of the nodes that use this node as a parent.
        self.children = []
        # Variable index of this node.
        self.var_id = var_id

    def __repr__(self):
        return f"InputNode at {self.loc}"

class AndNode(Node):
    """Two-input AND gate whose parent edges can each be negated.

    ``parents`` and ``children`` hold ``loc`` tuples ``(layer, index)`` of the
    connected nodes rather than references to the node objects themselves.
    """

    def __init__(self, var_id):
        super().__init__()
        # Variable index of this node.
        self.var_id = var_id
        self.is_output = False
        self.parents = []
        self.children = []
        # (layer, index) position inside the network; assigned by the builder.
        self.loc = None
        # Polarity of the edges to parents[0] / parents[1]. Declared here so a
        # freshly constructed node is usable without the attributes having to
        # be attached from the outside first.
        self.negate_0 = False
        self.negate_1 = False

    def __repr__(self):
        return f"AndNode at {self.loc}"
        
class OutputNode(Node):
    """Output gate of the network: a two-input AND with optional edge negation.

    ``parents`` and ``children`` hold ``loc`` tuples ``(layer, index)`` of the
    connected nodes rather than references to the node objects themselves.
    """

    def __init__(self, var_id):
        super().__init__()
        # Variable index of this node.
        self.var_id = var_id
        # An output node is an output by definition; previously this started as
        # False and relied on the caller to flip it afterwards.
        self.is_output = True
        self.parents = []
        self.children = []
        # (layer, index) position inside the network; assigned by the builder.
        self.loc = None
        # Polarity of the edges to parents[0] / parents[1]. Declared here so a
        # freshly constructed node is usable without the attributes having to
        # be attached from the outside first.
        self.negate_0 = False
        self.negate_1 = False

    def __repr__(self):
        return f"OutputNode at {self.loc}"


class NodeNetwork:
    """Randomly wired, layered network of two-input AND gates with negatable edges.

    ``nodes_`` is a list of layers: layer 0 holds the input nodes, followed by
    one layer of AND nodes per entry of ``hidden_layers`` and a final layer of
    output nodes. Every non-input node has exactly two distinct parents in the
    layer directly before it. Nodes refer to each other through ``loc`` tuples
    ``(layer, index)`` that index into ``nodes_``.
    """

    # Scratch files exchanged with the external aiger tools.
    _TMP_DIR = "tmp"
    _AAG_PATH = "tmp/tmp_aig.aag"
    _STIM_PATH = "tmp/stim"

    def __init__(self, num_inputs, hidden_layers, num_outputs):
        """Build the network with random wiring and random edge polarities.

        Args:
            num_inputs: number of input nodes.
            hidden_layers: sequence with the number of AND nodes per hidden layer.
            num_outputs: number of output nodes.

        Raises:
            ValueError: if a non-empty layer follows a layer with fewer than two
                nodes, because two distinct parents cannot be drawn from it.
        """
        # Validate before building anything: with a single-node predecessor
        # layer the parent-drawing loop below would never terminate.
        layer_sizes = [num_inputs, *hidden_layers, num_outputs]
        for prev_size, size in zip(layer_sizes, layer_sizes[1:]):
            if size > 0 and prev_size < 2:
                raise ValueError(
                    "Every non-empty layer needs at least 2 nodes in the layer "
                    f"before it to draw two distinct parents, got layer sizes {layer_sizes}"
                )

        self.num_inputs = num_inputs
        self.hidden_layers = hidden_layers
        self.num_outputs = num_outputs
        self.max_var_idx = 0
        self.nodes_ = []

        # Input nodes (layer 0)
        input_nodes = []
        for i in range(num_inputs):
            self.max_var_idx += 1
            inp_node = InputNode(self.max_var_idx)
            inp_node.loc = (0, i)
            input_nodes.append(inp_node)
        self.nodes_.append(input_nodes)

        # And nodes; layer indices start at 1 because layer 0 is the input layer
        for layer_idx, layer_size in enumerate(self.hidden_layers, start=1):
            and_nodes = []
            for j in range(layer_size):
                self.max_var_idx += 1
                and_node = AndNode(self.max_var_idx)
                and_node.loc = (layer_idx, j)
                self._wire_to_previous_layer(and_node)
                and_nodes.append(and_node)
            self.nodes_.append(and_nodes)

        # Output nodes (last layer)
        output_nodes = []
        output_layer_idx = len(self.hidden_layers) + 1
        for i in range(num_outputs):
            self.max_var_idx += 1
            output_node = OutputNode(self.max_var_idx)
            output_node.loc = (output_layer_idx, i)
            output_node.is_output = True
            self._wire_to_previous_layer(output_node)
            output_nodes.append(output_node)
        self.nodes_.append(output_nodes)

    def _wire_to_previous_layer(self, node):
        """Give ``node`` two distinct random parents from the most recent layer.

        Also registers ``node`` as a child of both parents and draws a random
        polarity for each of the two parent edges.
        """
        prev_layer = self.nodes_[-1]
        while len(node.parents) < 2:
            candidate = prev_layer[np.random.randint(low=0, high=len(prev_layer))]
            # Reject a parent that was already drawn so both parents differ.
            if candidate.loc not in node.parents:
                node.parents.append(candidate.loc)
                candidate.children.append(node.loc)
        node.negate_0 = random.choice([True, False])
        node.negate_1 = random.choice([True, False])

    def create_aag_repr_(self):
        """Return the network as text in the "aag" format consumed by the aiger tools.

        The text consists of a header line, one line per input, one line per
        output and one line ``<gate> <parent 0> <parent 1>`` per non-input node.
        A node is written as ``2 * var_id``; a negated parent edge adds 1.
        """
        num_gates = sum(self.hidden_layers) + self.num_outputs
        lines = [
            f"aag {self.max_var_idx} {self.num_inputs} 0 {self.num_outputs} {num_gates}"
        ]

        # Input nodes, then output nodes
        lines.extend(f"{node.var_id * 2}" for node in self.nodes_[0])
        lines.extend(f"{node.var_id * 2}" for node in self.nodes_[-1])

        # One gate line for every node outside the input layer
        for layer in self.nodes_[1:]:
            for gate in layer:
                layer_0, idx_0 = gate.parents[0]
                layer_1, idx_1 = gate.parents[1]
                p0 = self.nodes_[layer_0][idx_0].var_id * 2 + gate.negate_0
                p1 = self.nodes_[layer_1][idx_1].var_id * 2 + gate.negate_1
                lines.append(f"{2 * gate.var_id} {p0} {p1}")

        return "\n".join(lines) + "\n"

    def export_to_aag(self, path):
        """Write the text produced by ``create_aag_repr_`` to ``path``."""
        out_str = self.create_aag_repr_()
        with open(path, "w") as f:
            f.write(out_str)

    def viz(self):
        """Return a graphviz ``Source`` of the network, built via ``aiger/aigtodot``.

        The network is exported to a scratch file that is removed again even
        when the external tool fails.
        """
        os.makedirs(self._TMP_DIR, exist_ok=True)
        try:
            self.export_to_aag(self._AAG_PATH)
            dot = subprocess.check_output(["aiger/aigtodot", self._AAG_PATH])
        finally:
            if os.path.exists(self._AAG_PATH):
                os.remove(self._AAG_PATH)
        return Source(dot.decode("utf-8"), filename="test", format="png")

    def predict(self, X):
        """Evaluate the network on every row of ``X`` with ``aiger/aigsim``.

        Args:
            X: boolean array with one row per sample; every row is written to
                the stimulus file as a string of 0/1 characters.

        Returns:
            Boolean numpy array with one prediction per row. Only the last
            character of each parsed output chunk is used.

        Raises:
            AssertionError: if ``X`` is not of dtype bool.
            subprocess.CalledProcessError: if ``aigsim`` exits with an error.
        """
        assert X.dtype == "bool", f"Array has to be of dtype bool"
        if len(X) == 0:
            # Nothing to simulate; the output parsing below needs at least one row.
            return np.array([], dtype=bool)

        # One line of 0/1 characters per sample, terminated by a single ".".
        stim_str = "".join(
            "".join(f"{bit:d}" for bit in row) + "\n" for row in X
        ) + "."

        os.makedirs(self._TMP_DIR, exist_ok=True)
        try:
            with open(self._STIM_PATH, "w") as f:
                f.write(stim_str)
            self.export_to_aag(self._AAG_PATH)
            out = subprocess.check_output(
                ["aiger/aigsim", self._AAG_PATH, self._STIM_PATH]
            )
        finally:
            # Remove the scratch files even when the simulator call failed.
            for path in (self._STIM_PATH, self._AAG_PATH):
                if os.path.exists(path):
                    os.remove(path)

        # NOTE(review): the slicing below is tied to the exact text layout that
        # aigsim prints for this kind of circuit -- confirm when upgrading aiger.
        out = re.sub("Trace is a witness.+", "", out.decode("utf-8"))
        out = out[1:].replace("\n", "")
        out = out[:-1].split("  ")
        preds = np.array([bool(int(x[-1])) for x in out])

        return preds

In [None]:
def load_mnist(n_samples=10_000, pca=False, n_components=8):
    """Load MNIST from ``data/lut/MNIST.npz`` as boolean features and a binary target.

    Features are optionally reduced with PCA, min-max scaled to [0, 1] and
    thresholded at 0.5. The target is True for the digits 0-4 and False
    otherwise. PCA and scaler are fitted on the complete data set before
    sub-sampling and splitting.

    Args:
        n_samples: number of samples kept after shuffling (at most 70_000).
        pca: whether to reduce the features with PCA first.
        n_components: number of PCA components, only used when ``pca`` is set.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` from an unshuffled 80/20 split.
    """
    # NOTE: allow_pickle=True can execute arbitrary code when the file is
    # untrusted; only load archives from a trusted source.
    archive = np.load("data/lut/MNIST.npz", allow_pickle=True)
    features = archive["X"]
    digits = archive["y"]

    assert n_samples <= 70_000, f"Full data available only has 70_000 samples"

    if pca:
        features = PCA(n_components=n_components).fit_transform(features)

    features_scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(features)

    X = (features_scaled > 0.5).astype(bool)

    # Binary target: True for the digits 0, 1, 2, 3 and 4.
    y = digits == 0
    for digit in (1, 2, 3, 4):
        y = y | (digits == digit)

    X, y = shuffle(X, y, n_samples=n_samples, random_state=100)

    # shuffle=False keeps the order produced by the seeded shuffle above.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, shuffle=False
    )
    return X_train, X_test, y_train, y_test

# 10_000 samples, reduced to 8 PCA components and thresholded to boolean features.
X_train, X_test, y_train, y_test = load_mnist(
    n_samples=10_000,
    pca=True,
    n_components=8,
)

In [None]:
# Baseline: accuracy of a randomly wired, untrained network.
nn = NodeNetwork(8, [8, 8, 8, 8], 1)
# nn = NodeNetwork(784, [784, 784, 784], 1)

preds_train = nn.predict(X_train)
preds_test = nn.predict(X_test)
# accuracy_score expects (y_true, y_pred) and returns a fraction in [0, 1].
acc_train = accuracy_score(y_train, preds_train)
acc_test = accuracy_score(y_test, preds_test)
# The ".2%" format scales the fraction by 100; the previous ".2f" + "%" showed
# e.g. 0.52 as "0.52%".
print(f"Accuracy on training set: {acc_train:.2%}")
print(f"Accuracy on test set: {acc_test:.2%}")

## MLP performance

In [None]:
from sklearn.neural_network import MLPClassifier

In [None]:
# Reference model: a small MLP trained on the same boolean features.
clf = MLPClassifier(hidden_layer_sizes=(100,))
clf.fit(X_train, y_train)
preds_mlp_train = clf.predict(X_train)
preds_mlp_test = clf.predict(X_test)

# accuracy_score returns a fraction in [0, 1]; ".2%" scales it by 100 so the
# printed percent sign is correct.
print(f"Accuracy on training set: {accuracy_score(y_train, preds_mlp_train):.2%}")
print(f"Accuracy on test set: {accuracy_score(y_test, preds_mlp_test):.2%}")

## AIG local search

In [None]:
import copy
import ipywidgets as widgets
from IPython.display import display

In [None]:
# Fresh random network: 8 inputs, four hidden layers of 8 AND nodes, 1 output.
nn = NodeNetwork(num_inputs=8, hidden_layers=[8, 8, 8, 8], num_outputs=1)

TODOs:
- [x] Prevent a node from having the same parent twice
- [ ] Traverse the graph downwards

In [None]:
nn = NodeNetwork(8, [16, 16, 16, 16], 1)

# Upper bound on the number of search iterations so that this cell terminates
# under "Restart & Run All"; increase it for a longer search.
MAX_ITERATIONS = 50

# Accuracy of the currently accepted network. Starting at 0 means the best
# successor of the first iteration is always accepted.
acc_train = 0
# Best successor accuracy of the latest iteration. Defined up front because the
# progress widget below reads it before the first iteration has run.
acc_train_best = acc_train
iteration = 0
acc_hist = []

w = widgets.HTML(
    value=f"Iteration {iteration} Accuracy: {acc_train_best}",
    placeholder="Iteration progress",
    description="",
)
display(w)

for iteration in range(1, MAX_ITERATIONS + 1):
    successor_candidates = []

    # Build one successor per non-input node: rewire one of its parents (when
    # the previous layer offers an alternative) and redraw both edge polarities.
    num_gates = sum(len(layer) for layer in nn.nodes_[1:])
    for i in range(num_gates):
        nn_new = copy.deepcopy(nn)
        and_node = list(chain(*nn_new.nodes_[1:]))[i]
        prev_layer = nn_new.nodes_[and_node.loc[0] - 1]
        # Get new parent
        if len(prev_layer) > 2:
            old_parent_loc = and_node.parents.pop(random.choice([0, 1]))
            old_parent = nn_new.nodes_[old_parent_loc[0]][old_parent_loc[1]]
            old_parent.children.remove(and_node.loc)
            new_parent = random.choice(prev_layer)
            # To prevent from changing nothing or having 2 identical parents
            while (new_parent.loc == old_parent_loc) or (new_parent.loc in and_node.parents):
                new_parent = random.choice(prev_layer)
            and_node.parents.insert(0, new_parent.loc)
            new_parent.children.append(and_node.loc)
        # Change polarity
        and_node.negate_0 = random.choice([True, False])
        and_node.negate_1 = random.choice([True, False])
        successor_candidates.append(nn_new)

    # Score every successor. The scores are kept apart from `acc_train` so the
    # acceptance test below compares against the accepted network rather than
    # against whichever candidate happened to be evaluated last.
    accuracies = [
        accuracy_score(y_train, candidate.predict(X_train))
        for candidate in successor_candidates
    ]

    acc_train_best = max(accuracies)
    acc_hist.append(acc_train_best)

    # Accept the best successor only when it strictly improves on the current
    # network. No early stop on a non-improving iteration: successors are drawn
    # at random, so a later iteration can still find an improvement.
    if acc_train_best > acc_train:
        acc_train = acc_train_best
        # Candidates are already independent deep copies of the old network.
        nn = successor_candidates[accuracies.index(acc_train_best)]

    w.value = f"Iteration {iteration} Accuracy: {acc_train_best}"

In [None]:
# Best successor accuracy found in each local-search iteration.
fig, ax = plt.subplots(1, 1)
# acc_hist[0] belongs to iteration 1, so plot against 1-based iteration numbers.
ax.plot(range(1, len(acc_hist) + 1), acc_hist)
ax.set_xlabel("Iteration")
ax.set_ylabel("Best successor accuracy (training set)")
ax.set_title("AIG local search")
# Iterations are integers; avoid fractional tick labels. The trailing semicolon
# suppresses the repr of the last expression.
ax.xaxis.set_major_locator(MaxNLocator(integer=True));

In [None]:
# Render the network held in `nn`; viz() returns a graphviz Source, which is
# displayed because it is the last expression of the cell.
nn.viz()

In [None]:
# Held-out accuracy of the network selected by the local search.
preds_test = nn.predict(X_test)
# accuracy_score expects (y_true, y_pred) and returns a fraction in [0, 1];
# ".2%" scales it by 100 so the printed percent sign is correct.
acc_test = accuracy_score(y_test, preds_test)
print(f"Accuracy on test set: {acc_test:.2%}")