# Fit Multiple Observations

This tutorial shows how to model sources jointly from images observed in different ways: with the same instrument but different pointings and PSFs, or with entirely different instruments. For this guide we use a multi-band observation from the Hyper Suprime-Cam (HSC) and a single high-resolution image from the Hubble Space Telescope (HST).

In [None]:
# Import Packages
import astropy.io.fits as fits
import astropy.units as u
import jax.numpy as jnp
from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy.wcs import WCS

import scarlet2

## Load Data

We first load the HSC and HST images, their PSFs, and precomputed weight maps (inverse variances). We also load a catalog of sources detected jointly in both observations (see [here](https://github.com/astro-data-lab/scarlet-test-data/blob/main/scarlet_test_data/data/multiresolution_tutorial/get_source_catalog.py) for details on how this catalog was created).

In [None]:
from huggingface_hub import hf_hub_download

filename = hf_hub_download(repo_id="astro-data-lab/scarlet-test-data",
                           filename="multiresolution_tutorial/data.fits.gz",
                           repo_type="dataset")
with fits.open(filename) as hdul:
    # Load HSC observation
    data_hsc = jnp.array(hdul['HSC_OBS'].data, jnp.float32)
    wcs_hsc = WCS(hdul['HSC_OBS'].header)

    # Load HSC PSF and weights
    psf_hsc_data = jnp.array(hdul['HSC_PSF'].data, jnp.float32)
    obs_hsc_weights = jnp.array(hdul['HSC_WEIGHTS'].data, jnp.float32)

    # Load HST observation
    data_hst = jnp.array(hdul['HST_OBS'].data, jnp.float32)
    wcs_hst = WCS(hdul['HST_OBS'].header)

    # Load HST PSF and weights
    psf_hst_data = jnp.array(hdul['HST_PSF'].data, jnp.float32)
    obs_hst_weights = jnp.array(hdul['HST_WEIGHTS'].data, jnp.float32)

    # Load catalog table and metadata
    coords_table = Table(hdul['CATALOG'].data)
    radecsys = hdul['CATALOG'].header['RADECSYS']
    equinox = hdul['CATALOG'].header['EQUINOX']
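The weight maps loaded above are, by the usual convention, inverse variances (w = 1/σ²). Under this convention, noisier pixels count less in the fit. If you only have a variance map, the conversion can be sketched as below. The function `variance_to_weights` is hypothetical and not part of scarlet2. Treating non-finite or non-positive variances as masked pixels with zero weight is an assumption of this sketch.

```python
import numpy as np

def variance_to_weights(variance):
    """Convert a per-pixel variance map to inverse-variance weights.

    Pixels with non-finite or non-positive variance are treated as masked
    and receive zero weight (an assumption of this sketch).
    """
    variance = np.asarray(variance, dtype=np.float32)
    valid = np.isfinite(variance) & (variance > 0)
    weights = np.zeros_like(variance)
    weights[valid] = 1.0 / variance[valid]
    return weights

# Toy 2-band variance map (bands, height, width) with one bad pixel per band
var = np.array([[[4.0, 1.0], [np.inf, 0.25]],
                [[1.0, 0.0], [2.0, 0.5]]], dtype=np.float32)
w = variance_to_weights(var)
print(w)
```

With such weights, the goodness of fit of a model m to data d is the weighted sum Σ w (d − m)², and masked pixels simply drop out.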

In [None]:
# Convert the catalog coordinates to SkyCoord objects
ra_dec = SkyCoord(ra=coords_table['RA'] * u.deg,
                  dec=coords_table['DEC'] * u.deg,
                  frame=radecsys.lower(),
                  equinox=f'J{equinox}')

## Create Frame and Observations

We have two instruments with different pixel resolutions, so we need two separate observations. Because the HST image has a much higher resolution, the model `Frame` should use the HST resolution and the HST PSF. Each `Observation`, whether high or low resolution, is then matched to the model frame. This matching defines how the model is rendered into that observation's pixels.

In [None]:
# Scarlet Observations
obs_hst = scarlet2.Observation(data_hst,
                               wcs=wcs_hst,
                               psf=scarlet2.ArrayPSF(psf_hst_data),
                               channels=['F814W'],
                               weights=obs_hst_weights)

obs_hsc = scarlet2.Observation(data_hsc,
                               wcs=wcs_hsc,
                               psf=scarlet2.ArrayPSF(psf_hsc_data),
                               channels=['g', 'r', 'i', 'z', 'y'],
                               weights=obs_hsc_weights)
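The rendering step described above can be pictured with a minimal numpy sketch under strong simplifying assumptions:

- the PSFs are Gaussian;
- the model grid is an integer factor finer than, and aligned with, the observed grid;
- resampling is a flux-conserving block sum.

The model is first blurred with a difference kernel, whose width is σ_diff² = σ_obs² − σ_model², so that its effective PSF matches the observation. It is then summed into the coarser pixels. scarlet2 works with the actual PSF images and WCS and need not use this method. All helper names here are hypothetical.

```python
import numpy as np

def gaussian_kernel(sigma, size):
    """Normalized 2D Gaussian kernel of odd width `size` (pixels)."""
    r = np.arange(size) - size // 2
    g = np.exp(-0.5 * (r / sigma) ** 2)
    k = np.outer(g, g)
    return k / k.sum()

def convolve_fft(image, kernel):
    """Same-size circular convolution via FFT with a centered kernel."""
    padded = np.zeros_like(image)
    kh, kw = kernel.shape
    padded[:kh, :kw] = kernel
    # Move the kernel center to (0, 0) so the output is not shifted
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.real(np.fft.ifft2(np.fft.fft2(image) * np.fft.fft2(padded)))

def block_sum(image, factor):
    """Downsample by summing factor x factor blocks (conserves total flux)."""
    h, w = image.shape
    return image.reshape(h // factor, factor, w // factor, factor).sum(axis=(1, 3))

def render(model, sigma_model, sigma_obs, factor):
    """Render a high-resolution model into a coarser, blurrier observation.

    sigma_model and sigma_obs are Gaussian PSF widths in *model* pixels.
    """
    sigma_diff = np.sqrt(sigma_obs**2 - sigma_model**2)
    size = 2 * int(np.ceil(4 * sigma_diff)) + 1
    blurred = convolve_fft(model, gaussian_kernel(sigma_diff, size))
    return block_sum(blurred, factor)

# A point-like source on a 60x60 model grid, rendered to a 10x10 observed grid
model = np.zeros((60, 60))
model[30, 30] = 100.0
obs = render(model, sigma_model=1.0, sigma_obs=4.0, factor=6)
print(obs.shape, obs.sum())
```

Because the kernel is normalized and the block sum only regroups pixels, the total flux of the source is preserved in the rendered observation.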

We then define the model frame from both observations, covering either the union or the intersection of their footprints:

In [None]:
model_frame = scarlet2.Frame.from_observations(
    observations=[obs_hst, obs_hsc],
    coverage="union"  # or "intersection"
)
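Conceptually, `coverage` decides how much sky the model frame spans:

- With `"union"`, it must cover every region seen by any observation.
- With `"intersection"`, it covers only the region seen by all of them.

The sketch below uses axis-aligned boxes in a shared coordinate system. This is an assumption: real footprints come from the WCS and need not be aligned. `Box`, `union`, and `intersection` are hypothetical helpers, not scarlet2 API.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Box:
    """Axis-aligned footprint [x0, x1) x [y0, y1) in a shared coordinate system."""
    x0: float
    x1: float
    y0: float
    y1: float

def union(a, b):
    """Smallest box that covers both footprints."""
    return Box(min(a.x0, b.x0), max(a.x1, b.x1), min(a.y0, b.y0), max(a.y1, b.y1))

def intersection(a, b):
    """Overlap of both footprints, or None if they do not overlap."""
    x0, x1 = max(a.x0, b.x0), min(a.x1, b.x1)
    y0, y1 = max(a.y0, b.y0), min(a.y1, b.y1)
    if x0 >= x1 or y0 >= y1:
        return None
    return Box(x0, x1, y0, y1)

# Hypothetical footprints, e.g. in arcsec relative to a common reference point
hst = Box(0.0, 6.0, 0.0, 6.0)
hsc = Box(-2.0, 8.0, 1.0, 9.0)
print(union(hst, hsc))         # Box(x0=-2.0, x1=8.0, y0=0.0, y1=9.0)
print(intersection(hst, hsc))  # Box(x0=0.0, x1=6.0, y0=1.0, y1=6.0)
```

A union frame keeps sources seen by only one instrument, at the cost of model regions constrained by a single observation. An intersection frame is smaller but constrained everywhere by both.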

Finally, we can visualize the detected sources in the multi-band HSC and single-band HST images at their native resolutions:

In [None]:
norm_hst = scarlet2.plot.AsinhAutomaticNorm(obs_hst)
norm_hsc = scarlet2.plot.AsinhAutomaticNorm(obs_hsc)

scarlet2.plot.observation(obs_hst, norm=norm_hst, sky_coords=ra_dec, show_psf=True, label_kwargs={'color': 'red'});
scarlet2.plot.observation(obs_hsc, norm=norm_hsc, sky_coords=ra_dec, show_psf=True);
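`AsinhAutomaticNorm` chooses its display parameters from the observation itself, and how it does so is not reproduced here. The generic idea behind an asinh stretch is that it is roughly linear for faint pixels and logarithmic for bright ones. Faint outskirts and bright cores therefore stay visible in the same image. A minimal sketch follows. The parameters `minimum`, `stretch`, `Q`, and `maximum`, and the clipping to [0, 1], are choices of this sketch, not of the library.

```python
import numpy as np

def asinh_stretch(image, minimum, stretch, Q, maximum):
    """Map pixel values to [0, 1] with an asinh stretch.

    Roughly linear for (x - minimum) << stretch / Q, logarithmic above.
    """
    x = np.clip(np.asarray(image, dtype=float) - minimum, 0, None)
    top = np.arcsinh(Q * (maximum - minimum) / stretch)
    return np.clip(np.arcsinh(Q * x / stretch) / top, 0.0, 1.0)

pixels = np.array([0.0, 0.01, 0.1, 1.0, 10.0, 100.0])
out = asinh_stretch(pixels, minimum=0.0, stretch=0.1, Q=10.0, maximum=100.0)
print(out)
```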

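## Initialize Sources from Multiple Observations

Each source is initialized at its catalog position using both observations at once. `init.from_gaussian_moments` estimates an initial spectrum and morphology from Gaussian moments measured in the observations. If it raises a `ValueError`, we fall back to the pixel spectrum at the source center (`init.pixel_spectrum`) and a compact morphology.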

In [None]:
import scarlet2.init as init

with scarlet2.Scene(model_frame) as scene:
    for center in ra_dec:
        try:
            spectrum, morph = init.from_gaussian_moments([obs_hst, obs_hsc], center, min_corr=0.99)
        except ValueError:
            spectrum = init.pixel_spectrum([obs_hst, obs_hsc], center)
            morph = init.compact_morphology()
        scarlet2.Source(center, spectrum, morph)
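`from_gaussian_moments` starts from a source's image moments. For a source image I, the flux-weighted centroid and the second-moment matrix define an elliptical Gaussian with the same position, size, and orientation. The numpy sketch below measures these on a synthetic, noise-free Gaussian. The helper `gaussian_moments` is hypothetical and is not scarlet2's implementation. Real, noisy, blended data also need a weighting window and background handling, which are omitted here.

```python
import numpy as np

def gaussian_moments(image):
    """Return total flux, centroid (y, x), and 2x2 second-moment matrix of an image."""
    ny, nx = image.shape
    y, x = np.mgrid[:ny, :nx]
    flux = image.sum()
    yc = (y * image).sum() / flux
    xc = (x * image).sum() / flux
    dy, dx = y - yc, x - xc
    cov = np.array([[(dy * dy * image).sum(), (dy * dx * image).sum()],
                    [(dx * dy * image).sum(), (dx * dx * image).sum()]]) / flux
    return flux, (yc, xc), cov

# Synthetic elliptical Gaussian: sigma_y = 2, sigma_x = 4, centered at (20, 25)
y, x = np.mgrid[:41, :51]
img = 50.0 * np.exp(-0.5 * (((y - 20) / 2.0) ** 2 + ((x - 25) / 4.0) ** 2))
flux, center, cov = gaussian_moments(img)
print(center, np.sqrt(np.diag(cov)))
```

The square roots of the diagonal of `cov` recover the Gaussian widths, and the off-diagonal term vanishes because this toy source is aligned with the pixel axes.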

In [None]:
scarlet2.plot.scene(scene,
                    observation=obs_hsc,
                    show_rendered=True,
                    show_observed=True,
                    show_residual=True,
                    add_boxes=True,
                    norm=norm_hsc);
scarlet2.plot.scene(scene,
                    observation=obs_hst,
                    show_rendered=True,
                    show_observed=True,
                    show_residual=True,
                    norm=norm_hst,
                    add_boxes=True,
                    label_kwargs={'color': 'red'});

## Fit Multiple Observations

We define the fit parameters following our general recommendation (see, e.g., the [quickstart guide](../0-quickstart)):

In [None]:
from functools import partial

from numpyro.distributions import constraints
from scarlet2.module import relative_step

# Relative step sizes for the spectrum and morphology updates
spec_step = partial(relative_step, factor=0.05)
morph_step = partial(relative_step, factor=1e-3)

parameters = scene.make_parameters()
for i, source in enumerate(scene.sources):
    # Spectra are constrained to be positive
    parameters += scarlet2.Parameter(source.spectrum,
                                     name=f"spectrum.{i}",
                                     constraint=constraints.positive,
                                     stepsize=spec_step)
    # Morphologies are constrained to the unit interval [0, 1]
    parameters += scarlet2.Parameter(source.morphology,
                                     name=f"morph.{i}",
                                     constraint=constraints.unit_interval,
                                     stepsize=morph_step)
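The constraints above can be enforced by optimizing in an unconstrained space and mapping back. In numpyro, `biject_to(constraints.positive)` gives an exponential transform and `biject_to(constraints.unit_interval)` a sigmoid transform. That scarlet2 optimizes through such bijections is an assumption of this sketch. The step sizes are set relative to the parameter values, so `factor=0.05` presumably means steps of a few percent of the spectrum amplitudes. `relative_step_sketch` is hypothetical and uses the mean absolute value, which may differ from scarlet2's `relative_step`.

```python
import numpy as np

# Bijections between unconstrained values u and constrained parameters
def to_positive(u):
    return np.exp(u)                  # (-inf, inf) -> (0, inf)

def from_positive(p):
    return np.log(p)

def to_unit_interval(u):
    return 1.0 / (1.0 + np.exp(-u))   # (-inf, inf) -> (0, 1)

def from_unit_interval(p):
    return np.log(p) - np.log1p(-p)   # logit

def relative_step_sketch(x, factor):
    """Hypothetical relative step: a fraction of the mean absolute value."""
    return factor * np.mean(np.abs(x))

spectrum = np.array([120.0, 80.0, 40.0])
morph = np.array([0.2, 0.9, 0.5])

# A large update in unconstrained space still maps back to a valid parameter
u = from_positive(spectrum) - 10.0
print(to_positive(u))                 # still positive
v = from_unit_interval(morph) + 5.0
print(to_unit_interval(v))            # still inside (0, 1)
print(relative_step_sketch(spectrum, factor=0.05))
```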

The initial linear solver for the spectrum amplitudes and the fitting method, however, now receive lists of observations:

In [None]:
scene.set_spectra_to_match([obs_hsc, obs_hst], parameters)
scene_ = scene.fit([obs_hsc, obs_hst], parameters, max_iter=100, progress_bar=False)
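To see why matching spectrum amplitudes to several observations is a linear step, consider fixed morphologies. With the morphologies held fixed, the rendered model is linear in the amplitudes. Solving for them becomes a weighted least-squares problem in which the pixels of all observations are stacked together. The toy example below uses two 1D "observations" of two sources: a fine one and one that sums every 4 model pixels. The helper `solve_amplitudes` is hypothetical. `set_spectra_to_match` may differ in detail, for example in how it treats constraints.

```python
import numpy as np

def solve_amplitudes(observations):
    """Weighted least-squares amplitudes shared by all observations.

    observations: list of (data, weights, templates); templates has one
    column per source with its unit-amplitude model in this observation's pixels.
    """
    rows, rhs = [], []
    for data, weights, templates in observations:
        sw = np.sqrt(weights)
        rows.append(templates * sw[:, None])  # weight each pixel row
        rhs.append(data * sw)
    A = np.vstack(rows)
    b = np.concatenate(rhs)
    amplitudes, *_ = np.linalg.lstsq(A, b, rcond=None)
    return amplitudes

x = np.arange(20)
morph = np.stack([np.exp(-0.5 * ((x - 6) / 2.0) ** 2),
                  np.exp(-0.5 * ((x - 13) / 3.0) ** 2)], axis=1)  # (20, 2)
true_amp = np.array([5.0, 2.0])

fine = morph                                  # high-res: model pixels
coarse = morph.reshape(5, 4, 2).sum(axis=1)   # low-res: sums of 4 model pixels
obs_fine = (fine @ true_amp, np.full(20, 4.0), fine)
obs_coarse = (coarse @ true_amp, np.full(5, 1.0), coarse)

print(solve_amplitudes([obs_fine, obs_coarse]))
```

Scaling each row by √w makes the solution minimize Σ over observations and pixels of w (d − Σ_k a_k m_k)², so both observations constrain the same amplitudes.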

The fitted scene `scene_` models both observations much more accurately than the initialization, although it would likely benefit from more iterations:

In [None]:
scarlet2.plot.scene(scene_,
                    observation=obs_hsc,
                    show_rendered=True,
                    show_observed=True,
                    show_residual=True,
                    add_labels=True,
                    add_boxes=True,
                    norm=norm_hsc);
scarlet2.plot.scene(scene_,
                    observation=obs_hst,
                    show_rendered=True,
                    show_observed=True,
                    show_residual=True,
                    add_labels=True,
                    add_boxes=True,
                    norm=norm_hst,
                    box_kwargs={'edgecolor': 'red', 'facecolor': 'none'},
                    label_kwargs={'color': 'red'});
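The fit optimizes one objective summed over both observations, so the high-resolution HST pixels and the five HSC bands inform the same model. As a rough illustration, the numpy sketch below runs plain gradient descent on the summed weighted χ² of two toy 1D observations of one model. The loss is ½ Σ_o Σ w_o (d_o − R_o s)², where R_o is each observation's rendering matrix. scarlet2 uses its own optimizer, step sizes, and constraints, so this is only the underlying idea, with hypothetical names throughout.

```python
import numpy as np

def joint_loss_and_grad(s, observations):
    """Sum over observations of 0.5 * sum(w * (d - R s)^2), and its gradient."""
    loss, grad = 0.0, np.zeros_like(s)
    for data, weights, R in observations:
        resid = data - R @ s
        loss += 0.5 * np.sum(weights * resid**2)
        grad -= R.T @ (weights * resid)
    return loss, grad

n, factor = 20, 4
x = np.arange(n)
truth = 3.0 * np.exp(-0.5 * ((x - 9) / 2.5) ** 2)

R_fine = np.eye(n)                                        # high-res: model pixels
R_coarse = np.kron(np.eye(n // factor), np.ones(factor))  # low-res: sums of 4 pixels
obs = [(R_fine @ truth, np.full(n, 4.0), R_fine),
       (R_coarse @ truth, np.full(n // factor, 1.0), R_coarse)]

# Step size 1/L, with L the largest eigenvalue of the Hessian sum_o R^T W R
H = sum(R.T @ (w[:, None] * R) for _, w, R in obs)
step = 1.0 / np.linalg.eigvalsh(H).max()

s = np.zeros(n)
history = []
for _ in range(100):
    loss, grad = joint_loss_and_grad(s, obs)
    history.append(loss)
    s -= step * grad
print(history[0], history[-1])
```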