In [1]:
import sys

sys.path.insert(0, "../tests")
import logging

import datasets
import numba
import numpy as np
import numpy.typing as npt
import pandas as pd

import skrough as rgh

# Attach a default stderr handler to the root logger so records propagated
# from the "skrough" logger (whose level is set in later cells) get printed.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers, which some Jupyter setups install — confirm log output appears.
logging.basicConfig()

In [2]:
# Factorize the golf toy dataset; "Play" is presumably the decision/target
# column (confirm against rgh.dataprep.prepare_factorized_data).
# NOTE(review): `y_count` vs `x_counts` naming is inconsistent; names kept
# unchanged in case later cells refer to them.
(x, x_counts, y, y_count) = rgh.dataprep.prepare_factorized_data(
    datasets.golf_dataset(),
    "Play",
)

In [3]:
# Quiet the "skrough" logger down to errors only. Note that the following
# cell immediately raises it back to INFO, so a Run All ends at INFO.
# To undo a global logging.disable(...) call, use
# logging.disable(logging.NOTSET) — Logger.manager.disable is shared by
# every logger, not specific to "skrough".
logging.getLogger("skrough").setLevel(logging.ERROR)

In [4]:
# Show skrough's INFO-level progress messages (level given by name).
logging.getLogger("skrough").setLevel("INFO")

In [7]:
import importlib
import pprint

from attrs import evolve

import skrough.typing as rght
from skrough.algorithms import hooks
from skrough.algorithms.key_names import (
    CONFIG_CHAOS_FUN,
    CONFIG_EPSILON,
    CONFIG_RESULT_ATTRS_MAX_COUNT,
    CONFIG_SELECT_RANDOM_MAX_COUNT,
)
from skrough.algorithms.meta import processing, stage
from skrough.chaos_measures import gini_impurity
from skrough.structs.state import ProcessingState

# Reload dependencies before the modules that import from them: a module that
# does `from X import Y` keeps whichever `Y` existed when it was (re)loaded,
# so reloading `processing` before `stage` would leave `processing` bound to
# stale `stage` objects if it imports from it.
# NOTE(review): assumed dependency chain typing -> hooks -> stage ->
# processing; confirm in skrough. `ProcessingState` (skrough.structs.state)
# is not reloaded here and was imported before these reloads.
importlib.reload(rght)
importlib.reload(hooks)
importlib.reload(stage)
importlib.reload(processing)


def prepare_result(state: ProcessingState):
    """Extract the final output of a processing run.

    Passed as ``prepare_result_fun``: it hands back whatever the finalize
    hook stored under ``state.values["result"]`` (``KeyError`` if nothing
    was stored).
    """
    values = state.values
    result = values["result"]
    return result


def finalize(state: ProcessingState) -> None:
    """Build a nested depth tree and store it in ``state.values["result"]``.

    For ``depth == 0`` the result is the leaf ``[0]``. For ``depth > 0`` the
    result is ``[depth, child, child]``, where each child is the return value
    of ``state.processing_fun`` run on a copy of ``state`` whose ``values``
    contain only ``depth - 1``.

    Args:
        state: processing state; ``state.values["depth"]`` must be a
            non-negative integer.

    Raises:
        ValueError: if ``depth`` is negative (otherwise the recursion would
            never reach the base case and end in ``RecursionError``).
    """
    depth = state.values["depth"]
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")
    if depth == 0:
        state.values["result"] = [0]
        return
    # Each child state gets its own fresh `values` dict via `evolve`, so the
    # child's writes to "result" do not touch this node's values.
    children = [
        state.processing_fun(evolve(state, values={"depth": depth - 1}))
        for _ in range(2)
    ]
    state.values["result"] = [depth, *children]


# Assemble the processing function from the hooks above. `finalize` recurses
# through it via `state.processing_fun`, and its return value is whatever
# `prepare_result` extracts from the state.
fun = processing.ProcessingMultiStage.from_hooks(
    prepare_result_fun=prepare_result,
    finalize_hooks=finalize,
    init_hooks=None,
)


# Root state for the demo: "depth" sets how many levels `finalize` expands
# (a full binary tree, 2**depth leaves). The hooks above never draw from
# `rng`, but it is seeded so any future randomized hook stays reproducible
# under Restart & Run All.
ps = ProcessingState(
    values={"depth": 4},
    config={},
    processing_fun=fun,
    rng=np.random.default_rng(seed=0),
)

pprint.pprint(fun(ps), width=20)

[4,
 [3,
  [2,
   [1, [0], [0]],
   [1, [0], [0]]],
  [2,
   [1, [0], [0]],
   [1, [0], [0]]]],
 [3,
  [2,
   [1, [0], [0]],
   [1, [0], [0]]],
  [2,
   [1, [0], [0]],
   [1, [0], [0]]]]]
