Skip to content

Commit

Permalink
Merge pull request #296 from djones1040/master
Browse files Browse the repository at this point in the history
new SALT3 model
  • Loading branch information
kboone committed Apr 12, 2021
2 parents 750865e + 9839def commit 987b688
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 29 deletions.
5 changes: 3 additions & 2 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ Model & Components
TimeSeriesSource
StretchSource
SALT2Source

SALT3Source

*Effect components of Model: interstellar dust extinction*

.. autosummary::
Expand Down Expand Up @@ -130,7 +131,7 @@ magnitude systems*
Class Inheritance Diagrams
==========================

.. inheritance-diagram:: Source TimeSeriesSource StretchSource SALT2Source
.. inheritance-diagram:: Source TimeSeriesSource StretchSource SALT2Source SALT3Source
:parts: 1

.. inheritance-diagram:: PropagationEffect F99Dust OD94Dust CCM89Dust
Expand Down
16 changes: 15 additions & 1 deletion sncosmo/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ABMagSystem, CompositeMagSystem, SpectralMagSystem, _MAGSYSTEMS)

from .models import (
MLCS2k2Source, SALT2Source, SNEMOSource, SUGARSource,
MLCS2k2Source, SALT2Source, SALT3Source, SNEMOSource, SUGARSource,
TimeSeriesSource, _SOURCES)

from .specmodel import SpectrumModel
Expand Down Expand Up @@ -484,6 +484,11 @@ def load_salt2model(relpath, name=None, version=None):
return SALT2Source(modeldir=abspath, name=name, version=version)


def load_salt3model(relpath, name=None, version=None):
abspath = DATADIR.abspath(relpath, isdir=True)
return SALT3Source(modeldir=abspath, name=name, version=version)


def load_2011fe(relpath, name=None, version=None):

# filter warnings about RADESYS keyword in files
Expand Down Expand Up @@ -638,6 +643,15 @@ def load_2011fe(relpath, name=None, version=None):
args=('models/pierel/salt2-extended',), version='2.0',
meta=meta)

# SALT3
meta = {'type': 'SN Ia',
'subclass': '`~sncosmo.SALT3Source`',
'url': 'https://salt3.readthedocs.io/en/latest/',
'note': "See Kenworthy et al. 2021, ApJ, submitted."}
_SOURCES.register_loader('salt3', load_salt3model,
args=('models/salt3/salt3-k21',), version='1.0',
meta=meta)

meta = {'type': 'SN Ia',
'subclass': '`~sncosmo.SALT2Source`',
'url': 'http://snana.uchicago.edu/',
Expand Down
172 changes: 166 additions & 6 deletions sncosmo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
from .utils import integration_grid

__all__ = ['get_source', 'Source', 'TimeSeriesSource', 'StretchSource',
'SUGARSource', 'SALT2Source', 'MLCS2k2Source', 'SNEMOSource',
'Model', 'PropagationEffect', 'CCM89Dust', 'OD94Dust', 'F99Dust']
'SUGARSource', 'SALT2Source', 'SALT3Source', 'MLCS2k2Source',
'SNEMOSource', 'Model', 'PropagationEffect', 'CCM89Dust',
'OD94Dust', 'F99Dust']

_SOURCES = Registry()

Expand Down Expand Up @@ -952,7 +953,7 @@ def _set_colorlaw_from_file(self, name_or_obj):

# If there are more than 1+ncoeffs words in the file, we expect them to
# be of the form `keyword value`.
version = 0
version = None
colorlaw_range = [3000., 7000.]
for i in range(1+ncoeffs, len(words), 2):
if words[i] == 'Salt2ExtinctionLaw.version':
Expand All @@ -966,12 +967,12 @@ def _set_colorlaw_from_file(self, name_or_obj):

# Set extinction function to use.
if version == 0:
raise Exception("Salt2ExtinctionLaw.version 0 not supported.")
raise RuntimeError("Salt2ExtinctionLaw.version 0 not supported.")
elif version == 1:
self._colorlaw = SALT2ColorLaw(colorlaw_range, colorlaw_coeffs)
else:
raise Exception('unrecognized Salt2ExtinctionLaw.version: ' +
version)
raise RuntimeError('unrecognized Salt2ExtinctionLaw.version: ' +
version)

def colorlaw(self, wave=None):
"""Return the value of the CL function for the given wavelengths.
Expand All @@ -997,6 +998,165 @@ def colorlaw(self, wave=None):
return self._colorlaw(wave)


class SALT3Source(SALT2Source):
"""The SALT3 Type Ia supernova spectral timeseries model.
Kenworthy et al., 2021, ApJ, submitted. Model definitions
are the same as SALT2 except for the errors, which are now
given in flux space. Unlike SALT2, no file is used for scaling
the errors.
The spectral flux density of this model is given by
.. math::
F(t, \\lambda) = x_0 (M_0(t, \\lambda) + x_1 M_1(t, \\lambda))
\\times 10^{-0.4 CL(\\lambda) c}
where ``x0``, ``x1`` and ``c`` are the free parameters of the model,
``M_0``, ``M_1`` are the zeroth and first components of the model, and
``CL`` is the colorlaw, which gives the extinction in magnitudes for
``c=1``.
Parameters
----------
modeldir : str, optional
Directory path containing model component files. Default is `None`,
which means that no directory is prepended to filenames when
determining their path.
m0file, m1file, clfile : str or fileobj, optional
Filenames of various model components. Defaults are:
* m0file = 'salt2_template_0.dat' (2-d grid)
* m1file = 'salt2_template_1.dat' (2-d grid)
* clfile = 'salt2_color_correction.dat'
lcrv00file, lcrv11file, lcrv01file, cdfile : str or fileobj
(optional) Filenames of various model components for
model covariance in synthetic photometry. See
``bandflux_rcov`` for details. Defaults are:
* lcrv00file = 'salt2_lc_relative_variance_0.dat' (2-d grid)
* lcrv11file = 'salt2_lc_relative_variance_1.dat' (2-d grid)
* lcrv01file = 'salt2_lc_relative_covariance_01.dat' (2-d grid)
* cdfile = 'salt2_color_dispersion.dat' (1-d grid)
Notes
-----
The "2-d grid" files have the format ``<phase> <wavelength>
<value>`` on each line.
The phase and wavelength values of the various components don't
necessarily need to match. (In the most recent salt2 model data,
they do not all match.) The phase and wavelength values of the
first model component (in ``m0file``) are taken as the "native"
sampling of the model, even though these values might require
interpolation of the other model components.
"""

_param_names = ['x0', 'x1', 'c']
param_names_latex = ['x_0', 'x_1', 'c']
_SCALE_FACTOR = 1e-12

def __init__(self, modeldir=None,
m0file='salt3_template_0.dat',
m1file='salt3_template_1.dat',
clfile='salt3_color_correction.dat',
cdfile='salt3_color_dispersion.dat',
lcrv00file='salt3_lc_variance_0.dat',
lcrv11file='salt3_lc_variance_1.dat',
lcrv01file='salt3_lc_covariance_01.dat',
name=None, version=None):

self.name = name
self.version = version
self._model = {}
self._parameters = np.array([1., 0., 0.])

names_or_objs = {'M0': m0file, 'M1': m1file,
'LCRV00': lcrv00file, 'LCRV11': lcrv11file,
'LCRV01': lcrv01file,
'cdfile': cdfile, 'clfile': clfile}

# Make filenames into full paths.
if modeldir is not None:
for k in names_or_objs:
v = names_or_objs[k]
if (v is not None and isinstance(v, str)):
names_or_objs[k] = os.path.join(modeldir, v)

# model components are interpolated to 2nd order
for key in ['M0', 'M1']:
phase, wave, values = read_griddata_ascii(names_or_objs[key])
values *= self._SCALE_FACTOR
self._model[key] = BicubicInterpolator(phase, wave, values)

# The "native" phases and wavelengths of the model are those
# of the first model component.
if key == 'M0':
self._phase = phase
self._wave = wave

# model covariance is interpolated to 1st order
for key in ['LCRV00', 'LCRV11', 'LCRV01']:
phase, wave, values = read_griddata_ascii(names_or_objs[key])
self._model[key] = BicubicInterpolator(phase, wave, values)

# Set the colorlaw based on the "color correction" file.
self._set_colorlaw_from_file(names_or_objs['clfile'])

# Set the color dispersion from "color_dispersion" file
w, val = np.loadtxt(names_or_objs['cdfile'], unpack=True)
self._colordisp = Spline1d(w, val, k=1) # linear interp.

def _bandflux_rvar_single(self, band, phase):
"""Model relative variance for a single bandpass."""

# Raise an exception if bandpass is out of model range.
if (band.minwave() < self._wave[0] or band.maxwave() > self._wave[-1]):
raise ValueError('bandpass {0!r:s} [{1:.6g}, .., {2:.6g}] '
'outside spectral range [{3:.6g}, .., {4:.6g}]'
.format(band.name, band.wave[0], band.wave[-1],
self._wave[0], self._wave[-1]))

x1 = self._parameters[1]

# integrate m0 and m1 components
wave, dwave = integration_grid(band.minwave(), band.maxwave(),
MODEL_BANDFLUX_SPACING)
trans = band(wave)
m0 = self._model['M0'](phase, wave)
m1 = self._model['M1'](phase, wave)
tmp = trans * wave
f0 = np.sum(m0 * tmp, axis=1) * dwave / HC_ERG_AA
m1int = np.sum(m1 * tmp, axis=1) * dwave / HC_ERG_AA
ftot = f0 + x1 * m1int

# In the following, the "[:,0]" reduces from a 2-d array of shape
# (nphase, 1) to a 1-d array.
lcrv00 = self._model['LCRV00'](phase, band.wave_eff)[:, 0]
lcrv11 = self._model['LCRV11'](phase, band.wave_eff)[:, 0]
lcrv01 = self._model['LCRV01'](phase, band.wave_eff)[:, 0]

v = lcrv00 + 2.0 * x1 * lcrv01 + x1 * x1 * lcrv11

# v is supposed to be variance but can go negative
# due to interpolation. Correct negative values to some small
# number. (at present, use prescription of snfit : set
# negatives to 0.0001)
v[v < 0.0] = 0.0001

# avoid warnings due to evaluating 0. / 0. in f0 / ftot
with np.errstate(invalid='ignore'):
# new SALT3 error prescription
result = v/(ftot/(trans*wave*dwave).sum())/HC_ERG_AA/1e12

# treat cases where ftot is negative the same as snfit
result[ftot <= 0.0] = 10000.

return result


class MLCS2k2Source(Source):
"""A spectral time series model based on the MLCS2k2 model light curves,
using the Hsiao template at each phase, mangled to match the model
Expand Down
126 changes: 106 additions & 20 deletions sncosmo/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def setup_class(self):
phase = np.linspace(0., 100., 10)
wave = np.linspace(1000., 10000., 100)
vals1d = np.zeros(len(phase), dtype=np.float64)
vals = np.ones([len(phase), len(wave)], dtype=np.float64)

# Create some 2-d grid files
files = []
Expand Down Expand Up @@ -126,26 +127,111 @@ def setup_class(self):
cdfile.close()
clfile.close()

def test_bandflux_rcov(self):

# component 1:
# ans = (F0/F1)^2 S^2 (V00 + 2 x1 V01 + x1^2 V11)
# when x1=0, this reduces to S^2 V00 = 1^2 * 0.01 = 0.01
#
# component 2:
# cd^2 = 0.04

band = ['bessellb', 'bessellb', 'bessellr', 'bessellr',
'besselli']
phase = [10., 20., 30., 40., 50.]
self.source.set(x1=0.0)
result = self.source.bandflux_rcov(band, phase)
expected = np.array([[0.05, 0.04, 0., 0., 0.],
[0.04, 0.05, 0., 0., 0.],
[0., 0., 0.05, 0.04, 0.],
[0., 0., 0.04, 0.05, 0.],
[0., 0., 0., 0., 0.05]])
assert_allclose(result, expected)
def test_bandflux_rcov(self):

# component 1:
# ans = (F0/F1)^2 S^2 (V00 + 2 x1 V01 + x1^2 V11)
# when x1=0, this reduces to S^2 V00 = 1^2 * 0.01 = 0.01
#
# component 2:
# cd^2 = 0.04
phase = np.linspace(0., 100., 10)
wave = np.linspace(1000., 10000., 100)
vals = np.ones([len(phase), len(wave)], dtype=np.float64)

band = ['bessellb', 'bessellb', 'bessellr', 'bessellr',
'besselli']
band = np.array([sncosmo.get_bandpass(b) for b in band])
phase = np.array([10., 20., 30., 40., 50.])
self.source.set(x1=0.0)
result = self.source.bandflux_rcov(band, phase)
expected = np.array([[0.05, 0.04, 0., 0., 0.],
[0.04, 0.05, 0., 0., 0.],
[0., 0., 0.05, 0.04, 0.],
[0., 0., 0.04, 0.05, 0.],
[0., 0., 0., 0., 0.05]])
assert_allclose(result, expected)


class TestSALT3Source:

def setup_class(self):
"""Create a SALT3 model with a lot of components set to 1."""

phase = np.linspace(0., 100., 10)
wave = np.linspace(1000., 10000., 100)
vals1d = np.zeros(len(phase), dtype=np.float64)
vals = np.ones([len(phase), len(wave)], dtype=np.float64)

# Create some 2-d grid files
files = []
for i in [0, 1]:
f = StringIO()
sncosmo.write_griddata_ascii(phase, wave, vals, f)
f.seek(0) # return to start of file.
files.append(f)

# CL file. The CL in magnitudes will be
# CL(wave) = -(wave - B) / (V - B) [B = 4302.57, V = 5428.55]
# and transmission will be 10^(-0.4 * CL(wave))^c
clfile = StringIO()
clfile.write("1\n"
"0.0\n"
"Salt2ExtinctionLaw.version 1\n"
"Salt2ExtinctionLaw.min_lambda 3000\n"
"Salt2ExtinctionLaw.max_lambda 8000\n")
clfile.seek(0)

# Create some more 2-d grid files
for factor in [1., 0.01, 0.01, 0.01]:
f = StringIO()
sncosmo.write_griddata_ascii(phase, wave, factor * vals, f)
f.seek(0) # return to start of file.
files.append(f)

# Create a 1-d grid file (color dispersion)
cdfile = StringIO()
for w in wave:
cdfile.write("{0:f} {1:f}\n".format(w, 0.2))
cdfile.seek(0) # return to start of file.

# Create a SALT2Source
self.source = sncosmo.SALT3Source(m0file=files[0],
m1file=files[1],
clfile=clfile,
lcrv00file=files[3],
lcrv11file=files[4],
lcrv01file=files[5],
cdfile=cdfile)

for f in files:
f.close()
cdfile.close()
clfile.close()

def test_bandflux_rcov(self):

# component 1:
# ans = (F0/F1)^2 S^2 (V00 + 2 x1 V01 + x1^2 V11)
# when x1=0, this reduces to S^2 V00 = 1^2 * 0.01 = 0.01
#
# component 2:
# cd^2 = 0.04
phase = np.linspace(0., 100., 10)
wave = np.linspace(1000., 10000., 100)
vals = np.ones([len(phase), len(wave)], dtype=np.float64)
band = ['bessellb', 'bessellb', 'bessellr', 'bessellr',
'besselli']
band = np.array([sncosmo.get_bandpass(b) for b in band])
phase = np.array([10., 20., 30., 40., 50.])
self.source.set(x1=0.0)
result = self.source.bandflux_rcov(band, phase)
expected = np.array([[0.05, 0.04, 0., 0., 0.],
[0.04, 0.05, 0., 0., 0.],
[0., 0., 0.05, 0.04, 0.],
[0., 0., 0.04, 0.05, 0.],
[0., 0., 0., 0., 0.05]])
assert_allclose(result, expected)


class TestModel:
Expand Down

0 comments on commit 987b688

Please sign in to comment.