In [None]:
# stdlib
import sys
import warnings

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins

# Ignore all Python warnings raised from here on.
warnings.filterwarnings("ignore")

# Reconfigure the synthcity logger: clear what is currently registered
# (presumably the default sink — confirm), then log to stderr at DEBUG level.
log.remove()
log.add(sink=sys.stderr, level="DEBUG")

# Name of the plugin that later cells look up through Plugins().get(...).
eval_plugin = "decaf"

## Define the synthetic dataset generator

In [None]:
# stdlib
from typing import Any, Optional, Tuple

# third party
import networkx as nx
import numpy as np
import pandas as pd


# It will apply a perturbation at each node provided in perturb.
def gen_data_nonlinear(
    G: Any,
    base_mean: float = 0,
    base_var: float = 0.3,
    mean: float = 0,
    var: float = 1,
    SIZE: int = 10000,
    err_type: str = "normal",
    perturb: Optional[list] = None,
    sigmoid: bool = True,
    expon: float = 1.1,
) -> pd.DataFrame:
    """Sample ``SIZE`` rows from a nonlinear structural model over the DAG ``G``.

    Every node first receives independent noise. Nodes are then visited in
    topological order and each one accumulates a nonlinear transform of every
    parent that has an edge into it.

    Args:
        G: directed acyclic graph exposing ``nodes()`` and ``edges()`` and
            accepted by networkx's topological sort.
        base_mean: location of the noise drawn for unperturbed nodes.
        base_var: value passed as the sampler's ``scale`` for unperturbed
            nodes (for the normal sampler that is the standard deviation,
            despite the parameter name).
        mean: location of the normal noise drawn for perturbed nodes.
        var: value passed as the normal sampler's ``scale`` for perturbed nodes.
        SIZE: number of rows to sample.
        err_type: ``"gumbel"`` draws Gumbel noise for unperturbed nodes; any
            other value draws normal noise.
        perturb: nodes whose noise is drawn from ``normal(mean, var)`` instead
            of the base distribution. ``None`` (the default) or an empty list
            perturbs nothing.
        sigmoid: if True each parent contributes ``1 / (1 + exp(-parent))``;
            otherwise it contributes ``parent ** 2``.
        expon: unused; kept so existing callers keep working.

    Returns:
        A DataFrame with ``SIZE`` rows and one column per node, named
        ``str(node)`` in ``G.nodes()`` order.
    """
    # A None default avoids sharing one mutable list between calls.
    perturbed = perturb if perturb is not None else []

    list_edges = list(G.edges())
    list_vertex = list(G.nodes())
    order = list(nx.algorithms.dag.topological_sort(G))

    # Independent noise term per node, stored in G.nodes() order.
    g = []
    for v in list_vertex:
        if v in perturbed:
            g.append(np.random.normal(mean, var, SIZE))
            print("perturbing ", v, "with mean var = ", mean, var)
        elif err_type == "gumbel":
            g.append(np.random.gumbel(base_mean, base_var, SIZE))
        else:
            g.append(np.random.normal(base_mean, base_var, SIZE))

    # NOTE(review): g was filled in G.nodes() order but is indexed by node
    # label below. That only works when labels are the integers 0..n-1, and
    # even then position i holds the values computed for node i while the
    # column names follow G.nodes() order. Confirm the intended mapping with
    # whatever consumes the DAG before changing it.
    for o in order:
        for edge in list_edges:
            if o != edge[1]:
                continue
            # edge points into the node currently being finalised
            if sigmoid:
                # Logistic sigmoid; the parentheses matter: `1 / 1 + exp(-x)`
                # would evaluate to `1 + exp(-x)`.
                g[edge[1]] += 1 / (1 + np.exp(-g[edge[0]]))
            else:
                g[edge[1]] += g[edge[0]] ** 2

    # (n_nodes, SIZE) -> (SIZE, n_nodes) so that each node becomes a column.
    samples = np.swapaxes(g, 0, 1)

    return pd.DataFrame(samples, columns=list(map(str, list_vertex)))


def generate_synth(size: int = 100) -> Tuple[pd.DataFrame, list, dict]:
    """Build the fixed 10-node demo DAG and sample ``size`` rows from it.

    Returns:
        A tuple ``(data, dag_seed, bias_dict)``: the DataFrame produced by
        ``gen_data_nonlinear``, the list of ``[source, target]`` edges that
        define the graph, and the edge-removal dictionary.
    """
    # Causal structure, one [source, target] pair per directed edge.
    causal_edges = [
        [1, 2], [1, 3], [1, 4],
        [2, 5], [2, 0],
        [3, 0], [3, 6], [3, 7],
        [6, 9],
        [0, 8], [0, 9],
    ]  # fmt: skip

    # Edge-removal dictionary: this entry removes the edge into 6 from 3.
    edges_to_remove = {6: [3]}

    # Sample the data according to the causal structure above.
    graph = nx.DiGraph(causal_edges)
    samples = gen_data_nonlinear(graph, SIZE=size)

    return samples, causal_edges, edges_to_remove


## Synthetic dataset - with a known DAG

In [None]:
# Draw 1000 synthetic rows; also keep the DAG edge list and the edge-removal dict.
data, dag, bias = generate_synth(1000)

# Bare last expression so the notebook displays the DataFrame.
data

In [None]:
# Look up the plugin named by eval_plugin and configure it with n_iter=200.
model = Plugins().get(eval_plugin, n_iter=200)

# Fit on the synthetic data, passing the known edge list through `dag`.
model.fit(data, dag=dag)

In [None]:
# Request 10 rows from the fitted model, passing the edge-removal dict as biased_edges.
model.generate(10, biased_edges=bias)

## Synthetic dataset - with DAG learning

In [None]:
# Draw 200 fresh rows; the DAG and edge-removal dict are discarded on purpose,
# since this section fits without a supplied DAG.
data, _, _ = generate_synth(200)

# Bare last expression so the notebook displays the DataFrame.
data

In [None]:
# Same plugin as before, now with struct_learning_enabled=True and a batch size of 100.
model = Plugins().get(
    eval_plugin,
    n_iter=200,
    struct_learning_enabled=True,
    # Disabled option, kept for reference:
    # struct_learning_search_method="d-struct",
    batch_size=100,
)

# Fit on the data alone; no `dag` argument is supplied here.
model.fit(data)

In [None]:
# Request 10 rows from the model fitted in the previous cell.
model.generate(10)