Skip to content

Commit

Permalink
Merge pull request #253 from adj-smith/patch-1
Browse files Browse the repository at this point in the history
Added `check_prior_support()` to `sampler.MCMC()`
  • Loading branch information
semaphoreP committed Sep 1, 2021
2 parents e5776bb + 9da84a1 commit d8ce41a
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
5 changes: 4 additions & 1 deletion orbitize/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Prior(ABC):
Written: Sarah Blunt, 2018
"""

is_correlated = False

@abc.abstractmethod
def draw_samples(self, num_samples):
pass
Expand Down Expand Up @@ -359,9 +361,10 @@ def all_lnpriors(params, priors):
float: prior probability of this set of parameters
"""
logp = 0.

for param, prior in zip(params, priors):
param = np.array([param])

logp += prior.compute_lnprob(param) # retrun a float

return logp
Expand Down
43 changes: 43 additions & 0 deletions orbitize/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,46 @@ def chop_chains(self, burn, trim=0):

# Print a confirmation
print('Chains successfully chopped. Results object updated.')

def check_prior_support(self,):
"""
Review the positions of all MCMC walkers, to verify that they are supported by the prior space.
This function will raise a descriptive ValueError if any positions lie outside prior support.
Otherwise, it will return nothing.
Args:
None.
Returns:
None.
(written): Adam Smith, 2021
"""

# Flatten the walker/temperature positions for ease of manipulation.
all_positions = self.curr_pos.reshape(self.num_walkers*self.num_temps,self.num_params)

# Placeholder list to track any bad parameters that come up.
bad_parameters = []

# If there are no covarient priors, loop on each variable to locate any out-of-place parameters. (this is why we transpose the walkers)
if not np.any([prior.is_correlated for prior in self.priors]):
for i, x in enumerate(all_positions.T):
# Any issues with this parameter?
lnprob = self.priors[i].compute_lnprob(np.array(x))
supported = np.isfinite(lnprob).all() == True

if supported == False:
# Problem detected. Take note and continue the loop - we want to catch all the problem parameters.
bad_parameters.append(str(i))

# Throw our ValueError if necessary,
if len(bad_parameters) > 0:
raise ValueError("Attempting to start with walkers outside of prior support: check parameter(s) "+', '.join(bad_parameters))

# We're not done yet, however. There may be errors in covariant priors; run a check for that.
else:
for y in all_positions:
lnprob = orbitize.priors.all_lnpriors(y,self.priors)
if not np.isfinite(lnprob).all():
raise ValueError("Attempting to start with walkers outside of prior support: covariant prior failure.")

# otherwise exit the function and continue.
return
58 changes: 58 additions & 0 deletions tests/test_checkpriorsupport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
import orbitize
import orbitize.sampler as sampler
from orbitize import driver
import random
import warnings

warnings.filterwarnings('ignore')


def test_check_prior_support(PriorChanges=False):
'''
Test the check_prior_support() function to ensure it behaves correctly.
Should fail with a ValueError if any parameters are outside prior support.
Should behave normally if all parameters are within prior support.
'''

# set up as if we are running the tutorial retrieval
myDriver = driver.Driver(
'{}/GJ504.csv'.format(orbitize.DATADIR), # data file
'MCMC', # choose from: ['OFTI', 'MCMC']
1, # number of planets in system
1.22, # total system mass [M_sun]
56.95, # system parallax [mas]
mass_err=0.08, # mass error [M_sun]
plx_err=0.26, # parallax error [mas]
mcmc_kwargs={'num_temps':2, 'num_walkers':18,'num_threads':1}
)

# mess with the priors if requested
if PriorChanges:
zlist = []
for a in range(4):
x = random.randint(0,myDriver.sampler.num_temps-1)
y = random.randint(0,myDriver.sampler.num_walkers-1)
z = random.randint(0,myDriver.sampler.num_params-1)
myDriver.sampler.curr_pos[x,y,z] = -1000
zlist.append(z)


# run the tests
try:
orbits = myDriver.sampler.check_prior_support()
# catch the correct error
except ValueError as error:
errorCaught = True
# make sure nothing else broke
except:
print('something has gone horribly wrong')
# state if otherwise
else:
errorCaught = False

assert errorCaught == PriorChanges

if __name__ == '__main__':
test_check_prior_support()
test_check_prior_support(PriorChanges=True)

0 comments on commit d8ce41a

Please sign in to comment.