diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 91fad17d..67a7d611 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -1142,8 +1142,19 @@ def variational( errors = re.findall(pat, contents) if len(errors) > 0: valid = False - if require_converged and not valid: - raise RuntimeError('The algorithm may not have converged.') + if not valid: + if require_converged: + raise RuntimeError( + 'The algorithm may not have converged.\n' + 'If you would like to inspect the output, ' + 're-call with require_converged=False' + ) + # else: + get_logger().warning( + '%s\n%s', + 'The algorithm may not have converged.', + 'Proceeding because require_converged is set to False', + ) if not runset._check_retcodes(): msg = 'Error during variational inference:\n{}'.format( runset.get_err_msgs() diff --git a/test/test_variational.py b/test/test_variational.py index 2debadec..581a01d6 100644 --- a/test/test_variational.py +++ b/test/test_variational.py @@ -235,12 +235,22 @@ def test_variational_eta_fail(self): ) model = CmdStanModel(stan_file=stan) with self.assertRaisesRegex( - RuntimeError, 'algorithm may not have converged' + RuntimeError, + r'algorithm may not have converged\.\n.*require_converged', ): model.variational(algorithm='meanfield', seed=12345) - model.variational( - algorithm='meanfield', seed=12345, require_converged=False + with LogCapture() as log: + model.variational( + algorithm='meanfield', seed=12345, require_converged=False + ) + log.check_present( + ( + 'cmdstanpy', + 'WARNING', + 'The algorithm may not have converged.\n' + 'Proceeding because require_converged is set to False', + ) )