Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ColumnTransformer in parallel with joblib's auto memmapping #28822

Merged
merged 5 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ Changelog
being explicitly set as well.
:pr:`28483` by :user:`Stefanie Senger <StefanieSenger>`.

- |Fix| Fixed a bug in :class:`compose.ColumnTransformer` with `n_jobs > 1`, where the
intermediate selected columns were passed to the transformers as read-only arrays.
:pr:`28822` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.cross_decomposition`
..................................

Expand Down
9 changes: 5 additions & 4 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..base import TransformerMixin, _fit_context, clone
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
from ..preprocessing import FunctionTransformer
from ..utils import Bunch, _safe_indexing
from ..utils import Bunch
from ..utils._estimator_html_repr import _VisualBlock
from ..utils._indexing import _get_column_indices
from ..utils._metadata_requests import METHODS
Expand Down Expand Up @@ -389,7 +389,7 @@ def set_params(self, **kwargs):

def _iter(self, fitted, column_as_labels, skip_drop, skip_empty_columns):
"""
Generate (name, trans, column, weight) tuples.
Generate (name, trans, columns, weight) tuples.


Parameters
Expand Down Expand Up @@ -794,7 +794,7 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
)
try:
jobs = []
for idx, (name, trans, column, weight) in enumerate(transformers, start=1):
for idx, (name, trans, columns, weight) in enumerate(transformers, start=1):
if func is _fit_transform_one:
if trans == "passthrough":
output_config = _get_output_config("transform", self)
Expand All @@ -813,9 +813,10 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
jobs.append(
delayed(func)(
transformer=clone(trans) if not fitted else trans,
X=_safe_indexing(X, column, axis=1),
X=X,
y=y,
weight=weight,
columns=columns,
**extra_args,
params=routed_params[name],
)
Expand Down
27 changes: 26 additions & 1 deletion sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import warnings

import joblib
import numpy as np
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -36,7 +37,7 @@
assert_almost_equal,
assert_array_equal,
)
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.fixes import CSR_CONTAINERS, parse_version


class Trans(TransformerMixin, BaseEstimator):
Expand Down Expand Up @@ -2447,6 +2448,30 @@ def test_column_transformer_error_with_duplicated_columns(dataframe_lib):
transformer.fit_transform(df)


@pytest.mark.skipif(
    parse_version(joblib.__version__) < parse_version("1.3"),
    reason="requires joblib >= 1.3",
)
def test_column_transformer_auto_memmap():
    """Check that ColumnTransformer works in parallel with joblib's auto-memmapping.

    Non-regression test for issue #28781: with ``n_jobs > 1`` joblib may
    auto-memmap large inputs, handing workers read-only arrays; the selected
    columns must still be writable for in-place transformers.
    """
    # Small fixed-seed input; max_nbytes=1 below forces memmapping regardless of size.
    X = np.random.RandomState(0).uniform(size=(3, 4))

    # copy=False makes StandardScaler modify its input in place, which raises
    # if the array handed to the worker is read-only.
    scaler = StandardScaler(copy=False)

    transformer = ColumnTransformer(
        transformers=[("scaler", scaler, [0])],
        n_jobs=2,
    )

    # max_nbytes=1 forces joblib's auto-memmapping even for this tiny array,
    # reproducing the read-only-input condition of the original bug.
    with joblib.parallel_backend("loky", max_nbytes=1):
        Xt = transformer.fit_transform(X)

    # Result must match scaling the selected column without parallelism.
    assert_allclose(Xt, StandardScaler().fit_transform(X[:, [0]]))


# Metadata Routing Tests
# ======================

Expand Down
24 changes: 20 additions & 4 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .base import TransformerMixin, _fit_context, clone
from .exceptions import NotFittedError
from .preprocessing import FunctionTransformer
from .utils import Bunch
from .utils import Bunch, _safe_indexing
from .utils._estimator_html_repr import _VisualBlock
from .utils._metadata_requests import METHODS
from .utils._param_validation import HasMethods, Hidden
Expand Down Expand Up @@ -1258,7 +1258,7 @@ def make_pipeline(*steps, memory=None, verbose=False):
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)


def _transform_one(transformer, X, y, weight, params):
def _transform_one(transformer, X, y, weight, columns=None, params=None):
"""Call transform and apply weight to output.

Parameters
Expand All @@ -1275,11 +1275,17 @@ def _transform_one(transformer, X, y, weight, params):
weight : float
Weight to be applied to the output of the transformation.

columns : str, array-like of str, int, array-like of int, array-like of bool, slice
Columns to select before transforming.

params : dict
Parameters to be passed to the transformer's ``transform`` method.

This should be of the form ``process_routing()["step_name"]``.
"""
if columns is not None:
X = _safe_indexing(X, columns, axis=1)

res = transformer.transform(X, **params.transform)
# if we have a weight for this transformer, multiply output
if weight is None:
Expand All @@ -1288,7 +1294,14 @@ def _transform_one(transformer, X, y, weight, params):


def _fit_transform_one(
transformer, X, y, weight, message_clsname="", message=None, params=None
transformer,
X,
y,
weight,
columns=None,
message_clsname="",
message=None,
params=None,
):
"""
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
Expand All @@ -1297,6 +1310,9 @@ def _fit_transform_one(

``params`` needs to be of the form ``process_routing()["step_name"]``.
"""
if columns is not None:
X = _safe_indexing(X, columns, axis=1)

params = params or {}
with _print_elapsed_time(message_clsname, message):
if hasattr(transformer, "fit_transform"):
Expand Down Expand Up @@ -1792,7 +1808,7 @@ def transform(self, X, **params):
routed_params[name] = Bunch(transform={})

Xs = Parallel(n_jobs=self.n_jobs)(
delayed(_transform_one)(trans, X, None, weight, routed_params[name])
delayed(_transform_one)(trans, X, None, weight, params=routed_params[name])
for name, trans, weight in self._iter()
)
if not Xs:
Expand Down