In [12]:
from truthnet import truthnet
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.stats.api as sms
from tqdm.notebook import tqdm
import tikzplotlib as tpl
from datetime import datetime
import glob
from IPython.display import Image, Markdown, display

import re
from pathlib import Path

import numpy as np
import seaborn as sns
from dot2tex import dot2tex

In [1]:
def _get_qnet(df):
    """Fit a quasinet ``Qnet`` using every column of ``df`` as a feature.

    Args:
        df: DataFrame whose column labels become the model's feature names.

    Returns:
        The fitted ``qnet.Qnet`` instance.
    """
    from quasinet import qnet

    # Hyper-parameters are fixed for every model fitted in this notebook.
    # NOTE(review): random_state=None, so repeated fits may differ — confirm
    # whether a seed is wanted for reproducibility.
    settings = dict(
        min_samples_split=2,
        alpha=0.05,
        max_depth=-1,
        max_feats=-1,
        early_stopping=False,
        verbose=0,
        random_state=None,
        n_jobs=-1,
    )
    model = qnet.Qnet(feature_names=df.columns.values, **settings)

    # Every cell is cast to a numpy unicode string of at most 21 characters
    # before fitting (longer values would be truncated).
    model.fit(df.to_numpy(dtype="<U21"))
    return model


def _get_tnets(df, df_pos, df_neg):
    """Fit Qnets on the full, positive and negative cohorts over shared columns.

    Only columns that have at least one non-missing value in *all three* frames
    are kept. Each frame is encoded the same way: missing values become ``""``
    and observed values are cast to ``int``.

    Args:
        df: full cohort.
        df_pos: positive cohort.
        df_neg: negative cohort.
        (assumes all three frames share the same columns in the same order —
        the boolean column mask is built from all of them and applied to each.)

    Returns:
        dict with keys ``"all"``, ``"pos"``, ``"neg"`` (fitted Qnets) and
        ``"data"`` (the encoded full-cohort frame the ``"all"`` model was fit on).
    """
    # A column survives only if it is not entirely missing in any cohort.
    non_null_cols = df.notna().any() & df_pos.notna().any() & df_neg.notna().any()

    def _encode(frame):
        # Fill missing cells with a -9 sentinel so the column can be cast to int,
        # then swap the sentinel for an empty string.
        # NOTE(review): a genuine -9 in the data would also be blanked out.
        return frame.loc[:, non_null_cols].fillna(-9).astype(int).replace(-9, "")

    # Encode the full cohort once; the same frame is fitted and returned.
    data = _encode(df)
    return {
        "all": _get_qnet(data),
        "pos": _get_qnet(_encode(df_pos)),
        "neg": _get_qnet(_encode(df_neg)),
        "data": data,
    }

In [4]:
# Load the full / positive / negative Gibbons cohorts (in that order) and fit
# one Qnet per cohort.
glbl = _get_tnets(
    *[
        pd.read_csv(csv_path)
        for csv_path in (
            "data/gibbons_global/gibbons_global.csv",
            "data/gibbons_global/gibbons_global_pos.csv",
            "data/gibbons_global/gibbons_global_neg.csv",
        )
    ]
)

# Expose the encoded data and the three fitted models for the cells below.
data_samples, full_model, pos_model, neg_model = (
    glbl[key] for key in ("data", "all", "pos", "neg")
)

In [7]:
def export_qnet_tree_dotfiles(model, out_dirname):
    """Write one Graphviz ``.dot`` file per feature tree of a fitted Qnet.

    Args:
        model: fitted quasinet ``Qnet``; each entry of ``model.feature_names``
            names one output file ``<feature_name>.dot``.
        out_dirname (str): output directory; created (including missing
            parents) if it does not exist.
    """
    import os
    from quasinet import qnet

    # makedirs(exist_ok=True) also creates missing parent directories and avoids
    # the check-then-create race of exists()/mkdir().
    os.makedirs(out_dirname, exist_ok=True)
    for idx, feature_name in enumerate(model.feature_names):
        qnet.export_qnet_tree(
            model,
            idx,
            os.path.join(out_dirname, "{}.dot".format(feature_name)),
            outformat="graphviz",
            detailed_output=True,
        )

def _get_trees_png(model, outdir):
    """Render every feature tree of ``model`` to PNG and copy the results to ``outdir``.

    The ``.dot`` files and their ``.png`` renderings are produced in a scratch
    directory ``tmp_trees/`` (relative to the working directory). That directory
    is wiped first and left in place afterwards. Its contents are merged into
    ``outdir``. Rendering requires Graphviz's ``dot`` executable on PATH.
    Failures are reported as warnings rather than aborting, so the ``.dot``
    files are still copied.
    """
    import os
    import shutil
    import subprocess
    import warnings
    from pathlib import Path

    scratch = "tmp_trees"

    # Pure-Python replacement for the former `! rm -rf` / `! for file in ...`
    # shell magics: works outside IPython and with spaces in feature names.
    shutil.rmtree(scratch, ignore_errors=True)
    os.makedirs(scratch)
    export_qnet_tree_dotfiles(model, scratch)

    failed = []
    try:
        for dotfile in sorted(Path(scratch).glob("*.dot")):
            result = subprocess.run(
                ["dot", "-Tpng", str(dotfile), "-o", str(dotfile.with_suffix(".png"))],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            if result.returncode != 0:
                failed.append(dotfile.name)
    except FileNotFoundError:
        warnings.warn("Graphviz `dot` not found on PATH; no PNGs were rendered.")
    if failed:
        warnings.warn("Graphviz failed to render: " + ", ".join(failed))

    shutil.copytree(scratch, outdir, dirs_exist_ok=True)

In [8]:
# Render the full-cohort trees (dot + png) into the bondcourt output folder.
_get_trees_png(model=full_model, outdir="qnet_trees/bondcourt/")

Useful online tester for the regular expressions used below: https://pythex.org/

In [79]:
def process_dot(
    dotfile,
    target_range,
    width=150,
    height=219,
    leafstyle="nodeleaf",
    nodestyle="nodei",
    edgelabelstyle=r"midway,left,labelfont,yshift=\ystA",
    palette=None,
    outfile=None,
):
    """Convert a quasinet tree dotfile into a styled TikZ picture.

    Internal nodes get the TikZ style ``nodestyle``. Leaves get ``leafstyle`` plus
    a fill colour that mixes ``palette`` colours weighted by the leaf's
    ``Prob`` distribution. Curved edges are straightened with their label placed
    on the edge. Node coordinates are rescaled to span ``width`` x ``height`` bp.

    Args:
        dotfile: path to a dotfile written by ``export_qnet_tree_dotfiles``.
        target_range: outcome values to colour; each value ``i`` is looked up
            as the string key ``str(i)`` in a leaf's distribution and coloured
            with ``palette[i - 1]`` (so the range is assumed to start at 1).
        width, height: target extent of the node coordinates, in bp.
        leafstyle, nodestyle: TikZ style names for leaf / internal nodes.
        edgelabelstyle: TikZ options for edge labels, inserted verbatim.
        palette: RGB triples in [0, 1]; defaults to a seaborn "husl" palette
            with ``len(target_range)`` colours.
        outfile: if given, the TikZ code is also written to this path.

    Returns:
        The TikZ source as a string (it is also printed).
    """
    if palette is None:
        palette = sns.color_palette("husl", len(target_range))

    # Read once: the raw text is needed for the layout and for the leaf
    # probability distributions (which tagging strips out).
    raw_dot = Path(dotfile).read_text()

    dot = _tag_dot_labels(raw_dot)
    tikz_str = _dot_to_tikz(dot, leafstyle, nodestyle, edgelabelstyle)
    tikz_str = _rescale_node_positions(tikz_str, dot, width, height)

    fills = _leaf_fill_colors(raw_dot, target_range, palette)
    out = _style_leaves_and_edges(tikz_str, leafstyle, fills)

    print(out)
    if outfile is not None:
        with open(outfile, "w") as text_file:
            text_file.write(out)

    return out


def _tag_dot_labels(dot):
    """Prefix dot labels with "nonleaf"/"leaf" markers and drop leaf distributions."""
    # Labels that are a single word followed directly by a quote: internal nodes.
    dot = re.sub(r'label="(\w*)"', r'label="nonleaf\1"', dot, flags=re.DOTALL)
    # Trim the "Prob..." text of leaf labels (leave a marker for tagging).
    dot = re.sub(
        '\nProb:?(.*?)"',
        'Prob:"',
        dot,
        flags=re.DOTALL,
    )
    # Turn the marker into a "leaf" prefix.
    dot = re.sub(r'label="(\w)Prob:', r'label="leaf\1', dot, flags=re.DOTALL)
    return dot.replace(r"\n", "")


def _dot_to_tikz(dot, leafstyle, nodestyle, edgelabelstyle):
    """Render tagged dot source to TikZ and replace dot2tex's default styling."""
    tikz_str = dot2tex(
        dot,
        format="tikz",
        edgelabel=True,
        figonly=True,
    )

    to_replace = {
        "filled, rounded": "",
        r"[draw=black,rectangle,] {leaf": "[" + leafstyle + "] {",
        r"[draw=black,rectangle,] {nonleaf": "[" + nodestyle + "] {",
        r"\draw [black,]": r"\draw[edge1]",
        r"\begin{tikzpicture}[>=latex,line join=bevel,]": r"\begin{tikzpicture}[anchor=center,font=\bf\sffamily\fontsize{6}{6}\selectfont]",
    }
    for old, new in to_replace.items():
        tikz_str = tikz_str.replace(old, new)

    # Straighten curved edges and attach the edge label to the edge itself.
    # A callable replacement inserts edgelabelstyle verbatim; re.escape() in a
    # template would leave stray backslashes before characters like "." or "-".
    tikz_str = re.sub(
        r"..controls(.+) .. \((\d*)\);(\n.*)*?\n.*node (.*);",
        lambda m: "-- node[" + edgelabelstyle + "]" + m.group(4) + " (" + m.group(2) + ");",
        tikz_str,
    )
    # Remove dot2tex's scope blocks.
    tikz_str = re.sub(
        r"\\begin{scope}?(.*?)\\end{scope}",
        "",
        tikz_str,
        flags=re.DOTALL,
    )
    return tikz_str


def _rescale_node_positions(tikz_str, dot, width, height):
    """Rewrite node coordinates so the layout spans width x height bp."""
    positions = list(
        dot2tex(
            dot,
            format="positions",
            edgelabel=True,
            figonly=True,
        ).values()
    )
    x_coords = np.array([p[0] for p in positions])
    y_coords = np.array([p[1] for p in positions])
    new_x = np.round(x_coords / (x_coords.max() / width))
    new_y = np.round(y_coords / (y_coords.max() / height))

    lines = tikz_str.splitlines()
    node_lines = [i for i, line in enumerate(lines) if "\\node" in line]
    # NOTE(review): assumes the k-th node line corresponds to the k-th entry
    # of the positions dict — confirm against dot2tex's ordering.
    for k, i in enumerate(node_lines):
        # Accept coordinates with any number of decimals (or none).
        line = re.sub(r"\([0-9]+(?:\.[0-9]+)?bp,", "(" + str(new_x[k]) + "bp,", lines[i])
        lines[i] = re.sub(r",[0-9]+(?:\.[0-9]+)?bp\)", "," + str(new_y[k]) + "bp)", line)
    return "\n".join(lines)


def _leaf_fill_colors(raw_dot, target_range, palette):
    """Build one TikZ fill option per leaf, mixing palette colours by probability."""
    fills = []
    for chunk in re.findall("\nProb:?(.*?)Frac", raw_dot, flags=re.DOTALL):
        chunk = chunk.replace(" A", "A").replace("\n", "")
        # Parse space-separated "key:value" tokens into {key: probability}.
        probs = {}
        for token in chunk.split(" "):
            parts = token.split(":")
            if len(parts) > 1:
                probs[parts[0]] = float(parts[1])

        color = np.array([0, 0, 0])
        for i in target_range:
            key = str(i)
            if key in probs:
                # palette is indexed with i - 1: target_range is assumed to start at 1.
                color = color + np.asarray(palette[i - 1]) * probs[key]

        fills.append(
            "fill={rgb,1:red,"
            + str(color[0])
            + "; green,"
            + str(color[1])
            + "; blue,"
            + str(color[2])
            + "}"
        )
    return fills


def _style_leaves_and_edges(tikz_str, leafstyle, fills):
    """Add leaf fill colours and alternate edge labels between left and right."""
    lines = tikz_str.splitlines()

    # The k-th leaf line receives the k-th fill (raises IndexError if there
    # are more leaf lines than parsed distributions).
    leaf_lines = [i for i, line in enumerate(lines) if leafstyle in line]
    for k, i in enumerate(leaf_lines):
        lines[i] = lines[i].replace(leafstyle, leafstyle + "," + fills[k])

    # Alternate by parity so an odd number of edges no longer overruns the list.
    edge_lines = [i for i, line in enumerate(lines) if "draw" in line]
    for k, i in enumerate(edge_lines):
        side = ("left", "right")[k % 2]
        lines[i] = lines[i].replace("node[", "node[" + side + ",").replace("{ ", " {")

    return "\n".join(lines)

In [81]:
question = "4439"

# display(Image(f"qnet_trees/bondcourt/{question}.png"))

# Raw string: "\y" is an invalid escape in a normal literal (SyntaxWarning on
# Python 3.12+); the raw form passes the TeX macro \ystA through unchanged.
process_dot(
    f"qnet_trees/bondcourt/{question}.dot",
    target_range=range(1, 6),
    width=100,
    height=100,
    edgelabelstyle=r"midway,labelfont,yshift=\ystA",
    outfile=f"tree_{question}.tex",
)


\begin{tikzpicture}[anchor=center,font=\bf\sffamily\fontsize{6}{6}\selectfont]
%%

  \node (0) at (13.0bp,43.0bp) [nodeleaf,fill={rgb,1:red,0.737894642081866; green,0.5803747131925207; blue,0.26240567999273673}] {2};
  \node (1) at (21.0bp,62.0bp) [nodei] {4584};
  \node (3) at (56.0bp,43.0bp) [nodei] {4441};
  \node (2) at (41.0bp,24.0bp) [nodeleaf,fill={rgb,1:red,0.91032182998946; green,0.4760495983668104; blue,0.4674591566275868}] {1};
  \node (5) at (83.0bp,24.0bp) [nodei] {4555};
  \node (4) at (66.0bp,5.0bp) [nodeleaf,fill={rgb,1:red,0.8729714179081297; green,0.5799038447958372; blue,0.4470747589385788}] {1};
  \node (6) at (100.0bp,5.0bp) [nodeleaf,fill={rgb,1:red,0.747453797327665; green,0.5531639079242984; blue,0.3593181479488501}] {2};
  \node (7) at (42.0bp,81.0bp) [nodei] {4257};
  \node (8) at (64.0bp,62.0bp) [nodeleaf,fill={rgb,1:red,0.6927793152279502; green,0.6443672388374346; blue,0.4310248091807105}] {2};
  \node (9) at (64.0bp,100.0bp) [nodei] {4255};
  \node (10) a

'\n\\begin{tikzpicture}[anchor=center,font=\\bf\\sffamily\\fontsize{6}{6}\\selectfont]\n%%\n\n  \\node (0) at (13.0bp,43.0bp) [nodeleaf,fill={rgb,1:red,0.737894642081866; green,0.5803747131925207; blue,0.26240567999273673}] {2};\n  \\node (1) at (21.0bp,62.0bp) [nodei] {4584};\n  \\node (3) at (56.0bp,43.0bp) [nodei] {4441};\n  \\node (2) at (41.0bp,24.0bp) [nodeleaf,fill={rgb,1:red,0.91032182998946; green,0.4760495983668104; blue,0.4674591566275868}] {1};\n  \\node (5) at (83.0bp,24.0bp) [nodei] {4555};\n  \\node (4) at (66.0bp,5.0bp) [nodeleaf,fill={rgb,1:red,0.8729714179081297; green,0.5799038447958372; blue,0.4470747589385788}] {1};\n  \\node (6) at (100.0bp,5.0bp) [nodeleaf,fill={rgb,1:red,0.747453797327665; green,0.5531639079242984; blue,0.3593181479488501}] {2};\n  \\node (7) at (42.0bp,81.0bp) [nodei] {4257};\n  \\node (8) at (64.0bp,62.0bp) [nodeleaf,fill={rgb,1:red,0.6927793152279502; green,0.6443672388374346; blue,0.4310248091807105}] {2};\n  \\node (9) at (64.0bp,100.0bp) [

In [None]:
question = "1076"

# display(Image(f"qnet_trees/bondcourt/{question}.png"))

# Raw string: "\y" is an invalid escape in a normal literal (SyntaxWarning on
# Python 3.12+); the raw form passes the TeX macro \ystA through unchanged.
process_dot(
    f"qnet_trees/bondcourt/{question}.dot",
    target_range=range(1, 6),
    width=100,
    height=100,
    edgelabelstyle=r"midway,labelfont,yshift=\ystA",
    outfile=f"tree_{question}.tex",
)