
[MRG] Adds Minimal Cost-Complexity Pruning to Decision Trees #12887

Merged: 94 commits merged into scikit-learn:master on Aug 20, 2019

Conversation

thomasjpfan (Member) commented Dec 29, 2018

Reference Issues/PRs

Fixes #6557

What does this implement/fix? Explain your changes.

This PR implements Minimal Cost-Complexity Pruning based on L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification and Regression Trees", Wadsworth, Belmont, CA, 1984.

Most of this implementation is the same as in the literature. There are two differences:

  1. In Breiman, r(t) is an estimate of the probability of misclassification. This PR uses the impurity as r(t).
  2. The weighted number of samples is used to compute the probability of a point landing on a node.

A cost complexity parameter, alpha, was added to __init__ to control cost complexity pruning. The post-pruning is done at the end of fit.

The code performing Minimal Cost-Complexity Pruning is mostly written in Python. The Python part produces the node ids that will become the leaves of the new subtree. These leaves are passed to a Cython function called build_pruned_tree that builds the tree; that part is written in Cython since the tree-building API is in Cython.

In Cython, the Stack class is used to traverse the tree. Not all the fields of StackRecord are used. This is a trade-off between the code complexity of adding yet another Stack class and being a little memory-inefficient.
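For illustration, a minimal Python sketch of the same stack-based traversal pattern, written against the public tree_ arrays rather than the PR's Cython internals (the dataset and variable names are only for the example):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
tree = DecisionTreeClassifier(random_state=0).fit(X, y).tree_

stack = [(0, -1)]  # (node_id, parent_id); -1 marks the root's parent
while stack:
    node_id, parent_id = stack.pop()
    left = tree.children_left[node_id]
    right = tree.children_right[node_id]
    if left == -1:  # TREE_LEAF: both children are -1 at a leaf
        print("leaf", node_id, "parent", parent_id)
    else:
        stack.append((left, node_id))
        stack.append((right, node_id))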

Currently, prune_tree is public, which allows for the following use case:

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(alpha=0.0)
clf.fit(X, y)  # X, y: any training data
clf.set_params(alpha=0.1)
clf.prune_tree()

If we prefer, we can make prune_tree private and not encourage this use case.

thomasjpfan added 15 commits Dec 28, 2018
jnothman (Member) left a comment

Some API comments.

Some questions for the novice:

  • does calling prune_tree with the same alpha repeatedly return the same tree?
  • does calling prune_tree with increasing alpha return a strict sub-tree?
sklearn/tree/tree.py (two outdated review threads, resolved)
@@ -510,6 +515,110 @@ def decision_path(self, X, check_input=True):
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

    def prune_tree(self):

jnothman (Member), Jan 1, 2019:

I think this needs to be called after fit automatically to facilitate cross validation etc.

I wonder if this should instead be a public function?

thomasjpfan (Author, Member), Jan 1, 2019:

Currently, prune_tree is a public function that is called at the end of fit. It should work with our cross validation classes/functions.
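For context, a hedged sketch of that cross-validation use case, assuming the ccp_alpha spelling the PR eventually settled on (the discussion above still calls it alpha):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
# pruning happens inside fit, so the parameter is tunable like any other
grid = GridSearchCV(DecisionTreeClassifier(random_state=0),
                    param_grid={"ccp_alpha": [0.0, 0.01, 0.05]}, cv=5)
grid.fit(X, y)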

thomasjpfan (Member, Author) commented Jan 1, 2019

does calling prune_tree with the same alpha repeatedly return the same tree?

As long as the original tree is the same, using the same alpha will return the same tree. I will add a test for this behavior.

does calling prune_tree with increasing alpha return a strict sub-tree?

When alpha gets high enough, the entire tree can be pruned, leaving just the root node.
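A minimal sketch of that limiting behavior, again assuming the final ccp_alpha API (the dataset is arbitrary):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
stump = DecisionTreeClassifier(ccp_alpha=10.0).fit(X, y)
assert stump.tree_.node_count == 1  # only the root remains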

thomasjpfan added 2 commits Jan 1, 2019
###############################################################################
# Plot training and test scores vs alpha
# --------------------------------------
# Calcuate and plot the the training scores and test accuracy scores

NicolasHug (Contributor), Jan 1, 2019:

Just a few typos (I'll read up on tree pruning and will try to provide a more in-depth review later)

  • Calcuate
  • the the
  • above "a smaller trees"

also I think you should avoid the `:math:` notation in comments

thomasjpfan (Author, Member), Jan 1, 2019:

The :math: notation is currently used in other examples such as: https://github.com/scikit-learn/scikit-learn/blob/master/examples/svm/plot_svm_scale_c.py. Are we discouraging the usage of :math: in our examples?

NicolasHug (Contributor), Jan 1, 2019:

It's OK in the docstrings since it will be rendered like regular rst by sphinx, but in the comments it is not necessary.

[a minimized reply is hidden here]

NicolasHug (Contributor), Jan 1, 2019:

Ooh ok I didn't know it worked like that, sorry

thomasjpfan added 2 commits Jan 1, 2019
NicolasHug (Contributor) left a comment

A first pass of cosmetic comments.

@thomasjpfan it seems to me that in general, and unless there's a compelling reason not to, sklearn code uses (potentially long) descriptive variable names.

For example par_idx could be renamed to parent_idx.
Same for cur_alpha, cur_idx, etc.


# bubble up values to ancestor nodes
for idx in leaf_idicies:
    cur_R = r_node[idx]

NicolasHug (Contributor), Jan 1, 2019:

Avoid upper-case in variable names (same for R_diff)

leaves_in_subtree = np.zeros(shape=n_nodes, dtype=np.uint8)

stack = [(0, -1)]
while len(stack) > 0:

NicolasHug (Contributor), Jan 1, 2019:

while stack is more pythonic (same below)


stack = [(0, -1)]
while len(stack) > 0:
    node_id, parent = stack.pop()

NicolasHug (Contributor), Jan 1, 2019:

node_idx to stay consistent with the rest of the function. I would also suggest parent_idx instead of parent, and the parents array could just be named parent.

# computes number of leaves in all branches and the overall impurity of
# the branch. The overall impurity is the sum of r_node in its leaves.
n_leaves = np.zeros(shape=n_nodes, dtype=np.int32)
leaf_idicies, = np.where(leaves_in_subtree)

NicolasHug (Contributor), Jan 1, 2019:

leaf_indices

r_branch[leaf_idicies] = r_node[leaf_idicies]

# bubble up values to ancestor nodes
for idx in leaf_idicies:

NicolasHug (Contributor), Jan 1, 2019:

for leaf_idx in...?


# descendants of branch are not in subtree
stack = [cur_idx]
while len(stack) > 0:

NicolasHug (Contributor), Jan 1, 2019:

while stack

inner_nodes[idx] = False
leaves_in_subtree[idx] = 0
in_subtree[idx] = False
n_left, n_right = child_l[idx], child_r[idx]

NicolasHug (Contributor), Jan 1, 2019:

Usually n_something denotes a count. Here these are just indices, right?

leaves_in_subtree[cur_idx] = 1

# updates number of leaves
cur_leaves, n_leaves[cur_idx] = n_leaves[cur_idx], 0

NicolasHug (Contributor), Jan 1, 2019:

I would propose

n_pruned_leaves = n_leaves[cur_idx] - 1
n_leaves[cur_idx] = 0

and accordingly update n_leaves[cur_idx] below


# bubble up values to ancestors
cur_idx = parents[cur_idx]
while cur_idx != _tree.TREE_LEAF:

NicolasHug (Contributor), Jan 1, 2019:

It's a bit weird to bubble up to a leaf.

Whatever you're comparing to here should explicitly be the same value as what you used for defining the root's parent above (stack = [(0, -1)])

I would simply use while cur_idx != -1:

adrinjalali (Member) commented Jan 1, 2019

Some random thoughts:

  • In the context of ensembles and random forests, the parameter needs to be exposed there as well.
  • It'd be nice if we knew the overhead of the pruning. Especially once the user uses it in the context of random forests, the overhead is multiplied by the number of trees, and that times the number of fits in a grid search might be significant. Related to that, a few points come to mind:
    • having some numbers related to the overhead would be nice.
    • a potential warm_start for the pruning, maybe (since the rest of fit doesn't have to be run again for different alpha values).
    • contemplating moving the code to Cython might be an idea.
  • I'm not sure it's necessary to create a copy of the tree for the pruned one. A masked version of the tree would probably be optimal for trying out multiple alpha values and a potential warm_start. That also depends on how much overhead the copying has.
NicolasHug (Contributor) commented Jan 1, 2019

@jnothman, to add to @thomasjpfan's answers:

does calling prune_tree with the same alpha repeatedly return the same tree?

The procedure is deterministic so calling prune_tree with same alpha and same original tree will give you the same pruned tree. Also as far as I understand, tree.prune_tree(alpha) == tree.prune_tree(alpha).prune_tree(alpha).

does calling prune_tree with increasing alpha return a strict sub-tree?

A subtree yes, but not necessarily a strict one:

with alpha_1 > alpha_2, tree.prune_tree(alpha_1) is a subtree of tree.prune_tree(alpha_2), but they may also be equal. This is because the alpha parameter is only used as a threshold here.
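A hedged sketch of this monotonicity, written against the merged ccp_alpha API rather than the prune_tree draft discussed here:

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
big = DecisionTreeClassifier(ccp_alpha=0.005, random_state=0).fit(X, y)
small = DecisionTreeClassifier(ccp_alpha=0.02, random_state=0).fit(X, y)
# a larger alpha prunes at least as much, but not necessarily strictly more
assert small.tree_.node_count <= big.tree_.node_count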

in_subtree = np.ones(shape=n_nodes, dtype=np.bool)

cur_alpha = 0
while cur_alpha < self.alpha:

NicolasHug (Contributor), Jan 1, 2019:

2 thoughts:

  • on the resources that I read (here and here), the very first pruning step is to remove all the pure leaves (equivalent to using alpha=0 apparently). This is not done here since cur_alpha is immediately overwritten. I wonder if this is done by default in the tree growing algorithm.
  • As you check for cur_alpha < self.alpha and cur_alpha is computed before the tree is pruned in the loop, this means that the alpha of the returned pruned tree will be greater than self.alpha. It would seem more natural to me to return a tree whose alpha is less than self.alpha. In any case we would need to explain how alpha is used in the docs, something like "subtrees whose scores are less than alpha are discarded. The score is computed as ..."

thomasjpfan (Author, Member), Jan 7, 2019:
  • When alpha = 0, the number of leaves does not contribute to the cost-complexity measure, which I interpreted as "do not prune". Removing the leaves when alpha = 0 will increase the cost-complexity measure.

  • Returning a tree whose alpha is less than self.alpha makes sense and should be documented.

NicolasHug (Contributor), Jan 7, 2019:

Removing the leaves when alpha=0, will increase the cost-complexity measure.

It cannot increase the cost-complexity: the first step is to prune (some of) the pure leaves. That is, if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf (which will still be pure). Since N stays pure, R(T) is unchanged while the number of leaves drops, so the cost-complexity measure cannot go up for any alpha >= 0. The process is repeated with N and its sibling if needed.

thomasjpfan (Author, Member), Jan 7, 2019:

That is if a node N has 2 child leaves where all the samples in both leaves belong to the same class, then the first step will remove those 2 leaves and make N a leaf

This makes sense. I will review the tree building code to see if this can happen.

To prevent future confusion, I want to get on the same page with our definition of a pure leaf. From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves. From reading your response, you consider two leaves to be pure if they are siblings and their samples belong to the same class. Is this correct?

NicolasHug (Contributor), Jan 7, 2019:

From my understanding, a pure leaf is a leaf whose samples belong to the same class, independent of all other leaves

I meant this as well

NicolasHug (Contributor), Jan 7, 2019:

I just looked at the tree code; I think we can assume that this "first step" is not needed here after all, since a node is made a leaf according to the min_impurity_decrease (or deprecated min_impurity_split) parameter.

That is, if a node is pure according to min_impurity_decrease, it will be made a leaf, and thus the example case I mentioned above (a node with 2 pure leaves) cannot exist.

jnothman (Member) commented Jan 2, 2019

Where I was going with my questions was the idea of warm_start as well...

thomasjpfan (Member, Author) commented Jan 7, 2019

@adrinjalali @jnothman

  • Exposing this to the ensemble trees makes sense.
  • I will do some experimenting to benchmark the overhead of pruning, in its current form and a Cython version of it. I'll post the results here.
  • The masked tree together with the warm_start parameter are great ideas. A masked tree would allow the level of pruning to be adjusted without growing the tree again, which looks really nice. The current copying approach allows the original tree to be deleted, and the pruned tree will take up less space.
jnothman (Member) commented Jan 8, 2019 (comment minimized)

amueller (Member) commented Jul 17, 2019

conflicts!

adrinjalali (Member) left a comment

I think I forgot to approve here. LGTM, excited about this one, thanks @thomasjpfan

adrinjalali (Member) commented Jul 26, 2019

This is kinda ready, isn't it?

thomasjpfan (Member, Author) commented Jul 26, 2019

@adrinjalali This should be ready to go!

NicolasHug (Contributor) left a comment

I made these comments a while ago and didn't submit; I'll try to provide another round soon.


.. math::

   R_\alpha(T) = R(T) + \alpha|\tilde{T}|

NicolasHug (Contributor), Jul 26, 2019:

why use tilde? Using |T| should be enough

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_tree_plot_cost_complexity_pruning.py`

.. topic:: References:

NicolasHug (Contributor), Jul 26, 2019:

Please link to it directly in the section. There are many references; it's not obvious which one is used for what.

:class:`ensemble.RandomTreesEmbedding`,
:class:`ensemble.GradientBoostingClassifier`,
and :class:`ensemble.GradientBoostingRegressor`.
:issue:`12887` by :user:`Thomas Fan <thomasjpfan>`.

NicolasHug (Contributor), Jul 26, 2019, suggested change:
- :issue:`12887` by :user:`Thomas Fan <thomasjpfan>`.
+ :pr:`12887` by :user:`Thomas Fan <thomasjpfan>`.
complexity pruning provides another option to control the size of a tree. In
:class:`DecisionTreeClassifier`, this pruning technique is parameterized by the
cost complexity parameter, ``ccp_alpha``. Greater values of ``ccp_alpha``
increases the number of nodes pruned. Here we only show the effect of

NicolasHug (Contributor), Jul 26, 2019, suggested change:
- increases the number of nodes pruned. Here we only show the effect of
+ increase the number of nodes pruned. Here we only show the effect of
# Minimal cost complexity pruning recursively finds the node with the
# "weakest link". The weakest link is characterized by an effective alpha,
# where the nodes with the smallest effective alpha are pruned first.
# scikit-learn provides a

NicolasHug (Contributor), Jul 26, 2019, suggested change:
- # scikit-learn provides a
+ # scikit-learn provides

NicolasHug (Contributor), Jul 26, 2019:

I would also mention the purpose of the function, e.g.

To get an idea of what values of ccp_alpha could be appropriate, scikit-learn provides ...

# that returns the effective alphas and the corresponding total leaf impurities
# at each step of the pruning process. As alpha increases, more
# of the tree is pruned, which increases the total impurity of its leaves.
# In the following plot, the maximum effective alpha value is removed,

NicolasHug (Contributor), Jul 26, 2019:

I would suggest leaving this last sentence for a comment in the code.

There's a lot to process reading the example, and one may not immediately understand that you ignore it simply because it would flatten the plots.
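For context, a minimal usage sketch of the pruning-path helper as it ended up in scikit-learn (the train/test split is illustrative):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities  # one entry per pruning step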

ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")

NicolasHug (Contributor), Jul 26, 2019:

about steps-post, can we also add the actual points, to disambiguate the last ones?

NicolasHug (Contributor) left a comment

Some more.

Mostly looks good.

I wish we could have more tests. Can you think of some edge cases? E.g. trying to prune a single node tree?
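A hedged sketch of that edge case as a test (the test name is hypothetical; a pure target yields a single-leaf tree, which pruning should leave intact):

from sklearn.tree import DecisionTreeClassifier

def test_prune_single_node_tree():
    X, y = [[0], [1]], [0, 0]  # pure y: the fitted tree is already one leaf
    tree = DecisionTreeClassifier(ccp_alpha=10.0).fit(X, y)
    assert tree.tree_.node_count == 1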

@@ -0,0 +1,8 @@
"""

NicolasHug (Contributor), Jul 29, 2019:

remove? ^^

.. _minimal_cost_complexity_pruning:

Minimal Cost-Complexity Pruning
===============================

NicolasHug (Contributor), Jul 29, 2019:

Let's add a link to this section in every docstring for ccp_alpha. Else, there's no way users can know what ccp_alpha really is and how it works when they look e.g. at the RandomForest docstring.

@@ -508,6 +520,56 @@ def decision_path(self, X, check_input=True):
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

    def _prune_tree(self):
        """Prunes tree using Minimal Cost-Complexity Pruning."""

NicolasHug (Contributor), Jul 29, 2019, suggested change:
"""Prunes tree using Minimal Cost-Complexity Pruning."""
"""Prune tree using Minimal Cost-Complexity Pruning."""
# Build Pruned Tree
# =============================================================================

cpdef build_pruned_tree_ccp(

NicolasHug (Contributor), Jul 29, 2019:

the cpdef functions can just be def, and do not need to be in the pxd file.

cdef _cost_complexity_prune(Tree orig_tree,
                            _CCPPruneController controller,
                            unsigned char[:] leaves_in_subtree):
    """Performs cost complexity pruning.

NicolasHug (Contributor), Jul 29, 2019:

Please describe in more detail.

What does that exactly mean? What are the IN parameters, what are the OUT parameters?

thomasjpfan (Author, Member), Aug 16, 2019:

Added more details in the docstring


cdef:
    UINT32_t total_items = path_finder.count
    np.ndarray ccp_alphas = np.empty(shape=total_items,

NicolasHug (Contributor), Jul 29, 2019:

just use a view?

thomasjpfan (Author, Member), Aug 16, 2019:

Defining the numpy array here is nice since ccp_alphas and impurities will be returned as a dictionary for python to use.

NicolasHug (Contributor), Aug 16, 2019:

I meant float64[:] ccp_alphas = np.empty(shape=total_items, dtype=np.float64)

NicolasHug (Contributor), Aug 16, 2019:

ping but not super important I think

thomasjpfan (Author, Member), Aug 16, 2019:

Later when:

return {'ccp_alphas': ccp_alphas, 'impurities': impurities}

is called, this was designed to return a dict of numpy arrays (not memoryviews).

To make this clear, I added a docstring to ccp_pruning_path (which should have been there).



cdef class _CCPPruneController:
    """Base class used by build_pruned_tree_ccp to control pruning

NicolasHug (Contributor), Jul 29, 2019:

and ccp_pruning_path

        return self.ccp_alpha < effective_alpha

    cdef void after_pruning(self, unsigned char[:] in_subtree) nogil:
        """Called after pruning"""

NicolasHug (Contributor), Jul 29, 2019:

Please describe why this is needed and what it does

sklearn/tree/_tree.pyx (review thread, resolved)
        self.tree_ = pruned_tree

    def cost_complexity_pruning_path(self, X, y, sample_weight=None):
        """Prune tree using Minimal Cost-Complexity Pruning.

NicolasHug (Contributor), Jul 29, 2019:

That's the docstring for _prune_tree, please update

thomasjpfan added 11 commits Jul 30, 2019
NicolasHug (Contributor) left a comment

Thanks Thomas, last minor comments but LGTM!

:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node

NicolasHug (Contributor), Aug 16, 2019, suggested change:
- :math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node
+ :math:`\alpha_{eff}(t)=(R(t)-R(T_t))/(|T|-1)`. A non-terminal node

removed tilde

Minimal Cost-Complexity Pruning
===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree after it

NicolasHug (Contributor), Aug 16, 2019:

Please add ref here to L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees (Chapter 3)

ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")

NicolasHug (Contributor), Aug 16, 2019, suggested change:
- ax.plot(ccp_alphas, train_scores, label="train", drawstyle="steps-post")
+ ax.plot(ccp_alphas, train_scores, marker='o', label="train", drawstyle="steps-post")

NicolasHug (Contributor), Aug 16, 2019:

same below

        Tree orig_tree,
        DOUBLE_t ccp_alpha):
    """Build a pruned tree from the original tree by transforming the nodes in
    leaves_in_subtree into leaves.

NicolasHug (Contributor), Aug 16, 2019:

Remove this one??


cdef:
    UINT32_t total_items = path_finder.count
    np.ndarray ccp_alphas = np.empty(shape=total_items,

NicolasHug (Contributor), Aug 16, 2019:

ping but not super important I think

thomasjpfan added 5 commits Aug 16, 2019
jnothman (Member) left a comment

Incidental comments

where :math:`|\tilde{T}|` is the number of terminal nodes in :math:`T` and
:math:`R(T)` is traditionally defined as the total misclassification rate of
the terminal nodes. Alternatively, scikit-learn uses the total sample weighted
impurity of the terminal nodes for :math:`R(T)`. As shown in the previous

jnothman (Member), Aug 18, 2019:

best to say "above" rather than "in the previous section" or link to it so it can withstand change.

:math:`t`, and its branch, :math:`T_t`, can be equal depending on
:math:`\alpha`. We define the effective :math:`\alpha` of a node to be the
value where they are equal, :math:`R_\alpha(T_t)=R_\alpha(t)` or
:math:`\alpha_{eff}=(R(t)-R(T_t))/(|\tilde{T}|-1)`. A non-terminal node with

jnothman (Member), Aug 18, 2019:

use \frac since this is not easily readable anyway.
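For reference, the same definition written with \frac (one possible rendering of the suggestion, consistent with the tilde removal above):

.. math::

   \alpha_{eff}(t) = \frac{R(t) - R(T_t)}{|T| - 1}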

===============================

Minimal cost-complexity pruning is an algorithm used to prune a tree, described
in Chapter 3 of [BRE]_. This algorithm is parameterized by :math:`\alpha\ge0`

jnothman (Member), Aug 18, 2019:

It might be worth adding a small note to say that this is one method used to avoid over-fitting in trees.

:mod:`sklearn.tree`
...................

- |Feature| Adds minimal cost complexity pruning to

jnothman (Member), Aug 18, 2019:

might be worth mentioning what the public api is... i.e. is it just ccp_alpha?

NicolasHug merged commit 67c94c7 into scikit-learn:master on Aug 20, 2019
15 of 17 checks passed:
  • scikit-learn.scikit-learn Build #20190818.20: had test failures
  • scikit-learn.scikit-learn (Linux pylatest_conda_mkl_pandas): failed
  • LGTM analysis (C/C++, JavaScript): no code changes detected
  • LGTM analysis (Python): no new or fixed alerts
  • ci/circleci (deploy, doc, doc-min-dependencies, lint): passed
  • codecov/patch: 100% of diff hit (target 96.88%)
  • codecov/project: absolute coverage decreased by 0.18% but relative coverage increased by 3.11% compared to e8f2708
  • scikit-learn.scikit-learn: Linux py35_conda_openblas, Linux py35_ubuntu_atlas, Linux32 py35_ubuntu_atlas_32bit, Windows py35_pip_openblas_32bit, Windows py37_conda_mkl, and macOS pylatest_conda_mkl all succeeded
NicolasHug (Contributor) commented Aug 20, 2019

Failure is unrelated and Joel's comments were addressed so I guess it's good to merge 🎉

Thanks @thomasjpfan !

jnothman (Member) commented Aug 20, 2019 (comment minimized)
