[WIP] Random forest wrapper #621
Conversation
This PR depends on #407, right?
Hi, Saloni --
Great to see this change coming along fast! I know it's early, so I didn't want to go into too many detailed bits. High-level feedback would be:
- When in doubt, let's match the sklearn interfaces/package layouts/class names/etc. So we should do RandomForestClassifier and RandomForestRegressor classes. If we can do it cleanly, they'd just be slim wrappers around an underlying base class like this one (RandomForest) that supports both approaches.
- We should be careful about which functions modify state on the self object. Only the constructor, fit, and very clear "setter" functions should modify state; otherwise we get unexpected side effects.
- Should break out the unrelated changes from this PR.
- Would be great to add tests super early on and commit them, even if most are failing. That will definitely speed up development.
Thanks!
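As a rough illustration of the layout suggested above, a shared base class with slim sklearn-style wrappers might look like the sketch below. The class and parameter names follow sklearn conventions but are assumptions for illustration, not cuml's final API:

```python
# Hypothetical sketch of the suggested layout: a shared RandomForest base
# class with slim task-specific wrappers. Names are illustrative only.

class RandomForest:
    """Shared implementation; only __init__, fit, and explicit setters
    should mutate state."""

    def __init__(self, n_estimators=10, max_depth=None, max_features=None,
                 min_samples_split=2, bootstrap=True):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.max_features = max_features
        self.min_samples_split = min_samples_split
        self.bootstrap = bootstrap


class RandomForestClassifier(RandomForest):
    """Slim classification wrapper; task-specific logic would layer on top."""


class RandomForestRegressor(RandomForest):
    """Slim regression wrapper; task-specific logic would layer on top."""
```

Both wrappers stay thin: shared hyperparameters and validation live in the base class, so classification and regression only diverge where the task actually requires it.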
@@ -0,0 +1,60 @@
import pytest
These are probably part of a separate change, right? I'd suggest separating them out, keeping them in separate branches on your machine, since it's otherwise easy to accidentally have a commit that spans both changes and gets hard to disentangle later.
Saloni based her PR on the branch of PR #407, so that's why you see so many commits; once that is merged, those commits should go away from this PR. That said, I might recommend not basing new PRs on branches of open PRs, which I mentioned offline to @Salonijain27.
Yes, I believe I did. I can close this PR and create a new one based on branch-0.8.
I fixed this branch by merging the 0.8 branch into it.
# min_rows_per_node in cuml = min_samples_split in sklearn
# max_leaves
def __init__(self, n_estimators=25, max_depth=None, max_features=None, min_rows_per_node=None, bootstrap=True):
Ideally we'd make the names and defaults match sklearn. So n_estimators=10, and change min_rows_per_node to min_samples_split, unless there's a blocker to one of those.
Do we need a type arg? Classifier vs. regressor e.g.?
At the moment we only have the classifier, but I will add that argument in.
class Randomforest(Base):
Probably RandomForest to match sklearn style. Maybe this is the base class, then we add RandomForestClassifier and RandomForestRegressor in future PRs?
@@ -0,0 +1,213 @@
import ctypes
This should probably be in ensemble/random_forest to match sklearn.
self.min_rows_per_node = min_rows_per_node
self.bootstrap = bootstrap

def _get_ctype_ptr(self, obj):
Not sure I understand this. Is this an idiom used elsewhere? Since it doesn't involve self, it seems like it should be in a utility function in a shared module somewhere.
It is in base.pyx; seems like a copy-paste issue. PR #612 moves it to a utility function instead indeed! We're thinking along the same lines :)
@Salonijain27 @JohnZed here you can see the new shiny input utility function that will deal with converting any input type and do the corresponding checks needed: https://github.com/rapidsai/cuml/blob/b49981e06b6a629557e89d4be8cded4bca2ca6c7/python/cuml/utils/input_utils.py
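A shared helper along the lines discussed above might be a minimal sketch like this. The device_ctypes_pointer attribute is an assumption based on numba's device array API; the function name is illustrative, not the one PR #612 actually uses:

```python
# Hypothetical shared utility: extract the raw device pointer from a
# numba GPU array, so individual models don't each reimplement it.
# device_ctypes_pointer is numba's device-array attribute (assumed here).

def get_ctype_ptr(gpu_array):
    # numba device arrays expose their raw pointer as a ctypes value
    return gpu_array.device_ctypes_pointer.value
```

Since the helper never touches self, it can live in a shared utils module and be imported by every model that needs to hand raw pointers to the C++ layer.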
def _get_column_ptr(self, obj):
    return self._get_ctype_ptr(obj._column._data.to_gpu_array())

def fit(self, X):
fit should probably take a labels or y param?
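A minimal sketch of the suggested signature, assuming plain array-like inputs (the attribute names are illustrative, not cuml's final API):

```python
import numpy as np

class RandomForest:
    """Sketch only: fit takes labels y, sklearn-style, and returns self."""

    def fit(self, X, y):
        X = np.asarray(X)
        y = np.asarray(y)
        # record shape and dtype at fit time for later input validation
        self.n_rows_, self.n_cols_ = X.shape
        self.dtype_ = X.dtype
        # ... actual training would happen here ...
        return self  # sklearn convention: fit returns self
```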
input_ptr = self._get_ctype_ptr(X_m)

cdef cumlHandle* handle_ = <cumlHandle*> <size_t> self.handle.getHandle()
self.labels_ = cudf.Series(np.zeros(self.n_rows, dtype=np.int32))
So this is just a placeholder right now, right?
cdef uintptr_t input_ptr
if (isinstance(X, cudf.DataFrame)):
    self.gdf_datatype = np.dtype(X[X.columns[0]]._column.dtype)
Yeah, definitely a utility function for this would be nice.
You are talking about PR #612 ;) that is the point of that PR.
cdef uintptr_t input_ptr
if (isinstance(X, cudf.DataFrame)):
    self.gdf_datatype = np.dtype(X[X.columns[0]]._column.dtype)
Not sure that predict should set anything on self. I think it's surprising if predict changes any internal state since you'll often generate one instance and call predict many times on it.
So it should probably be more like checking that the dtype here matches the expected self.dtype
Having a self.dtype (instead of the legacy, badly named gdf_datatype) for all models to be able to check inputs is being standardized in PR #612. This PR can either follow the example there if that one is merged first, or I can change this in that PR if this one makes it first.
I can edit it to follow the PR #612
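The validate-don't-mutate pattern discussed above could be sketched like this, using numpy stand-ins rather than cudf (names are illustrative, not cuml's API):

```python
import numpy as np

class RandomForest:
    """Sketch of the dtype-checking pattern: predict validates input
    against the dtype captured at fit time instead of mutating self."""

    def fit(self, X, y):
        X = np.asarray(X)
        self.dtype = X.dtype  # captured once, at fit time
        return self

    def predict(self, X):
        X = np.asarray(X)
        # validate rather than overwrite: predict leaves self untouched,
        # so one fitted instance can serve many predict calls safely
        if X.dtype != self.dtype:
            raise TypeError(
                "expected dtype %s, got %s" % (self.dtype, X.dtype))
        # ... real inference would go here; placeholder output only ...
        return np.zeros(len(X), dtype=np.int32)
```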
input_ptr = self._get_ctype_ptr(X_m)

cdef cumlHandle* handle_ = <cumlHandle*> <size_t> self.handle.getHandle()
clust_mat = numba_utils.row_matrix(self.cluster_centers_)
Maybe copypasta from another algorithm?
Yes, sorry, I changed it locally and forgot to update the branch.