
MNT deprecating presort #14907

Merged — 6 commits merged into scikit-learn:master on Sep 10, 2019
Conversation

adrinjalali (Member)

Related to #14711.

This PR removes presort from the tree codebase and deprecates it wherever it is public-facing.

It is a first step towards simplifying the tree codebase, and we can do it since presort is only used in the GradientBoosting* classes, which are superseded by HistGradientBoosting*.

@glemaitre is working on benchmarks.

Also pinging @ogrisel, @NicolasHug.
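
For context, a minimal sketch of the migration this deprecation points to (assuming scikit-learn 0.22, where the histogram-based estimators still require the experimental import):

from sklearn.experimental import enable_hist_gradient_boosting  # noqa: required in 0.22
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification

# Instead of GradientBoostingClassifier(presort=...), use the histogram-based
# estimator, which bins features into discrete values and does not sort at all.
X, y = make_classification(n_samples=10000, random_state=42)
clf = HistGradientBoostingClassifier(max_depth=3, random_state=42).fit(X, y)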

@glemaitre (Member) left a comment

We also need an entry in the what's new to note that this is deprecated.

Files reviewed:
- sklearn/ensemble/gradient_boosting.py
- sklearn/ensemble/tests/test_gradient_boosting.py
- sklearn/tree/tree.py
@glemaitre (Member)

Code for quick benchmark:

import matplotlib.pyplot as plt  # for plotting the results afterwards
import neurtu
import pandas as pd  # neurtu typically returns its results as a pandas DataFrame

def full_bench():
    from itertools import product
    from sklearn.datasets import make_classification
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid({
        'presort': [True, False],
        'max_depth': [3],
    })
    all_samples = [1000, 5000, 10000, 50000]
    all_features = [10, 30, 50]

    for n_samples, n_features in product(all_samples, all_features):
        X, y = make_classification(
            n_samples=n_samples, n_features=n_features, random_state=42
        )
        for params in param_grid:
            clf = GradientBoostingClassifier(**params, random_state=42)
            # tag each run with its configuration so results can be grouped later
            tags = params.copy()
            tags['n_samples'] = n_samples
            tags['n_features'] = n_features
            print(f"tags={tags}")
            yield neurtu.delayed(clf, tags=tags).fit(X, y)

bench = neurtu.Benchmark(wall_time=True, cpu_time=False, peak_memory=False, repeat=3)
results = bench(full_bench())

I'll post the plot soon.
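
A possible way to draw such a plot, assuming `results` can be flattened to one mean `wall_time` per configuration (neurtu's exact output shape may differ):

# Hedged sketch: assumes columns presort, max_depth, n_samples, n_features, wall_time.
df = results.reset_index()
fig, ax = plt.subplots()
for presort, group in df[df['n_features'] == 10].groupby('presort'):
    group = group.sort_values('n_samples')
    ax.plot(group['n_samples'], group['wall_time'], marker='o',
            label=f'presort={presort}')
ax.set(xlabel='n_samples', ylabel='wall time (s)', title='max_depth=3, n_features=10')
ax.legend()
plt.show()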

@glemaitre (Member) left a comment

Awaiting the benchmark; the code change is OK.

@glemaitre (Member)

max_depth  n_samples  n_features  mean wall_time ratio (presort=False / presort=True)
3          1000       10          2.254
3          1000       30          2.687
3          1000       50          3.061
3          5000       10          2.806
3          5000       30          3.793
3          5000       50          3.268
3          10000      10          2.792
3          10000      30          2.717
3          10000      50          2.812
3          50000      10          2.800
3          50000      30          2.472
3          50000      50          2.751

So we can expect a slowdown of roughly 2.5x to 3x.

@NicolasHug (Member)

Thanks for the benchmark @glemaitre, what are the actual times?

@glemaitre (Member)

Oops, the results are on my other computer. Let me relaunch the benchmark; I will add max_depth as well.

@glemaitre (Member)

max_depth  presort  n_samples  n_features  wall_time (mean, s)
3          False    1000       10            0.318
3          False    1000       30            0.838
3          False    1000       50            1.358
3          False    5000       10            1.659
3          False    5000       30            4.792
3          False    5000       50            7.918
3          False    10000      10            3.540
3          False    10000      30           10.293
3          False    10000      50           17.041
3          False    50000      10           20.918
3          False    50000      30           61.163
3          False    50000      50          101.901
3          True     1000       10            0.149
3          True     1000       30            0.311
3          True     1000       50            0.486
3          True     5000       10            0.607
3          True     5000       30            1.518
3          True     5000       50            2.447
3          True     10000      10            1.251
3          True     10000      30            3.275
3          True     10000      50            5.278
3          True     50000      10            7.403
3          True     50000      30           20.662
3          True     50000      50           33.641
5          False    1000       10            0.514
5          False    1000       30            1.342
5          False    1000       50            2.176
5          False    5000       10            2.670
5          False    5000       30            7.722
5          False    5000       50           12.773
5          False    10000      10            5.705
5          False    10000      30           16.656
5          False    10000      50           27.569
5          False    50000      10           33.875
5          False    50000      30           99.544
5          False    50000      50          166.204
5          True     1000       10            0.262
5          True     1000       30            0.575
5          True     1000       50            0.904
5          True     5000       10            1.071
5          True     5000       30            2.785
5          True     5000       50            4.482
5          True     10000      10            2.251
5          True     10000      30            6.029
5          True     10000      50            9.859
5          True     50000      10           13.644
5          True     50000      30           38.908
5          True     50000      50           63.522
8          False    1000       10            0.786
8          False    1000       30            2.013
8          False    1000       50            3.163
8          False    5000       10            4.243
8          False    5000       30           11.859
8          False    5000       50           19.431
8          False    10000      10            8.973
8          False    10000      30           25.771
8          False    10000      50           42.466
8          False    50000      10           52.705
8          False    50000      30          155.048
8          False    50000      50          259.642
8          True     1000       10            0.497
8          True     1000       30            1.092
8          True     1000       50            1.613
8          True     5000       10            2.193
8          True     5000       30            5.370
8          True     5000       50            8.535
8          True     10000      10            4.407
8          True     10000      30           11.774
8          True     10000      50           19.051
8          True     50000      10           28.975
8          True     50000      30           81.835
8          True     50000      50          129.156
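
As a sanity check, a hedged sketch reproducing the earlier slowdown ratios from two rows of this table (hand-copied values, not the actual post-processing):

import pandas as pd

# Two configurations (max_depth=3, n_features=10) copied from the table above.
df = pd.DataFrame({
    'n_samples': [1000, 1000, 5000, 5000],
    'presort':   [False, True, False, True],
    'wall_time': [0.318, 0.149, 1.659, 0.607],
})
mean = df.set_index(['n_samples', 'presort'])['wall_time']
ratio = mean.xs(False, level='presort') / mean.xs(True, level='presort')
print(ratio)  # ~2.13 and ~2.73: slowdown of presort=False relative to presort=True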

@NicolasHug (Member) left a comment

Thanks for the benchmark. So basically this makes the code slower, but mostly in cases where the new Hist version would be much better anyway.

Maybe we can recommend using HistGradientBoostingXXX in the warning message and in the docstring?

LGTM regardless
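
A hypothetical sketch of such a warning; the helper name and exact wording are illustrative, not the code that was merged:

import warnings

def _warn_presort_deprecated():  # hypothetical helper, not the actual sklearn code
    warnings.warn(
        "The parameter 'presort' is deprecated and has no effect. "
        "It will be removed in v0.24. Consider using "
        "HistGradientBoostingClassifier or HistGradientBoostingRegressor instead.",
        FutureWarning,
    )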

@glemaitre (Member)

> Maybe we can recommend using HistGradientBoostingXXX in the warning message and in the docstring?

Indeed, it would be a nice addition.

@adrinjalali (Member, Author)

Added the recommendation; anything else @NicolasHug, @glemaitre?

@NicolasHug (Member) left a comment

Minor comments, LGTM

Feel free to merge @adrinjalali

Files reviewed:
- doc/whats_new/v0.22.rst
- sklearn/tree/tests/test_tree.py
@adrinjalali merged commit 99d3d34 into scikit-learn:master on Sep 10, 2019
@adrinjalali deleted the tree/simplify branch on September 10, 2019 at 18:39
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because it relies on node_impurity, which is not set for SurvivalTree.
vasselai added a commit to vasselai/monoensemble that referenced this pull request Aug 10, 2021
Further corrections are necessary to make 'monoensemble' work with current sklearn. The main ones that need your attention:

(1) the "presort" and "X_idx_sorted" sklearn parameters have been deprecated. See, respectively:
scikit-learn/scikit-learn#14907
scikit-learn/scikit-learn#16818
Since I don't know exactly how you prefer to handle this in light of the suggestions in the first link above, the only thing I did to leave 'monoensemble' in a working state was to comment out "presort=self.presort" on line 1540 of 'mono_gradient_boosting.py'. A more definitive solution will be necessary, since right now a FutureWarning is issued every iteration due to the "X_idx_sorted" deprecation (which, besides being annoying, means the code will soon break again if "X_idx_sorted" is not eliminated from the code base).

(2) on line 436 of 'mono_forest.py', the call to "_generate_unsampled_indices" throws an error because that function now takes an extra parameter, 'n_samples_bootstrap':
https://github.com/scikit-learn/scikit-learn/blob/4b8cd880397f279200b8faf9c75df13801cb45b7/sklearn/ensemble/_forest.py#L123
I obviously also do not know your preference here, but given the implementation in that link, it seems safe to assume that thus far your code was operating with the equivalent of 'n_samples_bootstrap = 1'. So that is what I imposed for now on line 436 of 'mono_forest.py'. A sketch of the new call follows below.
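
For reference, a minimal sketch of the new call described in (2); this is a private scikit-learn API that may change between versions, and the seed and sample count below are made up:

from sklearn.ensemble._forest import _generate_unsampled_indices

n_samples = 100  # hypothetical training-set size
oob_indices = _generate_unsampled_indices(
    random_state=0,        # per-tree seed
    n_samples=n_samples,
    # A full-size bootstrap (the behaviour before this parameter existed)
    # corresponds to n_samples_bootstrap == n_samples.
    n_samples_bootstrap=n_samples,
)
print(oob_indices)  # samples never drawn into the bootstrap (out-of-bag)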