Skip to content

Commit

Permalink
Merge 5a9784c into 401c1b4
Browse files Browse the repository at this point in the history
  • Loading branch information
djperrefort committed Feb 24, 2020
2 parents 401c1b4 + 5a9784c commit aabd41f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 13 deletions.
6 changes: 2 additions & 4 deletions sncosmo/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,7 @@ def fit_lc(data, model, vparam_names, bounds=None, method='minuit',
vparam_names = [s for s in model.param_names if s in vparam_names]

# initialize bounds
if bounds is None:
bounds = {}
bounds = copy.deepcopy(bounds) if bounds else {}

# Check that 'z' is bounded (if it is going to be fit).
if 'z' in vparam_names:
Expand Down Expand Up @@ -1067,8 +1066,7 @@ def mcmc_lc(data, model, vparam_names, bounds=None, priors=None,
# Make a copy of the model so we can modify it with impunity.
model = copy.copy(model)

if bounds is None:
bounds = {}
bounds = copy.deepcopy(bounds) if bounds else {}
if priors is None:
priors = {}

Expand Down
80 changes: 71 additions & 9 deletions sncosmo/tests/test_fitting.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,40 @@
# Licensed under a 3-clause BSD style license - see LICENSES

from copy import deepcopy
from os.path import dirname, join

import pytest
import numpy as np
from numpy.random import RandomState
from numpy.testing import assert_allclose, assert_almost_equal
import pytest
from astropy.table import Table
from numpy.random import RandomState
from numpy.testing import assert_allclose

import sncosmo

try:
import iminuit

HAS_IMINUIT = True

except ImportError:
HAS_IMINUIT = False

try:
import nestle

HAS_NESTLE = True

except ImportError:
HAS_NESTLE = False

try:
import emcee

HAS_EMCEE = True

except ImportError:
HAS_EMCEE = False


class TestFitting:
def setup_class(self):
Expand All @@ -38,12 +51,14 @@ def setup_class(self):
model.set(**params)
flux = model.bandflux(bands, times, zp=zp, zpsys=zpsys)
fluxerr = len(bands) * [0.1 * np.max(flux)]
data = Table({'time': times,
'band': bands,
'flux': flux,
'fluxerr': fluxerr,
'zp': zp,
'zpsys': zpsys})
data = Table({
'time': times,
'band': bands,
'flux': flux,
'fluxerr': fluxerr,
'zp': zp,
'zpsys': zpsys
})

# reset parameters
model.set(z=0., t0=0., amplitude=1.)
Expand All @@ -52,6 +67,53 @@ def setup_class(self):
self.data = data
self.params = params

def _test_mutation(self, fit_func):
"""Test a fitting function does not mutate arguments"""

# Some fitting functions require bounds for all varied parameters
bounds = {}
for param, param_val in self.params.items():
bounds[param] = (param_val * .95, param_val * 1.05)

# Preserve original input data
vparams = list(self.params.keys())
test_data = deepcopy(self.data)
test_model = deepcopy(self.model)
test_bounds = deepcopy(bounds)
test_vparams = deepcopy(vparams)

# Check for argument mutation
fit_func(test_data, test_model, test_vparams, bounds=test_bounds)
param_preserved = all(a == b for a, b in zip(vparams, test_vparams))
model_preserved = all(
a == b for a, b in
zip(self.model.parameters, test_model.parameters)
)

err_msg = '``{}`` argument was mutated'
assert all(self.data == test_data), err_msg.format('data')
assert bounds == test_bounds, err_msg.format('bounds')
assert param_preserved, err_msg.format('vparam_names')
assert model_preserved, err_msg.format('model')

@pytest.mark.skipif('not HAS_IMINUIT')
def test_fitlc_arg_mutation(self):
"""Test ``fit_lc`` does not mutate it's arguments"""

self._test_mutation(sncosmo.fit_lc)

@pytest.mark.skipif('not HAS_NESTLE')
def test_nestlc_arg_mutation(self):
"""Test ``nest_lc`` does not mutate it's arguments"""

self._test_mutation(sncosmo.nest_lc)

@pytest.mark.skipif('not HAS_EMCEE')
def test_mcmclc_arg_mutation(self):
"""Test ``mcmc_lc`` does not mutate it's arguments"""

self._test_mutation(sncosmo.mcmc_lc)

@pytest.mark.skipif('not HAS_IMINUIT')
def test_fit_lc(self):
"""Ensure that fit results match input model parameters (data are
Expand Down

0 comments on commit aabd41f

Please sign in to comment.