Skip to content

Commit

Permalink
Merge branch 'mepland-improve_ctreeviz_univar' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
parrt committed Jan 5, 2023
2 parents 9ef0d45 + 8db4ea5 commit e84511b
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions dtreeviz/trees.py
Expand Up @@ -1473,7 +1473,7 @@ def _ctreeviz_univar(shadow_tree,
for patch in barcontainers:
for rect in patch.patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['edge'])
rect.set_edgecolor(colors['rect_edge'])
ax.set_xlim(*overall_feature_range)
ax.set_xticks(overall_feature_range)
ax.set_yticks([0, max([max(h) for h in hist])])
Expand All @@ -1488,23 +1488,27 @@ def _ctreeviz_univar(shadow_tree,
y_noise = np.random.normal(mu + i * class_step, sigma, size=len(bucket))
ax.scatter(bucket, y_noise, alpha=colors['scatter_marker_alpha'], marker='o', s=dot_w, c=color_map[i],
edgecolors=colors['scatter_edge'], lw=.3)
else:
raise ValueError(f'Unrecognized gtype = {gtype}!')

ax.tick_params(axis='both', which='major', width=.3, labelcolor=colors['tick_label'],
labelsize=ticks_fontsize)

splits = [node.split() for node in shadow_tree.internal]
splits = sorted(splits)
bins = [ax.get_xlim()[0]] + splits + [ax.get_xlim()[1]]

if 'splits' in show: # this gets the horiz bars showing prediction region
pred_box_height = .07 * ax.get_ylim()[1]
if 'preds' in show: # this gets the horiz bars showing prediction region
pred_box_height = .07 * (ax.get_ylim()[1] - ax.get_ylim()[0])
bins = [ax.get_xlim()[0]] + splits + [ax.get_xlim()[1]]
for i in range(len(bins) - 1):
left = bins[i]
right = bins[i + 1]
inrange = y_train[(X_train >= left) & (X_train <= right)]
if 0 == len(inrange):
continue
values, counts = np.unique(inrange, return_counts=True)
pred = values[np.argmax(counts)]
rect = patches.Rectangle((left, 0), (right - left), pred_box_height, linewidth=.3,
rect = patches.Rectangle((left, -2*pred_box_height), (right - left), pred_box_height, linewidth=.3, alpha=colors['tesselation_alpha'],
edgecolor=colors['edge'], facecolor=color_map[pred])
ax.add_patch(rect)

Expand All @@ -1518,8 +1522,9 @@ def _ctreeviz_univar(shadow_tree,
ax.set_title(title, fontsize=fontsize, color=colors['title'])

if 'splits' in show:
split_heights = [*ax.get_ylim()]
for split in splits:
ax.plot([split, split], [*ax.get_ylim()], '--', color=colors['split_line'], linewidth=1)
ax.plot([split, split], split_heights, '--', color=colors['split_line'], linewidth=1)


def _ctreeviz_bivar(shadow_tree, fontsize, ticks_fontsize, fontname, show,
Expand Down

0 comments on commit e84511b

Please sign in to comment.