Open
Enhancement
Labels: improvement (Improvement / enhancement to an existing function)
Description
For sufficiently large datasets, cuML's RandomForestClassifier can produce significantly lower accuracy than scikit-learn's when both are run with their default parameters. The main culprits seem to be cuML's defaults for max_depth and max_features. Reproducer on the Covertype dataset:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import cuml

# Load the Covertype dataset; the last column is the class label.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz"
data = pd.read_csv(url, header=None)
X, y = data.iloc[:, :-1], data.iloc[:, -1]
print(X.shape)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Baseline: scikit-learn with default parameters (apart from n_estimators and n_jobs).
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print("sklearn RF accuracy: " + str(accuracy_score(y_test, y_pred)))

# cuML with default parameters (apart from n_estimators).
clf_gpu = cuml.ensemble.RandomForestClassifier(n_estimators=100)
clf_gpu.fit(X_train, y_train)
y_pred = clf_gpu.predict(X_test)
print("cuML RF default accuracy: " + str(accuracy_score(y_test, y_pred)))

# cuML with a deeper tree limit and all features considered at each split.
clf_gpu_mod = cuml.ensemble.RandomForestClassifier(n_estimators=100, max_depth=30, max_features=1.0)
clf_gpu_mod.fit(X_train, y_train)
y_pred = clf_gpu_mod.predict(X_test)
print("cuML RF accuracy with greater max_depth and max_features: " + str(accuracy_score(y_test, y_pred)))
(581012, 54)
sklearn RF accuracy: 0.955259330654114
cuML RF default accuracy: 0.7143016961696341
cuML RF accuracy with greater max_depth and max_features: 0.9637616928994949

cuML's default max_depth is just 16, whereas scikit-learn's default (max_depth=None) leaves tree depth unlimited. It would be great if cuML had better or "smarter" defaults, particularly when working with larger datasets.
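To give a rough sense of why a depth cap of 16 might bite on a dataset this size, here is a back-of-the-envelope sketch in plain Python. The helper names (`max_leaves`, `min_depth_to_isolate`) are made up for illustration. It assumes a binary tree in which a depth limit of d allows at most 2^d leaves; whether cuML counts depth in exactly this way is an assumption, not something verified here.

```python
import math


def max_leaves(max_depth):
    """Upper bound on the number of leaves in a binary tree with max_depth levels of splits."""
    return 2 ** max_depth


def min_depth_to_isolate(n_samples):
    """Smallest depth at which a perfectly balanced binary tree has at least n_samples leaves.

    Uses integer arithmetic: (n - 1).bit_length() equals ceil(log2(n)) for n >= 1.
    """
    return (n_samples - 1).bit_length()


n_rows = 581012                              # rows in covtype, from the printed shape
n_train = n_rows - math.ceil(0.2 * n_rows)   # rows left for training with test_size=0.2

print(n_train)                        # 464809
print(max_leaves(16))                 # 65536
print(min_depth_to_isolate(n_train))  # 19
```

Under these assumptions, a tree capped at depth 16 has far fewer leaves available than there are training rows, while an unlimited-depth tree can keep splitting until the leaves are pure. This is only a capacity argument; it says nothing about how much of the accuracy gap comes from max_depth versus max_features.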