In StackingClassifier and StackingRegressor, we added an option passthrough=True/False which concatenates the internal predictions of the first-layer models with the original dataset. While everything works fine when X is only numerical, things get complicated when dealing with mixed types and a dataframe.
Let's illustrate the issue by taking the Titanic dataset:
Workflow
Some imports
```python
import numpy as np
from pandas.api.types import CategoricalDtype
from sklearn.base import clone
from sklearn.compose import make_column_selector as selector
from sklearn.compose import make_column_transformer
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import RidgeClassifierCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OrdinalEncoder

np.random.seed(0)
```
Load the dataset
```python
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
subset_feature = ['embarked', 'sex', 'pclass', 'age', 'fare']
X = X[subset_feature]
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
```
Define first-layer learners
Since we are dealing with mixed types, the first-layer learners need to include some preprocessing. Here, we stack a linear model and a tree-based model, so we define two types of preprocessors.
```python
gradient_based_processor = make_column_transformer(
    (make_pipeline(StandardScaler(), SimpleImputer(strategy="median")),
     selector(dtype_exclude=CategoricalDtype)),
    (make_pipeline(SimpleImputer(strategy='constant', fill_value='missing'),
                   OneHotEncoder(handle_unknown='ignore')),
     selector(dtype_include=CategoricalDtype))
)

categories = [
    X[col].dtype.categories.tolist() + ["missing"] if X[col].isnull().any()
    else X[col].dtype.categories.tolist()
    for col in selector(dtype_include=CategoricalDtype)(X)
]

tree_based_processor = make_column_transformer(
    (make_pipeline(SimpleImputer(strategy='constant', fill_value='missing'),
                   OrdinalEncoder(categories=categories)),
     selector(dtype_include=CategoricalDtype))
)
```
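As a side note, here is a toy illustration (not the Titanic data) of how the explicit category list above is built: any categorical column containing NaN gets an extra "missing" category, so the constant imputer's fill value is a known category for the ordinal encoder.

```python
import pandas as pd

# Toy dataframe: "embarked" has a missing value, "sex" does not.
df = pd.DataFrame({
    "embarked": pd.Series(["S", "C", None], dtype="category"),
    "sex": pd.Series(["female", "male", "male"], dtype="category"),
})

# Same list comprehension as above, applied to the toy data: append
# "missing" only for columns that actually contain NaN.
categories = [
    df[col].dtype.categories.tolist() + ["missing"] if df[col].isnull().any()
    else df[col].dtype.categories.tolist()
    for col in df.select_dtypes("category").columns
]
print(categories)  # [['C', 'S', 'missing'], ['female', 'male']]
```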
```python
rf = make_pipeline(clone(tree_based_processor), RandomForestClassifier())
lr = make_pipeline(clone(gradient_based_processor), LogisticRegression())

# sanity check
print(lr.fit(X_train, y_train).score(X_test, y_test))
print(rf.fit(X_train, y_train).score(X_test, y_test))
```

```
0.8048780487804879
0.8048780487804879
```
Up to now, we have standard learners. Let's try to stack them.
Only stack the predictions of the first-layer models
Let's use some stacking. By default, we only stack the predictions of the first-layer models:
```python
numerical_ridge = RidgeClassifierCV()
model = StackingClassifier(
    estimators=[("lr", lr), ("rf", rf)],
    final_estimator=numerical_ridge,
    passthrough=False
)
print(model.fit(X_train, y_train).score(X_test, y_test))
```

```
0.8048780487804879
```
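To see why this case is easy, here is a minimal self-contained sketch (purely numerical toy data, not the Titanic example) of what the final estimator receives with passthrough=False: only the cross-validated predictions of the first layer, which are always numeric regardless of the original dtypes.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=100, random_state=0)
stack = StackingClassifier(
    estimators=[("lr", LogisticRegression()),
                ("dt", DecisionTreeClassifier(random_state=0))],
    passthrough=False,
).fit(X, y)

# transform returns the meta-features fed to the final estimator:
# one probability column per estimator (binary classification).
meta = stack.transform(X)
print(meta.shape, meta.dtype)
```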
We don't have any issue since the predictions are only numerical. The issue starts if we want to concatenate the original X.
Concatenate X with the predictions of the first-layer models
If we try to directly pass through X, we get an error since the ridge classifier does not encode the data in X:
```python
model = StackingClassifier(
    estimators=[("lr", lr), ("rf", rf)],
    final_estimator=numerical_ridge,
    passthrough=True
)
try:
    print(model.fit(X_train, y_train).score(X_test, y_test))
except Exception as e:
    print(e)
```

```
could not convert string to float: 'S'
```
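The underlying mechanics can be sketched with toy data: passthrough=True essentially hstacks the (float) first-layer predictions with the raw X, so a single string column makes the whole combined array non-numeric, which is what the ridge classifier then chokes on.

```python
import numpy as np
import pandas as pd

# Toy stand-ins: one string column in X is enough to trigger the issue.
X_raw = pd.DataFrame({"embarked": ["S", "C"], "age": [22.0, 38.0]})
first_layer_preds = np.array([[0.9], [0.1]])

# Concatenating floats with a mixed-dtype dataframe falls back to a
# plain object-dtype ndarray: the dataframe (and its dtypes) is lost.
combined = np.hstack([first_layer_preds, X_raw])
print(type(combined), combined.dtype)  # ndarray with dtype object
```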
Thus, one has to create a pipeline to deal with the data:
```python
ridge = make_pipeline(clone(gradient_based_processor), RidgeClassifierCV())
ridge.named_steps["columntransformer"].set_params(remainder="passthrough")

model = StackingClassifier(
    estimators=[("lr", lr), ("rf", rf)],
    final_estimator=ridge,
    passthrough=True
)
try:
    print(model.fit(X_train, y_train).score(X_test, y_test))
except Exception as e:
    print(e)
```

```
make_column_selector can only be applied to pandas dataframes
```
So here we actually have the following issues:
- The concatenation does not preserve the dataframe. This could be solved by stacking all the data into X (which should be manageable without importing pandas). However, it means that we might want to generate feature names for the predictions produced by the first layer. So this is my first question: what shall we do?
- If the above issue is fixed, we run into another one. When creating the pipeline, the user needs to explicitly let the remainder pass through in the final estimator; otherwise, the prediction columns would be dropped. So we might want to interfere with the column transformer of the last step to make sure the predictions pass through. But API-wise, I don't know what the best way to do this is.
ping @jnothman @qinhanmin2014 @thomasjpfan
Sorry for the long narration.