Skip to content

Commit

Permalink
Implement InputShapeSetter callback (#786)
Browse files Browse the repository at this point in the history
  • Loading branch information
ottonemo committed Jul 29, 2021
1 parent f0a6ca3 commit 32577c5
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `load_best` attribute to `Checkpoint` callback to automatically load state of the best result at the end of training
- Added a `get_all_learnable_params` method to retrieve the named parameters of all PyTorch modules defined on the net, including those of criteria, if applicable
- Added `MlflowLogger` callback for logging to Mlflow (#769)
- Added `InputShapeSetter` callback for automatically setting the input dimension of the PyTorch module

### Changed

Expand Down
1 change: 1 addition & 0 deletions skorch/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'Freezer',
'GradientNormClipping',
'Initializer',
'InputShapeSetter',
'LRScheduler',
'LoadInitState',
'MlflowLogger',
Expand Down
75 changes: 74 additions & 1 deletion skorch/callbacks/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


__all__ = ['Checkpoint', 'EarlyStopping', 'ParamMapper', 'Freezer',
'Unfreezer', 'Initializer', 'LoadInitState', 'TrainEndCheckpoint']
'Unfreezer', 'Initializer', 'InputShapeSetter', 'LoadInitState',
'TrainEndCheckpoint']


class Checkpoint(Callback):
Expand Down Expand Up @@ -761,3 +762,75 @@ def on_train_end(self, net, **kwargs):
self.checkpoint_.save_model(net)
self.checkpoint_._sink("Final checkpoint triggered", net.verbose)
return self


class InputShapeSetter(Callback):
    """Sets the input dimension of the PyTorch module to the input dimension
    of the training data. By default the last dimension of X
    (``X.shape[-1]``) will be used.

    This can be of use when the shape of X is not known beforehand,
    e.g. when using a skorch model within an sklearn pipeline and
    grid-searching feature transformers, or using feature selection
    methods.

    Basic usage:

    >>> class MyModule(torch.nn.Module):
    ...     def __init__(self, input_dim=1):
    ...         super().__init__()
    ...         self.layer = torch.nn.Linear(input_dim, 3)
    ...     # ...
    >>> X1 = np.zeros((100, 5), dtype=np.float32)
    >>> X2 = np.zeros((100, 3), dtype=np.float32)
    >>> y = np.zeros(100, dtype=np.int64)
    >>> net = NeuralNetClassifier(MyModule, callbacks=[InputShapeSetter()])
    >>> net.fit(X1, y)  # net.module_.layer.in_features == 5
    >>> net.fit(X2, y)  # net.module_.layer.in_features == 3

    Parameters
    ----------
    param_name : str (default='input_dim')
      The parameter name is the parameter your model uses to define the
      input dimension in its ``__init__`` method.

    input_dim_fn : callable, None (default=None)
      In case your ``X`` value is more complex and deriving the input
      dimension is not as easy as ``X.shape[-1]`` you can pass a callable
      to this parameter which takes ``X`` and returns the input dimension.

    module_name : str (default='module')
      Only needs change when you are using more than one module in your
      skorch model (e.g., in case of GANs).

    """
    def __init__(
            self,
            param_name='input_dim',
            input_dim_fn=None,
            module_name='module',
    ):
        self.module_name = module_name
        self.param_name = param_name
        self.input_dim_fn = input_dim_fn

    def get_input_dim(self, X):
        """Return the input dimension derived from ``X``.

        If ``input_dim_fn`` was passed, its result for ``X`` is returned
        as is. Otherwise the size of the last axis of ``X`` is returned.

        Raises
        ------
        ValueError
          If no ``input_dim_fn`` is set and ``X.shape`` has fewer than
          two entries.

        """
        if self.input_dim_fn is not None:
            return self.input_dim_fn(X)

        shape = X.shape
        if len(shape) < 2:
            raise ValueError(
                "Expected at least two-dimensional input data for X. "
                "If your data is one-dimensional, please use the "
                "`input_dim_fn` parameter to infer the correct "
                "input shape."
            )
        return shape[-1]

    def on_train_begin(self, net, X, y, **kwargs):
        """Set ``<module_name>__<param_name>`` on ``net`` to the input
        dimension of ``X``, unless it already has that value.
        """
        params = net.get_params()
        input_dim = self.get_input_dim(X)
        param_name = f'{self.module_name}__{self.param_name}'

        # Skip ``set_params`` when nothing would change, so that an
        # already correctly sized module is left untouched.
        if params.get(param_name) == input_dim:
            return

        # Use a dedicated name instead of rebinding the hook's ``kwargs``.
        new_params = {param_name: input_dim}
        net.set_params(**new_params)
182 changes: 182 additions & 0 deletions skorch/tests/callbacks/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pytest
from sklearn.base import clone
import torch


class TestCheckpoint:
Expand Down Expand Up @@ -1140,3 +1141,184 @@ def test_pickle_initialized_callback(self, trainendcheckpoint_cls):
# does not raise
s = pickle.dumps(cp)
pickle.loads(s)


class TestInputShapeSetter:
    """Tests for the ``InputShapeSetter`` callback."""

    @pytest.fixture
    def module_cls(self):
        import torch

        class Module(torch.nn.Module):
            def __init__(self, input_dim=3):
                super().__init__()
                self.layer = torch.nn.Linear(input_dim, 2)

            def forward(self, X):
                return self.layer(X)

        return Module

    @pytest.fixture
    def net_cls(self):
        from skorch import NeuralNetClassifier
        return NeuralNetClassifier

    @pytest.fixture
    def input_shape_setter_cls(self):
        from skorch.callbacks import InputShapeSetter
        return InputShapeSetter

    def generate_data(self, n_input):
        """Return a float32 ``X`` with ``n_input`` features and its ``y``."""
        from sklearn.datasets import make_classification
        X, y = make_classification(
            1000,
            n_input,
            n_informative=n_input,
            n_redundant=0,
            random_state=0,
        )
        return X.astype(np.float32), y

    @pytest.fixture
    def data_fixed(self):
        return self.generate_data(n_input=10)

    @pytest.fixture(params=[2, 10, 20])
    def data_parametrized(self, request):
        return self.generate_data(n_input=request.param)

    def test_shape_set(
            self, net_cls, module_cls, input_shape_setter_cls,
            data_parametrized,
    ):
        """The module's first layer matches the number of features in X."""
        net = net_cls(module_cls, max_epochs=2, callbacks=[
            input_shape_setter_cls(),
        ])

        X, y = data_parametrized
        n_input = X.shape[1]
        net.fit(X, y)

        assert net.module_.layer.in_features == n_input

    def test_one_dimensional_x_raises(
            self, net_cls, module_cls, input_shape_setter_cls,
    ):
        """1-D input without ``input_dim_fn`` raises a helpful error."""
        net = net_cls(module_cls, max_epochs=2, callbacks=[
            input_shape_setter_cls(),
        ])

        X, y = np.zeros(10), np.zeros(10)

        with pytest.raises(ValueError) as e:
            net.fit(X, y)

        # ``e`` is pytest's ExceptionInfo wrapper; the raised exception
        # (and thus its message) lives on ``e.value``.
        assert (
            "Expected at least two-dimensional input data for X. "
            "If your data is one-dimensional, please use the `input_dim_fn` "
            "parameter to infer the correct input shape."
        ) in str(e.value)

    def test_shape_set_using_fn(
            self, net_cls, module_cls, input_shape_setter_cls,
            data_parametrized,
    ):
        """A custom ``input_dim_fn`` is used, and called exactly once."""
        fn_calls = 0

        def input_dim_fn(X):
            nonlocal fn_calls
            fn_calls += 1
            return X.shape[1]

        net = net_cls(module_cls, max_epochs=2, callbacks=[
            input_shape_setter_cls(input_dim_fn=input_dim_fn),
        ])

        X, y = data_parametrized
        n_input = X.shape[1]
        net.fit(X, y)

        assert net.module_.layer.in_features == n_input
        assert fn_calls == 1

    def test_parameter_name(
            self, net_cls, input_shape_setter_cls, data_parametrized,
    ):
        """``param_name`` selects which module argument gets set."""
        class MyModule(torch.nn.Module):
            def __init__(self, other_input_dim=22):
                super().__init__()
                self.layer = torch.nn.Linear(other_input_dim, 2)

            def forward(self, X):
                return self.layer(X)

        net = net_cls(MyModule, max_epochs=2, callbacks=[
            input_shape_setter_cls(param_name='other_input_dim'),
        ])

        X, y = data_parametrized
        n_input = X.shape[1]
        net.fit(X, y)

        assert net.module_.layer.in_features == n_input

    def test_module_name(
            self, net_cls, module_cls, input_shape_setter_cls,
            data_parametrized,
    ):
        """``module_name`` targets a specific module on a multi-module net."""
        class MyNet(net_cls):
            def initialize_module(self):
                kwargs = self.get_params_for('module')
                self.module_ = self.module(**kwargs)

                kwargs = self.get_params_for('module2')
                self.module2_ = self.module(**kwargs)

        net = MyNet(
            module=module_cls,
            max_epochs=2,
            callbacks=[
                input_shape_setter_cls(module_name='module'),
                input_shape_setter_cls(module_name='module2'),
            ],
        )

        X, y = data_parametrized
        n_input = X.shape[1]
        net.fit(X, y)

        assert net.module_.layer.in_features == n_input
        assert net.module2_.layer.in_features == n_input

    def test_no_module_reinit_when_already_correct(
            self, net_cls, module_cls, input_shape_setter_cls, data_fixed,
    ):
        """No extra module initialization if the dimension already fits."""
        with patch('skorch.classifier.NeuralNetClassifier.initialize_module',
                   side_effect=net_cls.initialize_module, autospec=True):
            net = net_cls(
                module_cls, max_epochs=2, callbacks=[input_shape_setter_cls()],

                # set the input dim to the correct shape beforehand
                module__input_dim=data_fixed[0].shape[-1],
            )

            net.fit(*data_fixed)

            # first initialization due to `initialize()` but not
            # a second one since the input shape is already correct.
            assert net.initialize_module.call_count == 1

    def test_no_module_reinit_partial_fit(
            self, net_cls, module_cls, input_shape_setter_cls, data_fixed,
    ):
        """``partial_fit`` with unchanged input shape does not re-init."""
        with patch('skorch.classifier.NeuralNetClassifier.initialize_module',
                   side_effect=net_cls.initialize_module, autospec=True):
            net = net_cls(
                module_cls, max_epochs=2, callbacks=[input_shape_setter_cls()],
            )

            net.fit(*data_fixed)
            # first initialization due to `initialize()`, second
            # by setting the input dimension in `on_train_begin`
            assert net.initialize_module.call_count == 2

            net.partial_fit(*data_fixed)
            # no re-initialization when there was no change in
            # input dimension.
            assert net.initialize_module.call_count == 2

0 comments on commit 32577c5

Please sign in to comment.