Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions cmdstanpy/stanfit/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,11 @@ def summary(
) -> pd.DataFrame:
"""
Run cmdstan/bin/stansummary over all output CSV files, assemble
summary into DataFrame object; first row contains summary statistics
for total joint log probability `lp__`, remaining rows contain summary
summary into DataFrame object. The first row contains statistics
for the total joint log probability `lp__`, but is omitted when the
Stan model has no parameters. The remaining rows contain summary
statistics for all parameters, transformed parameters, and generated
quantities variables listed in the order in which they were declared
in the Stan program.
quantities variables, in program declaration order.

:param percentiles: Ordered non-empty sequence of percentiles to report.
Must be integers from (1, 99), inclusive. Defaults to
Expand All @@ -467,7 +467,6 @@ def summary(

:return: pandas.DataFrame
"""

if len(percentiles) == 0:
raise ValueError(
'Invalid percentiles argument, must be ordered'
Expand Down Expand Up @@ -526,7 +525,14 @@ def summary(
comment='#',
float_precision='high',
)
mask = [x == 'lp__' or not x.endswith('__') for x in summary_data.index]
mask = (
[not x.endswith('__') for x in summary_data.index]
if self._is_fixed_param
else [
x == 'lp__' or not x.endswith('__') for x in summary_data.index
]
)
summary_data.index.name = None
return summary_data[mask]

def diagnose(self) -> Optional[str]:
Expand Down
11 changes: 9 additions & 2 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,8 @@ def test_fixed_param_unspecified(self):
iter_sampling=100, show_progress=False
)
self.assertEqual(datagen_fit.step_size, None)
summary = datagen_fit.summary()
self.assertNotIn('lp__', list(summary.index))

exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
shutil.copyfile(datagen_model.exe_file, exe_only)
Expand All @@ -608,6 +610,8 @@ def test_fixed_param_unspecified(self):
)
self.assertEqual(datagen2_fit.chains, 4)
self.assertEqual(datagen2_fit.step_size, None)
summary = datagen2_fit.summary()
self.assertNotIn('lp__', list(summary.index))

def test_bernoulli_file_with_space(self):
self.test_bernoulli_good('bernoulli with space in name.stan')
Expand Down Expand Up @@ -743,11 +747,11 @@ def test_validate_good_run(self):

self.assertEqual(
list(fit.draws_pd(vars=['theta', 'lp__']).columns),
['theta', 'lp__']
['theta', 'lp__'],
)
self.assertEqual(
list(fit.draws_pd(vars=['lp__', 'theta']).columns),
['lp__', 'theta']
['lp__', 'theta'],
)

summary = fit.summary()
Expand All @@ -756,6 +760,9 @@ def test_validate_good_run(self):
self.assertIn('95%', list(summary.columns))
self.assertNotIn('1%', list(summary.columns))
self.assertNotIn('99%', list(summary.columns))
self.assertEqual(summary.index.name, None)
self.assertIn('lp__', list(summary.index))
self.assertIn('theta', list(summary.index))

summary = fit.summary(percentiles=[1, 45, 99])
self.assertIn('1%', list(summary.columns))
Expand Down