Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leaf samples distribution #61

Merged
merged 8 commits into from Oct 28, 2019
11 changes: 9 additions & 2 deletions dtreeviz/shadow.py
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from sys import maxsize
from collections import defaultdict, Sequence
from typing import Mapping, List, Tuple
from numbers import Number
Expand Down Expand Up @@ -230,17 +231,23 @@ def get_node_type(_tree_model):
return node_type

@staticmethod
def get_leaf_sample_counts(_tree_model, min_samples=0, max_samples=maxsize):
    """Get the number of samples for each leaf.

    Only leaves whose sample count lies in the inclusive range
    [min_samples, max_samples] are reported.

    :param _tree_model: fitted tree model
        Must expose ``tree_.n_node_samples`` and ``tree_.node_count``.
    :param min_samples: int
        Min number of samples for a leaf (inclusive)
    :param max_samples: int
        Max number of samples for a leaf (inclusive)

    :return: tuple
        Contains a tuple of leaf ids and a tuple of leaf samples. Both are
        empty when no leaf falls inside the requested range.
    """

    # NOTE(review): node_type[i] is assumed to be truthy for leaf nodes;
    # confirm against ShadowDecTree.get_node_type.
    node_type = ShadowDecTree.get_node_type(_tree_model)
    n_node_samples = _tree_model.tree_.n_node_samples

    leaf_samples = [(i, n_node_samples[i])
                    for i in range(_tree_model.tree_.node_count)
                    if node_type[i] and min_samples <= n_node_samples[i] <= max_samples]

    # zip(*[]) yields nothing, so unpacking it into two names would raise
    # ValueError; return two empty sequences instead when the filter
    # excludes every leaf.
    if not leaf_samples:
        return (), ()

    x, y = zip(*leaf_samples)
    return x, y

Expand Down
31 changes: 29 additions & 2 deletions dtreeviz/trees.py
Expand Up @@ -6,6 +6,7 @@
import matplotlib.patches as patches
import tempfile
import os
from sys import maxsize
from sys import platform as PLATFORM
from colour import Color, rgb2hex
from typing import Mapping, List
Expand Down Expand Up @@ -1216,12 +1217,17 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
colors: dict = None,
fontsize: int = 14,
fontname: str = "Arial",
grid: bool = False):
grid: bool = False,
bins: int = 10,
min_samples: int = 0,
max_samples: int = maxsize):
"""Visualize the number of training samples from each leaf.

If display_type = 'plot' it will show leaf samples using a plot.
If display_type = 'text' it will show leaf samples as plain text. This method is preferred if the number
of leaves is very large and the plot becomes very big and hard to interpret.
If display_type = 'hist' it will show a histogram of leaf samples. Useful when you want to easily see the general
distribution of leaf samples.

:param tree_model: sklearn.tree
The tree to interpret
Expand All @@ -1237,9 +1243,15 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
Plot labels font name
:param grid: bool
Whether to show the grid lines
:param bins: int
Number of histogram bins
:param min_samples: int
Min number of samples for a leaf
:param max_samples: int
Max number of samples for a leaf
"""

leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model)
leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model, min_samples, max_samples)

if display_type == "plot":
colors = adjust_colors(colors)
Expand All @@ -1262,6 +1274,21 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
elif display_type == "text":
for leaf, samples in zip(leaf_id, leaf_samples):
print(f"leaf {leaf} has {samples} samples")
elif display_type == "hist":
colors = adjust_colors(colors)

fig, ax = plt.subplots(figsize=figsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(.3)
ax.spines['bottom'].set_linewidth(.3)
n, bins, patches = ax.hist(leaf_samples, bins=bins, color=colors["hist_bar"])
for rect in patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
ax.set_xlabel("leaf sample", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.set_ylabel("leaf count", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.grid(b=grid)


def ctreeviz_leaf_samples(tree_model: (tree.DecisionTreeClassifier),
Expand Down