FIX handle properly missing value in MSE and Friedman-MSE `children_impurity` (#28327)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
3 people committed Feb 13, 2024
1 parent 29911fb commit e4d7d9a
Showing 4 changed files with 82 additions and 32 deletions.
3 changes: 2 additions & 1 deletion doc/whats_new/v1.4.rst
@@ -153,7 +153,8 @@ Changelog
:class:`tree.DecisionTreeRegressor` are handling missing values properly. The internal
criterion was not initialized when no missing values were present in the data, leading
to potentially wrong criterion values.
-:pr:`28295` by :user:`Guillaume Lemaitre <glemaitre>`.
+:pr:`28295` by :user:`Guillaume Lemaitre <glemaitre>` and
+:pr:`28327` by :user:`Adam Li <adam2392>`.

:mod:`sklearn.utils`
....................
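For context, the behavior this changelog entry covers can be sketched standalone. The toy data and the non-negativity assertion below are lifted from the test added at the bottom of this commit; the snippet itself is illustrative only and not part of the patch:

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    # Toy data from the new test: missing values plus ordered targets.
    X = np.array([np.nan, np.nan, 3, 4, 5, 6]).reshape(-1, 1)
    y = np.arange(6)

    tree = DecisionTreeRegressor(random_state=0).fit(X, y)
    # MSE impurity is a (weighted) variance, so it should never be negative;
    # this is exactly what the new non-regression test asserts.
    assert all(tree.tree_.impurity >= 0)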
18 changes: 18 additions & 0 deletions sklearn/tree/_criterion.pyx
@@ -1150,6 +1150,8 @@ cdef class MSE(RegressionCriterion):
        cdef intp_t k
        cdef float64_t w = 1.0

+        cdef intp_t end_non_missing
+
        for p in range(start, pos):
            i = sample_indices[p]

@@ -1160,6 +1162,22 @@ cdef class MSE(RegressionCriterion):
                y_ik = self.y[i, k]
                sq_sum_left += w * y_ik * y_ik

+        if self.missing_go_to_left:
+            # add up the impact of these missing values on the left child
+            # statistics.
+            # Note: this only impacts the square sum as the sum
+            # is modified elsewhere.
+            end_non_missing = self.end - self.n_missing
+
+            for p in range(end_non_missing, self.end):
+                i = sample_indices[p]
+                if sample_weight is not None:
+                    w = sample_weight[i]
+
+                for k in range(self.n_outputs):
+                    y_ik = self.y[i, k]
+                    sq_sum_left += w * y_ik * y_ik
+
        sq_sum_right = self.sq_sum_total - sq_sum_left

        impurity_left[0] = sq_sum_left / self.weighted_n_left
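In plain NumPy terms, the added block ensures the left child's square sum also covers the missing-valued samples that the partitioner parks at the tail of `sample_indices` when `missing_go_to_left` is set. A single-output sketch of the quantity `children_impurity` computes for the left child; the helper name and the index arrays are invented for this illustration, not part of the codebase:

    import numpy as np

    def left_child_mse(y, w, left_idx, missing_idx, missing_go_to_left):
        # The fix: samples with missing values routed left must contribute
        # to the left child's statistics as well.
        idx = np.concatenate([left_idx, missing_idx]) if missing_go_to_left else left_idx
        weighted_n_left = w[idx].sum()
        sq_sum_left = np.sum(w[idx] * y[idx] ** 2)  # what the patched loop accumulates
        mean_left = np.sum(w[idx] * y[idx]) / weighted_n_left
        # MSE impurity is E[y^2] - (E[y])^2 under the sample weights.
        return sq_sum_left / weighted_n_left - mean_left**2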
7 changes: 7 additions & 0 deletions sklearn/tree/_splitter.pyx
@@ -264,6 +264,13 @@ cdef inline void shift_missing_values_to_left_if_required(
    intp_t[::1] samples,
    intp_t end,
) noexcept nogil:
+    """Shift missing value sample indices to the left of the split if required.
+
+    Note: this should always be called at the very end because it will
+    move samples around, thereby affecting the criterion.
+    This affects the computation of the children impurity, which affects
+    the computation of the next node.
+    """
    cdef intp_t i, p, current_end
    # The partitioner partitions the data such that the missing values are in
    # samples[-n_missing:] for the criterion to consume. If the missing values
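For intuition about why the ordering matters, the shift can be rendered in pure Python roughly as follows. This is a sketch assuming, as the comment above states, that the partitioner leaves the `n_missing` indices at the tail of `samples`; the real routine is the in-place nogil Cython function above and may differ in detail:

    def shift_missing_to_left(samples, pos, end, n_missing):
        # Swap the n_missing tail indices into the slots right after the
        # split position so they join the left partition.
        for p in range(n_missing):
            samples[pos + p], samples[end - 1 - p] = samples[end - 1 - p], samples[pos + p]
        # The split position advances past the missing samples. Running this
        # before the criterion is done reading would corrupt its view of the
        # children, hence "call at the very end".
        return pos + n_missing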
86 changes: 55 additions & 31 deletions sklearn/tree/tests/test_tree.py
@@ -15,11 +15,13 @@
from joblib.numpy_pickle import NumpyPickler
from numpy.testing import assert_allclose

-from sklearn import datasets, tree
+from sklearn import clone, datasets, tree
from sklearn.dummy import DummyRegressor
from sklearn.exceptions import NotFittedError
+from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score, mean_poisson_deviance, mean_squared_error
from sklearn.model_selection import train_test_split
+from sklearn.pipeline import make_pipeline
from sklearn.random_projection import _sparse_random_matrix
from sklearn.tree import (
DecisionTreeClassifier,
@@ -2510,44 +2512,53 @@ def test_missing_values_poisson():
    assert (y_pred >= 0.0).all()


+def make_friedman1_classification(*args, **kwargs):
+    X, y = datasets.make_friedman1(*args, **kwargs)
+    y = y > 14
+    return X, y
+
+
@pytest.mark.parametrize(
-    "make_data, Tree",
+    "make_data,Tree",
    [
-        (datasets.make_regression, DecisionTreeRegressor),
-        (datasets.make_classification, DecisionTreeClassifier),
+        (datasets.make_friedman1, DecisionTreeRegressor),
+        (make_friedman1_classification, DecisionTreeClassifier),
    ],
)
@pytest.mark.parametrize("sample_weight_train", [None, "ones"])
-def test_missing_values_is_resilience(make_data, Tree, sample_weight_train):
-    """Check that trees can deal with missing values and have decent performance."""
-
-    rng = np.random.RandomState(0)
-    n_samples, n_features = 1000, 50
-    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
+def test_missing_values_is_resilience(
+    make_data, Tree, sample_weight_train, global_random_seed
+):
+    """Check that trees can deal with missing values and have decent performance."""
+    n_samples, n_features = 5_000, 10
+    X, y = make_data(
+        n_samples=n_samples, n_features=n_features, random_state=global_random_seed
+    )

    # Create dataset with missing values
    X_missing = X.copy()
+    rng = np.random.RandomState(global_random_seed)
    X_missing[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan
    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
-        X_missing, y, random_state=0
+        X_missing, y, random_state=global_random_seed
    )

    if sample_weight_train == "ones":
-        sample_weight_train = np.ones(X_missing_train.shape[0])
+        sample_weight = np.ones(X_missing_train.shape[0])
+    else:
+        sample_weight = None

    # Train tree with missing values
-    tree_with_missing = Tree(random_state=rng)
-    tree_with_missing.fit(X_missing_train, y_train, sample_weight=sample_weight_train)
-    score_with_missing = tree_with_missing.score(X_missing_test, y_test)
+    native_tree = Tree(max_depth=10, random_state=global_random_seed)
+    native_tree.fit(X_missing_train, y_train, sample_weight=sample_weight)
+    score_native_tree = native_tree.score(X_missing_test, y_test)

-    # Train tree without missing values
-    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-    tree = Tree(random_state=rng)
-    tree.fit(X_train, y_train, sample_weight=sample_weight_train)
-    score_without_missing = tree.score(X_test, y_test)
+    tree_with_imputer = make_pipeline(
+        SimpleImputer(), Tree(max_depth=10, random_state=global_random_seed)
+    )
+    tree_with_imputer.fit(X_missing_train, y_train)
+    score_tree_with_imputer = tree_with_imputer.score(X_missing_test, y_test)

-    # Score is still 90 percent of the tree's score that had no missing values
-    assert score_with_missing >= 0.9 * score_without_missing
+    assert (
+        score_native_tree > score_tree_with_imputer
+    ), f"{score_native_tree=} should be strictly greater than {score_tree_with_imputer}"


def test_missing_value_is_predictive():
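The reworked test now benchmarks native NaN support against an imputation baseline instead of requiring 90% of a NaN-free tree's score. A condensed standalone version of that comparison, with smaller data than the test's 5_000 samples; illustrative only, with no guarantee on which score wins for other seeds or datasets:

    import numpy as np
    from sklearn.datasets import make_friedman1
    from sklearn.impute import SimpleImputer
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline
    from sklearn.tree import DecisionTreeRegressor

    X, y = make_friedman1(n_samples=1000, n_features=10, random_state=0)
    rng = np.random.RandomState(0)
    X[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Tree consuming NaNs natively vs. a mean-imputation pipeline.
    native = DecisionTreeRegressor(max_depth=10, random_state=0).fit(X_train, y_train)
    imputed = make_pipeline(
        SimpleImputer(), DecisionTreeRegressor(max_depth=10, random_state=0)
    ).fit(X_train, y_train)
    print(native.score(X_test, y_test), imputed.score(X_test, y_test))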
@@ -2617,7 +2628,19 @@ def test_deterministic_pickle():
    assert pickle1 == pickle2


-def test_regression_tree_missing_values_toy():
+@pytest.mark.parametrize(
+    "X",
+    [
+        # missing values will go left for greedy splits
+        np.array([np.nan, 2, np.nan, 4, 5, 6]),
+        np.array([np.nan, np.nan, 3, 4, 5, 6]),
+        # missing values will go right for greedy splits
+        np.array([1, 2, 3, 4, np.nan, np.nan]),
+        np.array([1, 2, 3, np.nan, 6, np.nan]),
+    ],
+)
+@pytest.mark.parametrize("criterion", ["squared_error", "friedman_mse"])
+def test_regression_tree_missing_values_toy(X, criterion):
    """Check that we properly handle missing values in regression trees using a toy
    dataset.
@@ -2629,15 +2652,16 @@ def test_regression_tree_missing_values_toy():
    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/28254
+    https://github.com/scikit-learn/scikit-learn/issues/28316
    """

-    # With this dataset, the missing values will always be sent to the left child
-    # at the first split. The leaf will be pure.
-    X = np.array([np.nan, np.nan, 3, 4, 5, 6]).reshape(-1, 1)
+    X = X.reshape(-1, 1)
    y = np.arange(6)

-    tree = DecisionTreeRegressor(random_state=0).fit(X, y)
+    tree = DecisionTreeRegressor(criterion=criterion, random_state=0).fit(X, y)
+    tree_ref = clone(tree).fit(y.reshape(-1, 1), y)
    assert all(tree.tree_.impurity >= 0)  # MSE should always be positive
+    # Check the impurity match after the first split
+    assert_allclose(tree.tree_.impurity[:2], tree_ref.tree_.impurity[:2])

    # Find the leaves with a single sample where the MSE should be 0
    leaves_idx = np.flatnonzero(
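One detail worth spelling out in the reference-tree trick above: the root impurity of an MSE tree is just the variance of `y`, independent of the features, so `tree_ref` fit on `y` itself provides a clean baseline for the first entries of `tree_.impurity`. Worked numbers for this toy target:

    import numpy as np

    y = np.arange(6)
    # Root MSE impurity = Var(y) = E[y^2] - (E[y])^2
    print((y**2).mean() - y.mean() ** 2)  # 35/12 ≈ 2.9167 for tree and tree_ref alike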
