Skip to content

Commit

Permalink
Fix ColumnTransformer in parallel with joblib's auto memmapping (#28822)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Apr 22, 2024
1 parent 3af257e commit 51fca39
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.5.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ Changelog
being explicitly set as well.
:pr:`28483` by :user:`Stefanie Senger <StefanieSenger>`.

- |Fix| Fixed a bug in :class:`compose.ColumnTransformer` with `n_jobs > 1`, where the
intermediate selected columns were passed to the transformers as read-only arrays.
:pr:`28822` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.cross_decomposition`
..................................

Expand Down
9 changes: 5 additions & 4 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..base import TransformerMixin, _fit_context, clone
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
from ..preprocessing import FunctionTransformer
from ..utils import Bunch, _safe_indexing
from ..utils import Bunch
from ..utils._estimator_html_repr import _VisualBlock
from ..utils._indexing import _get_column_indices
from ..utils._metadata_requests import METHODS
Expand Down Expand Up @@ -389,7 +389,7 @@ def set_params(self, **kwargs):

def _iter(self, fitted, column_as_labels, skip_drop, skip_empty_columns):
"""
Generate (name, trans, column, weight) tuples.
Generate (name, trans, columns, weight) tuples.
Parameters
Expand Down Expand Up @@ -794,7 +794,7 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
)
try:
jobs = []
for idx, (name, trans, column, weight) in enumerate(transformers, start=1):
for idx, (name, trans, columns, weight) in enumerate(transformers, start=1):
if func is _fit_transform_one:
if trans == "passthrough":
output_config = _get_output_config("transform", self)
Expand All @@ -813,9 +813,10 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
jobs.append(
delayed(func)(
transformer=clone(trans) if not fitted else trans,
X=_safe_indexing(X, column, axis=1),
X=X,
y=y,
weight=weight,
columns=columns,
**extra_args,
params=routed_params[name],
)
Expand Down
27 changes: 26 additions & 1 deletion sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import warnings

import joblib
import numpy as np
import pytest
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -36,7 +37,7 @@
assert_almost_equal,
assert_array_equal,
)
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.fixes import CSR_CONTAINERS, parse_version


class Trans(TransformerMixin, BaseEstimator):
Expand Down Expand Up @@ -2447,6 +2448,30 @@ def test_column_transformer_error_with_duplicated_columns(dataframe_lib):
transformer.fit_transform(df)


@pytest.mark.skipif(
    parse_version(joblib.__version__) < parse_version("1.3"),
    reason="requires joblib >= 1.3",
)
def test_column_transformer_auto_memmap():
    """Check that ColumnTransformer works in parallel with joblib's auto-memmapping.

    Non-regression test for issue #28781.
    """
    rng = np.random.RandomState(0)
    X = rng.uniform(size=(3, 4))

    # copy=False makes StandardScaler attempt to scale the received array in
    # place, which exercises writability of the columns handed to the
    # transformer by the parallel dispatch.
    transformer = ColumnTransformer(
        transformers=[("scaler", StandardScaler(copy=False), [0])],
        n_jobs=2,
    )

    # max_nbytes=1 forces joblib's auto-memmapping even for this tiny array,
    # so workers receive a read-only memmap of X.
    with joblib.parallel_backend("loky", max_nbytes=1):
        Xt = transformer.fit_transform(X)

    expected = StandardScaler().fit_transform(X[:, [0]])
    assert_allclose(Xt, expected)


# Metadata Routing Tests
# ======================

Expand Down
24 changes: 20 additions & 4 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .base import TransformerMixin, _fit_context, clone
from .exceptions import NotFittedError
from .preprocessing import FunctionTransformer
from .utils import Bunch
from .utils import Bunch, _safe_indexing
from .utils._estimator_html_repr import _VisualBlock
from .utils._metadata_requests import METHODS
from .utils._param_validation import HasMethods, Hidden
Expand Down Expand Up @@ -1258,7 +1258,7 @@ def make_pipeline(*steps, memory=None, verbose=False):
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)


def _transform_one(transformer, X, y, weight, params):
def _transform_one(transformer, X, y, weight, columns=None, params=None):
"""Call transform and apply weight to output.
Parameters
Expand All @@ -1275,11 +1275,17 @@ def _transform_one(transformer, X, y, weight, params):
weight : float
Weight to be applied to the output of the transformation.
columns : str, array-like of str, int, array-like of int, array-like of bool, slice
Columns to select before transforming.
params : dict
Parameters to be passed to the transformer's ``transform`` method.
This should be of the form ``process_routing()["step_name"]``.
"""
if columns is not None:
X = _safe_indexing(X, columns, axis=1)

res = transformer.transform(X, **params.transform)
# if we have a weight for this transformer, multiply output
if weight is None:
Expand All @@ -1288,7 +1294,14 @@ def _transform_one(transformer, X, y, weight, params):


def _fit_transform_one(
transformer, X, y, weight, message_clsname="", message=None, params=None
transformer,
X,
y,
weight,
columns=None,
message_clsname="",
message=None,
params=None,
):
"""
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
Expand All @@ -1297,6 +1310,9 @@ def _fit_transform_one(
``params`` needs to be of the form ``process_routing()["step_name"]``.
"""
if columns is not None:
X = _safe_indexing(X, columns, axis=1)

params = params or {}
with _print_elapsed_time(message_clsname, message):
if hasattr(transformer, "fit_transform"):
Expand Down Expand Up @@ -1792,7 +1808,7 @@ def transform(self, X, **params):
routed_params[name] = Bunch(transform={})

Xs = Parallel(n_jobs=self.n_jobs)(
delayed(_transform_one)(trans, X, None, weight, routed_params[name])
delayed(_transform_one)(trans, X, None, weight, params=routed_params[name])
for name, trans, weight in self._iter()
)
if not Xs:
Expand Down

0 comments on commit 51fca39

Please sign in to comment.