Fix failing tests for GPyTorch v1.10 (#956)
BenjaminBossan committed Apr 18, 2023
1 parent 51bba10 commit 7276c93
Showing 1 changed file with 95 additions and 54 deletions.
149 changes: 95 additions & 54 deletions skorch/tests/test_probabilistic.py
@@ -114,6 +114,21 @@ def forward(self, x):
return latent_pred


+class MyBernoulliLikelihood(gpytorch.likelihoods.BernoulliLikelihood):
+    """This class only exists to add a param to BernoulliLikelihood
+
+    BernoulliLikelihood used to have parameters before gpytorch v1.10, but now
+    it does not have any parameters anymore. This is not an issue per se, but
+    there are a few things we cannot test anymore, e.g. that parameters are
+    passed to the likelihood correctly when using grid search. Therefore, create
+    a custom class with a (pointless) parameter.
+    """
+    def __init__(self, *args, some_parameter=1, **kwargs):
+        self.some_parameter = some_parameter
+        super().__init__(*args, **kwargs)
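
For context on why this class is needed: skorch routes constructor arguments prefixed with likelihood__ to the likelihood, and the grid-search and set_params tests further down rely on there being at least one such routable parameter. A rough sketch of that routing, assuming a variational GP module MyModule and toy arrays X, y (none of which are part of this diff), mirroring the GPBinaryClassifier fixture defined later in this file:

    # Sketch only: MyModule, X, and y are assumed stand-ins for the test's
    # module_cls and data fixtures.
    import gpytorch
    from sklearn.model_selection import GridSearchCV
    from skorch.probabilistic import GPBinaryClassifier

    gpc = GPBinaryClassifier(
        MyModule,
        likelihood=MyBernoulliLikelihood,
        likelihood__some_parameter=3,  # forwarded to MyBernoulliLikelihood.__init__
        criterion=gpytorch.mlls.VariationalELBO,
        criterion__num_data=len(y),
    )

    # set_params (and hence GridSearchCV) re-sets the same prefixed parameter:
    gpc.set_params(likelihood__some_parameter=2)
    gs = GridSearchCV(gpc, {'likelihood__some_parameter': [1, 2]}, cv=3)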


class BaseProbabilisticTests:
"""Base class for all GP estimators.
@@ -220,34 +235,24 @@ def pipe(self, gp):
# saving and loading #
######################

-    @pytest.mark.xfail(strict=True)
-    def test_pickling(self, gp_fit):
-        # Currently fails because of issues outside of our control, this test
-        # should alert us to when the issue has been fixed. Some issues have
-        # been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
-        # not all.
-        pickle.dumps(gp_fit)
+    def test_pickling(self, gp_fit, data):
+        loaded = pickle.loads(pickle.dumps(gp_fit))
+        X, _ = data

-    def test_pickle_error_msg(self, gp_fit):
-        # Should eventually be replaced by a test that saves and loads the model
-        # using pickle and checks that the predictions are identical
-        msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
-               " https://github.com/pytorch/pytorch/issues/38137. "
-               "Try using 'dill' instead of 'pickle'.")
-        with pytest.raises(pickle.PicklingError, match=msg):
-            pickle.dumps(gp_fit)
+        y_pred_before = gp_fit.predict(X)
+        y_pred_after = loaded.predict(X)
+        assert np.allclose(y_pred_before, y_pred_after)

-    def test_deepcopy(self, gp_fit):
-        # Should eventually be replaced by a test that saves and loads the model
-        # using deepcopy and checks that the predictions are identical
-        msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
-               " https://github.com/pytorch/pytorch/issues/38137. "
-               "Try using 'dill' instead of 'pickle'.")
-        with pytest.raises(pickle.PicklingError, match=msg):
-            copy.deepcopy(gp_fit)  # doesn't raise
+    def test_deepcopy(self, gp_fit, data):
+        copied = copy.deepcopy(gp_fit)
+        X, _ = data
+
+        y_pred_before = gp_fit.predict(X)
+        y_pred_after = copied.predict(X)
+        assert np.allclose(y_pred_before, y_pred_after)

-    def test_clone(self, gp_fit):
-        clone(gp_fit)  # doesn't raise
+    def test_clone(self, gp_fit, data):
+        clone(gp_fit)  # does not raise

def test_save_load_params(self, gp_fit, tmpdir):
gp2 = clone(gp_fit).initialize()
@@ -335,7 +340,8 @@ def test_grid_search_works(self, gp, data, recwarn):
params = {
'lr': [0.01, 0.02],
'max_epochs': [10, 20],
-            'likelihood__max_plate_nesting': [1, 2],
+            # this parameter does not exist but that's okay
+            'likelihood__some_parameter': [1, 2],
}
gp.set_params(verbose=0)
gs = GridSearchCV(gp, params, refit=True, cv=3, scoring=self.scoring)
@@ -419,32 +425,29 @@ def test_multioutput_predict_proba(self, gp_multioutput, data):
])
def test_set_params_uninitialized_net_correct_message(
self, gp, kwargs, expected, capsys):
-        # When gp is initialized, if module or optimizer need to be
-        # re-initialized, alert the user to the fact what parameters
-        # were responsible for re-initialization. Note that when the
-        # module parameters but not optimizer parameters were changed,
-        # the optimizer is re-initialized but not because the
-        # optimizer parameters changed.
+        # When gp is uninitialized, there is nothing to alert the user to
gp.set_params(**kwargs)
msg = capsys.readouterr()[0].strip()
assert msg == expected

@pytest.mark.parametrize('kwargs,expected', [
({}, ""),
(
-            {'likelihood__max_plate_nesting': 2},
+            # this parameter does not exist but that's okay
+            {'likelihood__some_parameter': 2},
("Re-initializing module because the following "
"parameters were re-set: likelihood__max_plate_nesting.\n"
"parameters were re-set: likelihood__some_parameter.\n"
"Re-initializing criterion.\n"
"Re-initializing optimizer.")
),
(
{
-                'likelihood__max_plate_nesting': 2,
+                # this parameter does not exist but that's okay
+                'likelihood__some_parameter': 2,
'optimizer__momentum': 0.567,
},
("Re-initializing module because the following "
"parameters were re-set: likelihood__max_plate_nesting.\n"
"parameters were re-set: likelihood__some_parameter.\n"
"Re-initializing criterion.\n"
"Re-initializing optimizer.")
),
@@ -570,23 +573,6 @@ def gp(self, gp_cls, module_cls):
)
return gpr

-    # pickling and deepcopy work for ExactGPRegressor but not for the others, so
-    # override the expected failures here.
-
-    def test_pickling(self, gp_fit):
-        # does not raise
-        pickle.dumps(gp_fit)
-
-    def test_pickle_error_msg(self, gp_fit):
-        # Should eventually be replaced by a test that saves and loads the model
-        # using pickle and checks that the predictions are identical
-        # FIXME
-        pickle.dumps(gp_fit)
-
-    def test_deepcopy(self, gp_fit):
-        # FIXME
-        copy.deepcopy(gp_fit)  # doesn't raise

def test_wrong_module_type_raises(self, gp_cls):
# ExactGPRegressor requires the module to be an ExactGP, if it's not,
# raise an appropriate error message to the user.
@@ -649,6 +635,32 @@ def gp(self, gp_cls, module_cls, data):
assert gpr.batch_size < self.n_samples
return gpr

+    # Since GPyTorch v1.10, GPRegressor works with pickle/deepcopy.
+
+    def test_pickling(self, gp_fit, data):
+        # TODO: remove once Python 3.7 is no longer supported
+        if version_gpytorch < Version('1.10'):
+            pytest.skip("GPyTorch < 1.10 does not support pickling.")
+
+        loaded = pickle.loads(pickle.dumps(gp_fit))
+        X, _ = data
+
+        y_pred_before = gp_fit.predict(X)
+        y_pred_after = loaded.predict(X)
+        assert np.allclose(y_pred_before, y_pred_after)
+
+    def test_deepcopy(self, gp_fit, data):
+        # TODO: remove once Python 3.7 is no longer supported
+        if version_gpytorch < Version('1.10'):
+            pytest.skip("GPyTorch < 1.10 does not support deepcopy.")
+
+        copied = copy.deepcopy(gp_fit)
+        X, _ = data
+
+        y_pred_before = gp_fit.predict(X)
+        y_pred_after = copied.predict(X)
+        assert np.allclose(y_pred_before, y_pred_after)
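
The Version and version_gpytorch names used in the two skips above are not visible in this hunk; presumably they are defined once near the top of the test module, roughly like this:

    # Assumed module-level setup (not shown in this diff): parse the installed
    # GPyTorch version once so tests can be skipped on older releases.
    import gpytorch
    from packaging.version import Version

    version_gpytorch = Version(gpytorch.__version__)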


class TestGPBinaryClassifier(BaseProbabilisticTests):
"""Tests for GPBinaryClassifier."""
@@ -686,11 +698,40 @@ def gp(self, gp_cls, module_cls, data):
gpc = gp_cls(
module_cls,
module__inducing_points=torch.from_numpy(X[:10]),

+            likelihood=MyBernoulliLikelihood,
criterion=gpytorch.mlls.VariationalELBO,
criterion__num_data=int(0.8 * len(y)),
batch_size=24,
)
# we want to make sure batching is properly tested
assert gpc.batch_size < self.n_samples
return gpc

+    # Since GPyTorch v1.10, GPBinaryClassifier is the only estimator left that
+    # still has issues with pickling/deepcopying.
+
+    @pytest.mark.xfail(strict=True)
+    def test_pickling(self, gp_fit, data):
+        # Currently fails because of issues outside of our control, this test
+        # should alert us to when the issue has been fixed. Some issues have
+        # been fixed in https://github.com/cornellius-gp/gpytorch/pull/1336 but
+        # not all.
+        pickle.dumps(gp_fit)
+
+    def test_pickle_error_msg(self, gp_fit, data):
+        # Should eventually be replaced by a test that saves and loads the model
+        # using pickle and checks that the predictions are identical
+        msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
+               " https://github.com/pytorch/pytorch/issues/38137. "
+               "Try using 'dill' instead of 'pickle'.")
+        with pytest.raises(pickle.PicklingError, match=msg):
+            pickle.dumps(gp_fit)
+
+    def test_deepcopy(self, gp_fit, data):
+        # Should eventually be replaced by a test that saves and loads the model
+        # using deepcopy and checks that the predictions are identical
+        msg = ("This GPyTorch model cannot be pickled. The reason is probably this:"
+               " https://github.com/pytorch/pytorch/issues/38137. "
+               "Try using 'dill' instead of 'pickle'.")
+        with pytest.raises(pickle.PicklingError, match=msg):
+            copy.deepcopy(gp_fit)  # doesn't raise
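
The error message kept above points users to dill. A minimal sketch of that workaround, reusing the gp_fit and X names from the test fixtures (dill mirrors pickle's dumps/loads API, although whether it fully round-trips a fitted GPBinaryClassifier is not guaranteed):

    # dill can often serialize objects that the standard pickle module cannot.
    import dill
    import numpy as np

    dumped = dill.dumps(gp_fit)
    restored = dill.loads(dumped)
    assert np.allclose(gp_fit.predict(X), restored.predict(X))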
