Skip to content

Commit

Permalink
Merge pull request #37 from prmiles/update_chainstat_display
Browse files Browse the repository at this point in the history
added acceptance rate display feature - this closes #36
  • Loading branch information
Paul Miles committed Apr 30, 2019
2 parents 6320aba + 5dc4e17 commit 5cddb8f
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 57 deletions.
70 changes: 68 additions & 2 deletions pymcmcstat/chain/ChainStatistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def chainstats(chain=None, results=None, returnstats=False):
tau, m = integrated_autocorrelation_time(chain)
# print statistics
print_chain_statistics(names, meanii, stdii, mcerr, tau, p)
# print acceptance rate
print_chain_acceptance_info(chain, results=results)
# assign stats to dictionary
stats = dict(mean=list(meanii),
std=list(stdii),
Expand Down Expand Up @@ -88,7 +90,8 @@ def print_chain_statistics(names, meanii, stdii, mcerr, tau, p):
'''
npar = len(names)
# print statistics
print('\n---------------------')
print('\n')
print(30*'-')
print('{:10s}: {:>10s} {:>10s} {:>10s} {:>10s} {:>10s}'.format(
'name',
'mean',
Expand All @@ -112,7 +115,70 @@ def print_chain_statistics(names, meanii, stdii, mcerr, tau, p):
mcerr[ii],
tau[ii],
p[ii]))
print('---------------------')
print(30*'-')


# ----------------------------------------------------
def print_chain_acceptance_info(chain, results=None):
    '''
    Print chain acceptance rate(s)

    The net acceptance rate of the provided chain is estimated as the
    number of unique values in its first column divided by the number
    of rows.  If a results structure containing the key `iacce` is
    provided, the acceptance rate of every stage stored there is
    printed as well (e.g., with respect to delayed rejection), using
    `results['nsimu']` as the number of simulations when available and
    the chain length otherwise.

    Example display (if results dictionary provided):

    ::

        Acceptance rate information
        ---------------
        Results dictionary:
        Stage 1: 3.32%
        Stage 2: 22.60%
        Net : 25.92% -> 1296/5000
        ---------------
        Chain provided:
        Net : 32.10% -> 963/3000
        ---------------
        Note, the net acceptance rate from the results dictionary
        may be different if you only provided a subset of the chain,
        e.g., removed the first part for burnin-in.
        ------------------------------

    A rate is displayed as `nan` when its denominator is zero (empty
    chain or `nsimu` equal to zero).

    Args:
        * **chain** (:class:`~numpy.ndarray`): Sampling chain.  Any
          array-like is accepted; a 1-D chain is treated as a single
          parameter.
        * **results** (:py:class:`dict`): Results from MCMC simulation.
          Default is `None`.
    '''
    def percentage(count, total):
        # Avoid ZeroDivisionError for an empty chain or nsimu == 0.
        return count / total * 100 if total else float('nan')

    chain = np.asarray(chain)
    nsamples = chain.shape[0]
    first_column = chain if chain.ndim == 1 else chain[:, 0]

    print('Acceptance rate information')
    has_stage_info = results is not None and 'iacce' in results
    if has_stage_info:
        nsimu = results['nsimu'] if 'nsimu' in results else nsamples
        # Accept lists/tuples/scalars as well as arrays
        # (a plain list has no .sum()).
        iacce = np.atleast_1d(results['iacce'])
        accepted = int(iacce.sum())
        print(15*'-')
        print('Results dictionary:')
        for ii, stage in enumerate(iacce):
            print('Stage {:d}: {:4.2f}%'.format(
                ii + 1, percentage(stage, nsimu)))
        print('Net : {:4.2f}% -> {:d}/{:d}'.format(
            percentage(accepted, nsimu), accepted, nsimu))
        print(15*'-')
    print('Chain provided:')
    # A rejected proposal repeats the previous sample, so the number of
    # distinct values approximates the number of accepted proposals.
    unique_elem = np.unique(first_column).size
    print('Net : {:4.2f}% -> {:d}/{:d}'.format(
        percentage(unique_elem, nsamples), unique_elem, nsamples))
    if has_stage_info:
        print(15*'-')
        print('Note, the net acceptance rate from the results dictionary\n'
              + 'may be different if you only provided a subset of the chain,\n'
              + 'e.g., removed the first part for burnin-in.')
    print(30*'-')


# ----------------------------------------------------
Expand Down
161 changes: 106 additions & 55 deletions test/chain/test_ChainStatistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,129 +6,180 @@
@author: prmiles
"""

from pymcmcstat.chain import ChainStatistics
from pymcmcstat.chain import ChainStatistics as CS
import unittest
import numpy as np
import io
import sys

CS = ChainStatistics
chain = np.random.random_sample(size = (1000,2))
chain[:,1] = 1e6*chain[:,1]

chain = np.random.random_sample(size=(1000, 2))
chain[:, 1] = 1e6*chain[:, 1]


# --------------------------
# chainstats
# --------------------------
class Chainstats_Eval(unittest.TestCase):

def test_cs_eval_with_return(self):
stats = CS.chainstats(chain = chain, returnstats = True)
stats = CS.chainstats(chain=chain, returnstats=True)
self.assertTrue(isinstance(stats,dict))

def test_cs_eval_with_no_return(self):
stats = CS.chainstats(chain = chain, returnstats = False)
stats = CS.chainstats(chain=chain, returnstats=False)
self.assertEqual(stats, None)

def test_cs_eval_with_no_chain(self):
stats = CS.chainstats(chain = None, returnstats = True)
stats = CS.chainstats(chain=None, returnstats=True)
self.assertTrue(isinstance(stats, str))


# --------------------------
class ChainAcceptanceRateInfo(unittest.TestCase):
    """
    Check the text printed by `CS.print_chain_acceptance_info`, with
    and without a results dictionary.
    """

    @staticmethod
    def _capture_display(*args, **kwargs):
        """
        Call `CS.print_chain_acceptance_info` with the given arguments
        and return everything it wrote to standard output.

        The stream that was active before the call is restored in a
        `finally` clause, so a failure inside the function under test
        cannot leave `sys.stdout` redirected, and any capturing done by
        the test runner stays in place.
        """
        buffer = io.StringIO()
        previous_stdout = sys.stdout
        sys.stdout = buffer
        try:
            CS.print_chain_acceptance_info(*args, **kwargs)
        finally:
            sys.stdout = previous_stdout
        return buffer.getvalue()

    def test_string_display(self):
        # Without results only the chain-based rate is reported.
        output = self._capture_display(chain)
        self.assertIsInstance(output, str, msg='Caputured string')
        self.assertNotIn('Results dictionary', output,
                         msg='Expect results dictionary not included')

    def test_string_display_with_results(self):
        # With `iacce` available the per-stage section must be printed.
        output = self._capture_display(
            chain, results=dict(
                nsimu=5000,
                iacce=np.array([200, 800])
            ))
        self.assertIsInstance(output, str, msg='Caputured string')
        self.assertIn('Results dictionary', output,
                      msg='Expect results dictionary included')


# --------------------------
class BatchMeanSTD(unittest.TestCase):

def test_len_s(self):
s = CS.batch_mean_standard_deviation(chain, b = None)
s = CS.batch_mean_standard_deviation(chain, b=None)
self.assertEqual(len(s), 2)

def test_too_few_batches(self):
with self.assertRaises(SystemExit, msg = 'too few batches'):
CS.batch_mean_standard_deviation(chain, b = chain.shape[0])
with self.assertRaises(SystemExit, msg='too few batches'):
CS.batch_mean_standard_deviation(chain, b=chain.shape[0])


# --------------------------
class PowerSpectralDensity(unittest.TestCase):

def test_nfft_none_size(self):
x = chain[:,0]
y = CS.power_spectral_density_using_hanning_window(x = x)
nfft = min(len(x),256)
x = chain[:, 0]
y = CS.power_spectral_density_using_hanning_window(x=x)
nfft = min(len(x), 256)
n2 = int(np.floor(nfft/2))
self.assertEqual(n2, len(y))

def test_nfft_not_none_size(self):
x = chain[:,0]
x = chain[:, 0]
nfft = 100
y = CS.power_spectral_density_using_hanning_window(x = x, nfft = nfft)
y = CS.power_spectral_density_using_hanning_window(x=x, nfft=nfft)
n2 = int(np.floor(nfft/2))
self.assertEqual(n2, len(y))

def test_nw_not_none(self):
x = chain[:,0]
y = CS.power_spectral_density_using_hanning_window(x = x, nw = len(x))
nfft = min(len(x),256)
x = chain[:, 0]
y = CS.power_spectral_density_using_hanning_window(x=x, nw=len(x))
nfft = min(len(x), 256)
n2 = int(np.floor(nfft/2))
self.assertEqual(n2, len(y))



# --------------------------
def setup_chains():
chains = []
chains=[]
for ii in range(4):
chains.append(np.concatenate((ii*np.linspace(0, 1, 1000).reshape(1000,1), ii*np.linspace(2.5, 3.3, 1000).reshape(1000,1)), axis = 1))
chains.append(
np.concatenate((ii*np.linspace(0, 1, 1000).reshape(1000, 1),
ii*np.linspace(2.5, 3.3, 1000).reshape(1000,1)),
axis=1))
return chains


class GelmanRubin(unittest.TestCase):
def standard_check(self, psrf):
self.assertTrue(isinstance(psrf, dict), msg = 'Expect dictionary output')
self.assertTrue(isinstance(psrf, dict),
msg='Expect dictionary output')
check_these = ['B', 'W', 'V', 'R', 'neff']
for _, ps in enumerate(psrf):
for ct in check_these:
self.assertTrue(ct in psrf[ps], msg = str('{} not in {}'.format(ct, ps)))

self.assertTrue(ct in psrf[ps],
msg=str('{} not in {}'.format(ct, ps)))

def test_gelman_rubin(self):
chains = setup_chains()
chains=setup_chains()
capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = capturedOutput # and redirect stdout.
psrf = CS.gelman_rubin(chains = chains, display = True)
psrf = CS.gelman_rubin(chains=chains, display=True)
sys.stdout = sys.__stdout__ # Reset redirect.
self.assertTrue(isinstance(capturedOutput.getvalue(), str), msg = 'Caputured string')
self.assertTrue(isinstance(capturedOutput.getvalue(), str),
msg='Caputured string')
self.standard_check(psrf)

def test_gelman_rubin_no_display(self):
chains = setup_chains()
chains=setup_chains()
capturedOutput = io.StringIO() # Create StringIO object
sys.stdout = capturedOutput # and redirect stdout.
psrf = CS.gelman_rubin(chains = chains, display = False)
psrf = CS.gelman_rubin(chains=chains, display=False)
sys.stdout = sys.__stdout__ # Reset redirect.
self.assertTrue(isinstance(capturedOutput.getvalue(), str), msg = 'Caputured string')
self.assertEqual(capturedOutput.getvalue(), '', msg = 'Caputured string')
self.assertTrue(isinstance(capturedOutput.getvalue(), str),
msg='Caputured string')
self.assertEqual(capturedOutput.getvalue(), '',
msg='Caputured string')
self.standard_check(psrf)

def test_gelman_rubin_with_pres(self):
chains = setup_chains()
pres = []
for _, chain in enumerate(chains):
pres.append(dict(chain = chain, nsimu = chain.shape[0]))

psrf = CS.gelman_rubin(chains = pres)
pres.append(dict(chain=chain, nsimu=chain.shape[0]))
psrf = CS.gelman_rubin(chains=pres)
self.standard_check(psrf)

def test_gelman_rubin_with_names(self):
chains = setup_chains()
psrf = CS.gelman_rubin(chains = chains, names = ['a', 'b'])
psrf = CS.gelman_rubin(chains=chains, names=['a', 'b'])
self.standard_check(psrf)

def test_gelman_rubin_raise_error(self):
chains = setup_chains()
for _ in range(len(chains)-1):
chains=setup_chains()
for _ in range(len(chains) - 1):
chains.pop(-1)
with self.assertRaises(ValueError, msg = 'Must have multiple chains'):
CS.gelman_rubin(chains = chains)

with self.assertRaises(ValueError,
msg='Must have multiple chains'):
CS.gelman_rubin(chains=chains)


# --------------------------
class PSRF(unittest.TestCase):
def test_calc_psrf(self):
x = np.concatenate((np.linspace(0, 1, 1000).reshape(1000,1), np.linspace(2.5, 3.3, 1000).reshape(1000,1)), axis = 1)
psrf = CS.calculate_psrf(x, nsimu = 1000, nchains = 2)
self.assertTrue(isinstance(psrf, dict), msg = 'Expect dictionary output')
self.assertAlmostEqual(psrf['R'], 8.001818935964492, places = 6, msg = str('R: {} neq {}'.format(psrf['R'], 8.001818935964492)))
self.assertAlmostEqual(psrf['B'], 2879.9999999999964, places = 6, msg = str('R: {} neq {}'.format(psrf['B'], 2879.9999999999964)))
self.assertAlmostEqual(psrf['W'], 0.06853867547894903, places = 6, msg = str('R: {} neq {}'.format(psrf['W'], 0.06853867547894903)))
self.assertAlmostEqual(psrf['V'], 4.388470136803464, places = 6, msg = str('R: {} neq {}'.format(psrf['V'], 4.388470136803464)))
self.assertAlmostEqual(psrf['neff'], 3.047548706113521, places = 6, msg = str('R: {} neq {}'.format(psrf['neff'], 3.047548706113521)))
x = np.concatenate((np.linspace(0, 1, 1000).reshape(1000, 1),
np.linspace(2.5, 3.3, 1000).reshape(1000, 1)),
axis=1)
psrf = CS.calculate_psrf(x, nsimu=1000, nchains=2)
self.assertTrue(isinstance(psrf, dict), msg='Expect dictionary output')
self.assertAlmostEqual(psrf['R'], 8.001818935964492, places=6,
msg=str('R: {} neq {}'.format(psrf['R'], 8.001818935964492)))
self.assertAlmostEqual(psrf['B'], 2879.9999999999964, places=6,
msg=str('R: {} neq {}'.format(psrf['B'], 2879.9999999999964)))
self.assertAlmostEqual(psrf['W'], 0.06853867547894903, places=6,
msg=str('R: {} neq {}'.format(psrf['W'], 0.06853867547894903)))
self.assertAlmostEqual(psrf['V'], 4.388470136803464, places=6,
msg=str('R: {} neq {}'.format(psrf['V'], 4.388470136803464)))
self.assertAlmostEqual(psrf['neff'], 3.047548706113521, places=6,
msg=str('R: {} neq {}'.format(psrf['neff'], 3.047548706113521)))

0 comments on commit 5cddb8f

Please sign in to comment.