Skip to content

Commit

Permalink
fix #76 and modify test to catch it
Browse files Browse the repository at this point in the history
  • Loading branch information
kbarbary committed Feb 13, 2015
1 parent 8916712 commit 00f90b0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
29 changes: 17 additions & 12 deletions sncosmo/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,9 @@ def _nest_lc(data, model, vparam_names, modelcov,
bounds=None, priors=None, ppfs=None, tied=None,
nobj=100, maxiter=10000, maxcall=1000000, verbose=False):
"""Assumes that data has already been standardized.
Run `data = standardize_data(data)`"""

# Order vparam_names the same way it is ordered in the model:
vparam_names = [s for s in model.param_names if s in vparam_names]
(do `data = standardize_data(data)`). Output samples and parameter
names are not reordered (order wil be the same as input parameter list).
"""

if ppfs is None:
ppfs = {}
Expand Down Expand Up @@ -596,7 +594,7 @@ def loglikelihood(parameters):

res = nest.nest(loglikelihood, prior, npar, nipar, nobj=nobj,
maxiter=maxiter, maxcall=maxcall, verbose=verbose)
res.vparam_names = vparam_names
res.vparam_names = copy.copy(vparam_names)
res.ndof = len(data) - len(vparam_names)
return res

Expand Down Expand Up @@ -706,6 +704,9 @@ def nest_lc(data, model, vparam_names, bounds, guess_amplitude_bound=False,
model = copy.copy(model)
bounds = copy.copy(bounds) # need to copy this dict b/c we modify it below

# Order vparam_names the same way it is ordered in the model:
vparam_names = [s for s in model.param_names if s in vparam_names]

# Drop data that the model doesn't cover.
data = cut_bands(data, model, z_bounds=bounds.get('z', None))

Expand Down Expand Up @@ -747,7 +748,7 @@ def nest_lc(data, model, vparam_names, bounds, guess_amplitude_bound=False,
return res, model


def mcmc_lc(data, model, param_names, errors, bounds=None, nwalkers=10,
def mcmc_lc(data, model, vparam_names, errors, bounds=None, nwalkers=10,
nburn=100, nsamples=500, verbose=False):
"""Run an MCMC chain to get model parameter samples.
Expand All @@ -766,7 +767,7 @@ def mcmc_lc(data, model, param_names, errors, bounds=None, nwalkers=10,
Table of photometric data. Must include certain column names.
model : `~sncosmo.Model`
The model to fit.
param_names : iterable
vparam_names : iterable
Model parameters to vary.
errors : iterable
The starting positions of the walkers are randomly selected from a
Expand Down Expand Up @@ -795,21 +796,25 @@ def mcmc_lc(data, model, param_names, errors, bounds=None, nwalkers=10,
raise ImportError("mcmc_lc() requires the emcee package.")

data = standardize_data(data)
ndim = len(param_names)
idx = np.array([model.param_names.index(name) for name in param_names])

# Order vparam_names the same way it is ordered in the model:
vparam_names = [s for s in model.param_names if s in vparam_names]

ndim = len(vparam_names)
idx = np.array([model.param_names.index(name) for name in vparam_names])

# Check that z is bounded if it is being varied.
if bounds is None:
bounds = {}
if 'z' in param_names:
if 'z' in vparam_names:
if 'z' not in bounds or None in bounds['z']:
raise ValueError('z must be bounded if fit.')

# Drop data that the model doesn't cover.
data = cut_bands(data, model, z_bounds=bounds.get('z', None))

# Convert bounds indicies to integers
bounds_idx = dict([(param_names.index(name), bounds[name])
bounds_idx = dict([(vparam_names.index(name), bounds[name])
for name in bounds])

# define likelihood
Expand Down
15 changes: 11 additions & 4 deletions sncosmo/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def setup_class(self):

def test_fit_lc(self):
"""Ensure that fit results match input model parameters (data are
noise-free)."""
noise-free).
Pass in parameter names in order different from that stored in
model; tests parameter re-ordering."""
res, fitmodel = sncosmo.fit_lc(self.data, self.model,
['z', 't0', 'amplitude'],
['amplitude', 'z', 't0'],
bounds={'z': (0., 1.0)})

# set model to true parameters and compare to fit results.
Expand All @@ -70,14 +73,18 @@ def test_wrong_param_names(self):
res, fitmodel = sncosmo.fit_lc(self.data, self.model, [])

def test_nest_lc(self):
"""Ensure that nested sampling runs."""
"""Ensure that nested sampling runs.
Pass in parameter names in order different from that stored in
model; tests parameter re-ordering.
"""

np.random.seed(0) # seed the RNG for reproducible results.

self.model.set(**self.params)

res, fitmodel = sncosmo.nest_lc(
self.data, self.model, ['z', 't0', 'amplitude'],
self.data, self.model, ['amplitude', 'z', 't0'],
bounds={'z': (0., 1.0)}, guess_amplitude_bound=True, nobj=50)

assert_allclose(fitmodel.parameters, self.model.parameters, rtol=0.05)

0 comments on commit 00f90b0

Please sign in to comment.