# Fit Transients

This guide shows how to model a series of multi-epoch images containing a transient.

In [None]:
# Import packages and setup
import astropy.io.fits as fits
import jax.numpy as jnp
import matplotlib.pyplot as plt
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS

import scarlet2

We will load four ZTF images, in the g and r bands, from before and after the appearance of the transient. To speed up the processing and fitting, we have already resampled all images and PSFs to the same WCS using SWarp, so we can create one {py:class}`~scarlet2.Observation` to hold all four images:

In [None]:
from huggingface_hub import hf_hub_download

filename = hf_hub_download(
    repo_id="astro-data-lab/scarlet-test-data",
    filename="transient_tutorial/data.fits.gz",
    repo_type="dataset",
)

data = []
weight = []
psf = []
channels = []
with fits.open(filename) as hdul:
    for i in range(4):
        # get the image, weights, and PSF for each epoch
        idx = i * 3
        header = hdul[idx].header
        print("loading", header["FILENAME"])
        if i == 0:
            wcs = WCS(header)
        data.append(hdul[idx].data)
        weight.append(hdul[idx + 1].data)
        psf.append(hdul[idx + 2].data)

        # channel: combined band and epoch identifier
        # any label is valid for each, e.g. timestamps for the epoch
        channels.append((header["FILTER"], i))

obs = scarlet2.Observation(
    jnp.array(data).astype(float),
    weights=jnp.array(weight).astype(float),
    psf=scarlet2.ArrayPSF(jnp.array(psf).astype(float)),
    wcs=wcs,
    channels=channels,
)

```{note}
The `channels` attribute of the frame accepts arbitrary labels, so it can be extended to identify both bands and epochs.
```
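
As a minimal sketch of this labeling scheme (plain Python, independent of scarlet2), the snippet below groups channel indices by band. The label values are an assumption for illustration: they take the band names used later in this guide and assume the images are stored in the order g, r, g, r.

```python
from collections import defaultdict

# Assumed channel labels (band, epoch index), in the order the images are loaded.
# The actual FILTER values and their order in the file may differ.
channels = [("ZTF_g", 0), ("ZTF_r", 1), ("ZTF_g", 2), ("ZTF_r", 3)]

# Collect the channel indices that belong to each band
indices_by_band = defaultdict(list)
for idx, (band, epoch) in enumerate(channels):
    indices_by_band[band].append(idx)

print(dict(indices_by_band))
```

Because each label is a tuple, the same list can be grouped by epoch instead by keying on the second element.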

To avoid the preprocessing step of resampling and aligning the images (and PSFs), one can instead treat each image as its own {py:class}`~scarlet2.Observation` and resample them on the fly (see the [multi-resolution tutorial](multiresolution) for details), but that makes the fitting much more computationally demanding.

Let's look at what we have:

In [None]:
norm = scarlet2.plot.AsinhAutomaticNorm(obs)
scarlet2.plot.observation(obs, norm=norm, add_labels=False, show_psf=True);

This is very much a false-color image: each of the four channels is interpreted as a distinct band, although we only have the g and r bands. Because our plotting routines assume that channels are ordered by increasing wavelength, and the transient appears in the latter two epochs, it is visible as a red excess in the color image above, slightly to the left of the source center. To see the channels separately, use the `split_channels` option of {py:func}`scarlet2.plot.observation`:

In [None]:
scarlet2.plot.observation(obs, add_labels=False, show_psf=True, split_channels=True);

## Define Transient Scene

We first need to define a model frame, which covers the same sky area as the data. As the ZTF PSF is not well-sampled, we reduce the internal model PSF to a very narrow Gaussian:

In [None]:
model_psf = scarlet2.GaussianPSF(sigma=0.5)
model_frame = scarlet2.Frame.from_observations(obs, model_psf=model_psf)
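
To get a feeling for how narrow this model PSF is, the sketch below converts the Gaussian width to a full width at half maximum using the standard relation FWHM = 2·sqrt(2 ln 2)·σ. It assumes that `sigma` is given in model-frame pixels.

```python
import math

sigma = 0.5  # model PSF width; assumed to be in model-frame pixels

# Standard conversion for a Gaussian profile
fwhm = 2.0 * math.sqrt(2.0 * math.log(2.0)) * sigma
print(f"FWHM = {fwhm:.3f} pixels")
```

Under that assumption, the internal model PSF is only slightly wider than a single pixel.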

In _scarlet2_ we treat transients as sources that have independent amplitudes in every band and epoch (defined by {py:class}`~scarlet2.TransientArraySpectrum`), while static sources only have independent amplitudes in every band, i.e. their `spectrum` is shared across all epochs (implemented in {py:class}`~scarlet2.StaticArraySpectrum`). If we know that the transient is "off" for some epochs (e.g. pre-explosion), we can set those amplitudes to zero.

As our model frame treats the channels as a combined (band, epoch) identifier, the `spectrum` attributes for every source inherit this overloaded definition. So, we need to take care to set/fit the elements of this generalized spectrum vector correctly. For that purpose, we define lookup functions (`band_selector` and `epoch_selector`), which operate on the channel information and return the band or the epoch, respectively.
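
The bookkeeping described here can be sketched with plain NumPy. This is not scarlet2's implementation: the helper functions `expand_static` and `mask_transient` are hypothetical, the channel labels assume the order g, r, g, r, and the amplitudes are arbitrary placeholders.

```python
import numpy as np

# Assumed channel labels (band, epoch index)
channels = [("ZTF_g", 0), ("ZTF_r", 1), ("ZTF_g", 2), ("ZTF_r", 3)]

band_selector = lambda channel: channel[0]
epoch_selector = lambda channel: channel[1]


def expand_static(amplitudes, bands, channels):
    """Repeat one amplitude per band across every channel of that band."""
    lookup = dict(zip(bands, amplitudes))
    return np.array([lookup[band_selector(c)] for c in channels])


def mask_transient(amplitudes, epochs, channels):
    """Keep per-channel amplitudes only in epochs where the transient is on."""
    on = np.array([epoch_selector(c) in epochs for c in channels])
    return np.where(on, amplitudes, 0.0)


# Placeholder amplitudes, for illustration only
static = expand_static([2.0, 3.0], ["ZTF_g", "ZTF_r"], channels)
transient = mask_transient(np.array([1.0, 1.0, 5.0, 7.0]), [2, 3], channels)
print(static)     # [2. 3. 2. 3.]
print(transient)  # [0. 0. 5. 7.]
```

The static source has two free amplitudes that fill four channels, while the transient has one amplitude per channel, with the "off" epochs forced to zero.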

We can now define a {py:class}`~scarlet2.Scene`:

In [None]:
# coordinates of the transient
ra = 215.39425925333
dec = 37.90971372
coord = SkyCoord(ra, dec, unit="deg")

# separate channel information into band and epoch: elements 0 and 1
# this depends on how `channels` encodes the multi-epoch information
band_selector = lambda channel: channel[0]
epoch_selector = lambda channel: channel[1]

with scarlet2.Scene(model_frame) as scene:
    # 1) Host galaxy that is static across epochs
    try:
        spectrum, morph = scarlet2.init.from_gaussian_moments(obs, coord, box_sizes=[15, 21])
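    except IndexError:
        # fallback: `spectrum` must be set here too, since it is used below
        spectrum = scarlet2.init.pixel_spectrum(obs, coord)
        morph = scarlet2.init.compact_morphology()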
    # the host is barely resolved and the data are noisy:
    # use a starlet morphology for extra stability (especially against noise)
    morph = scarlet2.StarletMorphology.from_image(morph)

    # Select the transient-free epochs to initialize amplitudes for the static source
    # These will be shared across all epochs
    spectrum = spectrum[0:2]
    bands = ["ZTF_g", "ZTF_r"]
    scarlet2.Source(
        coord, scarlet2.StaticArraySpectrum(spectrum, bands=bands, band_selector=band_selector), morph
    )

    # 2) Point source for the transient, placed initially at same center
    # Define the epochs where the transient is allowed to have a non-zero amplitude
    epochs = [2, 3]
    # As we already know that the transient is present, we can measure the flux at the center location
    # This will be a mixture of host and transient light, to be corrected by the fitting procedure
    # Initializing as zero also works
    spectrum = scarlet2.init.pixel_spectrum(obs, coord)
    scarlet2.PointSource(
        coord, scarlet2.TransientArraySpectrum(spectrum, epochs=epochs, epoch_selector=epoch_selector)
    )

print(scene.sources)

## Fitting

Fitting works as usual by defining the {py:class}`~scarlet2.Parameters`. Because the two spectra and the host morphology (of type {py:class}`~scarlet2.StarletMorphology`) aren't simple arrays but models themselves, their free parameters are the array attributes `.data` and `.coeffs`, respectively, as shown in the source listing printed above, e.g. `TransientArraySpectrum(data=f32[4],...)`. These fundamental degrees of freedom of the scene are what we have to pass to the parameters class:

In [None]:
from numpyro.distributions import constraints

pos_step = 1e-2
morph_step = lambda p: scarlet2.relative_step(p, factor=1e-3)
SED_step = lambda p: scarlet2.relative_step(p, factor=5e-2)

parameters = scene.make_parameters()
# Static host galaxy parameters
parameters += scarlet2.Parameter(
    scene.sources[0].spectrum.data, name=f"spectrum.{0}", constraint=constraints.positive, stepsize=SED_step
)
parameters += scarlet2.Parameter(
    scene.sources[0].morphology.coeffs,
    name=f"morph.{0}",
    stepsize=morph_step,
)

# Transient point source parameters:
# no positive constraint on spectrum because it can be zero
parameters += scarlet2.Parameter(scene.sources[1].spectrum.data, name=f"spectrum.{1}", stepsize=SED_step)
parameters += scarlet2.Parameter(
    scene.sources[1].center, name=f"center.{1}", constraint=constraints.positive, stepsize=pos_step
)

In [None]:
# Fit the scene
stepnum = 1000
scene_ = scene.fit(obs, parameters, max_iter=stepnum, e_rel=1e-4, progress_bar=False)

## Inspect Result

In [None]:
# Plot the model, for each epoch
scarlet2.plot.scene(
    scene_,
    observation=obs,
    norm=norm,
    show_model=True,
    show_observed=True,
    show_rendered=True,
    show_residual=True,
    add_labels=True,
    add_boxes=True,
    split_channels=False,
    box_kwargs={"edgecolor": "red", "facecolor": "none"},
    label_kwargs={"color": "red"},
)
plt.show()

The fit looks good: there is modest reddening to the left of the center, and there are no noticeable residuals. Here are the best-fitting fluxes:

In [None]:
print("----------------- {}".format(channels))
for k, src in enumerate(scene_.sources):
    print("Source {}, Fluxes: {}".format(k, scarlet2.measure.flux(src)))

Note that the host galaxy, source 0, has the same flux in each epoch of the same band, while the transient, source 1, has zero flux in the epochs where we forced it to be 'off'.
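
These two properties can also be checked programmatically. The helper `check_fluxes` below is hypothetical (it is not part of scarlet2), and the flux values are made-up placeholders, not results of this fit. In the notebook one would instead pass the per-source outputs of `scarlet2.measure.flux`, assuming they are arrays with one entry per channel.

```python
import numpy as np


def check_fluxes(channels, host_flux, transient_flux, on_epochs):
    """Return (host_is_static, transient_is_off) for per-channel flux arrays."""
    host_flux = np.asarray(host_flux, dtype=float)
    transient_flux = np.asarray(transient_flux, dtype=float)
    bands = np.array([band for band, _ in channels])
    # Host: all channels of the same band must carry the same flux
    host_is_static = all(
        np.allclose(host_flux[bands == band], host_flux[bands == band][0])
        for band in np.unique(bands)
    )
    # Transient: flux must vanish in every epoch outside `on_epochs`
    off = np.array([epoch not in on_epochs for _, epoch in channels])
    transient_is_off = bool(np.all(transient_flux[off] == 0.0))
    return host_is_static, transient_is_off


# Placeholder inputs, for illustration only
channels = [("ZTF_g", 0), ("ZTF_r", 1), ("ZTF_g", 2), ("ZTF_r", 3)]
host = [2.0, 3.0, 2.0, 3.0]
transient = [0.0, 0.0, 5.0, 7.0]
result = check_fluxes(channels, host, transient, on_epochs=[2, 3])
print(result)
```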


