Skip to content

Commit

Permalink
Merge pull request #117 from dnouri/feature/params-for-utility
Browse files Browse the repository at this point in the history
params_for utility
  • Loading branch information
benjamin-work committed Nov 17, 2017
2 parents 5d42d37 + 2760c10 commit 7f6da55
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
6 changes: 2 additions & 4 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from skorch.utils import get_dim
from skorch.utils import to_numpy
from skorch.utils import to_var
from skorch.utils import params_for


# pylint: disable=unused-argument
Expand Down Expand Up @@ -875,10 +876,7 @@ def get_iterator(self, dataset, train=False):
return iterator(dataset, **kwargs)

def _get_params_for(self, prefix):
if not prefix.endswith('__'):
prefix += '__'
return {key[len(prefix):]: val for key, val in self.__dict__.items()
if key.startswith(prefix)}
return params_for(prefix, self.__dict__)

def _get_param_names(self):
return self.__dict__.keys()
Expand Down
16 changes: 16 additions & 0 deletions skorch/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ def test_no_duplicates(self, duplicate_items, collections):
])
def test_duplicates(self, duplicate_items, collections, expected):
assert duplicate_items(*collections) == expected


class TestParamsFor:
@pytest.fixture
def params_for(self):
from skorch.utils import params_for
return params_for

@pytest.mark.parametrize('prefix, kwargs, expected', [
('p1', {'p1__a': 1, 'p1__b': 2}, {'a': 1, 'b': 2}),
('p2', {'p1__a': 1, 'p1__b': 2}, {}),
('p1', {'p1__a': 1, 'p1__b': 2, 'p2__a': 3}, {'a': 1, 'b': 2}),
('p2', {'p1__a': 1, 'p1__b': 2, 'p2__a': 3}, {'a': 3}),
])
def test_params_for(self, params_for, prefix, kwargs, expected):
assert params_for(prefix, kwargs) == expected
18 changes: 18 additions & 0 deletions skorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,21 @@ def duplicate_items(*collections):
else:
seen.add(item)
return duplicates


def params_for(prefix, kwargs):
"""Extract parameters that belong to a given sklearn module prefix from
``kwargs``. This is useful to obtain parameters that belong to a
submodule.
Example usage
-------------
>>> kwargs = {'encoder__a': 3, 'encoder__b': 4, 'decoder__a': 5}
>>> params_for('encoder', kwargs)
{'a': 3, 'b': 4}
"""
if not prefix.endswith('__'):
prefix += '__'
return {key[len(prefix):]: val for key, val in kwargs.items()
if key.startswith(prefix)}

0 comments on commit 7f6da55

Please sign in to comment.