Skip to content

Commit

Permalink
Merge pull request #212 from sblunt/param_idx-labels
Browse files Browse the repository at this point in the history
Param idx labels for Results and MCMC classes
  • Loading branch information
sblunt committed Jul 20, 2021
2 parents 62a7509 + 1f70982 commit b974a97
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 3 deletions.
18 changes: 15 additions & 3 deletions orbitize/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,12 @@ def __init__(self, sampler_name=None, post=None, lnlike=None, tau_ref_epoch=None
self.post = post
self.lnlike = lnlike
self.tau_ref_epoch = tau_ref_epoch
self.labels = labels
if self.labels is not None:
self.param_idx = dict(zip(self.labels, np.arange(len(self.labels))))
else:
self.param_idx = None
self.data=data
self.labels=labels
self.num_secondary_bodies=num_secondary_bodies
self.curr_pos = curr_pos
self.version_number = version_number
Expand All @@ -89,13 +93,17 @@ def add_samples(self, orbital_params, lnlikes, labels, curr_pos=None):
Written: Henry Ngo, 2018
"""

# Adding the orbitize version number to the results
self.version_number = orbitize.__version__

# If no exisiting results then it is easy
if self.post is None:
self.post = orbital_params
self.lnlike = lnlikes
self.labels = labels
self.param_idx = dict(zip(self.labels, np.arange(len(self.labels))))

# Otherwise, need to append properly
else:
self.post = np.vstack((self.post, orbital_params))
Expand Down Expand Up @@ -144,7 +152,7 @@ def save_results(self, filename):
hf.create_dataset('lnlike', data=self.lnlike)
if self.labels is not None:
hf['col_names'] = np.array(self.labels).astype('S')
hf.attrs['parameter_labels'] = self.labels # Rob: added this to account for the RV labels
hf.attrs['parameter_labels'] = self.labels
if self.num_secondary_bodies is not None:
hf.attrs['num_secondary_bodies'] = self.num_secondary_bodies
if self.curr_pos is not None:
Expand Down Expand Up @@ -190,6 +198,10 @@ def load_results(self, filename, append=False):
# again, probably an old file without saved parameter labels
# old files only fit single planets
labels = ['sma1', 'ecc1', 'inc1', 'aop1', 'pan1', 'tau1', 'plx', 'mtot']

# rebuild parameter dictionary
self.param_idx = dict(zip(labels, np.arange(len(labels))))

try:
num_secondary_bodies = int(hf.attrs['num_secondary_bodies'])
except KeyError:
Expand Down Expand Up @@ -246,10 +258,10 @@ def load_results(self, filename, append=False):
# Only proceed if object is completely empty
if self.sampler_name is None and self.post is None and self.lnlike is None and self.tau_ref_epoch is None and self.version_number is None:
self._set_sampler_name(sampler_name)
self.labels = labels
self._set_version_number(version_number)
self.add_samples(post, lnlike, self.labels)
self.tau_ref_epoch = tau_ref_epoch
self.labels = labels
self.num_secondary_bodies = num_secondary_bodies
else:
raise Exception(
Expand Down
7 changes: 7 additions & 0 deletions orbitize/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,20 @@ def __init__(self, system, num_temps=20, num_walkers=1000, num_threads=1, like='
# get priors from the system class. need to remove and record fixed priors
self.priors = []
self.fixed_params = []

self.sampled_param_idx = {}
sampled_param_counter = 0
for i, prior in enumerate(system.sys_priors):

# check for fixed parameters
if not hasattr(prior, "draw_samples"):
self.fixed_params.append((i, prior))
else:
self.priors.append(prior)
self.sampled_param_idx[self.system.labels[i]] = sampled_param_counter
sampled_param_counter += 1

# initialize walkers initial postions
self.num_params = len(self.priors)

if prev_result_filename is None:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
import orbitize.results as results
import matplotlib.pyplot as plt

std_param_idx_fixed_mtot_plx = {
'sma1': 0, 'ecc1':1, 'inc1':2, 'aop1':3, 'pan1':4, 'tau1':5
}

std_param_idx = {
'sma1': 0, 'ecc1':1, 'inc1':2, 'aop1':3, 'pan1':4, 'tau1':5, 'plx':6, 'mtot':7
}

def test_mcmc_runs(num_temps=0, num_threads=1):
"""
Expand Down Expand Up @@ -144,6 +151,34 @@ def test_examine_chop_chains(num_temps=0, num_threads=1):
assert mcmc.results.post.shape[0] == expected_total_orbits


def test_mcmc_param_idx():

# use the test_csv dir
input_file = os.path.join(orbitize.DATADIR, 'test_val.csv')
data_table = read_input.read_formatted_file(input_file)

# Manually set 'object' column of data table
data_table['object'] = 1

# construct Driver with fixed mass and plx
n_walkers = 100
myDriver = Driver(input_file, 'MCMC', 1, 1, 0.01,
mcmc_kwargs={'num_temps': 0, 'num_threads': 1,
'num_walkers': n_walkers}
)

# check that sampler.param_idx behaves as expected
assert myDriver.sampler.sampled_param_idx == std_param_idx_fixed_mtot_plx

# construct Driver with no fixed params
myDriver = Driver(input_file, 'MCMC', 1, 1, 0.01, mass_err=0.1, plx_err=0.2,
mcmc_kwargs={'num_temps': 0, 'num_threads': 1,
'num_walkers': n_walkers}
)

assert myDriver.sampler.sampled_param_idx == std_param_idx


if __name__ == "__main__":
# Parallel Tempering tests
test_mcmc_runs(num_temps=2, num_threads=1)
Expand All @@ -154,3 +189,5 @@ def test_examine_chop_chains(num_temps=0, num_threads=1):
# Test examine/chop chains
test_examine_chop_chains(num_temps=5) # PT
test_examine_chop_chains(num_temps=0) # Ensemble
# param_idx utility tests
test_mcmc_param_idx()
4 changes: 4 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import os

std_labels = ['sma1', 'ecc1', 'inc1', 'aop1', 'pan1', 'tau1', 'plx', 'mtot']
std_param_idx = {
'sma1': 0, 'ecc1':1, 'inc1':2, 'aop1':3, 'pan1':4, 'tau1':5, 'plx':6, 'mtot':7
}


def simulate_orbit_sampling(n_sim_orbits):
Expand Down Expand Up @@ -139,6 +142,7 @@ def test_save_and_load_results(results_to_test, has_lnlike=True):
expected_length = original_length * 2
assert loaded_results.post.shape == (expected_length, 8)
assert loaded_results.labels.tolist() == std_labels
assert loaded_results.param_idx == std_param_idx
if has_lnlike:
assert loaded_results.lnlike.shape == (expected_length,)

Expand Down

0 comments on commit b974a97

Please sign in to comment.