Skip to content

Commit

Permalink
FIX Fixes OrdinalEncoder.inverse_tranform nan encoded values (#24087)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored and glemaitre committed Aug 5, 2022
1 parent 796f5eb commit 4bdd3b1
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 5 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.1.rst
Expand Up @@ -36,6 +36,13 @@ Changelog
a node if there are duplicates in the dataset.
:pr:`23395` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.preprocessing`
............................

- |Fix| :meth:`preprocessing.OrdinalEncoder.inverse_transform` correctly handles
use cases where `unknown_value` or `encoded_missing_value` is `nan`. :pr:`24087`
by `Thomas Fan`_.

.. _changes_1_1_1:

Version 1.1.1
Expand Down
14 changes: 9 additions & 5 deletions sklearn/preprocessing/_encoders.py
Expand Up @@ -1408,19 +1408,23 @@ def inverse_transform(self, X):
found_unknown = {}

for i in range(n_features):
labels = X[:, i].astype("int64", copy=False)
labels = X[:, i]

# replace values of X[:, i] that were nan with actual indices
if i in self._missing_indices:
X_i_mask = _get_mask(X[:, i], self.encoded_missing_value)
X_i_mask = _get_mask(labels, self.encoded_missing_value)
labels[X_i_mask] = self._missing_indices[i]

if self.handle_unknown == "use_encoded_value":
unknown_labels = labels == self.unknown_value
X_tr[:, i] = self.categories_[i][np.where(unknown_labels, 0, labels)]
unknown_labels = _get_mask(labels, self.unknown_value)

known_labels = ~unknown_labels
X_tr[known_labels, i] = self.categories_[i][
labels[known_labels].astype("int64", copy=False)
]
found_unknown[i] = unknown_labels
else:
X_tr[:, i] = self.categories_[i][labels]
X_tr[:, i] = self.categories_[i][labels.astype("int64", copy=False)]

# insert None values for unknown values
if found_unknown:
Expand Down
61 changes: 61 additions & 0 deletions sklearn/preprocessing/tests/test_encoders.py
Expand Up @@ -1928,6 +1928,15 @@ def test_ordinal_encoder_unknown_missing_interaction():
X_test_trans = oe.transform(X_test)
assert_allclose(X_test_trans, [[np.nan], [-3]])

# Non-regression test for #24082
X_roundtrip = oe.inverse_transform(X_test_trans)

# np.nan is unknown so it maps to None
assert X_roundtrip[0][0] is None

# -3 is the encoded missing value so it maps back to nan
assert np.isnan(X_roundtrip[1][0])


@pytest.mark.parametrize("with_pandas", [True, False])
def test_ordinal_encoder_encoded_missing_value_error(with_pandas):
Expand All @@ -1953,3 +1962,55 @@ def test_ordinal_encoder_encoded_missing_value_error(with_pandas):

with pytest.raises(ValueError, match=error_msg):
oe.fit(X)


@pytest.mark.parametrize(
"X_train, X_test_trans_expected, X_roundtrip_expected",
[
(
# missing value is not in training set
# inverse transform will considering encoded nan as unknown
np.array([["a"], ["1"]], dtype=object),
[[0], [np.nan], [np.nan]],
np.asarray([["1"], [None], [None]], dtype=object),
),
(
# missing value in training set,
# inverse transform will considering encoded nan as missing
np.array([[np.nan], ["1"], ["a"]], dtype=object),
[[0], [np.nan], [np.nan]],
np.asarray([["1"], [np.nan], [np.nan]], dtype=object),
),
],
)
def test_ordinal_encoder_unknown_missing_interaction_both_nan(
X_train, X_test_trans_expected, X_roundtrip_expected
):
"""Check transform when unknown_value and encoded_missing_value is nan.
Non-regression test for #24082.
"""
oe = OrdinalEncoder(
handle_unknown="use_encoded_value",
unknown_value=np.nan,
encoded_missing_value=np.nan,
).fit(X_train)

X_test = np.array([["1"], [np.nan], ["b"]])
X_test_trans = oe.transform(X_test)

# both nan and unknown are encoded as nan
assert_allclose(X_test_trans, X_test_trans_expected)
X_roundtrip = oe.inverse_transform(X_test_trans)

n_samples = X_roundtrip_expected.shape[0]
for i in range(n_samples):
expected_val = X_roundtrip_expected[i, 0]
val = X_roundtrip[i, 0]

if expected_val is None:
assert val is None
elif is_scalar_nan(expected_val):
assert np.isnan(val)
else:
assert val == expected_val

0 comments on commit 4bdd3b1

Please sign in to comment.