Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def validate(self, chains: int) -> None:
)
if self.step_size is not None:
if isinstance(self.step_size, Real):
if self.step_size < 0:
if self.step_size <= 0:
Comment thread
mitzimorris marked this conversation as resolved.
raise ValueError(
'step_size must be > 0, found {}'.format(self.step_size)
)
Expand Down Expand Up @@ -336,7 +336,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
'init_alpha must not be set when algorithm is Newton'
)
if isinstance(self.init_alpha, Real):
if self.init_alpha < 0:
if self.init_alpha <= 0:
Comment thread
mitzimorris marked this conversation as resolved.
raise ValueError('init_alpha must be greater than 0')
else:
raise ValueError('init_alpha must be type of float')
Expand Down Expand Up @@ -403,6 +403,7 @@ def __init__(
elbo_samples: int = None,
eta: Real = None,
adapt_iter: int = None,
adapt_engaged: bool = True,
tol_rel_obj: Real = None,
eval_elbo: int = None,
output_samples: int = None,
Expand All @@ -413,6 +414,7 @@ def __init__(
self.elbo_samples = elbo_samples
self.eta = eta
self.adapt_iter = adapt_iter
self.adapt_engaged = adapt_engaged
self.tol_rel_obj = tol_rel_obj
self.eval_elbo = eval_elbo
self.output_samples = output_samples
Expand Down Expand Up @@ -453,19 +455,19 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument
' found {}'.format(self.elbo_samples)
)
if self.eta is not None:
if self.eta < 1 or not isinstance(self.eta, (Integral, Real)):
if self.eta < 0 or not isinstance(self.eta, (Integral, Real)):
raise ValueError(
'eta must be a non-negative number,'
' found {}'.format(self.eta)
)
if self.adapt_iter is not None:
if self.adapt_iter < 1 or not isinstance(self.eta, Integral):
if self.adapt_iter < 1 or not isinstance(self.adapt_iter, Integral):
raise ValueError(
'adapt_iter must be a positive integer,'
' found {}'.format(self.adapt_iter)
)
if self.tol_rel_obj is not None:
if self.tol_rel_obj < 1 or not isinstance(
if self.tol_rel_obj <= 0 or not isinstance(
self.tol_rel_obj, (Integral, Real)
):
raise ValueError(
Expand Down Expand Up @@ -503,9 +505,13 @@ def compose(self, idx: int, cmd: List) -> str:
cmd.append('elbo_samples={}'.format(self.elbo_samples))
if self.eta is not None:
cmd.append('eta={}'.format(self.eta))
if self.adapt_iter is not None:
cmd.append('adapt')
cmd.append('iter={}'.format(self.adapt_iter))
cmd.append('adapt')
if self.adapt_engaged:
cmd.append('engaged=1')
if self.adapt_iter is not None:
cmd.append('iter={}'.format(self.adapt_iter))
else:
cmd.append('engaged=0')
if self.tol_rel_obj is not None:
cmd.append('tol_rel_obj={}'.format(self.tol_rel_obj))
if self.eval_elbo is not None:
Expand Down
10 changes: 9 additions & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,10 +900,12 @@ def variational(
grad_samples: int = None,
elbo_samples: int = None,
eta: Real = None,
adapt_engaged: bool = True,
adapt_iter: int = None,
tol_rel_obj: Real = None,
eval_elbo: int = None,
output_samples: int = None,
require_converged: bool = True,
) -> CmdStanVB:
"""
Run CmdStan's variational inference algorithm to approximate
Expand Down Expand Up @@ -961,6 +963,8 @@ def variational(

:param eta: Stepsize scaling parameter.

:param adapt_engaged: Whether eta adaptation is engaged.

:param adapt_iter: Number of iterations for eta adaptation.

:param tol_rel_obj: Relative tolerance parameter for convergence.
Expand All @@ -970,6 +974,9 @@ def variational(
:param output_samples: Number of approximate posterior output draws
to save.

:param require_converged: Whether or not to raise an error if stan
reports that "The algorithm may not have converged".

:return: CmdStanVB object
"""
variational_args = VariationalArgs(
Expand All @@ -978,6 +985,7 @@ def variational(
grad_samples=grad_samples,
elbo_samples=elbo_samples,
eta=eta,
adapt_engaged=adapt_engaged,
adapt_iter=adapt_iter,
tol_rel_obj=tol_rel_obj,
eval_elbo=eval_elbo,
Expand Down Expand Up @@ -1010,7 +1018,7 @@ def variational(
errors = re.findall(pat, contents)
if len(errors) > 0:
valid = False
if not valid:
if require_converged and not valid:
raise RuntimeError('The algorithm may not have converged.')
if not runset._check_retcodes():
msg = 'Error during variational inference.\n{}'.format(
Expand Down
25 changes: 10 additions & 15 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,22 +553,17 @@ def scan_variational_csv(path: str) -> Dict:
lineno = scan_column_names(fd, dict, lineno)
line = fd.readline().lstrip(' #\t').rstrip()
lineno += 1
if not line.startswith('Stepsize adaptation complete.'):
raise ValueError(
'line {}: expecting adaptation msg, found:\n\t "{}"'.format(
lineno, line
)
)
line = fd.readline().lstrip(' #\t\n')
lineno += 1
if not line.startswith('eta = 1'):
raise ValueError(
'line {}: expecting eta = 1, found:\n\t "{}"'.format(
lineno, line
if line.startswith('Stepsize adaptation complete.'):
line = fd.readline().lstrip(' #\t\n')
lineno += 1
if not line.startswith('eta'):
raise ValueError(
'line {}: expecting eta, found:\n\t "{}"'.format(
lineno, line
)
)
)
line = fd.readline().lstrip(' #\t\n')
lineno += 1
line = fd.readline().lstrip(' #\t\n')
lineno += 1
xs = line.split(',')
variational_mean = [float(x) for x in xs]
dict['variational_mean'] = variational_mean
Expand Down
19 changes: 17 additions & 2 deletions test/test_cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,26 @@ def test_args_variational(self):
self.assertIn('method=variational', ' '.join(cmd))
self.assertIn('output_samples=1', ' '.join(cmd))

args = VariationalArgs(tol_rel_obj=1)
args = VariationalArgs(tol_rel_obj=0.01)
args.validate(chains=1)
cmd = args.compose(idx=0, cmd=[])
self.assertIn('method=variational', ' '.join(cmd))
self.assertIn('tol_rel_obj=1', ' '.join(cmd))
self.assertIn('tol_rel_obj=0.01', ' '.join(cmd))

args = VariationalArgs(adapt_engaged=True, adapt_iter=100)
args.validate(chains=1)
cmd = args.compose(idx=0, cmd=[])
self.assertIn('adapt engaged=1 iter=100', ' '.join(cmd))

args = VariationalArgs(adapt_engaged=False)
args.validate(chains=1)
cmd = args.compose(idx=0, cmd=[])
self.assertIn('adapt engaged=0', ' '.join(cmd))

args = VariationalArgs(eta=0.1)
args.validate(chains=1)
cmd = args.compose(idx=0, cmd=[])
self.assertIn('eta=0.1', ' '.join(cmd))

def test_args_bad(self):
args = VariationalArgs(algorithm='no_such_algo')
Expand Down
4 changes: 4 additions & 0 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def test_variational_eta_fail(self):
):
model.variational(algorithm='meanfield', seed=12345)

model.variational(
algorithm='meanfield', seed=12345, require_converged=False
)


if __name__ == '__main__':
unittest.main()