
[BUG] Bootstrapping causes accuracy drop in cuML RF #2895

Status: Closed
hcho3 opened this issue Oct 1, 2020 · 24 comments
Labels: bug (Something isn't working), inactive-30d, inactive-90d

hcho3 (Contributor) commented Oct 1, 2020

Describe the bug
I have been investigating the accuracy bug in cuML RF (#2518), and I managed to isolate the cause of the accuracy drop: the bootstrapping option causes cuML RF to do worse than sklearn.

Steps/Code to reproduce bug
Download the dataset in NumPy format (obtained from #2561):

Then run the following script:

import itertools

import numpy as np
from sklearn.model_selection import cross_validate, KFold
from sklearn.ensemble import RandomForestClassifier
from cuml.ensemble import RandomForestClassifier as cuml_RandomForestClassifier

# Preprocessed data
X = np.load('data/loans_X.npy')
y = np.load('data/loans_y.npy')

param_range = {
    'n_estimators': [1, 10, 100],
    'max_features': [1.0],
    'bootstrap': [False, True],
    'random_state': [0]
}

max_depth = 21
n_bins = 64

cv_fold = KFold(n_splits=10, shuffle=True, random_state=2020)

param_set = (dict(zip(param_range, x)) for x in itertools.product(*param_range.values()))
for params in param_set:
    print(f'==== params = {params} ====')
    skl_clf = RandomForestClassifier(n_jobs=-1, max_depth=max_depth, **params)
    scores = cross_validate(skl_clf, X, y, cv=cv_fold, n_jobs=-1, return_train_score=True)
    skl_train_acc = scores['train_score']
    skl_cv_acc = scores['test_score']
    print(f'sklearn: Training accuracy = {skl_train_acc.mean()} (std={skl_train_acc.std()}), ' +
          f'CV accuracy = {skl_cv_acc.mean()} (std={skl_cv_acc.std()})')
    
    for split_algo in [0, 1]:
        cuml_clf = cuml_RandomForestClassifier(n_bins=n_bins, max_depth=max_depth, n_streams=1,
                                               split_algo=split_algo, **params)
        scores = cross_validate(cuml_clf, X, y, cv=cv_fold, return_train_score=True)
        cuml_train_acc = scores['train_score']
        cuml_cv_acc = scores['test_score']
        print(f'cuml, split_algo = {split_algo}: Training accuracy = {cuml_train_acc.mean()} ' +
              f'(std={cuml_train_acc.std()}), CV accuracy = {cuml_cv_acc.mean()} ' +
              f'(std={cuml_cv_acc.std()})')

cuML RF gives substantially lower training accuracy than sklearn (up to 9 percentage points lower):

Training accuracy, bootstrap=True

| n_estimators | sklearn | cuML (split_algo=0) | cuML (split_algo=1) |
| --- | --- | --- | --- |
| 1 | 0.876951 | 0.822472 | 0.821807 |
| 10 | 0.925004 | 0.857921 | 0.861096 |
| 100 | 0.931354 | 0.84961 | 0.852527 |

On the other hand, turning off bootstrapping with bootstrap=False improves the accuracy of cuML RF relative to sklearn:

Training accuracy, bootstrap=False

| n_estimators | sklearn | cuML (split_algo=0) | cuML (split_algo=1) |
| --- | --- | --- | --- |
| 1 | 0.92087 | 0.921404 | 0.928852 |
| 10 | 0.922088 | 0.921404 | 0.928852 |
| 100 | 0.92228 | 0.921404 | 0.928852 |

To make sure that bootstrapping is the issue, I wrote the following script to generate bootstraps with NumPy and fed the same bootstraps into both cuML RF and sklearn:

import time

import numpy as np
from sklearn.base import clone
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from cuml.ensemble import RandomForestClassifier as cuml_RandomForestClassifier

# Fit n_estimators copies of base_estimator, each on a bootstrap sample
# drawn with NumPy, so both libraries see identical bootstraps.
def fit_with_custom_bootstrap(base_estimator, X, y, *, n_estimators, random_state):
    assert len(X.shape) == 2 and len(y.shape) == 1
    assert X.shape[0] == y.shape[0]
    rng = np.random.default_rng(seed=random_state)
    estimators = []
    for _ in range(n_estimators):
        estimator = clone(base_estimator)
        indices = rng.choice(X.shape[0], size=(X.shape[0],), replace=True)
        bootstrap_X, bootstrap_y = X[indices, :], y[indices]
        assert bootstrap_X.shape == X.shape
        assert bootstrap_y.shape == y.shape
        estimator.fit(bootstrap_X, bootstrap_y)

        estimators.append(estimator)
    return estimators

# Aggregate by unweighted majority vote over each tree's hard predictions
# (used below for the cuML estimators).
def predict_unweighted_vote(estimators, X_test):
    s = np.zeros((X_test.shape[0], 2))
    for estimator in estimators:
        s[np.arange(X_test.shape[0]), estimator.predict(X_test).astype(np.int32)] += 1.0
    s /= len(estimators)
    return np.argmax(s, axis=1)

# Aggregate by averaging predict_proba across trees, matching sklearn's
# own soft-voting aggregation (used below for the sklearn estimators).
def predict_weighted_vote(estimators, X_test):
    s = estimators[0].predict_proba(X_test)
    for estimator in estimators[1:]:
        s += estimator.predict_proba(X_test)
    s /= len(estimators)
    return np.argmax(s, axis=1)

X = np.load('data/loans_X.npy')
y = np.load('data/loans_y.npy')
assert np.array_equal(np.unique(y), np.array([0., 1.]))

max_depth = 21
n_bins = 64
split_algo = 0
n_estimators = 1  # Also number of bootstraps

# Since we generate our own bootstraps, disable bootstrap in cuML / sklearn
params = {
    'n_estimators': 1,
    'max_features': 1.0,
    'bootstrap': False,
    'random_state': 0
}

cuml_clf = cuml_RandomForestClassifier(n_bins=n_bins, max_depth=max_depth, n_streams=1,
                                       split_algo=split_algo, **params)

tstart = time.perf_counter()
estimators = fit_with_custom_bootstrap(cuml_clf, X, y, n_estimators=n_estimators, random_state=0)
tend = time.perf_counter()
print(f'cuml, Training: {tend - tstart} sec')
tstart = time.perf_counter()
y_pred = predict_unweighted_vote(estimators, X)
tend = time.perf_counter()
print(f'cuml, Prediction: {tend - tstart} sec')
print(accuracy_score(y, y_pred))

skl_clf = RandomForestClassifier(n_jobs=-1, max_depth=max_depth, **params)

tstart = time.perf_counter()
estimators = fit_with_custom_bootstrap(skl_clf, X, y, n_estimators=n_estimators, random_state=0)
tend = time.perf_counter()
print(f'sklearn, Training: {tend - tstart} sec')
tstart = time.perf_counter()
y_pred = predict_weighted_vote(estimators, X)
tend = time.perf_counter()
print(f'sklearn, Prediction: {tend - tstart} sec')
print(accuracy_score(y, y_pred))

The results now look much better: cuML RF gives training accuracy competitive with sklearn.

| n_estimators | sklearn | cuML (split_algo=0) | cuML (split_algo=1) |
| --- | --- | --- | --- |
| 1 | 0.87526379 | 0.875951111 | 0.875735555 |
| 10 | 0.92300364 | 0.921437212 | 0.931396502 |
| 100 | 0.9296966 | 0.919802517 | 0.930890215 |
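
(Aside on the two predict helpers in the script above: predict_unweighted_vote majority-votes hard per-tree predictions, while predict_weighted_vote averages predict_proba across trees, and the two schemes can disagree given the same trees. A small illustration, independent of either library:

import numpy as np

# Toy example: three trees' class-1 probabilities for one sample.
probs = np.array([0.6, 0.6, 0.1])

# Majority vote over hard predictions: two of three trees say class 1.
hard_votes = (probs > 0.5).astype(int)   # -> [1, 1, 0]
print(np.bincount(hard_votes).argmax())  # 1

# Averaging probabilities: mean is ~0.433 < 0.5, so class 0.
print(int(probs.mean() > 0.5))           # 0
)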
beckernick (Member) commented Oct 1, 2020

Just to add to this: using 500,000 rows from the Higgs Boson dataset (prepared via the first few cells of this notebook), I see the following:

Data

import os

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# This is a 2.7 GB file.
# Please make sure you have enough space available before
# uncommenting the code below and downloading this file.

DATA_DIRECTORY = "./"
DATASET_PATH = os.path.join(DATA_DIRECTORY, "HIGGS.csv.gz")

# if not os.path.isfile(DATASET_PATH):
#     !wget https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz -P {DATA_DIRECTORY}

# This function is borrowed and adapted from
# https://github.com/NVIDIA/gbm-bench/blob/master/datasets.py
# Thanks!

def prepare_higgs(nrows=None):
    higgs = pd.read_csv(DATASET_PATH, nrows=nrows)
    X = higgs.iloc[:, 1:].to_numpy(dtype=np.float32)
    y = higgs.iloc[:, 0].to_numpy(dtype=np.int64)
    return train_test_split(X, y, stratify=y, random_state=77, test_size=0.2)

Default max depths, 64 bins

NROWS = 500_000
X_train, X_test, y_train, y_test = prepare_higgs(nrows=NROWS)

# In-sample accuracy: i.e., can it learn the signal?
from sklearn.ensemble import RandomForestClassifier
import cuml

clf = RandomForestClassifier(n_jobs=-1)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

clf = cuml.ensemble.RandomForestClassifier(bootstrap=False, n_bins=64)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

clf = cuml.ensemble.RandomForestClassifier(bootstrap=True, n_bins=64)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

Output:
0.9999975
0.838890016078949
0.7924475073814392

max depth = 20, n_bins = 64

NROWS = 500_000
X_train, X_test, y_train, y_test = prepare_higgs(nrows=NROWS)

# In-sample accuracy: i.e., can it learn the signal?
from sklearn.ensemble import RandomForestClassifier
import cuml

clf = RandomForestClassifier(n_jobs=-1, max_depth=20)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

clf = cuml.ensemble.RandomForestClassifier(bootstrap=False, max_depth=20, n_bins=64)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

clf = cuml.ensemble.RandomForestClassifier(bootstrap=True, max_depth=20, n_bins=64)
clf.fit(X_train, y_train)
print(clf.score(X_train, y_train))

Output:
0.9478925
0.9565474987030029
0.8567699790000916

In this example, the results seem to support the bootstrapping hypothesis, and they provide another data point:

  • Bootstrapping is clearly reducing the in-sample accuracy of cuML.
  • cuML requires a higher max_depth than scikit-learn to learn the signal, regardless of bootstrap=False|True.

In your tests, you used max_depth=21 for both the cuML and scikit-learn runs. Did you see a similar pattern, where cuML by default needed a higher max depth to learn the signal?

hcho3 (Contributor, Author) commented Oct 1, 2020

@beckernick Are you using the latest cuML? The previous version (0.15) had an issue where cuML interpreted max_depth differently than sklearn.

beckernick (Member) commented Oct 1, 2020

Yes, from the nightly as of around 9 AM EDT on 2020-10-01.

conda list | grep "rapids\|scikit-learn"
# packages in environment at /raid/nicholasb/miniconda3/envs/rapids-tpcx-bb-20201001:
cudf                      0.16.0a201001   cuda_10.2_py37_gffa25070c4_1923    rapidsai-nightly
cuml                      0.16.0a201001   cuda10.2_py37_g7e72203b7_836    rapidsai-nightly
dask-cudf                 0.16.0a201001   py37_gffa25070c4_1923    rapidsai-nightly
faiss-proc                1.0.0                      cuda    rapidsai-nightly
libcudf                   0.16.0a201001   cuda10.2_gffa25070c4_1923    rapidsai-nightly
libcuml                   0.16.0a201001   cuda10.2_g7e72203b7_836    rapidsai-nightly
libcumlprims              0.16.0a200930   cuda10.2_g1c28023_35    rapidsai-nightly
librmm                    0.16.0a201001   cuda10.2_g12ac71a_394    rapidsai-nightly
rmm                       0.16.0a201001   cuda_10.2_py37_g12ac71a_394    rapidsai-nightly
scikit-learn              0.23.2           py37h6785257_0    conda-forge
ucx                       1.8.1+g6b29558       cuda10.2_0    rapidsai-nightly
ucx-proc                  1.0.0                       gpu    rapidsai-nightly
ucx-py                    0.16.0a201001   py37_g6b29558_177    rapidsai-nightly

hcho3 (Contributor, Author) commented Oct 1, 2020

Got it. It is then likely that bootstrapping is not the only issue causing the accuracy drop; there are probably multiple factors at play. I will investigate further.

miroenev (Contributor) commented Oct 2, 2020

I'll try this on the Airline delays & NYC taxi datasets and report back. Good detective work so far @hcho3 !

alexis-intellegens commented Oct 2, 2020

Can I just note that RandomForestRegressor has a similar issue: its measured R^2 always seems worse than the sklearn implementation's. This is especially evident on large datasets, e.g.:

from time import time
import cuml
import cudf
import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from sklearn.ensemble import RandomForestRegressor
from cuml.ensemble import RandomForestRegressor as RFGPU

X, y = make_regression(n_samples=20_000, 
                       n_features=20, 
                       n_informative=10, 
                       n_targets=1, 
                       noise=0.2,
                       random_state=42)
X = X.astype(np.float32)
y = y.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, 
                                                    y, 
                                                    test_size=0.33, 
                                                    random_state=42)

start = time()
regr = RandomForestRegressor(n_estimators=100,
                             max_depth=25,
                             n_jobs=-1,
                             random_state=42)
regr.fit(X_train, y_train)
pred = regr.predict(X_test)
end = time()

print(f"The R^2 score is: {np.round(r2_score(pred, y_test), 2)} and it took {np.round(end - start, 2)} seconds")
# 0.8 R^2

start = time()
# I also tested this with default n_bins, default max_features/bootstrap and split_algo = 1
regr2 = RFGPU(n_estimators=100,
             max_depth=26, max_features="auto", bootstrap=True, n_bins=512,
             split_algo=0)
regr2.fit(X_train, y_train)
pred = regr2.predict(X_test)
end = time()

print(f"The R^2 score is: {np.round(r2_score(pred, y_test), 2)} and it took {np.round(end - start, 2)} seconds")
# 0.6 R^2
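
(One caveat worth flagging in the snippet above: sklearn's signature is r2_score(y_true, y_pred), and R^2 is not symmetric in its arguments, so r2_score(pred, y_test) reports a different value than the conventional orientation. A small self-contained illustration:

import numpy as np
from sklearn.metrics import r2_score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.6])

# R^2 normalizes by the variance of the first argument (the "truth"),
# so swapping the arguments changes the score.
print(r2_score(y_true, y_pred))  # conventional orientation, ~0.956
print(r2_score(y_pred, y_true))  # generally a different value, ~0.945
)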

miroenev (Contributor) commented Oct 2, 2020

Hmm, I'm finding some conflicting results, where setting bootstrap to True leads to better accuracy on binary classification tasks. The accuracy gap on the NYC taxi dataset is somewhat dramatic. Below are results from running on a single GPU with 3 cross-validation folds.

RAPIDS image ** = rapidsai/rapidsai-nightly:0.16-cuda11.0-base-ubuntu18.04-py3.7
model = cuml/sklearn.ensemble.RandomForestClassifier
scoring = cuml/sklearn.metrics.accuracy_score
GPU = 1x V100 16GB
CPUs = 2x Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz [ 40 cores total ]

RF Parameters Used

'max_depth': 15,
'max_features': 1.0,
'n_bins': 64, [ GPU only ]
'n_jobs':-1, [ CPU only ]
'n_estimators': 100,
'random_state': 0

Dataset 1 - Airline Stats, 2019 Full Year
Learning Objective = predict arrival delays > 15 minutes
Dataset shape = (7422037, 14)

'bootstrap': True,

cuml: average-score = 0.9364989995956421

  • total_time = 157 seconds
  • cv-fold scores : [0.9365155100822449, 0.936671257019043, 0.9363102316856384]

sklearn: average-score = 0.9391411831653146

  • total_time = 943 seconds
  • fold scores : [0.9390751423454837, 0.9391670491530815, 0.9391813579973782]

'bootstrap': False,

cuml: average-score = 0.9342407782872518

  • total_time = 235 seconds
  • cv-fold scores : [0.9340819120407104, 0.9345200061798096, 0.9341204166412354]

sklearn: average-score = 0.9359070174608258

  • total_time = 1344 seconds
  • fold scores : [0.9358985789116253, 0.9359013306124515, 0.9359211428584008]

Dataset 2 - NYC Yellow Cab Trips, 2020 January
Learning Objective = predict above average tips (>$2.20)
Dataset shape = (6339567, 18)

'bootstrap': True,

cuml: average-score = 0.9740167458852133

  • total_time = 144 seconds
  • cv-fold scores : [0.9733003973960876, 0.9758374691009521, 0.9729123711585999]

sklearn: average-score = 0.9947840820278816

  • total_time = 655 seconds
  • fold scores : [0.9951050292385853, 0.9945573578515129, 0.9946898589935466]

'bootstrap': False,

cuml: average-score = 0.8833222190539042

  • total_time = 229 seconds
  • cv-fold scores : [0.9405258893966675, 0.8215708136558533, 0.8878699541091919]

sklearn: average-score = 0.9944858492986693

  • total_time = 940 seconds
  • fold scores : [0.9944778571662927, 0.9944904763226768, 0.9944892144070384]

** essentially identical results with rapidsai/rapidsai:0.15-cuda10.2-base-ubuntu18.04-py3.7


vinaydes (Contributor) commented Oct 6, 2020

I wanted to iterate over random_state and observe its effect on accuracy. However, in the code @hcho3 posted above, the line estimator = clone(base_estimator) in fit_with_custom_bootstrap does not copy the random_state parameter, so all the estimators end up with the same default random_state. I replaced it with copy.copy and it worked; a sketch follows.
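
A minimal sketch of that change, assuming the same setup as the script @hcho3 posted above:

import copy

import numpy as np

# Same as fit_with_custom_bootstrap above, but with copy.copy instead of
# sklearn.base.clone; the shallow copy of the unfitted estimator carries
# over all constructor parameters, including random_state.
def fit_with_custom_bootstrap(base_estimator, X, y, *, n_estimators, random_state):
    rng = np.random.default_rng(seed=random_state)
    estimators = []
    for _ in range(n_estimators):
        estimator = copy.copy(base_estimator)  # was: clone(base_estimator)
        indices = rng.choice(X.shape[0], size=(X.shape[0],), replace=True)
        estimator.fit(X[indices, :], y[indices])
        estimators.append(estimator)
    return estimators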
After trying 50 different values of random_state, I observed the following accuracy numbers for sklearn and cuML RF.
[image: histogram of sklearn accuracies over 50 random_state values]
Sklearn accuracy numbers are normally distributed around 80% accuracy.

[image: histogram of cuML RF accuracies over the same random_state values]
In cuML RF, the accuracy is around 80% in most cases, but in some cases it hovers around 60%; it is almost as if there is a second, smaller bell curve with a smaller peak. KISS99 is the default random number generator for cuML RF.

To ensure that KISS99 is not the cause of the issue, I re-did the experiment with the Philox generator from cuRAND; the result did not change much.
[image: histogram of cuML RF accuracies using the Philox generator]

The same result was observed with or without custom bootstrapping. All results were obtained with the latest pull from cuML branch 0.16.
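
(That experiment swaps the device-side generator inside cuML. For the host-side custom-bootstrap experiments earlier in this thread, a hypothetical analogue is possible too, since NumPy also ships a Philox bit generator; a minimal sketch:

import numpy as np

# Draw bootstrap indices from Philox instead of NumPy's default PCG64,
# to rule out the host-side RNG as a factor in the custom bootstraps.
rng = np.random.Generator(np.random.Philox(seed=0))
n = 10  # stand-in for X.shape[0]
indices = rng.choice(n, size=(n,), replace=True)
print(indices)
)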

vinaydes (Contributor) commented Oct 7, 2020

The script I used to generate the accuracy data:

import time
import numpy as np
from sklearn.base import clone
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from cuml.ensemble import RandomForestClassifier as cuml_RandomForestClassifier
from sklearn.utils import shuffle



X = np.load('data/loans_X.npy')
y = np.load('data/loans_y.npy')
# X, y = shuffle(X, y)

assert np.array_equal(np.unique(y), np.array([0., 1.]))

max_depth = 12
n_bins = 64
split_algo = 1
n_estimators = 1  # Also number of bootstraps

params = {
    'n_estimators': 1,
    'max_features': 1.0,
    'bootstrap': True,
    'random_state': 2
}

cuml_clf = cuml_RandomForestClassifier(n_bins=n_bins, max_depth=max_depth, n_streams=1,
                                       split_algo=split_algo, **params)
# for s in range(0, 500):
# for s in [0]:
for s in range(50):
    cuml_clf.random_state = s
    tstart = time.perf_counter()
    cuml_clf.fit(X, y)
    tend = time.perf_counter()
    # print(f'cuml, Training: {tend - tstart} sec')
    # tstart = time.perf_counter()
    y_pred = cuml_clf.predict(X)
    # tend = time.perf_counter()
    print('random_state = ', s, ', ', sep='', end='')
    print('accuracy = ', accuracy_score(y, y_pred), sep='', end='')
    print()

skl_clf = RandomForestClassifier(n_jobs=-1, max_depth=max_depth, **params)
# for s in range(0, 500):
# for s in [0]:
for s in range(50):
    skl_clf.random_state = s
    tstart = time.perf_counter()
    skl_clf.fit(X, y)
    tend = time.perf_counter()
    # print(f'cuml, Training: {tend - tstart} sec')
    # tstart = time.perf_counter()
    y_pred = skl_clf.predict(X)
    # tend = time.perf_counter()
    print('random_state = ', s, ', ', sep='', end='')
    print('accuracy = ', accuracy_score(y, y_pred), sep='', end='')
    print()

hcho3 (Contributor, Author) commented Oct 13, 2020

@vinaydes I am now seeing a different pattern, where cuML is consistently worse than sklearn (your results show that cuML sometimes does as well as sklearn):
[image: train/CV accuracy of sklearn vs. cuML across 50 random_state values]

My code:

import itertools

import numpy as np
from sklearn.model_selection import cross_validate, KFold
from sklearn.ensemble import RandomForestClassifier
from cuml.ensemble import RandomForestClassifier as cuml_RandomForestClassifier

# Preprocessed data
X = np.load('data/loans_X.npy')
y = np.load('data/loans_y.npy')

param_range = {
    'n_estimators': [1, 10, 100],
    'max_features': [1.0],
    'bootstrap': [False, True],
    'random_state': list(range(50))
}

max_depth = 21
n_bins = 64

cv_fold = KFold(n_splits=10, shuffle=True, random_state=2020)

param_set = (dict(zip(param_range, x)) for x in itertools.product(*param_range.values()))
print('n_estimators,bootstrap,random_state,legend,train_acc,train_acc_std,cv_acc,cv_acc_std')
for params in param_set:
    skl_clf = RandomForestClassifier(n_jobs=-1, max_depth=max_depth, **params)
    scores = cross_validate(skl_clf, X, y, cv=cv_fold, n_jobs=-1, return_train_score=True)
    skl_train_acc = scores['train_score']
    skl_cv_acc = scores['test_score']
    print(f'{params["n_estimators"]},{params["bootstrap"]},{params["random_state"]},sklearn,{skl_train_acc.mean()},{skl_train_acc.std()},{skl_cv_acc.mean()},{skl_cv_acc.std()}')
    
    for split_algo in [0, 1]:
        cuml_clf = cuml_RandomForestClassifier(n_bins=n_bins, max_depth=max_depth, n_streams=1, split_algo=split_algo, **params)
        scores = cross_validate(cuml_clf, X, y, cv=cv_fold, return_train_score=True)
        cuml_train_acc = scores['train_score']
        cuml_cv_acc = scores['test_score']
        print(f'{params["n_estimators"]},{params["bootstrap"]},{params["random_state"]},cuML (split_algo={split_algo}),{cuml_train_acc.mean()},{cuml_train_acc.std()},{cuml_cv_acc.mean()},{cuml_cv_acc.std()}')

The major differences are:

  1. I am using 10-fold cross validation with data shuffling; and
  2. I am setting max_depth to 21.

darshats commented

Hello,
Is there an update on this bug? I'm using RAPIDS v0.16 for pixel-level segmentation and see this pattern as well: the scikit-learn random forest is consistently better by about 3 percentage points.
I have a pretty high tree depth of about 128, and the training data has ~1.2 million labels. With bootstrap=False there was a small improvement, from 96% to 96.3%, whereas scikit-learn is at 99.42%. I can share the data if it helps.

Thanks,
Darshat Shah

teju85 (Member) commented Nov 18, 2020

@darshats we are working on a better backend for building decision trees on GPUs. In the process, we have also identified a few shortcomings in the existing backend. We are aiming for the new backend to completely replace the existing one within the next couple of releases or so.

In our initial studies, we have found that the accuracy of this new backend can be significantly better (at least on some toy datasets) than the existing one. Are you OK with trying out our nightly builds? If so, I suggest rerunning your training with the new backend enabled via the use_experimental_backend option; a sketch follows. If you do manage to try it out, we'd love to know whether it helps improve accuracy.
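
A minimal sketch of opting in, on a nightly build (the toy data and the specific hyper-parameter values are illustrative only, not a recommendation):

import numpy as np
from sklearn.datasets import make_classification
from cuml.ensemble import RandomForestClassifier as cumlRF

# Toy data just to exercise the flag.
X, y = make_classification(n_samples=10_000, n_features=20, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int32)

# Opt in to the experimental decision-tree backend; the values below are
# chosen to satisfy the conditions the docs list for it (split_algo=1,
# max_features=1.0, max_depth < 14).
clf = cumlRF(n_estimators=100, max_depth=13, n_bins=64,
             split_algo=1, max_features=1.0,
             use_experimental_backend=True)
clf.fit(X, y)
print(clf.score(X, y))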

darshats commented

Sure, I can try that out. So the code remains exactly the same, except that I use one of the latest nightly builds and set this flag to true?

teju85 (Member) commented Nov 18, 2020

That's correct.

darshats commented

As per the documentation, max_depth must be below 14, whereas I am using 128. Will it work now?

use_experimental_backend : boolean (default = False)
    If set to true and the following conditions are also met, the experimental
    decision tree training implementation would be used:
        split_algo = 1 (GLOBAL_QUANTILE)
        0 < max_depth < 14
        max_features = 1.0 (feature sub-sampling disabled)
        quantile_per_tree = false (no per-tree quantile computation)

teju85 (Member) commented Nov 18, 2020

IIRC, there was a PR merged recently to remove this max_depth limit. @vinaydes, am I right?

darshats commented

I made the change; the accuracy did go up by approximately 1 percentage point, but it is still about 1.5 points behind scikit-learn's RF.

teju85 (Member) commented Nov 19, 2020

Yes @darshats, we are aware of this accuracy difference and are working towards closing the gap. In the meantime, can you provide us with the hyper-parameter values you used in this experiment?

darshats commented

Hi Thejaswi,
I used n_estimators=200, max_depth=128, split_algo=1, bootstrap=False, use_experimental_backend=True. All other parameters are default.

teju85 (Member) commented Nov 19, 2020

Thank you. This will be useful for us.

vinaydes (Contributor) commented

Sorry for the delay in replying. As you have already figured out, there is no restriction on max_depth in the nightly build.

github-actions bot commented

This issue has been marked stale due to no recent activity in the past 30d. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be marked rotten if there is no activity in the next 60d.

github-actions bot commented

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.

hcho3 closed this as completed May 18, 2021.