Skip to content

Commit

Permalink
Causal trees bootstrapping and max_leaf_nodes fixes with minor upda…
Browse files Browse the repository at this point in the history
…te (#583)

* Add bootstrap fix for causal forest fit
* Black reformat
* Fix max_leaf_nodes behaviour with BestFirstCausalTreeBuilder
* Add penalty for treatment and control samples distribution
* Add treatment and control info to tree builders
* Manage criteria public variables via BaseCausalDecisionTree
* Add groups info to causal tree plot
* Update CausalTreeRegressor and CausalRandomForestRegressor interfaces
* Update causal trees notebooks in examples
* Black code reformat
* Update causal trees tests
* Remove Literal for python=3.7 support
* Fix setup_requires
  • Loading branch information
alexander-pv committed Dec 12, 2022
1 parent 7050c74 commit aa02308
Show file tree
Hide file tree
Showing 14 changed files with 1,947 additions and 1,227 deletions.
3 changes: 2 additions & 1 deletion causalml/inference/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .causal.causaltree import CausalTreeRegressor, CausalRandomForestRegressor
from .causal.causaltree import CausalTreeRegressor
from .causal.causalforest import CausalRandomForestRegressor
from .plot import uplift_tree_string, uplift_tree_plot, plot_dist_tree_leaves_values
from .uplift import DecisionTree, UpliftTreeClassifier, UpliftRandomForestClassifier
from .utils import (
Expand Down
27 changes: 27 additions & 0 deletions causalml/inference/tree/causal/_builder.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
# cython: linetrace=True

from sklearn.tree._tree cimport Node, Tree, TreeBuilder
from sklearn.tree._tree cimport Splitter, SplitRecord
from sklearn.tree._utils cimport StackRecord, Stack
from sklearn.tree._utils cimport PriorityHeapRecord, PriorityHeap
from sklearn.tree._tree cimport SIZE_t, DOUBLE_t


cdef struct FrontierRecord:
    # Record of information about a Node on the frontier of candidate splits.
    # These records are maintained in a heap so that the Node with the best
    # improvement in impurity can be accessed first, which allows trees to be
    # grown greedily on this improvement.
    #
    # NOTE(review): the per-field notes below are inferred from the field
    # names only; confirm them against the tree builder that fills this record.
    SIZE_t node_id          # presumably the identifier of the node in the tree
    SIZE_t start            # presumably the start of the node's sample range
    SIZE_t end              # presumably the end of the node's sample range
    SIZE_t pos              # presumably the split position within [start, end]
    SIZE_t depth            # presumably the depth of the node in the tree
    bint is_leaf            # flag: whether the node is treated as a leaf
    double impurity         # impurity value recorded for the node
    double impurity_left    # impurity value recorded for the left side
    double impurity_right   # impurity value recorded for the right side
    double improvement      # improvement in impurity; the heap ordering key
Loading

0 comments on commit aa02308

Please sign in to comment.