# NITTA ML for Synthesis Prototype

## 1. Data Crawling

In [None]:
import asyncio
import os
import pickle
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Any

import pandas as pd
from aiohttp import ClientSession
from cached_property import cached_property
from cachetools import cached
from dataclasses_json import dataclass_json, LetterCase
import datetime

In [None]:
NITTA_PORT = 53829
NITTA_BASEURL = f"http://localhost:{NITTA_PORT}"
# NITTA repository root; used as the working directory when NITTA is started
NITTA_ROOT_DIR = "../.."
# may need to be updated; alternative (direct Windows build):
# NITTA_ROOT_DIR + r"\.stack-work\dist\29cc6475\build\NITTA\nitta.exe"
NITTA_EXE_PATH = "stack exec nitta -- "
# Leaf label = normalized (duration, depth) . weights: shorter synthesis gives a higher label,
# smaller depth a slightly higher one.
METRICS_WEIGHTS = pd.Series({"duration": -1, "depth": -0.1})
# Inner-node label mix: LAMBDA * best leaf label + (1 - LAMBDA) * mean leaf label
LAMBDA = 0.6
# seconds to wait after starting NITTA before querying its HTTP server
WAIT_NITTA_DELAY = 0.5
pd.set_option("display.max_colwidth", 120)

In [None]:
def cached_node_method(wrapped):
    """Memoize a node method per node instance and per argument identity.

    Keying on the node's ``sid`` alone is not enough: it ignores the method arguments
    (e.g. a different ``metrics_distrib``) and collides between trees built from different
    examples, whose nodes reuse the same sids.

    Results are kept in a dict on the node itself (attribute ``_cache_<method name>``),
    keyed by the ``id()`` of every argument. The arguments are stored next to the result,
    which keeps them alive, and are compared by identity on lookup, so a reused or stale id
    (e.g. after unpickling) never produces a false hit. Arguments therefore need not be
    hashable (``pandas.DataFrame`` is not).
    """
    import functools

    cache_attr = f"_cache_{wrapped.__name__}"

    @functools.wraps(wrapped)
    def wrapper(self, *args, **kwargs):
        cache = self.__dict__.setdefault(cache_attr, {})
        key = (tuple(map(id, args)),
               tuple(sorted((name, id(value)) for name, value in kwargs.items())))
        hit = cache.get(key)
        if hit is not None:
            cached_args, cached_kwargs, result = hit
            same_args = all(a is b for a, b in zip(cached_args, args))
            same_kwargs = all(cached_kwargs[name] is value for name, value in kwargs.items())
            if same_args and same_kwargs:
                return result
        result = wrapped(self, *args, **kwargs)
        cache[key] = (args, kwargs, result)
        return result

    return wrapper

In [None]:
def debounce(s):
    """Decorator: let the wrapped function run at most once every `s` seconds.

    A call arriving less than `s` seconds after the previous executed call *finished*
    is dropped and returns None. The number of dropped calls is printed right before
    the next call that actually runs.
    """
    import functools

    def decorate(f):
        last_finished = None  # time.time() taken right after the last executed call
        skipped = 0

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            nonlocal last_finished, skipped
            now = time.time()
            if last_finished is not None and now - last_finished < s:
                skipped += 1
                return None
            if skipped > 0:
                print(f"-- skipped {skipped} calls")
            result = f(*args, **kwargs)
            last_finished = time.time()
            skipped = 0
            return result
        return wrapped
    return decorate

@debounce(1)
def log_debug(*args):
    """Print a timestamped debug line; at most one line per second, extra calls are dropped."""
    # %H:%M:%S instead of %T: %T is not supported by every platform's strftime
    print("--", datetime.datetime.now().strftime("%H:%M:%S"), *args)

In [None]:
# Shared dataclass_json options: fields are (de)serialized with camelCase keys.
nitta_dataclass_params = {"letter_case": LetterCase.CAMEL}


@dataclass_json(**nitta_dataclass_params)
@dataclass
class NittaNodeDecision:
    """Decision that produced a node; only its kind (``tag``) is used in this notebook."""
    tag: str
    
@dataclass_json(**nitta_dataclass_params)
@dataclass
class NittaNode:
    """A node of the NITTA synthesis tree, deserialized from the NITTA HTTP API.

    The scalar fields come from the server's JSON via ``NittaNode.from_dict`` (camelCase
    keys, see ``nitta_dataclass_params``). ``children`` and ``parent`` are not part of that
    payload: ``retrieve_subforest`` fills them in while crawling the tree.
    """
    score: Optional[int]
    is_terminal: bool
    is_finish: bool
    parameters: Any  # NittaNodeParameters
    decision: NittaNodeDecision
    duration: Optional[int]
    sid: str

    children: Optional[List['NittaNode']] = field(default=None, repr=False)
    parent: Optional['NittaNode'] = field(default=None, repr=False)

    @property
    def is_leaf(self):
        """True for terminal nodes; ``retrieve_subforest`` never expands them."""
        return self.is_terminal

    @cached_property
    def subtree_size(self):
        """Number of nodes in the subtree rooted here (this node included).

        Requires ``children`` to have been retrieved.
        """
        assert self.children is not None
        return sum(child.subtree_size for child in self.children) + 1

    @cached_property
    def depth(self) -> int:
        """Depth derived from the sid: the root ``-`` is 0, ``-0-1`` is 2."""
        return self.sid.count('-') if self.sid != '-' else 0

    @cached_property
    def subtree_leafs_metrics(self) -> pd.DataFrame:
        """One ``(duration, depth)`` row per *finished* leaf of this subtree.

        Unfinished leaves contribute no row. Every row keeps index 0, so the index of
        the concatenated frame is not unique.
        """
        if self.is_leaf:
            if not self.is_finish:
                return pd.DataFrame()
            return pd.DataFrame(dict(duration=[self.duration], depth=[self.depth]))
        else:
            return pd.concat([child.subtree_leafs_metrics for child in self.children])

    @cached_node_method
    def get_subtree_leafs_labels(self, metrics_distrib: pd.DataFrame) -> pd.Series:
        """Labels (``compute_label``) of all leaves of this subtree, in tree order."""
        if self.is_leaf:
            return pd.Series([self.compute_label(metrics_distrib)])
        else:
            return pd.concat([child.get_subtree_leafs_labels(metrics_distrib) for child in self.children])

    @cached_node_method
    def compute_label(self, metrics_distrib: pd.DataFrame) -> float:
        """Training label of this node; higher means a better synthesis outcome.

        - unfinished leaf: the artificial low value -3;
        - finished leaf: its (duration, depth) metrics, standardized with the mean/std of
          ``metrics_distrib``, dotted with ``METRICS_WEIGHTS``;
        - inner node: ``LAMBDA * max + (1 - LAMBDA) * mean`` of its subtree's leaf labels.

        NOTE(review): with a single finished leaf or a constant metric, ``std()`` is NaN/0
        and leaf labels are not finite.
        """
        if self.is_leaf:
            if not self.is_finish:
                # unsuccessful synthesis, very low artificial label
                return -3

            metrics = self.subtree_leafs_metrics.iloc[0]
            normalized_metrics = (metrics - metrics_distrib.mean()) / metrics_distrib.std()
            return normalized_metrics.dot(METRICS_WEIGHTS)

        subtree_labels = self.get_subtree_leafs_labels(metrics_distrib)
        return LAMBDA * subtree_labels.max() + (1 - LAMBDA) * subtree_labels.mean()

    @cached_property
    def alternative_siblings(self) -> dict:
        """Count this node's siblings by decision kind.

        Returns ``alt_bindings`` (``BindDecisionView``), ``alt_refactorings`` (any other
        tag) and ``alt_dataflows`` (``DataflowDecisionView``). The node itself (matched by
        sid) is excluded; a node without parent has all counts at zero.
        """
        bindings, refactorings, dataflows = 0, 0, 0

        if self.parent is not None:
            for sibling in self.parent.children:
                if sibling.sid == self.sid:
                    continue
                if sibling.decision.tag == "BindDecisionView":
                    bindings += 1
                elif sibling.decision.tag == "DataflowDecisionView":
                    dataflows += 1
                else:
                    refactorings += 1

        return dict(alt_bindings=bindings,
                    alt_refactorings=refactorings,
                    alt_dataflows=dataflows)

    async def retrieve_subforest(self, session, levels_left=None):
        """Download this node's children (and their subtrees) from NITTA into ``children``.

        ``levels_left`` bounds the recursion: ``None`` fetches the whole subtree, ``0``
        only the direct children (which end up with ``children == []``), ``-1`` nothing.
        Leaves are never expanded. Child subtrees are requested concurrently via
        ``asyncio.gather``.
        """
        self.children = []
        if self.is_leaf or levels_left == -1:
            return

        async with session.get(NITTA_BASEURL + f"/node/{self.sid}/subForest") as resp:
            children_raw = await resp.json()

        # log_debug is debounced, so most of these messages are dropped
        log_debug(f"{len(children_raw)} children from {self.sid}")

        for child_raw in children_raw:
            child = NittaNode.from_dict(child_raw)
            child.parent = self
            self.children.append(child)

        levels_left_for_child = None if levels_left is None else levels_left - 1
        await asyncio.gather(
            *[child.retrieve_subforest(session, levels_left_for_child) for child in self.children]
        )
        
                    
async def retrieve_whole_nitta_tree(max_depth=None) -> NittaNode:
    """Download the NITTA synthesis tree starting from the root node ``-``.

    ``max_depth`` is passed to ``NittaNode.retrieve_subforest`` as ``levels_left``
    (``None`` downloads the whole tree). Prints the elapsed wall time.
    """
    started = time.perf_counter()
    async with ClientSession() as session:
        async with session.get(NITTA_BASEURL + "/node/-") as response:
            root_payload = await response.json()
        root = NittaNode.from_dict(root_payload)
        await root.retrieve_subforest(session, max_depth)

    elapsed = time.perf_counter() - started
    print(f"Finished tree retrieval in {elapsed:.2f} s")
    return root
    

In [None]:
# Smoke test: start NITTA on one example, download its whole synthesis tree and report its size.
print("Test retrieve_whole_nitta_tree")

example = "examples/counter.lua"
nitta_tree = None
# NITTA serves the tree over HTTP while this block runs; it is killed in `finally`.
with subprocess.Popen(f"{NITTA_EXE_PATH} -p={NITTA_PORT} {example}", 
                      cwd=NITTA_ROOT_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
                     ) as proc:
    try:
        # fixed start-up delay; may be too short if `stack` has to build first
        time.sleep(WAIT_NITTA_DELAY)
        # show NITTA's start-up output (read1 may block until NITTA prints something)
        print(proc.stdout.read1().decode())
        nitta_tree = await retrieve_whole_nitta_tree()
        print(f"Nodes: {nitta_tree.subtree_size}")
    finally:
        proc.kill()

In [None]:
# with open("treedump.pickle", "wb") as f:
#     pickle.dump(nitta_tree, f)

In [None]:
# with open("treedump.pickle", "rb") as f:
#     nitta_tree = pickle.load(f)

In [None]:
# (metric, value) pairs of the first finished leaf
tuple(nitta_tree.subtree_leafs_metrics.iloc[0].items())

In [None]:
print(f"Tree size: {nitta_tree.subtree_size} nodes")
# histograms of duration and depth over all finished leaves
_ = nitta_tree.subtree_leafs_metrics.hist(figsize=(9, 4))

In [None]:
# number of finished leaves per (duration, depth) combination
nitta_tree.subtree_leafs_metrics.value_counts()

In [None]:
# label of the root: LAMBDA-weighted mix of the best and the mean leaf label of the whole tree
nitta_tree.compute_label(nitta_tree.subtree_leafs_metrics)

In [None]:
# distribution of leaf labels (unfinished leaves appear at the artificial value -3)
nitta_tree.get_subtree_leafs_labels(nitta_tree.subtree_leafs_metrics).hist()

In [None]:
def collect_all_labels(tree, metrics_distrib):
    """Return a DataFrame with ``label``, ``depth`` and ``duration`` of every node of ``tree``.

    Nodes are listed in pre-order (a node before its children, children in order);
    labels are computed against ``metrics_distrib``.
    """
    def walk(node):
        yield {"label": node.compute_label(metrics_distrib),
               "depth": node.depth,
               "duration": node.duration}
        for child in node.children:
            yield from walk(child)

    return pd.DataFrame(list(walk(tree)))
    
# labels of every node, normalized against the metrics of all finished leaves of the tree
ds = nitta_tree.subtree_leafs_metrics
lbls = collect_all_labels(nitta_tree, ds)
_ = lbls.label.hist(bins=50)

In [None]:
def select_best(node, metrics_distrib):
    """Walk down from ``node``, always taking the child with the highest label, and return the leaf reached.

    Labels come from ``compute_label(metrics_distrib)``; ties go to the first child.
    """
    current = node
    while not current.is_leaf:
        current = max(current.children, key=lambda child: child.compute_label(metrics_distrib))
    return current

# greedy descent from the root along the highest-labelled children
best = select_best(nitta_tree, ds)
best

In [None]:
# Sanity check of the labelling: the greedy-best leaf should have the smallest depth among
# the finished leaves that share its duration.
lfs = nitta_tree.subtree_leafs_metrics
print(f"depths with best node's duration ({best.duration}):")
display(lfs[lfs.duration == best.duration].depth.value_counts())
print(f"best node's depth: {best.depth}")

# NITTA synthesis tree to CSV dataset implementation

In [None]:
import subprocess
import shlex
from glob import glob
import os
from pathlib import Path

In [None]:
def _extract_params_dict(node: NittaNode) -> dict:
    """Flatten a node's decision parameters into dataset columns.

    - bind / dataflow decisions: a shallow copy of ``node.parameters``; for dataflows the
      ``pNotTransferableInputs`` list is collapsed to its sum;
    - the root (``RootView``): no columns;
    - any other tag is a refactoring and yields only ``pRefactoringType``.
    """
    tag = node.decision.tag
    if tag == "RootView":
        return {}
    if tag not in ("BindDecisionView", "DataflowDecisionView"):
        # refactorings: only their kind is recorded
        return {"pRefactoringType": tag}

    params = node.parameters.copy()
    if tag == "DataflowDecisionView":
        params["pNotTransferableInputs"] = sum(params["pNotTransferableInputs"])
    return params


def assemble_tree_dataframe(example: str, node: NittaNode, metrics_distrib=None, include_label=True,
                            levels_left=None) -> pd.DataFrame:
    """Flatten (part of) a synthesis tree into one dataset row per node.

    Each row holds the example name, sid, decision tag, NITTA's own score (``old_score``),
    the leaf flag, the sibling counts of ``alternative_siblings`` and the decision
    parameters of ``_extract_params_dict``. With ``include_label``, a ``label`` column is
    added (``compute_label`` against ``metrics_distrib``; defaults to the finished-leaf
    metrics of ``node``'s subtree).

    ``levels_left`` limits the depth: ``None`` means the whole subtree, ``-1`` only ``node``.
    The root (sid ``"-"``) gets a row only when it is the sole node returned. Every row
    has index 0, so the result's index is not unique.
    """
    if include_label and metrics_distrib is None:
        metrics_distrib = node.subtree_leafs_metrics

    # dict(...) rather than a dict literal: a parameter clashing with a fixed column raises
    node_df = pd.DataFrame(dict(
        example=example,
        sid=node.sid,
        tag=node.decision.tag,
        old_score=node.score,
        is_leaf=node.is_leaf,
        **node.alternative_siblings,
        **_extract_params_dict(node),
    ), index=[0])
    if include_label:
        node_df["label"] = node.compute_label(metrics_distrib)

    if node.is_leaf or levels_left == -1:
        return node_df

    child_levels = levels_left - 1 if levels_left is not None else None
    frames = [] if node.sid == "-" else [node_df]
    frames.extend(
        assemble_tree_dataframe(example, child, metrics_distrib, include_label, child_levels)
        for child in node.children
    )
    return pd.concat(frames)

In [None]:
# Preview of the dataset rows for the root's direct children. Rows are tagged with the
# example the tree was built from, named like process_example does (file basename);
# the hard-coded "spi3" did not match the loaded tree.
tdf = assemble_tree_dataframe(os.path.basename(example), nitta_tree, levels_left=0)
display(tdf)

In [None]:
# Dataset rows for the first three levels below the root, tagged with the basename of the
# example the tree was built from (as process_example does).
tdf = assemble_tree_dataframe(os.path.basename(example), nitta_tree, levels_left=2)
display(tdf)

In [None]:
# rows of nodes -0-0-0, -0-0-1 and -0-0-2 (presumably the first children of -0-0)
tdf[tdf.sid.isin(["-0-0-0", "-0-0-1", "-0-0-2"])]

In [None]:
DATA_ROOT = Path("data")  # crawled trees (.pickle) and datasets (.csv) are written here
os.makedirs(DATA_ROOT, exist_ok=True)

# Source: https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
def is_port_in_use(port):
    """Return True if a TCP connection to localhost:`port` succeeds (i.e. something listens there)."""
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        error_code = probe.connect_ex(("localhost", port))
    return error_code == 0

async def process_example(example: str) -> pd.DataFrame:
    """Crawl one NITTA example into a dataset.

    Starts NITTA (``NITTA_EXE_PATH``) on ``NITTA_PORT`` for ``example``, downloads the whole
    synthesis tree, pickles it to ``DATA_ROOT/<basename>.pickle``, turns it into rows with
    ``assemble_tree_dataframe`` and writes ``DATA_ROOT/<basename>.csv``. NITTA is killed
    afterwards in every case.

    Returns the dataframe that was written.
    Raises RuntimeError if ``NITTA_PORT`` is already taken (e.g. a NITTA server left running).
    """
    if is_port_in_use(NITTA_PORT):
        raise RuntimeError(f"Port {NITTA_PORT} is already in use, shutdown NITTA server if that's running.")

    example_name = os.path.basename(example)
    df = None

    print(f"Processing example {example!r}")
    # NITTA's console output is never read here, so discard it instead of piping it:
    # an undrained pipe can fill up and stall the server in the middle of the crawl.
    with subprocess.Popen(f"{NITTA_EXE_PATH} -p={NITTA_PORT} {example}", cwd=NITTA_ROOT_DIR,
                          stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, shell=True) as proc:
        try:
            print("NITTA is running.")
            # await instead of time.sleep: do not block the event loop while NITTA starts
            await asyncio.sleep(WAIT_NITTA_DELAY)
            print("Retrieving tree...")

            tree = await retrieve_whole_nitta_tree()
            with open(DATA_ROOT / f"{example_name}.pickle", "wb") as f:
                pickle.dump(tree, f)

            print(f"Nodes: {tree.subtree_size}. Building dataframe...")
            df = assemble_tree_dataframe(example_name, tree).reset_index(drop=True)

            print(f"Data's ready, {len(df)} rows")

            target_filepath = DATA_ROOT / f"{example_name}.csv"
            print(f"Saving to {target_filepath}")
            df.to_csv(target_filepath, index=False)
        finally:
            proc.kill()
            print("NITTA is dead")
    print("DONE")
    return df

print("Available examples:")
# sorted: glob's order is filesystem-dependent, and cells below pick examples by index
examples = sorted(map(os.path.abspath, glob(f'{NITTA_ROOT_DIR}/examples/*.lua')))
pd.Series(examples)

In [None]:
# crawl one example (picked by index from `examples`) into data/<name>.pickle and data/<name>.csv
r = await process_example(examples[1])

# 2. Model Training

In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np
from glob import glob
from pathlib import Path
import missingno
from IPython.display import display
import matplotlib.pyplot as plt

In [None]:
DATA_ROOT = Path("data")

# sorted for a deterministic order: glob's order is filesystem-dependent, and the
# concatenation in the next cell (rows, and the order in which columns first appear)
# follows this list
data_csvs = sorted(glob(str(DATA_ROOT / "*.csv")))
print("Available CSVs:")
data_csvs

In [None]:
# one frame with the rows of every crawled example, renumbered 0..N-1
df = pd.concat(map(pd.read_csv, data_csvs), ignore_index=True)
df

In [None]:
# upsampling to cope with imbalanced data
# (disabled) would repeat the rows of the listed small examples 8x / 50x / 100x:
# dfu = pd.concat([
#     df,
#     pd.concat([df[df.example == "counter.lua"]]*(8 - 1)),
#     pd.concat([df[df.example == "fibonacci.lua"]]*(50 - 1)),
#     pd.concat([df[df.example == "spi2.lua"]]*(100 - 1))
# ])
# currently no upsampling: dfu is the same object as df (not a copy)
dfu = df
# rows per example
dfu.example.value_counts()

In [None]:
def preprocess_df(df: pd.DataFrame) -> pd.DataFrame:
    """Turn a raw tree dataframe (as written by ``process_example``) into numeric model features.

    - ``is_leaf``, ``pCritical``, ``pPossibleDeadlock`` and ``pRestrictedTime``: True/False
      become 1/0; other values (e.g. NaN) are kept;
    - ``tag`` is one-hot encoded into numeric ``tag_*`` columns; every known tag gets a
      column even when it does not occur in ``df``, so the feature layout does not depend
      on which decisions happen to be present;
    - identifier/unused columns are dropped and the remaining NaNs are filled with 0.

    Returns a new frame; ``df`` is not modified.
    """
    def map_bool(c):
        return c.apply(lambda v: 1 if v is True else (0 if v is False else v))

    def map_categorical(df, c, options=None):
        # uint8 instead of pandas>=2's default bool: bool dummies mixed with float columns
        # turn the feature matrix (.values) into an object array
        dummies = pd.get_dummies(c, prefix=c.name, dtype="uint8")
        if options:
            # get_dummies(columns=...) is ignored for a Series, so add missing categories
            # explicitly; sorting keeps get_dummies' own (sorted) column order
            all_columns = sorted(set(dummies.columns) | set(options))
            dummies = dummies.reindex(columns=all_columns, fill_value=0)
        return pd.concat([df.drop([c.name], axis=1), dummies], axis=1)

    df = df.copy()
    df.is_leaf = map_bool(df.is_leaf)
    df.pCritical = map_bool(df.pCritical)
    df.pPossibleDeadlock = map_bool(df.pPossibleDeadlock)
    df.pRestrictedTime = map_bool(df.pRestrictedTime)
    df = map_categorical(df, df.tag, ['tag_BindDecisionView','tag_BreakLoopView','tag_ConstantFoldingView','tag_DataflowDecisionView','tag_OptimizeAccumView','tag_ResolveDeadlockView'])
    df = df.drop(["pWave", "example", "sid", "old_score", "is_leaf", "pRefactoringType"], axis="columns")

    df = df.fillna(0)
    return df

pdf = preprocess_df(dfu)
# display only: reset_index() returns a new frame and does not modify pdf
pdf.reset_index()

In [None]:
# training column layout (includes "label"); new_evaluator relies on this order
FINAL_COLUMNS = list(pdf.columns)
display(FINAL_COLUMNS)
# parameter columns plus the two raw columns preprocess_df drops
METRICS_COLUMNS = [*(name for name in FINAL_COLUMNS if name.startswith("p")), "pRefactoringType", "pWave"]
display(METRICS_COLUMNS)

In [None]:
# Visualize missing values per column.
# NOTE(review): preprocess_df ends with fillna(0), so pdf has no NaNs left and this matrix
# shows no gaps; running it on dfu would show the raw missingness. Confirm the intent.
missingno.matrix(pdf.sort_values(["pOutputNumber", "pWaitTime"]))

In [None]:
from sklearn.model_selection import train_test_split
from typing import Tuple

TARGET_COLUMNS = ["label"]

def create_datasets(df) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Shuffle ``df``, split it 80/20 and wrap both parts as ``tf.data`` datasets.

    Features are all columns except ``TARGET_COLUMNS``. The train dataset is shuffled,
    batched by 16 and repeated forever (so ``fit`` needs ``steps_per_epoch``); the test
    dataset is shuffled and batched by 16 but finite.
    """
    shuffled = df.sample(frac=1)
    train_df, test_df = train_test_split(shuffled, test_size=0.2)

    total = len(df)
    print(f"N:\t{total}")
    print(f"Train:\t{len(train_df)}, {len(train_df) / total * 100:.0f}%")
    print(f"Test:\t{len(test_df)}, {len(test_df) / total * 100:.0f}%")

    def df_to_dataset(frame, shuffle=True, batch_size=16, repeat=False, print_cols=False):
        targets = frame[TARGET_COLUMNS].copy()
        features = frame.drop(columns=TARGET_COLUMNS)
        if print_cols:
            print(f"Feature columns: {features.columns.values.tolist()}")

        dataset = tf.data.Dataset.from_tensor_slices((features.values, targets.values))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=10000)
        if batch_size:
            dataset = dataset.batch(batch_size)
        if repeat:
            dataset = dataset.repeat()
        return dataset

    train_ds = df_to_dataset(train_df, batch_size=16, repeat=True, print_cols=True)
    test_ds = df_to_dataset(test_df)
    return train_ds, test_ds

# random 80/20 split (no fixed seed); prints the split sizes and the feature column order
train_ds, test_ds = create_datasets(pdf)

In [None]:
# inspect one (features, targets) training batch
next(iter(train_ds))

In [None]:
from tensorflow.keras import layers
from tensorflow.keras import regularizers

def create_model(df, **kwargs) -> tf.keras.Model:
    """Build and compile the MLP regressor that predicts a node's label.

    ``df`` is the preprocessed training frame. The input width is its number of feature
    columns, i.e. every column except ``TARGET_COLUMNS``, which is the same split
    ``create_datasets`` feeds to the model. Deriving it from ``df`` (instead of a fixed 20)
    keeps the model in sync when the dataset has a different set of parameter/tag columns.
    ``kwargs`` is accepted for call compatibility and is unused.
    """
    n_features = len([c for c in df.columns if c not in TARGET_COLUMNS])

    model = tf.keras.Sequential([
        layers.InputLayer(input_shape=(n_features,)),
        layers.Dense(128, activation="relu", kernel_regularizer="l2"),
        layers.Dense(128, activation="relu", kernel_regularizer="l2"),
        layers.Dense(64, activation="relu", kernel_regularizer="l2"),
        layers.Dense(64, activation="relu", kernel_regularizer="l2"),
        layers.Dense(32, activation="relu"),
        layers.Dense(1)
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        # regression on the continuous label
        loss="mse",
        metrics=["mae"],
    )
    return model


In [None]:
# accumulates the training history of every run of the fit cell below (reset here)
total_hist_df = pd.DataFrame()
model = create_model(pdf)
model.summary()

In [None]:
# Continue training the current model: this run's history goes to hist_df and is
# appended to total_hist_df.
raw_history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=60,
    steps_per_epoch=2250,  # train_ds repeats forever, so the epoch length must be given
)
hist_df = pd.DataFrame(raw_history.history)
total_hist_df = pd.concat([total_hist_df, hist_df], ignore_index=True)

In [None]:
# loss and MAE curves (train vs. validation) of the last training run
hist_df[["loss", "val_loss"]].plot()
plt.grid()
hist_df[["mae", "val_mae"]].plot()
plt.grid()
hist_df

In [None]:
# loss and MAE curves over all training runs since total_hist_df was last reset
total_hist_df[["loss", "val_loss"]].plot()
plt.grid()
total_hist_df[["mae", "val_mae"]].plot()
plt.grid()
total_hist_df

In [None]:
print("Feature importance by weights absolute value:")
# Row i of the first Dense kernel belongs to the i-th feature column in pdf's column order
# (create_datasets feeds pdf without TARGET_COLUMNS), so the names must come from an
# ordered list: a set iterates in arbitrary order and would mislabel the weights.
pd.DataFrame(zip(np.abs(model.weights[0].numpy()).sum(axis=1),
                 [c for c in pdf.columns if c not in TARGET_COLUMNS])) \
    .set_index(1) \
    .sort_values(0, ascending=False)

In [None]:
MODELS_ROOT = Path("models")
os.makedirs(MODELS_ROOT, exist_ok=True)

mname = "model_N"  # name under models/; the evaluation section loads this name again
model.save(MODELS_ROOT / mname)

In [None]:
# optionally save total_hist_df for reference
# with open(MODELS_ROOT / f"{mname}_total_hist.pickle", "wb") as f:
#     pickle.dump(total_hist_df, f)

# 3. Model evaluation

In [None]:
MODELS_ROOT = Path("models")

# Reload the evaluator model; `mname` must still be defined from the save cell
# (or swap in the commented line to load a specific saved model).
# model = tf.keras.models.load_model(MODELS_ROOT / "example_model_1")
model = tf.keras.models.load_model(MODELS_ROOT / mname)
model.summary()

In [None]:
def new_evaluator(node: NittaNode):
    """Score a node with the trained global ``model`` (higher is better).

    Builds the node's single dataset row (no label, node only), preprocesses it like the
    training data and aligns it to ``final_columns``, zero-filling features the node does
    not have. ``final_columns`` is presumably a copy of ``FINAL_COLUMNS`` without ``label``;
    it must match the column order the model was trained on.
    """
    final_columns = ['alt_bindings', 'alt_refactorings', 'alt_dataflows', 'pOutputNumber', 'pAlternative', 'pAllowDataFlow', 'pCritical', 'pPercentOfBindedInputs', 'pPossibleDeadlock', 'pNumberOfBindedFunctions', 'pRestless', 'pNotTransferableInputs', 'pRestrictedTime', 'pWaitTime', 'tag_BindDecisionView', 'tag_BreakLoopView', 'tag_ConstantFoldingView', 'tag_DataflowDecisionView', 'tag_OptimizeAccumView', 'tag_ResolveDeadlockView']
    metrics_columns = [cn for cn in final_columns if cn.startswith("p")] + ["pRefactoringType", "pWave"]

    node_df = assemble_tree_dataframe("", node, include_label=False, levels_left=-1)
    # ensure every parameter column exists so preprocess_df can map/drop them
    filled_metrics_df = pd.concat([pd.DataFrame(columns=metrics_columns), node_df])
    preprocessed_df = preprocess_df(filled_metrics_df)
    right_final_columns_df = pd.concat([pd.DataFrame(columns=final_columns), preprocessed_df])[final_columns]
    final_df = right_final_columns_df.fillna(0)

    # Cast explicitly: the concatenations with empty (object-dtype) frames and bool columns
    # can leave an object array, which Keras cannot convert to a tensor.
    features = final_df.values.astype("float32")
    return model.predict(features)[0][0]

In [None]:
from aiohttp import ServerDisconnectedError
from collections import defaultdict

def old_evaluator(node: NittaNode):
    """Baseline evaluator: rank a node by the score NITTA itself attached to it."""
    heuristic_score = node.score
    return heuristic_score

counters = defaultdict(int)  # evaluator name -> number of nodes visited by select_best_by_evaluator


def reset_counters():
    """Start a fresh tally by rebinding the module-level ``counters`` to an empty mapping."""
    global counters
    counters = defaultdict(int)

async def select_best_by_evaluator(session, evaluator, node, children_limit=None):
    """Greedy depth-first search for a finished leaf, guided by ``evaluator``.

    Each visited node's children are downloaded (one level), scored with ``evaluator``
    and explored best-first; only the ``children_limit`` best ones are tried when the
    limit is set (a falsy limit means no limit). Every visit is counted in
    ``counters[evaluator.__name__]``.

    Returns the first finished leaf found, or None if this subtree has none (unfinished
    leaves and nodes whose download drops the connection are dead ends).
    """
    counters[evaluator.__name__] += 1

    if node.is_leaf:
        return node if node.is_finish else None

    try:
        await node.retrieve_subforest(session, 0)
    except ServerDisconnectedError:
        # NITTA closed the connection for this node: treat it as a dead end
        return None

    ranked = sorted(((evaluator(child), child) for child in node.children),
                    key=lambda pair: pair[0], reverse=True)
    if children_limit:
        ranked = ranked[:children_limit]

    for _, child in ranked:
        found = await select_best_by_evaluator(session, evaluator, child, children_limit)
        if found is not None:
            return found

    return None

In [None]:
# Smoke test of the model-based evaluator on the first child of the root of a freshly crawled tree.
reset_counters()

example = "examples/counter.lua"
nitta_tree = None
# NITTA runs only while the tree is downloaded; it is killed in `finally`.
with subprocess.Popen(f"{NITTA_EXE_PATH} -p={NITTA_PORT} {example}", 
                      cwd=NITTA_ROOT_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
                     ) as proc:
    try:
        time.sleep(WAIT_NITTA_DELAY)
        print(proc.stdout.read1().decode())
        nitta_tree = await retrieve_whole_nitta_tree()
    finally:
        proc.kill()
        
# works offline: new_evaluator only uses the already downloaded nodes and the model
new_evaluator(nitta_tree.children[0])

In [None]:
# Compare the trained model (new_evaluator) with NITTA's own score (old_evaluator) as guides
# for the greedy search, each trying at most the 2 best-ranked children per node.
reset_counters()

example = "examples/counter.lua"
nitta_tree = None
with subprocess.Popen(f"{NITTA_EXE_PATH} -p={NITTA_PORT} {example}", 
                      cwd=NITTA_ROOT_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
                     ) as proc:
    try:
        time.sleep(WAIT_NITTA_DELAY)
        print(proc.stdout.read1().decode())

        root = await retrieve_whole_nitta_tree()

        # the search re-downloads the children of every node it visits (replacing the
        # ones fetched above), so NITTA must keep running until both searches finish
        async with ClientSession() as session:
            best_new = await select_best_by_evaluator(session, new_evaluator, root, 2)
            print("NEW DONE", best_new)
            best_old = await select_best_by_evaluator(session, old_evaluator, root, 2)
            print("OLD DONE", best_old)
    finally:
        proc.kill()

In [None]:
# Side-by-side result of both searches: synthesis duration and depth of the found leaf, and
# how many nodes each search visited.
# NOTE(review): select_best_by_evaluator may return None (no finished leaf reached), in
# which case the .duration/.depth accesses below fail.
display(best_old)
display(best_new)
display(pd.DataFrame(dict(duration=[best_old.duration, best_new.duration],
                          depth=[best_old.depth, best_new.depth],
                          evaluator_calls=[counters["old_evaluator"], counters["new_evaluator"]]), index=["old", "new"]))