In [1]:
import os
import numpy as np
import sncosmo
from astropy.table import Table

In [3]:
# -----------------------
# 1) Model registry (sncosmo)
# -----------------------
# Built-in sources that work out-of-the-box.
# You can extend this dict by registering more templates later (e.g., SNANA files, MOSFiT, Villar KN, etc.).
CLASS_TO_SOURCE = {
    "SNIa": "salt2",          # SALT2
    "91bg": "nugent-sn91bg",  # 91bg-like
    "SNIbc": "nugent-sn1bc",  # stripped-envelope
    "II": "nugent-sn2p",      # Type II-P
    # Placeholders below â€” add templates later with sncosmo.register(). For now we map to something reasonable:
    # "Iax": "nugent-sn91bg",   # placeholder
    # "KN": "nugent-sn1bc",     # placeholder
    # "SLSN": "nugent-sn1bc",   # placeholder
    # "ILOT": "nugent-sn2p",    # placeholder
    # "CART": "nugent-sn1bc",   # placeholder
    # "PISN": "nugent-sn2p",    # placeholder
}

LSST_BANDS = ["lsstu", "lsstg", "lsstr", "lssti", "lsstz", "lssty"]


In [4]:
# -----------------------
# 2) Helpers: parameter sampling and flux generation
# -----------------------
rng = np.random.default_rng(42)

def sample_params(class_name):
    """Return (model, param_dict). Sets z, t0, and amplitude/shape parameters."""
    src = CLASS_TO_SOURCE[class_name]
    model = sncosmo.Model(source=src)

    # Redshift/time of max
    z = rng.uniform(0.03, 0.15)          # narrow for speed/stability
    t0 = 0.0

    params = {"z": z, "t0": t0}

    if src == "salt2":
        # SALT2: x1 and c shape parameters, x0 scales amplitude
        params["x1"] = rng.normal(0.0, 1.0)
        params["c"]  = rng.normal(0.0, 0.1)
        # Choose target peak mag in lsstg, then set x0 accordingly
        target_peak_mag = rng.uniform(20.0, 22.5)
        x0 = sncosmo.models.SALT2Source()._fluxscale_for_peakmag(band="lsstg", mag=target_peak_mag, zpsys="ab")
        params["x0"] = x0
    else:
        # Nugent style models use overall amplitude
        target_peak_mag = rng.uniform(20.0, 22.5)
        # Convert desired peak magnitude in lsstg into amplitude.
        # We search amplitude to match peak mag.
        # Simple 1D search (few steps) is fine for a quick solution.
        amp = 1.0
        best_amp = amp
        best_err = 1e9
        for a in np.logspace(-6, 2, 30):
            model.set(z=z, t0=t0, amplitude=a)
            try:
                m = model.bandmag("lsstg", "ab", t0)  # near peak
                err = abs(m - target_peak_mag)
                if err < best_err:
                    best_err = err
                    best_amp = a
            except Exception:
                pass
        params["amplitude"] = best_amp

    model.set(**params)
    return model, params

def flux_grid(model, times, bands, zp=27.5, zpsys="ab", add_noise=None):
    """Return array shape (len(times), len(bands)) of fluxes at times for each band.
       add_noise: if float in mag, adds Gaussian noise per point in mag space."""
    n_t = len(times)
    n_b = len(bands)
    out = np.zeros((n_t, n_b), dtype=np.float32)
    for j, b in enumerate(bands):
        mag = model.bandmag(b, zpsys, times)   # length n_t
        if add_noise is not None:
            mag = mag + rng.normal(0.0, add_noise, size=mag.shape)
        # convert mag->flux at zp
        # flux = 10^(-0.4*(mag - zp))
        flux = 10.0 ** (-0.4 * (mag - zp))
        out[:, j] = flux.astype(np.float32)
    return out


In [5]:
# -----------------------
# 3) Dataset synthesis
# -----------------------
def simulate_pair(class_name, TDENSE=400, TVIS=60, tmin=-20.0, tmax=80.0,
                  visible_mag_noise=0.05):
    """
    Returns:
      dense_x:  (TDENSE, 1)
      dense_y:  (TDENSE, NBANDS)
      vis_x:    (TVIS, 1)
      vis_y:    (TVIS, NBANDS)
    """
    model, _ = sample_params(class_name)

    dense_times = np.linspace(tmin, tmax, TDENSE).astype(np.float32)
    vis_times   = np.linspace(tmin, tmax, TVIS).astype(np.float32)
    # You can randomize vis_times for variety, but keep TVIS fixed for array shapes.
    # vis_times += rng.normal(0.0, 0.5, size=vis_times.shape)

    dense_y = flux_grid(model, dense_times, LSST_BANDS, add_noise=None)           # hidden
    vis_y   = flux_grid(model, vis_times,   LSST_BANDS, add_noise=visible_mag_noise)  # visible noisy

    dense_x = dense_times.reshape(-1, 1)
    vis_x   = vis_times.reshape(-1, 1)

    return dense_x, dense_y, vis_x, vis_y


In [6]:
def build_dataset(classes, n_per_class=100, TDENSE=400, TVIS=60,
                  tmin=-20.0, tmax=80.0, visible_mag_noise=0.05):
    samples = []
    for cname in classes:
        for _ in range(n_per_class):
            dx, dy, vx, vy = simulate_pair(cname, TDENSE, TVIS, tmin, tmax, visible_mag_noise)
            samples.append((dx, dy, vx, vy))
    # Stack into arrays with shapes:
    # context_x: (N, TDENSE, 1)
    # context_y: (N, TDENSE, 6)
    # target_x:  (N, TVIS, 1)
    # target_y:  (N, TVIS, 6)
    context_x = np.stack([s[0] for s in samples], axis=0).astype(np.float32)
    context_y = np.stack([s[1] for s in samples], axis=0).astype(np.float32)
    target_x  = np.stack([s[2] for s in samples], axis=0).astype(np.float32)
    target_y  = np.stack([s[3] for s in samples], axis=0).astype(np.float32)
    return context_x, context_y, target_x, target_y

In [8]:
context_x, context_y, target_x, target_y = build_dataset(
        ["SNIa", "91bg", "SNIbc", "II"], n_per_class=100, TDENSE=400, TVIS=60
    )

FileNotFoundError: [Errno 2] No such file or directory: 'salt2_template_0.dat'