Skip to content
Browse files

Add DataFrameTransformer (#507)

* Refactor tests in

Make assert_dicts_equal a standalone function.

* Implement DataFrameTransformer

An sklearn compatible transformer that helps working with pandas
DataFrames by transforming the DataFrame into a representation that
works well with neural networks.

* Fix bug: DataFrameTransformer didn't work without floats

It was assumed that there is always at least one float column.

* Implement a helper method to help construct signature

Allows to call DataFrameTransformer().describe_signature(df), which
describes what keys are needed, what their dtypes are, and how many
input units they require.

* Address reviewer comments

* Use np.issubdtype instead of enumerating all possible dtypes.
* describe_signature values are not namedtuples anymore, but

* Address reviewer comments

* use pd.api.types.CategoricalDtype
* rename variable for more clarity

* Include reviewer comment in

Use np.issubdtype.

Co-Authored-By: Thomas J Fan <>

* Fix syntax error

Co-authored-by: Thomas J Fan <>
Co-authored-by: ottonemo <>
  • Loading branch information
3 people committed Feb 14, 2020
1 parent d47357a commit 817697fa1bdc4fbfb27ae211302d244e293f1d1c
Showing with 499 additions and 10 deletions.
  1. +2 −0
  2. +252 −0 skorch/
  3. +245 −10 skorch/tests/
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](

### Added

- Add DataFrameTransformer, an sklearn compatible transformer that helps working with pandas DataFrames by transforming the DataFrame into a representation that works well with neural networks (#507)

### Changed

- When using caching in scoring callbacks, no longer uselessly iterate over the data; this can save time if iteration is slow (#552, #557)
@@ -4,13 +4,18 @@
from collections import Sequence
from collections import namedtuple
from functools import partial

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
import torch

from skorch.cli import parse_args
from skorch.utils import _make_split
from skorch.utils import is_torch_data_type
from skorch.utils import to_tensor

class SliceDict(dict):
@@ -257,3 +262,250 @@ def predefined_split(dataset):
return partial(_make_split, valid_ds=dataset)

class DataFrameTransformer(BaseEstimator, TransformerMixin):
"""Transform a DataFrame into a dict useful for working with skorch.
Transforms cardinal data to floats and categorical data to vectors
of ints so that they can be embedded.
Although skorch can deal with pandas DataFrames, the default
behavior is often not very useful. Use this transformer to
transform the DataFrame into a dict with all float columns
concatenated using the key "X" and all categorical values encoded
as integers, using their respective column names as keys.
Your module must have a matching signature for this to work. It
must accept an argument ``X`` for all cardinal
values. Additionally, for all categorical values, it must accept
an argument with the same name as the corresponding column (see
example below). If you need help with the required signature, use
the ``describe_signature`` method of this class and pass it your
You can choose whether you want to treat int columns the same as
float columns (default) or as categorical values.
To one-hot encode categorical features, initialize their
corresponding embedding layers using the identity matrix.
>>> df = pd.DataFrame({
... 'col_floats': np.linspace(0, 1, 12),
... 'col_ints': [11, 11, 10] * 4,
... 'col_cats': ['a', 'b', 'a'] * 4,
... })
>>> # cast to category dtype to later learn embeddings
>>> df['col_cats'] = df['col_cats'].astype('category')
>>> y = np.asarray([0, 1, 0] * 4)
>>> class MyModule(nn.Module):
... def __init__(self):
... super().__init__()
... self.reset_params()
>>> def reset_params(self):
... self.embedding = nn.Embedding(2, 10)
... self.linear = nn.Linear(2, 10)
... self.out = nn.Linear(20, 2)
... self.nonlin = nn.Softmax(dim=-1)
>>> def forward(self, X, col_cats):
... # "X" contains the values from col_floats and col_ints
... # "col_cats" contains the values from "col_cats"
... X_lin = self.linear(X)
... X_cat = self.embedding(col_cats)
... X_concat =, X_cat), dim=1)
... return self.nonlin(self.out(X_concat))
>>> net = NeuralNetClassifier(MyModule)
>>> pipe = Pipeline([
... ('transform', DataFrameTransformer()),
... ('net', net),
... ])
>>>, y)
treat_int_as_categorical : bool (default=False)
Whether to treat integers as categorical values or as cardinal
values, i.e. the same as floats.
float_dtype : numpy dtype or None (default=np.float32)
The dtype to cast the cardinal values to. If None, don't change
int_dtype : numpy dtype or None (default=np.int64)
The dtype to cast the categorical values to. If None, don't
change them. If you do this, it can happen that the categorical
values will have different dtypes, reflecting the number of
unique categories.
The value of X will always be 2-dimensional, even if it only
contains 1 column.
import pandas as pd

def __init__(
self.treat_int_as_categorical = treat_int_as_categorical
self.float_dtype = float_dtype
self.int_dtype = int_dtype

def _check_dtypes(self, df):
"""Perform a check on the DataFrame to detect wrong dtypes or keys.
Makes sure that there are no conflicts in key names.
If dtypes are found that cannot be dealt with, raises a
TypeError with a message indicating which ones caused trouble.
If there already is a column named 'X'.
If a wrong dtype is found.
if 'X' in df:
raise ValueError(
"DataFrame contains a column named 'X', which clashes "
"with the name chosen for cardinal features; consider "
"renaming that column.")

wrong_dtypes = []

for col, dtype in zip(df, df.dtypes):
if isinstance(dtype, self.pd.api.types.CategoricalDtype):
if np.issubdtype(dtype, np.integer):
if np.issubdtype(dtype, np.floating):
wrong_dtypes.append((col, dtype))

if not wrong_dtypes:

wrong_dtypes = sorted(wrong_dtypes, key=lambda tup: tup[0])
msg_dtypes = ", ".join(
"{} ({})".format(col, dtype) for col, dtype in wrong_dtypes)
msg = ("The following columns have dtypes that cannot be "
"interpreted as numerical dtypes: {}".format(msg_dtypes))
raise TypeError(msg)

# pylint: disable=unused-argument
def fit(self, df, y=None, **fit_params):
return self

def transform(self, df):
"""Transform DataFrame to become a dict that works well with skorch.
df : pd.DataFrame
Incoming DataFrame.
X_dict: dict
Dictionary with all floats concatenated using the key "X"
and all categorical values encoded as integers, using their
respective column names as keys.

X_dict = {}
Xf = [] # floats

for col, dtype in zip(df, df.dtypes):
X_col = df[col]

if isinstance(dtype, self.pd.api.types.CategoricalDtype):
x =
if self.int_dtype is not None:
x = x.astype(self.int_dtype)
X_dict[col] = x

if (
np.issubdtype(dtype, np.integer)
and self.treat_int_as_categorical
x = X_col.astype('category')
if self.int_dtype is not None:
x = x.astype(self.int_dtype)
X_dict[col] = x


if not Xf:
return X_dict

X = np.stack(Xf, axis=1)
if self.float_dtype is not None:
X = X.astype(self.float_dtype)
X_dict['X'] = X
return X_dict

def describe_signature(self, df):
"""Describe the signature required for the given data.
Pass the DataFrame to receive a description of the signature
required for the module's forward method. The description
consists of three parts:
1. The names of the arguments that the forward method
2. The dtypes of the torch tensors passed to forward.
3. The number of input units that are required for the
corresponding argument. For the float parameter, this is just
the number of dimensions of the tensor. For categorical
parameters, it is the number of unique elements.
signature : dict
Returns a dict with each key corresponding to one key
required for the forward method. The values are dictionaries
of two elements. The key "dtype" describes the torch dtype
of the resulting tensor, the key "input_units" describes the
required number of input units.
X_dict = self.fit_transform(df)
signature = {}

X = X_dict.get('X')
if X is not None:
signature['X'] = dict(
dtype=to_tensor(X, device='cpu').dtype,

for key, val in X_dict.items():
if key == 'X':

tensor = to_tensor(val, device='cpu')
nunique = len(torch.unique(tensor))
signature[key] = dict(

return signature

0 comments on commit 817697f

Please sign in to comment.
You can’t perform that action at this time.