Skip to content

Commit

Permalink
Merge pull request #265 from stan-dev/feature/250-summary-specify-probs
Browse files Browse the repository at this point in the history
added percentiles arg to summary command, unit tests
  • Loading branch information
mitzimorris committed Aug 5, 2020
2 parents ac80e60 + de9a5bd commit 76d7a7a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def sample(
:param parallel_chains: Number of processes to run in parallel. Must be
a positive integer. Defaults to ``multiprocessing.cpu_count()``.
:param threads_per_chain: the number of threads to use in parallelized
:param threads_per_chain: The number of threads to use in parallelized
sections within an MCMC chain (e.g., when using the Stan functions
``reduce_sum()`` or ``map_rect()``). This will only have an effect
if the model was compiled with threading support. The total number
Expand Down
25 changes: 24 additions & 1 deletion cmdstanpy/stanfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,34 @@ def _assemble_sample(self) -> None:
xs = line.split(',')
self._sample[i, chain, :] = [float(x) for x in xs]

def summary(self) -> pd.DataFrame:
def summary(self, percentiles: List[int] = None) -> pd.DataFrame:
"""
Run cmdstan/bin/stansummary over all output csv files.
Echo stansummary stdout/stderr to console.
Assemble csv tempfile contents into pandasDataFrame.
:param percentiles: Ordered non-empty list of percentiles to report.
Must be integers from (1, 99), inclusive.
"""
percentiles_str = '--percentiles=5,50,95'
if percentiles is not None:
if len(percentiles) == 0:
raise ValueError(
'invalid percentiles argument, must be ordered'
' non-empty list from (1, 99), inclusive.'
)

cur_pct = 0
for pct in percentiles:
if pct > 99 or not pct > cur_pct:
raise ValueError(
'invalid percentiles spec, must be ordered'
' non-empty list from (1, 99), inclusive.'
)
cur_pct = pct
percentiles_str = '='.join(
['--percentiles', ','.join([str(x) for x in percentiles])]
)
cmd_path = os.path.join(
cmdstan_path(), 'bin', 'stansummary' + EXTENSION
)
Expand All @@ -516,6 +538,7 @@ def summary(self) -> pd.DataFrame:
)
cmd = [
cmd_path,
percentiles_str,
'--csv_file={}'.format(tmp_csv_path),
] + self.runset.csv_files
do_command(cmd, logger=self.runset._logger)
Expand Down
41 changes: 27 additions & 14 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,22 +444,35 @@ def test_validate_good_run(self):
drawset.shape,
(fit.runset.chains * fit.num_draws, len(fit.column_names)),
)
_ = fit.summary()
self.assertTrue(True)

# TODO - use cmdstan test files instead
expected = '\n'.join(
[
'Checking sampler transitions treedepth.',
'Treedepth satisfactory for all transitions.',
'\nChecking sampler transitions for divergences.',
'No divergent transitions found.',
'\nChecking E-BFMI - sampler transitions HMC potential energy.',
'E-BFMI satisfactory for all transitions.',
'\nEffective sample size satisfactory.',
]
summary = fit.summary()
self.assertIn('5%', list(summary.columns))
self.assertIn('50%', list(summary.columns))
self.assertIn('95%', list(summary.columns))
self.assertNotIn('1%', list(summary.columns))
self.assertNotIn('99%', list(summary.columns))

summary = fit.summary(percentiles=[1, 45, 99])
self.assertIn('1%', list(summary.columns))
self.assertIn('45%', list(summary.columns))
self.assertIn('99%', list(summary.columns))
self.assertNotIn('5%', list(summary.columns))
self.assertNotIn('50%', list(summary.columns))
self.assertNotIn('95%', list(summary.columns))

with self.assertRaises(ValueError):
fit.summary(percentiles=[])

with self.assertRaises(ValueError):
fit.summary(percentiles=[-1])

diagnostics = fit.diagnose()
self.assertIn(
'Treedepth satisfactory for all transitions.', diagnostics
)
self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
self.assertIn('No divergent transitions found.', diagnostics)
self.assertIn('E-BFMI satisfactory for all transitions.', diagnostics)
self.assertIn('Effective sample size satisfactory.', diagnostics)

def test_validate_big_run(self):
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
Expand Down

0 comments on commit 76d7a7a

Please sign in to comment.