Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,8 +1142,19 @@ def variational(
errors = re.findall(pat, contents)
if len(errors) > 0:
valid = False
if require_converged and not valid:
raise RuntimeError('The algorithm may not have converged.')
if not valid:
if require_converged:
raise RuntimeError(
'The algorithm may not have converged.\n'
'If you would like to inspect the output, '
're-call with require_converged=False'
)
# else:
get_logger().warning(
'%s\n%s',
'The algorithm may not have converged.',
'Proceeding because require_converged is set to False',
)
if not runset._check_retcodes():
msg = 'Error during variational inference:\n{}'.format(
runset.get_err_msgs()
Expand Down
16 changes: 13 additions & 3 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,22 @@ def test_variational_eta_fail(self):
)
model = CmdStanModel(stan_file=stan)
with self.assertRaisesRegex(
RuntimeError, 'algorithm may not have converged'
RuntimeError,
r'algorithm may not have converged\.\n.*require_converged',
):
model.variational(algorithm='meanfield', seed=12345)

model.variational(
algorithm='meanfield', seed=12345, require_converged=False
with LogCapture() as log:
model.variational(
algorithm='meanfield', seed=12345, require_converged=False
)
log.check_present(
(
'cmdstanpy',
'WARNING',
'The algorithm may not have converged.\n'
'Proceeding because require_converged is set to False',
)
)


Expand Down