-
-
Notifications
You must be signed in to change notification settings - Fork 25.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modified BaseDecisionTree so that min_weight_fraction_leaf works when… #6947
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,9 +297,13 @@ def fit(self, X, y, sample_weight=None, check_input=True, | |
sample_weight = expanded_class_weight | ||
|
||
# Set min_weight_leaf from min_weight_fraction_leaf | ||
if self.min_weight_fraction_leaf != 0. and sample_weight is not None: | ||
if self.min_weight_fraction_leaf != 0.: | ||
if sample_weight is None: | ||
sample_weight = np.repeat(1., n_samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Firstly, we don't need to explicitly do this, since all we want is the sum of weights, i.e. "total weight". So you could just do Secondly, I see why you might want this, but it now replicates the function of min_samples_leaf = max(min_samples_leaf, int(ceil(self.min_weight_fraction_leaf * n_samples))) In terms of testing, you should look at existing tests for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh my bad, I was not aware of that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we fix this in the cython code? |
||
|
||
min_weight_leaf = (self.min_weight_fraction_leaf * | ||
np.sum(sample_weight)) | ||
|
||
else: | ||
min_weight_leaf = 0. | ||
|
||
|
@@ -577,7 +581,7 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): | |
|
||
min_weight_fraction_leaf : float, optional (default=0.) | ||
The minimum weighted fraction of the input samples required to be at a | ||
leaf node. | ||
leaf node where weights are determined by sample_weight in the fit method. | ||
|
||
max_leaf_nodes : int or None, optional (default=None) | ||
Grow a tree with ``max_leaf_nodes`` in best-first fashion. | ||
|
@@ -831,7 +835,7 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): | |
|
||
min_weight_fraction_leaf : float, optional (default=0.) | ||
The minimum weighted fraction of the input samples required to be at a | ||
leaf node. | ||
leaf node where weights are determined by sample_weight in the fit method. | ||
|
||
max_leaf_nodes : int or None, optional (default=None) | ||
Grow a tree with ``max_leaf_nodes`` in best-first fashion. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't actually need this
if