diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 707118c2..8f8499bf 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -132,7 +132,7 @@ def validate(self, chains: int) -> None: ) if self.step_size is not None: if isinstance(self.step_size, Real): - if self.step_size < 0: + if self.step_size <= 0: raise ValueError( 'step_size must be > 0, found {}'.format(self.step_size) ) @@ -336,7 +336,7 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument 'init_alpha must not be set when algorithm is Newton' ) if isinstance(self.init_alpha, Real): - if self.init_alpha < 0: + if self.init_alpha <= 0: raise ValueError('init_alpha must be greater than 0') else: raise ValueError('init_alpha must be type of float') @@ -403,6 +403,7 @@ def __init__( elbo_samples: int = None, eta: Real = None, adapt_iter: int = None, + adapt_engaged: bool = True, tol_rel_obj: Real = None, eval_elbo: int = None, output_samples: int = None, @@ -413,6 +414,7 @@ def __init__( self.elbo_samples = elbo_samples self.eta = eta self.adapt_iter = adapt_iter + self.adapt_engaged = adapt_engaged self.tol_rel_obj = tol_rel_obj self.eval_elbo = eval_elbo self.output_samples = output_samples @@ -453,19 +455,19 @@ def validate(self, chains=None) -> None: # pylint: disable=unused-argument ' found {}'.format(self.elbo_samples) ) if self.eta is not None: - if self.eta < 1 or not isinstance(self.eta, (Integral, Real)): + if self.eta < 0 or not isinstance(self.eta, (Integral, Real)): raise ValueError( 'eta must be a non-negative number,' ' found {}'.format(self.eta) ) if self.adapt_iter is not None: - if self.adapt_iter < 1 or not isinstance(self.eta, Integral): + if self.adapt_iter < 1 or not isinstance(self.adapt_iter, Integral): raise ValueError( 'adapt_iter must be a positive integer,' ' found {}'.format(self.adapt_iter) ) if self.tol_rel_obj is not None: - if self.tol_rel_obj < 1 or not isinstance( + if self.tol_rel_obj <= 0 or not isinstance( self.tol_rel_obj, (Integral, Real) ): raise ValueError( @@ -503,9 +505,13 @@ def compose(self, idx: int, cmd: List) -> str: cmd.append('elbo_samples={}'.format(self.elbo_samples)) if self.eta is not None: cmd.append('eta={}'.format(self.eta)) - if self.adapt_iter is not None: - cmd.append('adapt') - cmd.append('iter={}'.format(self.adapt_iter)) + cmd.append('adapt') + if self.adapt_engaged: + cmd.append('engaged=1') + if self.adapt_iter is not None: + cmd.append('iter={}'.format(self.adapt_iter)) + else: + cmd.append('engaged=0') if self.tol_rel_obj is not None: cmd.append('tol_rel_obj={}'.format(self.tol_rel_obj)) if self.eval_elbo is not None: diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index e1ac37a2..f96159bc 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -900,10 +900,12 @@ def variational( grad_samples: int = None, elbo_samples: int = None, eta: Real = None, + adapt_engaged: bool = True, adapt_iter: int = None, tol_rel_obj: Real = None, eval_elbo: int = None, output_samples: int = None, + require_converged: bool = True, ) -> CmdStanVB: """ Run CmdStan's variational inference algorithm to approximate @@ -961,6 +963,8 @@ def variational( :param eta: Stepsize scaling parameter. + :param adapt_engaged: Whether eta adaptation is engaged. + :param adapt_iter: Number of iterations for eta adaptation. :param tol_rel_obj: Relative tolerance parameter for convergence. @@ -970,6 +974,9 @@ def variational( :param output_samples: Number of approximate posterior output draws to save. + :param require_converged: Whether or not to raise an error if stan + reports that "The algorithm may not have converged". + :return: CmdStanVB object """ variational_args = VariationalArgs( @@ -978,6 +985,7 @@ def variational( grad_samples=grad_samples, elbo_samples=elbo_samples, eta=eta, + adapt_engaged=adapt_engaged, adapt_iter=adapt_iter, tol_rel_obj=tol_rel_obj, eval_elbo=eval_elbo, @@ -1010,7 +1018,7 @@ def variational( errors = re.findall(pat, contents) if len(errors) > 0: valid = False - if not valid: + if require_converged and not valid: raise RuntimeError('The algorithm may not have converged.') if not runset._check_retcodes(): msg = 'Error during variational inference.\n{}'.format( diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index c5164a79..09d78dcd 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -553,22 +553,17 @@ def scan_variational_csv(path: str) -> Dict: lineno = scan_column_names(fd, dict, lineno) line = fd.readline().lstrip(' #\t').rstrip() lineno += 1 - if not line.startswith('Stepsize adaptation complete.'): - raise ValueError( - 'line {}: expecting adaptation msg, found:\n\t "{}"'.format( - lineno, line - ) - ) - line = fd.readline().lstrip(' #\t\n') - lineno += 1 - if not line.startswith('eta = 1'): - raise ValueError( - 'line {}: expecting eta = 1, found:\n\t "{}"'.format( - lineno, line + if line.startswith('Stepsize adaptation complete.'): + line = fd.readline().lstrip(' #\t\n') + lineno += 1 + if not line.startswith('eta'): + raise ValueError( + 'line {}: expecting eta, found:\n\t "{}"'.format( + lineno, line + ) ) - ) - line = fd.readline().lstrip(' #\t\n') - lineno += 1 + line = fd.readline().lstrip(' #\t\n') + lineno += 1 xs = line.split(',') variational_mean = [float(x) for x in xs] dict['variational_mean'] = variational_mean diff --git a/test/test_cmdstan_args.py b/test/test_cmdstan_args.py index 28d0176a..17316f94 100644 --- a/test/test_cmdstan_args.py +++ b/test/test_cmdstan_args.py @@ -601,11 +601,26 @@ def test_args_variational(self): self.assertIn('method=variational', ' '.join(cmd)) self.assertIn('output_samples=1', ' '.join(cmd)) - args = VariationalArgs(tol_rel_obj=1) + args = VariationalArgs(tol_rel_obj=0.01) args.validate(chains=1) cmd = args.compose(idx=0, cmd=[]) self.assertIn('method=variational', ' '.join(cmd)) - self.assertIn('tol_rel_obj=1', ' '.join(cmd)) + self.assertIn('tol_rel_obj=0.01', ' '.join(cmd)) + + args = VariationalArgs(adapt_engaged=True, adapt_iter=100) + args.validate(chains=1) + cmd = args.compose(idx=0, cmd=[]) + self.assertIn('adapt engaged=1 iter=100', ' '.join(cmd)) + + args = VariationalArgs(adapt_engaged=False) + args.validate(chains=1) + cmd = args.compose(idx=0, cmd=[]) + self.assertIn('adapt engaged=0', ' '.join(cmd)) + + args = VariationalArgs(eta=0.1) + args.validate(chains=1) + cmd = args.compose(idx=0, cmd=[]) + self.assertIn('eta=0.1', ' '.join(cmd)) def test_args_bad(self): args = VariationalArgs(algorithm='no_such_algo') diff --git a/test/test_variational.py b/test/test_variational.py index af22a0d7..b098b561 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -147,6 +147,10 @@ def test_variational_eta_fail(self): ): model.variational(algorithm='meanfield', seed=12345) + model.variational( + algorithm='meanfield', seed=12345, require_converged=False + ) + if __name__ == '__main__': unittest.main()