Skip to content

Commit

Permalink
Merge pull request #139 from sblunt/view_chop_chains
Browse files Browse the repository at this point in the history
Viewing and chopping chains
  • Loading branch information
sblunt committed Oct 1, 2019
2 parents 48420cb + 8d17c4b commit aaa49d7
Show file tree
Hide file tree
Showing 3 changed files with 364 additions and 46 deletions.
172 changes: 156 additions & 16 deletions docs/tutorials/MCMC_tutorial.ipynb

Large diffs are not rendered by default.

129 changes: 127 additions & 2 deletions orbitize/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from orbitize.system import radec2seppa
import orbitize.results

import matplotlib.pyplot as plt

# Python 2 & 3 handle ABCs differently
if sys.version_info[0] < 3:
ABC = abc.ABCMeta('ABC', (), {})
Expand Down Expand Up @@ -494,7 +496,7 @@ def __init__(self, system, num_temps=20, num_walkers=1000, num_threads=1, like='

def _fill_in_fixed_params(self, sampled_params):
"""
Fills in the missing parameters from the chain that aren't being sampeld
Fills in the missing parameters from the chain that aren't being sampled
Args:
sampled_params (np.array): either 1-D array of size = number of sampled params, or 2-D array of shape (num_models, num_params)
Expand Down Expand Up @@ -553,7 +555,7 @@ def _logl(self, params, include_logp=False):

return super(MCMC, self)._logl(full_params) + logp

def run_sampler(self, total_orbits, burn_steps=0, thin=1):
def run_sampler(self, total_orbits, burn_steps=0, thin=1, examine_chains=False):
"""
Runs PT MCMC sampler. Results are stored in ``self.chain`` and ``self.lnlikes``.
Results also added to ``orbitize.results.Results`` object (``self.results``)
Expand All @@ -569,6 +571,8 @@ def run_sampler(self, total_orbits, burn_steps=0, thin=1):
to discard certain number of steps at the beginning
thin (int): factor to thin the steps of each walker
by to remove correlations in the walker steps
examine_chains (boolean): Displays plots of walkers at each step by
running `examine_chains` after `total_orbits` sampled.
Returns:
``emcee.sampler`` object: the sampler used to run the MCMC
Expand Down Expand Up @@ -631,4 +635,125 @@ def run_sampler(self, total_orbits, burn_steps=0, thin=1):

print('Run complete')

if examine_chains:
self.examine_chains()

return sampler

def examine_chains(self, param_list=None, walker_list=None, n_walkers=None, step_range=None):
    """
    Plots position of walkers at each step from Results object. Returns list of figures, one per parameter

    Args:
        param_list: List of strings of parameters to plot (e.g. "sma1")
            If None (default), all parameters are plotted
        walker_list: List or array of walker numbers to plot
            If None (default), all walkers are plotted
        n_walkers (int): Randomly select `n_walkers` to plot
            Overrides walker_list if this is set
            If None (default), walkers selected as per `walker_list`
        step_range (array or tuple): Start and end values of step numbers to plot
            If None (default), all the steps are plotted

    Returns:
        List of ``matplotlib.pyplot.Figure`` objects:
            Walker position plot for each parameter selected

    Raises:
        Exception: if an entry of `param_list` is not a key of
            ``self.system.param_idx``

    (written): Henry Ngo, 2019
    """

    # Get the flattened chain from Results object (nwalkers*nsteps, nparams)
    flatchain = np.copy(self.results.post)
    total_samples, n_params = flatchain.shape
    # Builtin int with floor division: the `np.int` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, so calling it raises AttributeError.
    n_steps = int(total_samples // self.num_walkers)
    # Reshape it to (nwalkers, nsteps, nparams)
    # NOTE(review): assumes the flat posterior is one walker-major block;
    # confirm this still holds after several run_sampler calls have
    # appended samples to the Results object.
    chn = flatchain.reshape((self.num_walkers, n_steps, n_params))

    # Get list of walkers to use
    if n_walkers is not None:
        # n_walkers takes precedence: draw that many distinct walkers at random
        walkers_to_plot = np.random.choice(self.num_walkers, size=n_walkers, replace=False)
    elif walker_list is not None:
        walkers_to_plot = np.array(walker_list)
    else:
        # Neither selector given, so use every walker
        walkers_to_plot = np.arange(self.num_walkers)

    # Get list of parameters to use
    if param_list is None:
        params_to_plot = np.arange(n_params)
    else:
        # Translate user-supplied names into column indices, rejecting the
        # first unknown name before any figure is created
        param_idx = self.system.param_idx
        for name in param_list:
            if name not in param_idx:
                raise Exception('Invalid param name: {}. See system.param_idx.'.format(name))
        params_to_plot = np.array([param_idx[name] for name in param_list])

    # Loop through each parameter and make plot
    output_figs = []
    for pp in params_to_plot:
        fig, ax = plt.subplots()
        for ww in walkers_to_plot:
            ax.plot(chn[ww, :, pp], 'k-')
        ax.set_xlabel('Step')
        if step_range is not None:
            # Only the displayed x-range is limited; all steps are still drawn
            ax.set_xlim(step_range)
        output_figs.append(fig)

    return output_figs

def chop_chains(self, burn, trim=0):
    """
    Permanently removes steps from beginning (and/or end) of chains from the Results object.
    Also updates `curr_pos` if steps are removed from the end of the chain

    Args:
        burn (int): The number of steps to remove from the beginning of the chains
        trim (int): The number of steps to remove from the end of the chains (optional)

    Returns:
        None. Updates self.curr_pos and the `Results` object.

    Raises:
        ValueError: if `burn` or `trim` is negative, if together they
            exceed the number of steps per walker, or if `trim` > 0 and
            no step would remain to take the new `curr_pos` from.

    .. Warning:: Does not update bookkeeping arrays within `MCMC` sampler object.

    (written): Henry Ngo, 2019
    """

    # Retrieve information from results object
    flatchain = np.copy(self.results.post)
    total_samples, n_params = flatchain.shape
    # Builtin int with floor division: the `np.int` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, so calling it raises AttributeError.
    n_steps = int(total_samples // self.num_walkers)
    flatlnlikes = np.copy(self.results.lnlike) ## TODO: May have to change this to merge with other branches

    # Validate up front so bad input fails with a clear message instead of
    # an obscure slicing/reshape error further down
    if burn < 0 or trim < 0:
        raise ValueError(
            'burn and trim must be non-negative (got burn={}, trim={})'.format(burn, trim)
        )
    n_chopped_steps = n_steps - trim - burn
    if n_chopped_steps < 0:
        raise ValueError(
            'Cannot remove {} steps (burn={}, trim={}) from chains with only {} steps'.format(
                burn + trim, burn, trim, n_steps
            )
        )
    if trim > 0 and n_chopped_steps == 0:
        raise ValueError(
            'burn + trim removes every step; no final position left to update curr_pos'
        )

    # Reshape chain to (nwalkers, nsteps, nparams)
    chn = flatchain.reshape((self.num_walkers, n_steps, n_params))
    # Reshape lnlike to (nwalkers, nsteps)
    lnlikes = flatlnlikes.reshape((self.num_walkers, n_steps))

    # Find beginning and end indices for steps to keep
    keep_start = burn
    keep_end = n_steps - trim

    # Slice out the kept steps for every walker
    chopped_chain = chn[:, keep_start:keep_end, :]
    chopped_lnlikes = lnlikes[:, keep_start:keep_end]

    # Update current position if trimmed from edge: the walkers' last kept
    # step becomes the new current position
    if trim > 0:
        self.curr_pos = chopped_chain[:, -1, :]

    # Flatten likelihoods and samples
    flat_chopped_chain = chopped_chain.reshape(self.num_walkers * n_chopped_steps, n_params)
    flat_chopped_lnlikes = chopped_lnlikes.reshape(self.num_walkers * n_chopped_steps)

    # Replace the results object associated with this sampler
    self.results = orbitize.results.Results(
        sampler_name=self.__class__.__name__,
        post=flat_chopped_chain,
        lnlike=flat_chopped_lnlikes,
        tau_ref_epoch=self.system.tau_ref_epoch,
        labels=self.system.labels
    )

    # Print a confirmation
    print('Chains successfully chopped. Results object updated.')
109 changes: 81 additions & 28 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@
import orbitize.sampler as sampler
import orbitize.system as system
import orbitize.read_input as read_input
import matplotlib.pyplot as plt

def test_pt_mcmc_runs(num_threads=1):
def test_mcmc_runs(num_temps=0, num_threads=1):
"""
Tests the PTMCMC sampler by making sure it even runs
Tests the MCMC sampler by making sure it even runs
Args:
num_temps: Number of temperatures to use
Uses Parallel Tempering MCMC (ptemcee) if > 1,
otherwise, uses Affine-Invariant Ensemble Sampler (emcee)
num_threads: number of threads to run
"""

# use the test_csv dir
testdir = os.path.dirname(os.path.abspath(__file__))
input_file = os.path.join(testdir, 'test_val.csv')

# construct Driver
n_walkers=100
myDriver = Driver(input_file, 'MCMC', 1, 1, 0.01,
mcmc_kwargs={'num_temps':2, 'num_threads':num_threads, 'num_walkers':100}
mcmc_kwargs={'num_temps':2, 'num_threads':num_threads, 'num_walkers':n_walkers}
)

# run it a little (tests 0 burn-in steps)
Expand All @@ -34,36 +42,81 @@ def test_pt_mcmc_runs(num_threads=1):

assert returned_lnlike_test == pytest.approx(computed_lnlike_test, abs=0.01)

def test_ensemble_mcmc_runs(num_threads=1):
def test_examine_chop_chains(num_temps=0, num_threads=1):
"""
Tests the EnsembleMCMC sampler by making sure it even runs
Tests the MCMC sampler's examine_chains and chop_chains methods
Args:
num_temps: Number of temperatures to use
Uses Parallel Tempering MCMC (ptemcee) if > 1,
otherwise, uses Affine-Invariant Ensemble Sampler (emcee)
num_threads: number of threads to run
"""

# use the test_csv dir
testdir = os.path.dirname(os.path.abspath(__file__))
input_file = os.path.join(testdir, 'test_val.csv')

myDriver = Driver(input_file, 'MCMC', 1, 1, 0.01,
mcmc_kwargs={'num_temps':1, 'num_threads':num_threads, 'num_walkers':100}
)

# run it a little (tests 0 burn-in steps)
myDriver.sampler.run_sampler(100)

# run it a little more
myDriver.sampler.run_sampler(1000, burn_steps=1)

# run it a little more (tests adding to results object)
myDriver.sampler.run_sampler(1000, burn_steps=1)

# test that lnlikes being saved are correct
returned_lnlike_test = myDriver.sampler.results.lnlike[0]
computed_lnlike_test = myDriver.sampler._logl(myDriver.sampler.results.post[0])

assert returned_lnlike_test == pytest.approx(computed_lnlike_test, abs=0.01)
data_table = read_input.read_file(input_file)
# Manually set 'object' column of data table
data_table['object'] = 1

# construct the system
orbit = system.System(1, data_table, 1, 0.01)

# construct Driver
n_walkers = 20
mcmc = sampler.MCMC(orbit, num_temps, n_walkers, num_threads=num_threads)

# run it a little
n_samples1 = 2000 # 100 steps for each of 20 walkers
n_samples2 = 2000 # 100 steps for each of 20 walkers
n_samples = n_samples1+n_samples2
mcmc.run_sampler(n_samples1)
# run it a little more (tries examine_chains within run_sampler)
mcmc.run_sampler(n_samples2, examine_chains=True)
# (4000 orbit samples = 20 walkers x 200 steps)

# Try all variants of examine_chains
mcmc.examine_chains()
plt.close('all') # Close figures generated
fig_list = mcmc.examine_chains(param_list=['sma1','ecc1','inc1'])
# Should only get 3 figures
assert len(fig_list) == 3
plt.close('all') # Close figures generated
mcmc.examine_chains(walker_list=[10, 12])
plt.close('all') # Close figures generated
mcmc.examine_chains(n_walkers=5)
plt.close('all') # Close figures generated
mcmc.examine_chains(step_range=[50,100])
plt.close('all') # Close figures generated

# Now try chopping the chains
# Chop off first 50 steps
chop1=50
mcmc.chop_chains(chop1)
# Calculate expected number of orbits now
expected_total_orbits = n_samples - chop1*n_walkers
# Check lengths of arrays in results object
assert len(mcmc.results.lnlike) == expected_total_orbits
assert mcmc.results.post.shape[0] == expected_total_orbits

# With 150 steps left, now try to trim 25 steps off each end
chop2 = 25
trim2 = 25
mcmc.chop_chains(chop2,trim=trim2)
# Calculate expected number of orbits now
samples_removed = (chop1 + chop2 + trim2)*n_walkers
expected_total_orbits = n_samples - samples_removed
# Check lengths of arrays in results object
assert len(mcmc.results.lnlike) == expected_total_orbits
assert mcmc.results.post.shape[0] == expected_total_orbits

if __name__ == "__main__":
test_pt_mcmc_runs(num_threads=1)
test_pt_mcmc_runs(num_threads=4)
test_ensemble_mcmc_runs(num_threads=1)
test_ensemble_mcmc_runs(num_threads=8)
# Parallel Tempering tests
test_mcmc_runs(num_temps=2, num_threads=1)
test_mcmc_runs(num_temps=2, num_threads=4)
# Ensemble MCMC tests
test_mcmc_runs(num_temps=0, num_threads=1)
test_mcmc_runs(num_temps=0, num_threads=8)
# Test examine/chop chains
test_examine_chop_chains(num_temps=5) # PT
test_examine_chop_chains(num_temps=0) # Ensemble

0 comments on commit aaa49d7

Please sign in to comment.