In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import sys
sys.path.insert(0, './../')
import simulated_datasets_lib
import starnet_lib
import plotting_utils
import psf_transform_lib

import json

# run on the first GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('torch version: ', torch.__version__)
print(device)

In [None]:
# fix both random number generators so the simulated data and the posterior
# draws below are reproducible
NUMPY_SEED = 453
TORCH_SEED = 786
np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)  # bind the returned Generator so it is not displayed

# The PSF

In [None]:
# PSF parameters fitted using the ground-truth Hubble locations/fluxes
fitted_psf_params = np.load('../data/fitted_powerlaw_psf_params.npy')
init_psf_params = torch.Tensor(fitted_psf_params)

power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
psf = power_law_psf.forward().detach()

# the number of bands is the PSF's leading dimension (two bands here)
n_bands = psf.shape[0]

# Draw data

In [None]:
# simulation parameters for the star field (image size, priors, ...)
data_params_path = '../data/default_star_parameters.json'
with open(data_params_path) as params_file:
    data_params = json.load(params_file)

print(data_params)

In [None]:
# Constant sky background, one level per band.
background_levels = torch.tensor([686., 1123.])

# Fail loudly if the PSF has a different number of bands. Indexing by hand
# would otherwise either raise (fewer bands) or silently leave the extra bands
# with a zero background (more bands).
assert background_levels.shape[0] == n_bands, \
    'expected one background level per band'

# background has shape n_bands x slen x slen
background = torch.zeros(n_bands, data_params['slen'], data_params['slen'])
background[:] = background_levels.view(n_bands, 1, 1)

In [None]:
# Simulate a single noisy image from the PSF and background defined above.
n_images = 1

simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf,
    data_params,
    n_images=n_images,
    background=background,
    add_noise=True,
    transpose_psf=False,
)

# images is n_images x n_bands x slen x slen
images = simulated_dataset.images.detach()
print(images.size())

In [None]:
plt.matshow(images[0][0])  # first image, first band

# Load Encoder

In [None]:
# StarNet encoder for the simulated images. These settings need to match the
# architecture of the saved weights loaded in the next cell.
star_encoder = starnet_lib.StarEncoder(
    slen=data_params['slen'],
    n_bands=n_bands,
    patch_slen=8,
    step=2,
    edge_padding=3,
    max_detections=2,
)

In [None]:
# Load the pretrained weights onto the CPU, whatever device they were saved from.
# NOTE(review): the encoder is never moved to `device`; if `device` is CUDA,
# confirm that `images` ends up on the same device as the encoder.
star_encoder.load_state_dict(
    torch.load('../example_starnet_encoder', map_location='cpu'))

# switch to evaluation mode (affects layers such as dropout/batch-norm, if any)
star_encoder.eval();

# Getting samples from the approximate posterior

At the most generic level, the star encoder takes in an image and returns a set of variational parameters. We can then sample from the resulting variational distribution.

This is done using the `sample_star_encoder` method.

We demonstrate this below.

In [None]:
# draw samples of the full catalog from the approximate posterior
n_posterior_samples = 100
locs_sampled, fluxes_sampled, n_stars_sampled, _, _, _ = \
    star_encoder.sample_star_encoder(n_samples=n_posterior_samples, image=images)

In [None]:
# sampled number of stars, one entry per posterior sample
n_stars_sampled

In [None]:
# n_samples x max(n_stars_sampled) x 2
locs_sampled.size()

In [None]:
# fluxes_sampled is of shape n_samples x max(n_stars_sampled) x n_bands
fluxes_sampled.shape

For rows with fewer than `max(n_stars_sampled)` stars, the "empty" star slots have zeros in their location and flux entries.

### As an example, let's take a look at the catalog from the first sample

In [None]:
# To avoid overplotting, only mark the bright stars (log10 flux in the first
# band above 4.5) from the first posterior sample.
first_sample = 0
which_bright = torch.log10(fluxes_sampled[first_sample, :, 0]) > 4.5

# locations live in [0, 1]; scale them to pixel coordinates of the full image
pixel_scale = data_params['slen'] - 1

plt.matshow(images[0, 0])
plt.scatter(locs_sampled[first_sample, which_bright, 1] * pixel_scale,
            locs_sampled[first_sample, which_bright, 0] * pixel_scale,
            marker='x', color='red', alpha=0.5)

### We can also return the MAP estimate, rather than samples

In [None]:
# request the MAP estimate of the number of stars and of the star parameters
# instead of random samples
locs_map, fluxes_map, n_stars_map, _, _, _ = star_encoder.sample_star_encoder(
    image=images,
    return_map_n_stars=True,
    return_map_star_params=True,
)

In [None]:
# This time, let us zoom in on a subimage, using the `plot_subimage` helper
# from plotting_utils.

# Randomly choose the subimage coordinates. Draw scalars directly: calling
# int() on a size-1 array is deprecated in recent NumPy, and a scalar draw
# consumes the random stream in the same way.
# NOTE(review): x0/x1 can land within patch_slen of the image edge; confirm
# plot_subimage handles a truncated subimage.
x0 = int(np.random.choice(data_params['slen']))
x1 = int(np.random.choice(data_params['slen']))

fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

# arguments shared by both panels
subimage_kwargs = dict(patch_slen=10, add_colorbar=True, global_fig=fig)

# left: the MAP estimates in red, and truth in blue
_ = plotting_utils.plot_subimage(axarr[0],
                                 images[0, 0],
                                 locs_map[0],
                                 simulated_dataset.locs[0],
                                 x0, x1,
                                 **subimage_kwargs)

# right: all posterior samples, flattened into one list of locations.
# Padded "empty" stars have zero locations, so they may show up at the origin.
_ = plotting_utils.plot_subimage(axarr[1],
                                 images[0, 0],
                                 locs_sampled.view(-1, 2),
                                 simulated_dataset.locs[0],
                                 x0, x1,
                                 alpha=0.2,
                                 **subimage_kwargs)

# Under the hood: tiling and un-tiling

We now go into some detail about the pieces that make up the `sample_star_encoder` method.

Our variational distribution works as follows: we take the large 100 x 100 image and break it into image patches.

Our variational distribution factorizes over 2 x 2 tiles. 

To construct a variational distribution on a 2 x 2 tile, we feed the 2 x 2 tile, plus some border, into a neural network. 

The patch size is set by the `patch_slen` argument of `__init__` when defining the star encoder. The size of the border is set by `edge_padding`.

The method `get_image_patches` breaks up the full image and returns the patches to be fed into the neural network.

In [None]:
# Break the full image into patches; the true catalog (locs/fluxes) is
# re-expressed in per-patch coordinates at the same time.
image_patches, patch_locs, patch_fluxes, patch_n_stars, _ = \
    star_encoder.get_image_patches(
        images,
        locs=simulated_dataset.locs,
        fluxes=simulated_dataset.fluxes,
    )

In [None]:
# n_patches x n_bands x patch_slen x patch_slen,
# where patch_slen is the value passed when constructing the encoder
image_patches.size()

Note that this method also takes in (optionally) the ground truth locations and fluxes. These are originally in the parameterization for the full image; this function will then give you the parameterization of the locations and fluxes on each individual patch (`patch_locs`, `patch_fluxes`). This is needed for the sleep phase. 

### Let's plot an image patch

In [None]:
# Pick one patch. It is fixed so that later cells revisit the same patch;
# use np.random.choice(image_patches.shape[0]) for a random one instead.
indx = 1837
plt.matshow(image_patches[indx, 0])

# The neural network returns locations / fluxes only for stars in the central
# tile of the patch; outline that tile in red.
tile_lo = star_encoder.edge_padding - 0.5
tile_hi = star_encoder.patch_slen - star_encoder.edge_padding - 0.5
for edge in (tile_lo, tile_hi):
    plt.axvline(x=edge, color='r')
for edge in (tile_lo, tile_hi):
    plt.axhline(y=edge, color='r')

# Mark any true stars in this tile. Locations, both on the full image and on
# a tile, are parameterized in [0, 1], so rescale and shift them into the
# patch's pixel coordinates.
n_true = patch_n_stars[indx]
if n_true > 0:
    tile_slen = star_encoder.patch_slen - 2 * star_encoder.edge_padding
    true_px = patch_locs[indx][0:n_true] * tile_slen + star_encoder.edge_padding - 0.5
    plt.scatter(true_px[:, 1], true_px[:, 0], marker='o', color='b')

### the forward method

The forward method takes in image patches and returns variational distribution parameters on each tile. 

In [None]:
# The pass through the neural network is implemented by `get_var_params_all`,
# which maps every image patch to its variational parameters.
var_params_all = star_encoder.get_var_params_all(image_patches)

This is a large, flattened tensor holding the triangular array of variational parameters for each patch.

For each patch, we have variational parameters for the locations and fluxes

$\{\ell_{n, i}, f_{n, i}\}$ for $i = 1, ..., n; n = 1, ..., N_{max}$, 

as well as the $N_{max}$ probabilities for the number of stars. 

### indexing into this triangular array

In [None]:
# `utils` is a project module; its `sample_class_weights` is used in the
# wake-phase cell below to sample the number of stars on each patch.
# NOTE(review): consider moving this import into the import cell at the top.
import utils

### In the sleep phase, we need to evaluate the variational distribution at the true number of stars

In [None]:
# Sleep phase: select the variational parameters that correspond to the TRUE
# number of stars on each patch (patch_n_stars from get_image_patches),
# capped at the encoder's max_detections.
true_n_stars_capped = patch_n_stars.clamp(max=star_encoder.max_detections)

loc_mean, loc_logvar, log_flux_mean, log_flux_logvar = \
    star_encoder.get_var_params_for_n_stars(var_params_all,
                                            n_stars=true_n_stars_capped)

In [None]:
# look at the parameters on one (arbitrary, fixed) patch
indx = 1837

# loc_mean: n_patches x max_detections x 2
print(loc_mean.size())

# as before, "empty" stars have zero entries
print(loc_mean[indx])

In [None]:
# Same plot as before, now also marking the estimated star locations on this
# patch: true stars as blue circles, estimated location means as red crosses.
plt.matshow(image_patches[indx, 0])

# outline the central tile, where the network reports stars
tile_lo = star_encoder.edge_padding - 0.5
tile_hi = star_encoder.patch_slen - star_encoder.edge_padding - 0.5
for edge in (tile_lo, tile_hi):
    plt.axvline(x=edge, color='r')
for edge in (tile_lo, tile_hi):
    plt.axhline(y=edge, color='r')

n_true = patch_n_stars[indx]
if n_true > 0:
    # tile locations live in [0, 1]; rescale and shift them into the patch's
    # pixel coordinates so they plot in the right place
    tile_slen = star_encoder.patch_slen - 2 * star_encoder.edge_padding

    true_px = patch_locs[indx][0:n_true] * tile_slen + star_encoder.edge_padding - 0.5
    plt.scatter(true_px[:, 1], true_px[:, 0], marker='o', color='b')

    # loc_mean holds at most max_detections stars per patch, so the estimates
    # are truncated when the true count exceeds it
    est_px = loc_mean[indx][0:n_true].detach() * tile_slen + star_encoder.edge_padding - 0.5
    plt.scatter(est_px[:, 1], est_px[:, 0], marker='x', color='r')


### In the wake phase, we need to sample n_stars from the variational distribution


In [None]:
# Wake phase: sample the number of stars on each patch from the variational
# distribution returned by `get_logprob_n_from_var_params`.
n_samples = 5

# log-probabilities over the number of stars, one row per patch
# NOTE(review): second dimension was documented as max_detections; confirm
# whether the n = 0 category makes it max_detections + 1.
log_probs_nstar_patch = star_encoder.get_logprob_n_from_var_params(var_params_all)
probs_nstar_patch = torch.exp(log_probs_nstar_patch.detach())

# sampled number of stars on each patch, n_samples x n_patches
patch_n_stars_sampled = utils.sample_class_weights(probs_nstar_patch, n_samples)
patch_n_stars_sampled = patch_n_stars_sampled.view(n_samples, -1)

In [None]:
# n_stars now comes from the samples above. It selects which rows of the
# triangular array to index, so the returned variational parameters depend on
# the sampled number of stars.
loc_mean, loc_logvar, log_flux_mean, log_flux_logvar = \
    star_encoder.get_var_params_for_n_stars(
        var_params_all, n_stars=patch_n_stars_sampled)

In [None]:
# loc_mean now carries an extra dimension for the samples
print(loc_mean.size())

# From here, locations and log-fluxes can be sampled from these means and
# log-variances; each is simply Gaussian.

# Incorporating the Galaxy model

In [None]:
# Now our variational distribution, instead of returning a mean and variance for fluxes, 
# will return a mean and variance for galaxy parameters

In [None]:
# number of parameters describing each galaxy
gal_dim = 8
galaxy_encoder = starnet_lib.StarEncoder(
                            # take the image size and band count from the
                            # simulated data instead of hard-coding 100 and 2
                            slen = data_params['slen'],
                            n_bands = n_bands,
                            # patches may need to be larger now
                            # and corresponding padding parameters should change
                            patch_slen = 8,
                            step = 2,
                            edge_padding = 3,
                            max_detections = 2,
                            # the number of source parameters
                            n_source_params = gal_dim)

## Sleep phase

In [None]:
# latent variables for the ** full image **

# true locations
locs = simulated_dataset.locs

# Made-up galaxy parameters for now: uniform random values for every "on"
# source, and zeros for the padded "off" slots (the same convention as the
# star catalogs). Fill every simulated image, not only the first one, and
# create the tensor on the same device as the simulated locations.
n_sim_images = images.shape[0]
galaxy_params = torch.zeros(n_sim_images, simulated_dataset.max_stars, gal_dim,
                            device=locs.device)
for i in range(n_sim_images):
    n_on = int(simulated_dataset.n_stars[i])
    galaxy_params[i, 0:n_on, :] = torch.rand(n_on, gal_dim, device=locs.device)

print(galaxy_params)

In [None]:
# Convert the full image into image patches, and the full-image parameters
# into per-patch parameters. The galaxy parameters are passed through the
# `fluxes` argument; the variable names have not been updated yet.
image_patches, patch_locs, patch_galaxy_params, patch_n_stars, _ = \
    galaxy_encoder.get_image_patches(
        images,
        locs=simulated_dataset.locs,
        fluxes=galaxy_params,
    )

In [None]:
# one set of variational parameters for every image patch
var_params_all = galaxy_encoder.get_var_params_all(image_patches)

In [None]:
# distribution over the number of sources on each patch
log_probs_nstar_patch = galaxy_encoder.get_logprob_n_from_var_params(var_params_all)

# variational parameters at the true number of sources, clipped at max_detections
true_n_sources_capped = patch_n_stars.clamp(max=galaxy_encoder.max_detections)
loc_mean, loc_logvar, galaxy_param_mean, galaxy_param_logvar = \
    galaxy_encoder.get_var_params_for_n_stars(var_params_all,
                                              n_stars=true_n_sources_capped)

# Locations and galaxy parameters are both normal with a mean and a variance.
# Note that the **log-variance** is returned; take exp to get the variance.
#
# Caveat: "off" stars are marked with zeros, and exp(0) is not zero. The source
# code handles this with a binary matrix, is_on_array (1 if the star is on,
# 0 if it is off); premultiplying by it solves the problem.
# See get_is_on_from_n_stars_2d and get_is_on_from_n_stars.

In [None]:
# n_patches x max_detections x gal_dim
print(galaxy_param_mean.size())

## wake phase

In [None]:
# Wake phase: draw 5 samples of (locations, galaxy parameters, n_stars) from
# the variational distribution.
n_wake_samples = 5
locs_sampled, gal_params_sampled, n_stars_sampled, _, _, _ = \
    galaxy_encoder.sample_star_encoder(images, n_samples=n_wake_samples)

In [None]:
# you would then take these sampled n_stars, locs, galaxy parameters, and pass it into a generative model. 