# Quick Start Guide

This tutorial shows how to quickly get started using *scarlet* to model a hyperspectral image cube. For a more in-depth introduction to *scarlet*, read our [User Guide](user_docs.ipynb).

To run this tutorial you will need either `astropy` (http://www.astropy.org) or `sep` (https://github.com/kbarbary/sep) installed to open/create the source catalog, and `matplotlib` (https://matplotlib.org) to display the images.

In [None]:
# Import Packages and setup
import logging
logger = logging.getLogger('scarlet')
logger.setLevel(logging.DEBUG)
logger = logging.getLogger("proxmin")
logger.setLevel(logging.DEBUG)


import autograd.numpy as np
import scarlet
import scarlet.display

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='inferno')
matplotlib.rc('image', interpolation='none')

## Load and display the sample data

In [None]:
# Load the sample images
data = np.load("../data/hsc_cosmos_35.npz")
images = data["images"]
filters = data["filters"]
catalog = data["catalog"]
weights = 1/data["variance"]
psfs = scarlet.PSF(data["psfs"])

### Display a raw image cube
This is an example of how to display an RGB image from an image cube of multiband data. In this case a $\sinh^{-1}$ function is used to normalize the flux in each filter consistently before the bands are combined into an RGB image.

In [None]:
from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping

stretch = 0.1
Q = 10
norm = AsinhMapping(minimum=0, stretch=stretch, Q=Q)
img_rgb = scarlet.display.img_to_rgb(images, norm=norm)
plt.imshow(img_rgb)

# Mark all of the sources from the detection catalog
for k, src in enumerate(catalog):
    plt.text(src["x"], src["y"], str(k), color="red")
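To give a feel for what a $\sinh^{-1}$ stretch does, here is a minimal NumPy sketch. It is a simplified illustration and not the code behind `AsinhMapping`; the helper name `asinh_stretch` and the test cube are made up for this example. The mean intensity across channels is compressed with `arcsinh`, which is roughly linear for faint pixels and logarithmic for bright ones, and every channel is rescaled by the same factor so that colors are preserved.

```python
import numpy as np

def asinh_stretch(cube, minimum=0.0, stretch=0.1, Q=10.0):
    """Simplified asinh stretch of a (C, Ny, Nx) cube (illustrative only)."""
    shifted = cube - minimum
    # The mean over channels defines one intensity per pixel
    intensity = shifted.mean(axis=0)
    # Avoid dividing by zero for pixels at or below the minimum
    safe = np.where(intensity > 0, intensity, 1.0)
    scale = np.where(intensity > 0,
                     np.arcsinh(Q * safe / stretch) / (Q * safe),
                     0.0)
    # The same factor is applied to all channels, so band ratios are unchanged
    return shifted * scale

rng = np.random.default_rng(0)
cube = rng.uniform(0.01, 5.0, size=(3, 4, 4))  # made-up 3-band test cube
stretched = asinh_stretch(cube)
```

In this sketch a larger `Q` or a smaller `stretch` compresses bright pixels more strongly relative to faint ones.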

## Define the model frame and the observation

A `Frame` in *scarlet* describes the hyperspectral cube of the model or the observations. Think of it as the metadata that specifies which aspects of the sky are described. For observations, most of this information is contained in FITS headers. At the least, a `Frame` holds the `shape` of the cube, for which we use the convention `(C, Ny, Nx)`: `C` is the number of bands/channels and `Ny, Nx` are the number of pixels in each channel.
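The `(C, Ny, Nx)` convention can be illustrated with a plain NumPy array. The sizes below are made up for the example; the point is only the axis order, with the channel first and image rows (`y`) before columns (`x`).

```python
import numpy as np

C, Ny, Nx = 5, 58, 48          # hypothetical sizes: 5 bands, 58 rows, 48 columns
cube = np.zeros((C, Ny, Nx))

band_image = cube[2]           # a single channel, shape (Ny, Nx)
pixel_sed = cube[:, 10, 20]    # all channels at row y=10, column x=20, shape (C,)
```

This is also why positions are passed as `(y, x)` pairs when indexing into the cube.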

Additionally, you can and often must provide weights for all elements in the data cube, an image cube of the PSF model (one image for all channels or one for each channel), an `astropy.wcs.WCS` structure to translate from pixel to sky coordinates, and labels for all channels. These are specified so that the code can internally map from the model frame, in which you seek to fit a model, to the observed data frame.

In this example, we assume that bands and pixel locations are identical between the model and the observation. But we have ground-based images with different PSFs in each band, so we need to provide a reference PSF for the model. We simply choose a minimal Gaussian PSF that is barely well sampled and use it as our reference kernel:

In [None]:
from functools import partial
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8))
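To see why a Gaussian with `sigma=.8` is only barely well sampled, the sketch below builds such a kernel on a pixel grid with plain NumPy. The helper `gaussian_kernel` is hypothetical and written for this illustration; it is not the function used by `scarlet.psf.gaussian`.

```python
import numpy as np

def gaussian_kernel(sigma, size=15):
    """Hypothetical helper: normalized 2D Gaussian on an odd-sized pixel grid."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1]
    kernel = np.exp(-(x**2 + y**2) / (2 * sigma**2))
    return kernel / kernel.sum()   # unit total flux

kernel = gaussian_kernel(sigma=0.8)
# Full width at half maximum in pixels: 2*sqrt(2*ln 2)*sigma
fwhm = 2 * np.sqrt(2 * np.log(2)) * 0.8
```

The resulting full width at half maximum is just under two pixels, so a much narrower kernel would no longer be resolved by the pixel grid.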

With this we can fully specify the `Frame` and the `Observation`. Think of the latter as a `Frame` with a data portion.

In [None]:
model_frame = scarlet.Frame(
    images.shape,
    psf=model_psf,
    channels=filters)

observation = scarlet.Observation(
    images, 
    psf=psfs, 
    weights=weights, 
    channels=filters).match(model_frame)

The last command calls the `match` method to compute, e.g., the PSF difference kernel and the filter transformations.

We generally recommend this pattern:
1. define model frame
2. construct observation
3. match it to the model frame

Steps 2 and 3 are combined above using a fluent pattern.
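The idea of a PSF difference kernel can be sketched for the special case where both PSFs are Gaussian: the kernel that maps the narrow model PSF onto a wider observed PSF is itself a Gaussian whose variance is the difference of the two variances. The code below checks this numerically with NumPy. It is an illustration only, not the computation `match` performs; the helpers and the value of `sigma_obs` are made up.

```python
import numpy as np

def gaussian_kernel(sigma, size=41):
    """Hypothetical helper: normalized 2D Gaussian on an odd-sized pixel grid."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1]
    kernel = np.exp(-(x**2 + y**2) / (2 * sigma**2))
    return kernel / kernel.sum()

def convolve_same(a, b):
    """Linear convolution of two equally sized, odd-shaped images via zero-padded FFTs."""
    ny, nx = a.shape
    full_shape = (2 * ny - 1, 2 * nx - 1)
    product = np.fft.rfft2(a, s=full_shape) * np.fft.rfft2(b, s=full_shape)
    full = np.fft.irfft2(product, s=full_shape)
    # Crop the central region so the output is aligned with the inputs
    y0, x0 = (ny - 1) // 2, (nx - 1) // 2
    return full[y0:y0 + ny, x0:x0 + nx]

sigma_model, sigma_obs = 0.8, 2.0   # sigma_obs is a made-up observed PSF width
sigma_diff = np.sqrt(sigma_obs**2 - sigma_model**2)

# Model PSF convolved with the difference kernel should reproduce the observed PSF
matched = convolve_same(gaussian_kernel(sigma_model), gaussian_kernel(sigma_diff))
max_error = np.abs(matched - gaussian_kernel(sigma_obs)).max()
```

Real PSFs are not Gaussian, which is why the difference kernel has to be computed from the PSF images themselves, once per band.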

## Initialize the sources

You now need to define sources that are going to be fit. The full model, which we will call `Blend`, is a collection of those sources. We provide several pre-built source types:

* [RandomSource](api/scarlet.source.html#scarlet.source.RandomSource) fits per-band amplitude and non-parametric morphology starting from uniform random draws for both.
* [PointSource](api/scarlet.source.html#scarlet.source.PointSource) fits centers and per-band amplitude using the observed PSF model.
* [ExtendedSource](api/scarlet.source.html#scarlet.source.ExtendedSource) fits per-band amplitude and a non-parametric morphology (which can be constrained to be symmetric and/or monotonic with respect to the center).
* [MultiComponentSource](api/scarlet.source.html#scarlet.source.MultiComponentSource) splits an `ExtendedSource` into multiple components that are initially radially separated.
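As an illustration of what a symmetry constraint on a non-parametric morphology means, one simple reading is invariance under a 180-degree rotation about the central pixel. The sketch below enforces that on an odd-sized image by averaging it with its rotated copy. The helper `symmetrize` is hypothetical, and how *scarlet* implements the constraint internally is not shown here.

```python
import numpy as np

def symmetrize(morph):
    """Average an odd-sized image with its 180-degree rotation about the central pixel."""
    return 0.5 * (morph + morph[::-1, ::-1])

rng = np.random.default_rng(1)
morph = rng.uniform(size=(7, 7))   # made-up morphology image
sym = symmetrize(morph)
```

The averaging leaves the central pixel and the total flux unchanged, and only redistributes flux between pixels on opposite sides of the center.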

In our example, we assume *prior* knowledge that object 0 is a star and that object 1 should be modeled with a bulge-disc model. Everything else is assumed to be a galaxy.

In [None]:
sources = []
for k,src in enumerate(catalog):
    if k == 0:
        new_source = scarlet.PointSource(model_frame, (src['y'], src['x']), observation)
    elif k == 1:
        new_source = scarlet.MultiComponentSource(model_frame, (src['y'], src['x']), observation, symmetric=False, monotonic=True, thresh=5)
    else:
        new_source = scarlet.ExtendedSource(model_frame, (src['y'], src['x']), observation, symmetric=False, monotonic=True, thresh=5)
    sources.append(new_source)

## Create and fit the model
The `scarlet.Blend` class represents the sources as a tree and has the machinery to fit all of the sources to the given images. In this example the code is set to run for a maximum of 200 iterations, but will end early if the likelihood and all of the constraints converge.

In [None]:
blend = scarlet.Blend(sources, observation)
%time blend.fit(200)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(blend.loss)
plt.xlabel('Iteration')
plt.ylabel('Negative log-Likelihood')
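The "maximum number of iterations, but stop early on convergence" pattern can be sketched with a toy gradient descent. This is a deliberately simplified stand-in: it has no constraints, checks only the relative change of the loss, and the function `fit`, its `step` argument, and the quadratic test problem are all made up for the example.

```python
import numpy as np

def fit(loss_and_grad, x0, max_iter=200, e_rel=1e-3, step=0.1):
    """Toy gradient descent; stops when the relative loss change drops below e_rel."""
    x = np.array(x0, dtype=float)
    history = []
    for _ in range(max_iter):
        loss, grad = loss_and_grad(x)
        history.append(loss)
        # Early exit once the loss has stopped changing noticeably
        if len(history) > 1 and abs(history[-2] - loss) <= e_rel * abs(loss):
            break
        x = x - step * grad
    return x, history

target = np.array([1.0, -2.0, 3.0])

def quadratic(x):
    diff = x - target
    # The constant offset keeps the loss away from zero for the relative test
    return 0.5 * np.sum(diff**2) + 1.0, diff

x_fit, history = fit(quadratic, x0=np.zeros(3))
```

As with `blend.loss`, the length of `history` tells you how many iterations were actually run before the stopping criterion was met.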

## View the results

### View the full model
First we load the model for the entire blend, render it in the observation frame, and compute its residuals. We then show model and data with the same $\sinh^{-1}$ stretch and the residuals with a linear stretch.

In [None]:
# Load the model and calculate the residual
model = blend.get_model()
model_ = observation.render(model)
residual = images-model_
# Create RGB images
model_rgb = scarlet.display.img_to_rgb(model_, norm=norm)
residual_rgb = scarlet.display.img_to_rgb(residual)

# Show the data, model, and residual
fig = plt.figure(figsize=(15,5))
ax = [fig.add_subplot(1,3,n+1) for n in range(3)]
ax[0].imshow(img_rgb)
ax[0].set_title("Data")
ax[1].imshow(model_rgb)
ax[1].set_title("Model")
ax[2].imshow(residual_rgb)
ax[2].set_title("Residual")

for k,component in enumerate(blend):
    y,x = component.center
    ax[0].text(x, y, k, color="w")
    ax[1].text(x, y, k, color="w")
    ax[2].text(x, y, k, color="w")
plt.show()
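The residual is also the quantity that drives the fit. Assuming Gaussian noise and inverse-variance weights like the `weights` array defined earlier, the negative log-likelihood is, up to an additive constant, half the weighted sum of squared residuals. The sketch below writes that out in NumPy with made-up data; the helper `gaussian_nll` is hypothetical, and whether it matches the loss in *scarlet* term by term is an assumption.

```python
import numpy as np

def gaussian_nll(images, model, weights):
    """Negative Gaussian log-likelihood, up to an additive constant."""
    residual = images - model
    return 0.5 * np.sum(weights * residual**2)

rng = np.random.default_rng(2)
truth = np.ones((2, 4, 4))                       # made-up noiseless cube
variance = np.full(truth.shape, 0.01)
images = truth + rng.normal(scale=0.1, size=truth.shape)

loss_good = gaussian_nll(images, truth, 1 / variance)        # correct model
loss_bad = gaussian_nll(images, truth + 0.5, 1 / variance)   # biased model
```

A model that matches the data to within the noise gives a loss of roughly half the number of pixels, while a biased model is penalized much more heavily.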

### View the source models
We will now inspect the model for each source, in its original frame and in its observed frame. In this example, the two frames differ by an extra convolution from the minimal `model_psf` to the observed PSFs.

In [None]:
# Set the stretch based on the model
stretch = .3
Q = 10
norm = AsinhMapping(minimum=0, stretch=stretch, Q=Q)

for k,src in enumerate(sources):
    # Get the model for a single source
    model = src.get_model()
    model_ = observation.render(model)
    
    # Convert observation and models to RGB
    img_rgb = scarlet.display.img_to_rgb(images, norm=norm)
    model_rgb = scarlet.display.img_to_rgb(model, norm=norm)
    model_rgb_ = scarlet.display.img_to_rgb(model_, norm=norm)

    # Set the figure size
    ratio = src.frame.shape[2]/src.frame.shape[1]
    fig_height = 3*src.frame.shape[1]/20
    fig_width = max(2*fig_height*ratio,2)
    fig = plt.figure(figsize=(fig_width, fig_height))
    
    # Generate and show the figure
    ax = [fig.add_subplot(1,3,n+1) for n in range(3)]
    ax[0].imshow(img_rgb)
    ax[0].set_title("Data")
    ax[1].imshow(model_rgb_)
    ax[1].set_title("Observed model {0}".format(k))
    ax[2].imshow(model_rgb)
    ax[2].set_title("Model {0}".format(k))
    # Mark the source in the data image
    y,x = src.center
    ax[0].plot(x, y, "wx", mew=1, ms=10)
    ax[1].plot(x, y, "wx", mew=1, ms=10)
    ax[2].plot(x, y, "wx", mew=1, ms=10)
    plt.show()

We can see that the model of object 0 assumes a simple Gaussian shape, which is the internal representation of a point source. It also shows the effective PSF of all the other models. Source 1 uses the freedom of the 2-component model to represent a slightly redder core.

### SEDs and Fluxes

The color information in these plots stems from the per-band amplitude, which can be obtained as `source.sed`. However, it is more useful to compute per-band fluxes, which integrate over the morphology. The convention of these fluxes is given by the units and ordering of the original data cube. In the case of multi-component sources, the fluxes of all components are combined.

In [None]:
print("----------------- {}".format(filters))
for k, src in enumerate(sources):
    model = src.get_model()
    print("Source {}, Fluxes: {}".format(k, scarlet.measure.flux(model)))
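The relation between per-band amplitudes and per-band fluxes can be sketched with NumPy for a model that factorizes into an SED and a morphology. This illustrates the idea of integrating over the morphology; that `scarlet.measure.flux` reduces to exactly this sum is an assumption, and all values below are made up.

```python
import numpy as np

sed = np.array([1.0, 2.0, 4.0])      # made-up per-band amplitudes
morph = np.zeros((5, 5))
morph[1:4, 1:4] = 0.1
morph[2, 2] = 0.2                    # the morphology sums to 1.0

# Model cube with shape (C, Ny, Nx): each band is the morphology scaled by its amplitude
model = sed[:, None, None] * morph[None, :, :]

# Per-band flux: sum over both pixel axes
flux = model.sum(axis=(1, 2))
```

Because this morphology is normalized to unit sum, the fluxes equal the amplitudes; for any other normalization they differ by the factor `morph.sum()`.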

### Parameters and Errors

Internally, `Blend` solves an optimization problem, namely reducing the loss by adjusting the parameters of each component. The loss is the negative log-likelihood of the observed data given the model. Every component can declare its own parameters, which we can access with the `parameters` property:

In [None]:
for k,src in enumerate(sources):
    for p in src.parameters:
        print("Source {}, Parameter shape {}, Converged {}".format(k, p.shape, p.converged))

The parameter with length 5 is the SED, while the other describes the morphology. For object 0, this is simply a 2D center; the rest use images of different sizes.

Each parameter is a souped-up numpy array. It has attributes that store any constraints that were enforced during optimization, whether this parameter is considered converged, and an error estimate. In our example, several parameters have converged within relative changes of `e_rel=1e-3` (the default setting of `Blend.fit`), but others have not. This is why the fitter complained about non-convergence. The run above stopped because the loss no longer changed noticeably.

To demonstrate the use of the error estimate, we make a signal-to-noise map of the morphology of source 5:

In [None]:
p = sources[5].parameters[1]
plt.imshow(p / p.std)
plt.colorbar(label='SNR')

The SNR map shows that the center region is well determined by the data. However, this error estimate is purely statistical and does not include correlations between different parameters or different components. In fact, there's an upper lobe in the top-left corner of this source that is part of source 0. The gradient optimizer would exploit that and increase the morphology values there, but the monotonicity constraint has largely prevented that.

In [None]:
model = sources[6].get_model()
scarlet.measure.centroid(model, bbox=sources[6].bbox)

In [None]:
scarlet.measure.flux(model).sum()/5

In [None]:
scarlet.measure.centroid(model)

In [None]:
plt.imshow(sources[6]._morph)