
Implement better default parameters for RandomForestClassifier #6416

@btepera

Description


For sufficiently large datasets, cuML's implementation of `RandomForestClassifier` can produce significantly lower accuracy than scikit-learn's when both are run with their default parameters. The main culprits seem to be `max_depth` and `max_features`.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
import cuml

# UCI Covertype dataset: 54 feature columns, class label in the last column
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"

# pandas infers gzip compression from the .gz extension
data = pd.read_csv(url, header=None)

X, y = data.iloc[:, :-1], data.iloc[:, -1]
print(X.shape)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# scikit-learn baseline with default parameters (max_depth=None, i.e. unlimited depth)
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print("sklearn RF accuracy: " + str(accuracy_score(y_test, y_pred)))

# cuML with default parameters (max_depth=16)
clf_gpu = cuml.ensemble.RandomForestClassifier(n_estimators=100)
clf_gpu.fit(X_train, y_train)
y_pred = clf_gpu.predict(X_test)
print("cuML RF default accuracy: " + str(accuracy_score(y_test, y_pred)))

# cuML with a deeper depth limit and all features considered at each split
clf_gpu_mod = cuml.ensemble.RandomForestClassifier(n_estimators=100, max_depth=30, max_features=1.0)
clf_gpu_mod.fit(X_train, y_train)
y_pred = clf_gpu_mod.predict(X_test)
print("cuML RF accuracy with greater max_depth and max_features: " + str(accuracy_score(y_test, y_pred)))
```

```
(581012, 54)
sklearn RF accuracy: 0.955259330654114
cuML RF default accuracy: 0.7143016961696341
cuML RF accuracy with greater max_depth and max_features: 0.9637616928994949
```

cuML's default `max_depth` is only 16, whereas scikit-learn's default (`max_depth=None`) grows trees to unlimited depth. It would be great if cuML had better, or "smarter", defaults, particularly for larger datasets.
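To put the two parameters in perspective, here is a rough back-of-the-envelope sketch in plain Python (no cuML or scikit-learn needed). It assumes the 581,012 × 54 shape printed above and the 80/20 split from the repro, and it uses scikit-learn's conventions for `max_features` ("sqrt" means `int(sqrt(n_features))`, a float means a fraction of the features); cuML's exact internal handling may differ. The helper names `max_leaves` and `features_per_split` are hypothetical and only for illustration.

```python
import math

# Shape reported above for the Covertype data: 581,012 rows, 54 features.
N_ROWS, N_FEATURES = 581_012, 54

# train_test_split(test_size=0.2) rounds the test size up,
# and the training split gets the remaining rows.
n_test = math.ceil(0.2 * N_ROWS)
n_train = N_ROWS - n_test


def max_leaves(max_depth, n_samples):
    """Hypothetical helper: upper bound on leaves in one binary tree.

    A tree capped at depth d has at most 2**d leaves. Every leaf holds at
    least one sample, so no tree can have more leaves than training rows;
    with max_depth=None (scikit-learn's default) that is the only bound.
    """
    if max_depth is None:
        return n_samples
    return min(2 ** max_depth, n_samples)


def features_per_split(max_features, n_features):
    """Hypothetical helper: candidate features examined at each split.

    Follows scikit-learn's interpretation (assumed, not taken from cuML).
    """
    if max_features == "sqrt":
        return max(1, int(math.sqrt(n_features)))
    # A float is treated as a fraction of all features.
    return max(1, int(max_features * n_features))


print("training rows:", n_train)
print("leaf bound, max_depth=16:  ", max_leaves(16, n_train))
print("leaf bound, max_depth=30:  ", max_leaves(30, n_train))
print("leaf bound, max_depth=None:", max_leaves(None, n_train))
print("features per split, 'sqrt':", features_per_split("sqrt", N_FEATURES))
print("features per split, 1.0:   ", features_per_split(1.0, N_FEATURES))
```

Under these assumptions a depth-16 tree can have at most 2^16 = 65,536 leaves, roughly 14% of the 464,809 training rows. Since 2^19 = 524,288 already exceeds the row count, a limit of 30 no longer constrains the leaf count at all. Likewise, `max_features=1.0` evaluates all 54 features per split instead of 7.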

Labels: improvement (Improvement / enhancement to an existing function)
