# $\text{Morpheus}^{++}$ Label Idea

The input, $x$, to $\text{Morpheus}^{++}$ is a multiband image with height $H$, width $W$, and number of bands $B$:

$$x \in \mathbb{R}^{H \times W \times B}, \qquad x_{ijb} \in \mathbb{R}$$

The label has four components:
- Morphology, $\mathbb{R}^{1}$
- Claim vectors, $\mathbb{Z}^{B \times 8 \times 2}$
- Claim distribution, $\mathbb{R}^{B \times 8}$
- Center of mass distribution value, $\mathbb{R}^{1}$

The first component, morphology, is given by the original morpheus setup.

The claim vectors and claim distribution are new to Morpheus and describe deblending.

The center of mass distribution value represents whether or not a pixel is a center of mass, indicating an instance to which the claim vectors and claim distribution point.


## Center of Mass Distribution

The center of mass distribution indicates individual instances in the image. In https://arxiv.org/pdf/1911.10194.pdf, during training, the centers of mass are encoded as 2D Gaussians with a standard deviation of 8 pixels.


In [1]:
from itertools import starmap
from functools import partial

import numpy as np
from scipy import signal

#https://stackoverflow.com/a/46892763/2691018
def gaussian_kernel_2d(kernlen, std=8):
    """Return a ``kernlen`` x ``kernlen`` 2D Gaussian kernel array.

    The kernel is the outer product of a 1D Gaussian window with itself.
    For odd ``kernlen`` the central element equals 1.

    Args:
        kernlen: Number of samples along each side of the kernel.
        std: Standard deviation of the Gaussian, in samples.

    Returns:
        A 2D ``numpy.ndarray`` of shape ``(kernlen, kernlen)``.
    """
    # ``scipy.signal.gaussian`` was a deprecated alias and is gone from newer
    # SciPy releases; the window functions live in ``scipy.signal.windows``.
    gkern1d = signal.windows.gaussian(kernlen, std=std)
    return np.outer(gkern1d, gkern1d)


# UPDATES 'image' in place
def insert_gaussian(image, g_kern, y, x) -> None:
    """Stamp ``g_kern`` onto ``image`` around pixel ``(y, x)``, in place.

    The kernel sample at index ``g_kern.shape[k] // 2`` on each axis is placed
    on ``(y, x)``. Whatever part of the kernel falls outside the image is
    clipped, and the overlapping region of ``image`` is replaced by the
    element-wise maximum of its current values and the kernel values.

    Args:
        image: 2D array that is modified in place.
        g_kern: 2D kernel array.
        y: Row index of the stamp centre; expected to lie inside ``image``.
        x: Column index of the stamp centre; expected to lie inside ``image``.
    """

    def overlap(center, image_len, kernel_len):
        # Matching (image, kernel) slices along one axis.
        half = kernel_len // 2
        before = min(half, center)
        # ``kernel_len - half`` samples sit at or after the centre sample, so
        # an odd-length kernel keeps its final row/column instead of losing it
        # to an exclusive ``center + half`` upper bound.
        after = min(kernel_len - half, image_len - center)
        return (
            slice(center - before, center + after),
            slice(half - before, half + after),
        )

    image_ys, kernel_ys = overlap(y, image.shape[0], g_kern.shape[0])
    image_xs, kernel_xs = overlap(x, image.shape[1], g_kern.shape[1])

    # Maximum rather than sum, so overlapping stamps never exceed the kernel peak.
    image[image_ys, image_xs] = np.maximum(
        image[image_ys, image_xs],
        g_kern[kernel_ys, kernel_xs],
    )



# source_locations is a 2d array indicating source locations
def build_center_mass_image(
    source_locations:np.ndarray,
    gaussian_kernel_std:int,
) -> np.ndarray:
    """Build the center of mass image for a set of source locations.

    A 2D Gaussian is stamped at every nonzero pixel of ``source_locations``;
    where stamps overlap the larger value is kept.

    Args:
        source_locations: 2D array whose nonzero pixels mark sources.
        gaussian_kernel_std: Standard deviation, in pixels, of the Gaussian
            placed at each source.

    Returns:
        A ``float32`` array with the same shape as ``source_locations``.
    """
    center_of_mass = np.zeros_like(source_locations, dtype=np.float32)

    # ``ndarray.shape`` is (rows, cols), i.e. (height, width).
    height, _ = center_of_mass.shape

    # The kernel spans the image height, as before, but is forced to an odd
    # length: an even-length Gaussian window has no central sample, so its
    # peak would be below 1 and spread over the source pixel and its
    # upper/left neighbours.
    kernel_len = height if height % 2 else height + 1
    gaussian_kernel = gaussian_kernel_2d(kernel_len, std=gaussian_kernel_std)

    for src_y, src_x in zip(*np.nonzero(source_locations)):
        insert_gaussian(center_of_mass, gaussian_kernel, src_y, src_x)

    return center_of_mass

## Claim Vector

Unlike in https://arxiv.org/pdf/1911.10194.pdf, where each pixel is associated with a single instance, in our setting each pixel can be associated with multiple instances. This is encoded by considering each pixel's 8-connected neighborhood, where each pixel in the 8-connected set encodes an xy offset to its nearest center of mass. Encoding the data this way allows a pixel to be associated with at most 8 different sources.

## Claim Map

Each source associated with a pixel via a claim vector only contributes some fraction
of the flux in that pixel. To incorporate that information, the claim map is a
vector of weights that sum to one, each indicating how much of the flux comes from a
single source.

In [2]:
from typing import Callable, List, Tuple

import scarlet
import scarlet.psf as psf

def get_scarlet_source(
    model_frame,
    observation,
    morpheus_label:np.ndarray, 
    segmap:np.ndarray,
    source_id:int, 
    source_yx: Tuple[int, int]
) -> scarlet.component.FactorizedComponent:
    """Create the scarlet source model for one source.

    Args:
        model_frame: Frame passed through to ``scarlet.ExtendedSource``.
        observation: Observation passed through to ``scarlet.ExtendedSource``.
        morpheus_label: Currently unused.
        segmap: Currently unused.
        source_id: Currently unused.
        source_yx: ``(y, x)`` location handed to ``scarlet.ExtendedSource``.

    Returns:
        The object built by ``scarlet.ExtendedSource``.
    """
    # Every source is modelled the same way for now: scarlet's default
    # recommendation, an ExtendedSource with k=1. The label, segmap and id
    # arguments are accepted but do not influence the choice yet.
    extended_source = scarlet.ExtendedSource(model_frame, source_yx, observation)
    return extended_source


def get_scarlet_fit(
    filters:List[str], 
    psfs:np.ndarray, 
    model_psf:Callable, 
    flux:np.ndarray, # [b, h, w]
    source_locations: np.ndarray,
    morpheus_label:np.ndarray, # [h, w, m]
    segmap:np.ndarray, # [h, w]
    weights:np.ndarray = None, # presumably [b, h, w], same as flux -- TODO confirm
):
    """Fit scarlet to image for generating labels.

    Args:
        filters: Channel names handed to scarlet as ``channels``.
        psfs: Observed PSFs handed to ``scarlet.Observation``.
        model_psf: Model PSF handed to ``scarlet.Frame``.
        flux: Image to fit, ``[b, h, w]``.
        source_locations: 2D array; every nonzero pixel is a source and its
            value is used as the source id.
        morpheus_label: Forwarded to ``get_scarlet_source``, ``[h, w, m]``.
        segmap: Forwarded to ``get_scarlet_source``, ``[h, w]``.
        weights: Optional weights forwarded to ``scarlet.Observation``.
            ``None`` is passed through unchanged.

    Returns:
        A list with one rendered model per source, in the order the sources
        are returned by ``np.nonzero(source_locations)``.
    """
    model_frame = scarlet.Frame(
        flux.shape,
        psfs=model_psf,
        channels=filters
    )

    # The observed image is the ``flux`` argument; the previous code read the
    # undefined names ``images`` and ``weights`` here.
    observation = scarlet.Observation(
        flux,
        psfs=psfs,
        weights=weights,
        channels=filters,
    ).match(model_frame)

    src_ys, src_xs = np.nonzero(source_locations)
    src_ids = source_locations[src_ys, src_xs]

    # One scarlet source per nonzero pixel. (``starmap`` accepts a single
    # iterable of argument tuples, so the id and the location are zipped
    # together explicitly.)
    sources = [
        get_scarlet_source(
            model_frame,
            observation,
            morpheus_label,
            segmap,
            src_id,
            (src_y, src_x),
        )
        for src_id, src_y, src_x in zip(src_ids, src_ys, src_xs)
    ]

    blend = scarlet.Blend(sources, observation)
    blend.fit(200, e_rel = 1.e-6)

    # Render every fitted source through the observation.
    model_src_vals = [
        observation.render(src.get_model(frame=src.frame))
        for src in sources
    ]

    return model_src_vals

def get_claim_vector_image_and_map(
    source_locations:np.ndarray, 
    bhw:Tuple[int, int, int],
    model_src_vals:List[np.ndarray],
):
    """Build the claim vector image and the claim map image.

    For every pixel ``(i, j)`` and band ``b``, each of the 8 neighbours of the
    pixel is assigned the nearest source centre among the sources with
    positive model flux at ``(b, i, j)``. The claim vector stores
    ``neighbour - centre`` and the claim map stores that source's share of
    the pixel's positive model flux, renormalised so that the 8 entries sum
    to one. Pixels at which no source has positive flux are left as zeros.

    Args:
        source_locations: 2D array whose nonzero pixels are source centres.
        bhw: ``(n_bands, height, width)`` of the image.
        model_src_vals: One array per source, indexed as ``[b, i, j]``, in the
            same order as ``np.nonzero(source_locations)`` returns the
            centres.

    Returns:
        ``(claim_vector_image, claim_map_image)`` with shapes
        ``[height, width, n_bands, 8, 2]`` (float32) and
        ``[height, width, n_bands, 8]`` (float64).
    """
    # Imported locally because the surrounding cells never bring these names
    # into scope.
    from itertools import product

    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is optional
        tqdm = None

    n_bands, height, width = bhw
    claim_vector_image = np.zeros([height, width, n_bands, 8, 2], dtype=np.float32)
    # ``np.float`` was only an alias of the builtin ``float`` (float64) and no
    # longer exists in current NumPy.
    claim_map_image = np.zeros([height, width, n_bands, 8], dtype=np.float64)

    src_ys, src_xs = np.nonzero(source_locations)
    centers = np.array([src_ys, src_xs]).T # [n, 2]

    # Without sources there is nothing to claim (and argmin over an empty
    # axis would raise).
    if len(centers) == 0:
        return claim_vector_image, claim_map_image

    # Offsets of the 8-connected neighbours. The order fixes the meaning of
    # the length-8 axis of both outputs, so it must not change.
    neighbor_offsets = np.array([
        (di, dj)
        for di, dj in product([0, -1, 1], [0, -1, 1])
        if (di, dj) != (0, 0)
    ]) # [8, 2]

    # Updates claim_vector_image and claim_map_image in place
    def single_pixel_vector(i:int, j:int, b:int) -> None:
        src_flux = np.array([m[b, i, j] for m in model_src_vals]) # [n]
        has_flux = src_flux > 0
        if not has_flux.any():
            # No source contributes here: keep zeros instead of writing 0/0.
            return

        positive_flux = np.where(has_flux, src_flux, 0)
        normed_flux = positive_flux / positive_flux.sum() # [n]

        neighbors = neighbor_offsets + (i, j) # [8, 2]

        # Distance from each neighbour to each centre; sources without flux
        # at this pixel are excluded from the search.
        dists = np.linalg.norm(
            neighbors[:, None, :] - centers[None, :, :],
            axis=2,
        ) # [8, n]
        dists[:, ~has_flux] = np.inf
        nearest = np.argmin(dists, axis=1) # [8]

        claim_vector_image[i, j, b, ...] = neighbors - centers[nearest]

        raw_claim_map = normed_flux[nearest]
        claim_map_image[i, j, b, ...] = raw_claim_map / raw_claim_map.sum()

    idxs = product(range(height), range(width), range(n_bands))
    if tqdm is not None:
        idxs = tqdm(
            idxs,
            total=height * width * n_bands,
            unit="pixels"
        )

    for i, j, b in idxs:
        single_pixel_vector(i, j, b)

    return claim_vector_image, claim_map_image

## Label Function

Given a single image containing a flux image, source label, and source locations, we can create a label.
