-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
test_updates.py
92 lines (87 loc) · 2.89 KB
/
test_updates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import numpy as np
import theano
from theano.configparser import change_flags
from pymc3.variational.updates import (
sgd,
momentum,
nesterov_momentum,
adagrad,
rmsprop,
adadelta,
adam,
adamax,
adagrad_window
)
# Shared graph fixtures for the parametrized optimizer test below.
# In this file the test only builds update dictionaries and never evaluates
# them, so the concrete values do not matter. They are zero-initialised
# (rather than ``np.empty``) so that nothing ever depends on uninitialised
# memory if one of these graphs is evaluated.
#
# Scalar shared parameter and a scalar loss built from it.
_a = theano.shared(1.)
_b = _a*2
# 1-d shared parameter of length 10 and its sum as a scalar loss.
_m = theano.shared(np.zeros((10, ), theano.config.floatX))
_n = _m.sum()
# 3-d shared parameter; ``_n2`` is a scalar loss depending on all of
# ``_a``, ``_m`` and ``_m2`` at once.
_m2 = theano.shared(np.zeros((10, 10, 10), theano.config.floatX))
_n2 = _b + _n + _m2.sum()
@pytest.mark.parametrize(
    'opt',
    [sgd, momentum, nesterov_momentum,
     adagrad, rmsprop, adadelta, adam,
     adamax, adagrad_window],
    ids=['sgd', 'momentum', 'nesterov_momentum',
         'adagrad', 'rmsprop', 'adadelta', 'adam',
         'adamax', 'adagrad_window']
)
@pytest.mark.parametrize(
    'getter',
    [lambda t: t,  # all params -> ok
     lambda t: (None, t[1]),  # missing loss -> fail
     lambda t: (t[0], None),  # missing params -> fail
     lambda t: (None, None)],  # all missing -> partial
    ids=['all_params',
         'missing_loss',
         'missing_params',
         'all_missing']
)
@pytest.mark.parametrize(
    'kwargs',
    [dict(), dict(learning_rate=1e-2)],
    ids=['without_args', 'with_args']
)
@pytest.mark.parametrize(
    'loss_and_params',
    # NOTE: the 'matrix' case actually uses the 1-d parameter ``_m``; the id
    # is kept unchanged so existing test selections keep working.
    [(_b, [_a]), (_n, [_m]), (_n2, [_a, _m, _m2])],
    ids=['scalar', 'matrix', 'mixed']
)
def test_updates_fast(opt, loss_and_params, kwargs, getter):
    """Check the calling conventions every optimizer in ``updates`` supports.

    ``getter`` decides which of ``loss_or_grads`` / ``params`` are passed:

    * both given      -> the optimizer returns an updates ``dict``;
    * exactly one     -> the call raises ``ValueError``;
    * neither given   -> the optimizer returns a callable which, when later
      given the loss and params, produces the updates ``dict``.

    ``kwargs`` optionally adds hyperparameters (``learning_rate``) so that
    each convention is checked with and without them.
    """
    with change_flags(compute_test_value='ignore'):
        loss, param = getter(loss_and_params)
        # Hyperparameters first, then the (possibly None) loss/params;
        # the latter win on key collision, exactly as before.
        args = dict(kwargs, loss_or_grads=loss, params=param)
        if loss is None and param is None:
            # Only hyperparameters supplied: we should get a new callable.
            optimizer = opt(**args)
            assert callable(optimizer)
            # Invoke the *returned* callable (not ``opt`` again) so the
            # deferred path, and any hyperparameters bound into it, are
            # actually exercised for this loss/params combination.
            updates = optimizer(*loss_and_params)
            assert isinstance(updates, dict)
        elif loss is None or param is None:
            # The user supplied an incomplete set of
            # [loss_or_grads, params]; this must raise ValueError.
            with pytest.raises(ValueError):
                opt(**args)
        else:
            # Usual call to the optimizer (old behaviour): updates dict.
            updates = opt(**args)
            assert isinstance(updates, dict)