# Multi-resolution Deblending

This tutorial shows how to deblend images taken with different telescopes at different resolutions. In this case we will use a multiband observation with the Hyper Suprime-Cam (HSC) and a single-band, high-resolution image from the Hubble Space Telescope (HST). Before using this tutorial you should be familiar with the *scarlet* [User Guide](../user_docs.ipynb) and how to deblend single-resolution images.

In [None]:
# Import Packages and setup
import logging

import numpy as np

import scarlet
import scarlet.display
import astropy.io.fits as fits
from astropy.wcs import WCS
from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping
import sep

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# Default image settings for every plot below: the gist_stern colormap, and
# no interpolation so that individual pixels stay visible.
matplotlib.rc('image', cmap='gist_stern', interpolation='none')

## Load and display the sample data

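We first load the HSC and HST images, converting them to the machine's native byte order if necessary: FITS files store data big-endian, and `sep`, which we use below for detection, only accepts native-byte-order arrays. We also apply a small correction to the HST WCS and rescale the HST image so that its peak matches the HSC peak.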

In [None]:
def _native(arr):
    """Return a copy of `arr` in the machine's native byte order.

    FITS data are stored big-endian, while `sep` (used below for detection)
    only accepts native-byte-order arrays. `astype` with a native dtype swaps
    only when needed, always returns a fresh copy (so the result never refers
    to a closed file), and, unlike `ndarray.newbyteorder`, still exists in
    NumPy >= 2.0.
    """
    return arr.astype(arr.dtype.newbyteorder('='))


# Load the HSC image data (a multiband cube) and its WCS
with fits.open('../../data/test_resampling/Cut_HSC.fits') as obs_hdu:
    data_hsc = _native(obs_hdu[0].data)
    wcs_hsc = WCS(obs_hdu[0].header)
# Load the HSC PSF data (one PSF image per HSC band)
with fits.open('../../data/test_resampling/PSF_HSC.fits') as psf_hdu:
    psf_hsc = _native(psf_hdu[0].data)

# Load the HST image data and its WCS
with fits.open('../../data/test_resampling/Cut_HST.fits') as hst_hdu:
    data_hst = _native(hst_hdu[0].data)
    wcs_hst = WCS(hst_hdu[0].header)
# Apply a hard-coded astrometric correction to the HST WCS: shift the
# reference sky coordinate (CRVAL) by `hst_crval_shift` (in CRVAL units)
# along the direction at angle `hst_shift_angle` (radians).
# NOTE(review): presumably a measured HST/HSC registration offset -- TODO
# confirm where these values come from.
hst_crval_shift = 2.4750118475607095e-05
hst_shift_angle = 0.4136047623181346
wcs_hst.wcs.crval += hst_crval_shift * np.array([-np.cos(hst_shift_angle), -np.sin(hst_shift_angle)])
# Load the HST PSF and give it a leading band axis, matching the
# (bands, ny, nx) layout of the HSC PSF cube
with fits.open('../../data/test_resampling/PSF_HST.fits') as psf_hdu:
    psf_hst = _native(psf_hdu[0].data)
np1, np2 = np.shape(psf_hst)
psf_hst = psf_hst.reshape(1, np1, np2)

# Give the HST image a leading band axis as well, (1, n1, n2), and rescale it
# so that its peak matches the HSC peak
n1, n2 = np.shape(data_hst)
data_hst = data_hst.reshape(1, n1, n2) * np.max(data_hsc) / np.max(data_hst)

# Number of HSC bands and the HSC image size
r, N1, N2 = data_hsc.shape

Next we create a source catalog. Since HST has higher resolution and is less affected by blending, we use it for detection. We also run detection on the HSC image, which gives us the background RMS of each HSC band:

In [None]:
def makeCatalog(img, thresh=4):
    """Detect sources in `img` with `sep` and measure its background RMS.

    Parameters
    ----------
    img : numpy.ndarray
        Either a 2D image, or a 3D cube with the band axis first. For a cube,
        detection runs on the mean over bands.
    thresh : float, optional
        Detection threshold passed to `sep.extract`. Because `err` is given,
        it is relative to the global background RMS of the detection image.
        Default 4.

    Returns
    -------
    catalog : numpy structured array
        The `sep.extract` catalog of the detection image.
    bg_rms : numpy.ndarray or float
        For a cube, an array with the global background RMS of each band.
        For a 2D image, the global background RMS of that image.
    """
    if img.ndim == 3:
        detect = img.mean(axis=0)  # simple average over bands for detection
    else:
        detect = img

    bkg = sep.Background(detect)
    catalog = sep.extract(detect, thresh, err=bkg.globalrms)
    if img.ndim == 3:
        bg_rms = np.array([sep.Background(band).globalrms for band in img])
    else:
        # Reuse the background already measured on the detection image
        bg_rms = bkg.globalrms
    return catalog, bg_rms

# data_hst has shape (1, n1, n2), so it goes through the cube branch and
# bg_rms_hst is a 1-element array. The np.concatenate of the background
# levels further below relies on that.
catalog_hst, bg_rms_hst = makeCatalog(data_hst)
catalog_hsc, bg_rms_hsc = makeCatalog(data_hsc)

Finally we can visualize both the multiband HSC and single band HST images in their native resolutions:

In [None]:
# Create a color mapping for the HSC image: an asinh stretch whose minimum is
# the image minimum; `stretch` and `Q` control the stretch.
# hsc_norm is reused by the later model-display cells.
hsc_norm = AsinhMapping(minimum=data_hsc.min(), stretch=data_hsc.max()/20, Q=10)

# Get the source coordinates from the HST catalog
xo,yo = catalog_hst['x'], catalog_hst['y']
# Convert the HST coordinates to the HSC WCS
# NOTE(review): sep's catalog 'x'/'y' are column/row, yet they are passed to
# wcs_pix2world as (yo, xo), and the world2pix outputs are unpacked as
# (Yo, Xo). ra/dec therefore follow this swapped convention, and they are
# reused below to place the scarlet sources. Confirm that this matches
# scarlet's (y, x) coordinate convention before "fixing" the order.
ra, dec = wcs_hst.wcs_pix2world(yo,xo,0)
# The HSC WCS takes a third (band) pixel coordinate; `l` is unused
Yo,Xo, l = wcs_hsc.wcs_world2pix(ra, dec, 0, 0)

# Map the HSC image to RGB (img_rgb is reused by later display cells)
img_rgb = scarlet.display.img_to_rgb(data_hsc, norm=hsc_norm)
# Apply Asinh to the HST data
hst_img = np.arcsinh(data_hst[0])

# Left: HSC color image with the HST detections overplotted.
# Right: asinh-stretched HST image.
plt.subplot(121)
plt.imshow(img_rgb)
plt.plot(Xo,Yo, 'o')
plt.subplot(122)
plt.imshow(hst_img)
plt.show()

## Create the Scene and Observations

Unlike the single-resolution examples, we now have two different instruments with different pixel resolutions, so we need two different observations. Since the HST image has a much higher resolution, we define our [Scene](../observation.ipynb#scarlet.observation.Scene) on the HST pixel grid with the HST PSF, so our HST observation is a regular [Observation](../observation.ipynb#scarlet.observation.Observation). Since the HSC images need to be upsampled, we use the [LowResObservation](../observation.ipynb#scarlet.observation.LowResObservation) class.

To tell the scene which bands belong to which observation, we give each observation a `structure` parameter that maps it onto the 4 bands of the [Scene](../observation.ipynb#scarlet.observation.Scene). A `1` in `structure` means that the observation provides the corresponding scene band, and a `0` means it does not.

In [None]:
# The scene has r + 1 bands: the r HSC bands first, then the single HST band.
# A `1` in an observation's `structure` marks a scene band that the
# observation provides. The masks are built from `r` so that they always
# agree with the scene shape below, instead of hard-coding 3 HSC bands.
hsc_structure = np.array([1] * r + [0])
hst_structure = np.array([0] * r + [1])

# Map the HST band to the last band in the scene. No extra PSF is needed,
# because the scene itself uses the HST PSF.
obs_hst = scarlet.Observation(data_hst, wcs=wcs_hst, psfs=None, structure=hst_structure)
# Map the HSC bands to the first r bands in the scene
obs_hsc = scarlet.LowResObservation(data_hsc, wcs=wcs_hsc, psfs=psf_hsc, structure=hsc_structure)
# Initialize the Scene on the HST pixel grid (n1 x n2) with the HST PSF and WCS
scene = scarlet.Scene((r + 1, n1, n2), wcs=wcs_hst, psfs=psf_hst)

# Keep the order of the observations consistent with the `structure` parameter
# This implementation is a bit of a hack and will be refined in the future
obs = [obs_hsc, obs_hst]

# Background RMS per scene band, in the same order as `obs`: HSC bands, then HST
bg_rms = np.concatenate((bg_rms_hsc, bg_rms_hst))

## Initialize the Sources

The standard sources discussed in the [User Guide](../user_docs.ipynb#Components-and-Sources) are designed to be initialized with a single dataset. For multi-resolution models we need to use the [CombinedExtendedSource](../source.ipynb#scarlet.source.CombinedExtendedSource) to initialize extended sources by using the morphology from a high resolution observation and the SED from all observations. Because the initialization takes a list of observations, the `obs_idx` argument tells the [CombinedExtendedSource](../source.ipynb#scarlet.source.CombinedExtendedSource) which observation in the list of observations is used to initialize the morphology.

In [None]:
# One extended source per HST detection. The morphology is initialized from
# obs[obs_idx] (obs_idx=1 is the HST observation) and the SED from all
# observations.
sources = [
    scarlet.CombinedExtendedSource(
        (src_ra, src_dec),
        scene,
        obs,
        bg_rms,
        symmetric=False,
        monotonic=True,
        obs_idx=1,
    )
    for src_ra, src_dec in zip(ra, dec)
]

## Initialize the blend

Initializing a [Blend](../blend.ipynb#scarlet.blend.Blend) with multi-resolution observations is considerably more time consuming than in the single-resolution case. Mappings must be built from the low-resolution to the high-resolution frame, and the PSFs must be matched across resolutions.

In [None]:
import time

# `time.clock()` was removed in Python 3.8; `perf_counter()` is the
# high-resolution timer meant for measuring elapsed time.
t0 = time.perf_counter()

blend = scarlet.Blend(scene, sources, obs)
t1 = time.perf_counter()
print('setup time: {}'.format(t1 - t0))

## Display the initial guess

Compare the initial guess of the model in both the model frame and the HSC observation frame.

In [None]:


# Load the model and calculate the residual
_model = blend.get_model()
# Model resampled into the HSC observation frame
model = obs_hsc.model_to_frame(_model)
# RGB of the HSC scene bands (all but the last) at high resolution, and of
# the model in the HSC frame, both with the same stretch as the HSC data
_init_rgb = scarlet.display.img_to_rgb(_model[:-1], norm=hsc_norm)
init_rgb = scarlet.display.img_to_rgb(model, norm=hsc_norm)
residual_lr = data_hsc - model
# Trim the bottom source not part of the blend from the image
residual_lr_rgb = scarlet.display.img_to_rgb(residual_lr[:, :-5])

# Get the HR residual. The HST band is the LAST scene band
# (structure [0, ..., 0, 1]), so compare the HST data with that band as
# rendered by the HST observation. Note that `(data_hst - _model)[0]` would
# broadcast the 1-band HST data over every scene band and keep band 0, which
# is an HSC band.
residual_hr = data_hst[0] - obs_hst.get_model(_model)[-1]
# Symmetric limits for the diverging colormap must come from the largest
# absolute residual; otherwise strongly negative residuals get clipped.
vmax = np.abs(residual_hr).max()

plt.figure(figsize=(15, 10))
plt.subplot(231)
plt.imshow(img_rgb)
plt.title("Data")
plt.subplot(232)
plt.imshow(_init_rgb)
plt.title("HighRes Model")
plt.subplot(233)
plt.imshow(init_rgb)
plt.title("LowRes Model")
plt.subplot(235)
plt.imshow(residual_hr, cmap="seismic", vmin=-vmax, vmax=vmax)
plt.colorbar(fraction=.045)
plt.title("HST residual")
plt.subplot(236)
plt.imshow(residual_lr_rgb)
plt.title("HSC residual")
plt.show()

## Fit the model

In [None]:
import time

# `time.clock()` was removed in Python 3.8; use `perf_counter()` for
# elapsed-time measurements.
t0 = time.perf_counter()
# At most 200 iterations; e_rel is presumably the relative convergence
# tolerance -- see scarlet's Blend.fit documentation.
blend.fit(200, e_rel=1e-3)
t2 = time.perf_counter()
print("scarlet ran for {0} iterations in {1} seconds".format(blend.it, t2 - t0))

### View the full model
First we load the model for the entire blend and its residual. Then we display the model using the same $\sinh^{-1}$ stretch as the full image, and a linear stretch for the residuals, to see the improvement over our initial guess.

In [None]:
# Load the fitted model and calculate the residual
_model = blend.get_model()
# Model resampled into the HSC observation frame
model = obs_hsc.model_to_frame(_model)
# RGB of the HSC scene bands (all but the last) at high resolution, and of
# the model in the HSC frame, both with the same stretch as the HSC data
_fit_rgb = scarlet.display.img_to_rgb(_model[:-1], norm=hsc_norm)
fit_rgb = scarlet.display.img_to_rgb(model, norm=hsc_norm)
residual_lr = data_hsc - model
# Trim the bottom source not part of the blend from the image
residual_lr_rgb = scarlet.display.img_to_rgb(residual_lr[:, :-5])

# Get the HR residual. The HST band is the LAST scene band
# (structure [0, ..., 0, 1]), so compare the HST data with that band as
# rendered by the HST observation. Note that `(data_hst - _model)[0]` would
# broadcast the 1-band HST data over every scene band and keep band 0, which
# is an HSC band.
residual_hr = data_hst[0] - obs_hst.get_model(_model)[-1]
# Symmetric limits for the diverging colormap must come from the largest
# absolute residual; otherwise strongly negative residuals get clipped.
vmax = np.abs(residual_hr).max()

plt.figure(figsize=(15, 10))
plt.subplot(231)
plt.imshow(img_rgb)
plt.title("Data")
plt.subplot(232)
plt.imshow(_fit_rgb)
plt.title("HighRes Model")
plt.subplot(233)
plt.imshow(fit_rgb)
plt.title("LowRes Model")
plt.subplot(235)
plt.imshow(residual_hr, cmap="seismic", vmin=-vmax, vmax=vmax)
plt.colorbar(fraction=.045)
plt.title("HST residual")
plt.subplot(236)
plt.imshow(residual_lr_rgb)
plt.title("HSC residual")
plt.show()

### View the source models
It can also be useful to view the model for each source. For each source we show:
- the HSC and HST data with the source position marked;
- the source model and the data-minus-model residual at each resolution;
- the source's high- and low-resolution morphology and its SED.

In [None]:
# For each source, show a 3x3 panel:
#   row 1: HSC data with the source marked, its model in the HSC frame, residual
#   row 2: HST data with the source marked, its HST model, residual
#   row 3: high- and low-resolution morphology, and the SED
# NOTE(review): Xo/Yo/xo/yo are indexed by k, which assumes blend.sources keeps
# the HST catalog order used to build `sources`.
for k, src in enumerate(blend.sources):
    # Get the model for a single source
    model_hr = src.get_model()
    model_lr = obs_hsc.model_to_frame(model_hr)

    # Display the low resolution image and residuals
    img_lr_rgb = scarlet.display.img_to_rgb(model_lr)
    res = data_hsc - model_lr
    res_rgb = scarlet.display.img_to_rgb(res)

    plt.figure(figsize=(15, 15))

    plt.subplot(331)
    plt.imshow(img_rgb)
    plt.plot(Xo[k], Yo[k], 'o', markersize=5)
    plt.title("HSC Data")
    plt.subplot(332)
    plt.imshow(img_lr_rgb)
    plt.title("LR Model")
    plt.subplot(333)
    plt.imshow(res_rgb)
    plt.title("Data-Model")

    # HST rendering of the source; the HST band is the last one
    img_hr = obs_hst.get_model(model_hr)
    res = data_hst - img_hr[-1]
    # Symmetric diverging-colormap limits from the largest absolute residual
    vmax = np.abs(res).max()

    plt.subplot(334)
    plt.imshow(data_hst[0], cmap='gist_stern')
    plt.plot(xo[k], yo[k], 'o', markersize=5)
    plt.title("HST Data")
    plt.subplot(335)
    plt.imshow(img_hr[-1])
    plt.title("HR Model")
    plt.subplot(336)
    plt.imshow(res[0], cmap='seismic', vmin=-vmax, vmax=vmax)
    plt.title("Data-Model")

    # Display the morphology in high resolution
    morph_hr = src.morph
    plt.subplot(337)
    plt.imshow(morph_hr)
    plt.title('HR Morphology')

    # Display the morphology in low resolution
    # Eventually this will be a class method,
    # but for now we have to calculate this explicitly: project the HR
    # morphology through band 0 of the HSC resampling/convolution operator and
    # scatter the result onto the HSC pixels given by `_coord_lr`
    # (presumably their row/column indices -- TODO confirm).
    morph_lr = np.zeros((N1, N2))
    # Both index arrays must be integers; NumPy rejects float indices.
    rows_lr = obs_hsc._coord_lr[0].astype(int)
    cols_lr = obs_hsc._coord_lr[1].astype(int)
    morph_lr[rows_lr, cols_lr] = np.dot(morph_hr.flatten(), obs_hsc.resconv_op[0, :, :])
    plt.subplot(338)
    plt.imshow(morph_lr)
    plt.title('LR Morphology')
    plt.subplot(339)
    plt.plot(src.sed, '.-')
    plt.title('SED')
    plt.suptitle("Source {0}".format(k), y=.92)
    plt.show()
