TST: Asserts deprecation to remove warnings
thomasjpfan committed Nov 7, 2018
1 parent c7162b0 commit 077a0af
Showing 1 changed file with 18 additions and 10 deletions.
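The change wraps each call to a deprecated helper in pytest.warns(DeprecationWarning), which both asserts that the warning is actually raised and keeps it out of the test run's warning summary. A minimal, self-contained sketch of that pattern (the helper below is hypothetical and not part of this commit):

import warnings

import pytest


def legacy_helper():
    # stand-in for a deprecated helper such as filter_requires_grad
    warnings.warn("deprecated; will be removed in 0.5.0", DeprecationWarning)
    return 42


def test_legacy_helper_warns():
    # fails if no DeprecationWarning is emitted inside the block
    with pytest.warns(DeprecationWarning):
        assert legacy_helper() == 42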
skorch/tests/test_helper.py (28 changes: 18 additions & 10 deletions)
@@ -171,6 +171,7 @@ def test_grid_search_with_dict_works(
         print(gs.best_score_, gs.best_params_)
 
 
+# TODO: remove in 0.5.0
 class TestFilterParameterGroupsRequiresGrad():
 
     @pytest.fixture
@@ -187,7 +188,8 @@ def test_all_parameters_requires_gradient(self, filter_requires_grad):
             'params': [torch.zeros(1, requires_grad=True)]
         }]
 
-        filter_pgroups = list(filter_requires_grad(pgroups))
+        with pytest.warns(DeprecationWarning):
+            filter_pgroups = list(filter_requires_grad(pgroups))
         assert len(filter_pgroups) == 2
         assert len(list(filter_pgroups[0]['params'])) == 2
         assert len(list((filter_pgroups[1]['params']))) == 1
@@ -204,7 +206,8 @@ def test_some_params_requires_gradient(self, filter_requires_grad):
             'params': [torch.zeros(1, requires_grad=False)]
         }]
 
-        filter_pgroups = list(filter_requires_grad(pgroups))
+        with pytest.warns(DeprecationWarning):
+            filter_pgroups = list(filter_requires_grad(pgroups))
         assert len(filter_pgroups) == 2
         assert len(list(filter_pgroups[0]['params'])) == 1
         assert len(list(filter_pgroups[1]['params'])) == 0
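Taken together, the assertions in TestFilterParameterGroupsRequiresGrad describe the behaviour of the deprecated filter_requires_grad: every parameter group is kept (even if all of its tensors are filtered out), extra keys such as 'lr' survive (see the next hunk), and only tensors with requires_grad=False are dropped. A hedged sketch of that behaviour, reconstructed from the tests rather than taken from skorch:

import torch


def filter_requires_grad_sketch(pgroups):
    # yield a copy of every group, keeping only tensors that require gradients
    for group in pgroups:
        yield dict(group, params=(p for p in group['params'] if p.requires_grad))


pgroups = [
    {'params': [torch.zeros(1, requires_grad=True),
                torch.zeros(1, requires_grad=False)], 'lr': 0.1},
    {'params': [torch.zeros(1, requires_grad=False)]},
]
groups = list(filter_requires_grad_sketch(pgroups))
assert len(groups) == 2                     # groups are never dropped
assert len(list(groups[0]['params'])) == 1  # non-grad tensor filtered out
assert len(list(groups[1]['params'])) == 0
assert groups[0]['lr'] == 0.1               # extra keys are preserved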
@@ -222,14 +225,16 @@ def test_does_not_drop_group_when_requires_grad_is_false(
             'params': [torch.zeros(1, requires_grad=False)]
         }]
 
-        filter_pgroups = list(filter_requires_grad(pgroups))
+        with pytest.warns(DeprecationWarning):
+            filter_pgroups = list(filter_requires_grad(pgroups))
         assert len(filter_pgroups) == 2
         assert len(list(filter_pgroups[0]['params'])) == 0
         assert len(list(filter_pgroups[1]['params'])) == 0
 
         assert filter_pgroups[0]['lr'] == 0.1
 
 
+# TODO: remove in 0.5.0
 class TestOptimizerParamsRequiresGrad:
 
     @pytest.fixture
@@ -252,8 +257,9 @@ def test_passes_filtered_cgroups(
             'params': [torch.zeros(1, requires_grad=True)]
         }]
 
-        opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
-        filtered_opt = opt(pgroups, lr=0.2)
+        with pytest.warns(DeprecationWarning):
+            opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
+            filtered_opt = opt(pgroups, lr=0.2)
 
         assert isinstance(filtered_opt, torch.optim.SGD)
         assert len(list(filtered_opt.param_groups[0]['params'])) == 1
@@ -273,17 +279,19 @@ def test_passes_kwargs_to_neuralnet_optimizer(
             output_units=1,
         )
 
-        opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
-        net = NeuralNetClassifier(
-            module_cls, optimizer=opt, optimizer__momentum=0.9)
+        with pytest.warns(DeprecationWarning):
+            opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
+            net = NeuralNetClassifier(
+                module_cls, optimizer=opt, optimizer__momentum=0.9)
+            net.initialize()
 
-        net.initialize()
         assert isinstance(net.optimizer_, torch.optim.SGD)
         assert len(net.optimizer_.param_groups) == 1
         assert net.optimizer_.param_groups[0]['momentum'] == 0.9
 
     def test_pickle(self, filtered_optimizer, filter_requires_grad):
-        opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
+        with pytest.warns(DeprecationWarning):
+            opt = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
         # Does not raise
         pickle.dumps(opt)

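The TestOptimizerParamsRequiresGrad cases above additionally require the factory returned by filtered_optimizer to build a real torch optimizer from the filtered groups and to remain picklable. A rough sketch under those assumptions (names and implementation are illustrative, not skorch's):

from functools import partial
import pickle

import torch


def keep_requires_grad(pgroups):
    # drop tensors whose requires_grad is False, keep the group structure
    return [dict(g, params=[p for p in g['params'] if p.requires_grad])
            for g in pgroups]


def build_optimizer(optimizer_cls, filter_fn, pgroups, **kwargs):
    return optimizer_cls(filter_fn(pgroups), **kwargs)


def make_filtered_optimizer(optimizer_cls, filter_fn):
    # partial over module-level callables keeps the factory picklable
    return partial(build_optimizer, optimizer_cls, filter_fn)


opt = make_filtered_optimizer(torch.optim.SGD, keep_requires_grad)
sgd = opt([{'params': [torch.zeros(1, requires_grad=True)]}], lr=0.2)
assert isinstance(sgd, torch.optim.SGD)
pickle.dumps(opt)  # does not raise, mirroring test_pickle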