Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leaf samples distribution #61

Merged
8 commits merged on Oct 28, 2019
17 changes: 13 additions & 4 deletions dtreeviz/shadow.py
Expand Up @@ -230,19 +230,28 @@ def get_node_type(_tree_model):
return node_type

@staticmethod
def get_leaf_sample_counts(_tree_model):
def get_leaf_sample_counts(_tree_model, min_samples=0, max_samples=None):
"""Get the number of samples for each leaf.

Optionally, leaves with fewer than min_samples or more than max_samples samples can be filtered out.

:param min_samples: int
Min number of samples for a leaf
:param max_samples: int
Max number of samples for a leaf

:return: tuple
Contains a list of leaf ids and a list of leaf samples
Contains a numpy array of leaf ids and an array of leaf samples
"""

node_type = ShadowDecTree.get_node_type(_tree_model)
n_node_samples = _tree_model.tree_.n_node_samples

leaf_samples = [(i, n_node_samples[i]) for i in range(0, _tree_model.tree_.node_count) if node_type[i]]
max_samples = max_samples if max_samples else n_node_samples.max()
leaf_samples = [(i, n_node_samples[i]) for i in range(0, _tree_model.tree_.node_count) if node_type[i]
and min_samples <= n_node_samples[i] <= max_samples]
x, y = zip(*leaf_samples)
return x, y
return np.array(x), np.array(y)

@staticmethod
def get_leaf_sample_counts_by_class(_tree_model):
Expand Down
33 changes: 31 additions & 2 deletions dtreeviz/trees.py
Expand Up @@ -1257,12 +1257,20 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
colors: dict = None,
fontsize: int = 14,
fontname: str = "Arial",
grid: bool = False):
grid: bool = False,
bins: int = 10,
min_samples: int = 0,
max_samples: int = None):
"""Visualize the number of training samples from each leaf.

Optionally, leaves with fewer than min_samples or more than max_samples samples can be filtered out. This is
especially helpful when you want to investigate only the leaves whose number of samples falls within a specific range.

If display_type = 'plot' it will show leaf samples using a plot.
If display_type = 'text' it will show leaf samples as plain text. This method is preferred if the number
of leaves is very large and the plot becomes very big and hard to interpret.
If display_type = 'hist' it will show a histogram of leaf samples. Useful when you want to easily see the general
distribution of leaf samples.

:param tree_model: sklearn.tree
The tree to interpret
Expand All @@ -1278,9 +1286,15 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
Plot labels font name
:param grid: bool
Whether to show the grid lines
:param bins: int
Number of histogram bins
:param min_samples: int
Min number of samples for a leaf
:param max_samples: int
Max number of samples for a leaf
"""

leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model)
leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model, min_samples, max_samples)

if display_type == "plot":
colors = adjust_colors(colors)
Expand All @@ -1303,6 +1317,21 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
elif display_type == "text":
for leaf, samples in zip(leaf_id, leaf_samples):
print(f"leaf {leaf} has {samples} samples")
elif display_type == "hist":
colors = adjust_colors(colors)

fig, ax = plt.subplots(figsize=figsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(.3)
ax.spines['bottom'].set_linewidth(.3)
n, bins, patches = ax.hist(leaf_samples, bins=bins, color=colors["hist_bar"])
for rect in patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
ax.set_xlabel("leaf sample", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.set_ylabel("leaf count", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.grid(b=grid)


def ctreeviz_leaf_samples(tree_model: (tree.DecisionTreeClassifier),
Expand Down