In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import kl_objective_lib as kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed every RNG the notebook draws from. numpy alone is not enough: later
# cells sample with torch (e.g. torch.randn for location noise), so torch must
# be seeded as well for the notebook to be reproducible.
# torch is seeded first so the cell's last expression (np.random.seed -> None)
# does not display the torch Generator repr.
torch.manual_seed(34534)
np.random.seed(34534)

# Load the data

In [None]:
# SDSS image with the matching Hubble catalog; x0 / x1 presumably select the
# sub-image location — TODO confirm against SDSSHubbleData.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0 = 650, x1 = 120)

# psf file 
psf_fit_file = str(sdss_hubble_data.psf_file)

# image and background with singleton dimensions dropped
full_image = sdss_hubble_data.sdss_image.squeeze()
full_background = sdss_hubble_data.sdss_background.squeeze()

# true parameters: keep only catalog sources brighter than this flux cutoff
BRIGHT_FLUX_THRESHOLD = 1000.
which_bright = (sdss_hubble_data.fluxes > BRIGHT_FLUX_THRESHOLD)
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
# Show the full SDSS image with a colorbar.
plt.matshow(full_image.squeeze())
plt.colorbar()

# Optional overlay of the Hubble "true" locations. The * 100 presumably converts
# normalized locations to pixel coordinates — TODO confirm against the image size.
# plt.scatter(true_locs[:, 1] * 100, 
#            true_locs[:, 0] * 100)

# Our simulator: load the PSF

In [None]:
from copy import deepcopy

# PSF from the fitted PSF file (psf_fit_file, set when loading the data),
# evaluated at point (0, 0), then expanded to the full image side length.
psf_og = sdss_psf.psf_at_points(0, 0, psf_fit_file = psf_fit_file)
psf_init = torch.Tensor(
    simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1])
)

# Define the VAE (star encoder)

In [None]:
# Star encoder over the full image. These hyperparameters presumably have to
# match the checkpoint loaded in the next cell — TODO confirm.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen = full_image.shape[-1],
    stamp_slen = 9,
    step = 2,
    edge_padding = 3,
    n_bands = 1,
    max_detections = 2,
)

# Load the trained VAE weights

In [None]:
# Load pretrained encoder weights. This map_location returns each storage as
# deserialized (i.e. on CPU), so no GPU is needed to load the checkpoint.
star_encoder.load_state_dict(torch.load('../fits/starnet-10172019-no_reweighting', 
                                       map_location=lambda storage, loc: storage))
# Switch to eval mode for inference; the trailing ';' suppresses the module repr.
star_encoder.eval(); 

# Define the PSF transform

In [None]:
# Local transform applied to the PSF, initialized from psf_og on the full
# image side length; kernel_size presumably sets the transform's spatial
# extent — TODO confirm in psf_transform_lib.
psf_transform = psf_transform_lib.PsfLocalTransform(
    torch.Tensor(psf_og), full_image.shape[-1], kernel_size = 3
)

In [None]:
# Evaluate the transformed PSF with the transform's current parameters.
psf = psf_transform.forward()

# Check MAP estimates

In [None]:
# MAP (return_map = True) locations, fluxes and star counts on the full image.
# [None, None] adds two leading singleton dims, same as .unsqueeze(0).unsqueeze(0).
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    wake_sleep_lib.sample_star_encoder(
        star_encoder,
        full_image[None, None],
        full_background[None, None],
        n_samples = 1,
        return_map = True,
    )

In [None]:
# The same MAP quantities computed by the encoder's own full-image method,
# used below to cross-check sample_star_encoder.
map_locs_full_image_test, map_fluxes_full_image_test, map_n_stars_full_test = \
    star_encoder.get_results_on_full_image(
        full_image[None, None],
        full_background[None, None],
    )

In [None]:
# Max absolute difference between the two MAP flux estimates; presumably ~0 if
# the two code paths agree.
(map_fluxes_full_image_test - map_fluxes_full_image).abs().max()

In [None]:
# Max absolute difference between the two MAP location estimates; presumably ~0
# if the two code paths agree.
(map_locs_full_image - map_locs_full_image_test).abs().max()

In [None]:
# Cut the image and the background into the encoder's stamps; [0] keeps the
# first element returned by get_image_stamps.
image_stamps = star_encoder.get_image_stamps(
    full_image[None, None], locs = None, fluxes = None, trim_images = False
)[0]
background_stamps = star_encoder.get_image_stamps(
    full_background[None, None], locs = None, fluxes = None, trim_images = False
)[0]

In [None]:
# Encoder's last hidden layer for every stamp; detached from the autograd graph.
h = star_encoder._forward_to_last_hidden(image_stamps, background_stamps).detach()

In [None]:
# Per-stamp log-probabilities over the number of stars (argmaxed over dim 1 below).
log_probs = star_encoder._get_logprobs_from_last_hidden_layer(h)

In [None]:
# Most likely number of stars per stamp (argmax over dim 1), with a leading
# sample dimension of size 1 prepended.
n_stars_sampled = log_probs.argmax(dim = 1)[None]

In [None]:
# Location / flux variational parameters conditioned on the MAP star counts,
# via the batched (sample-dimension) helper in wake_sleep_lib.
logit_loc_mean, logit_loc_logvar, \
        log_flux_mean, log_flux_logvar = \
            wake_sleep_lib._get_params_from_last_hidden_layer_2dn_stars(star_encoder, h, n_stars_sampled)

In [None]:
# Inspect the shape of the batched location means.
logit_loc_mean.shape

In [None]:
# The same parameters via the encoder's own forward(), conditioning on the MAP
# star counts with the leading sample dimension squeezed away.
logit_loc_mean_test, logit_loc_log_var_test, \
            log_flux_mean_test, log_flux_log_var_test, log_probs_test = \
                star_encoder.forward(image_stamps, background_stamps, n_stars_sampled.squeeze())

# Mask of which of the max_detections slots are switched on for each stamp.
is_on_array_test = starnet_vae_lib.get_is_on_from_n_stars(n_stars_sampled.squeeze(), star_encoder.max_detections)

# MAP locations (sigmoid of logit means) and fluxes (exp of log means), zeroed
# for slots that are off. Not referenced again in the visible notebook.
map_locs = torch.sigmoid(logit_loc_mean_test).detach() * is_on_array_test.unsqueeze(2).float()
map_fluxes = torch.exp(log_flux_mean_test).detach() * is_on_array_test.float()

In [None]:
# Agreement of location means between the batched helper and forward();
# .squeeze() drops the helper's leading sample dimension. Presumably ~0.
(logit_loc_mean.squeeze() - logit_loc_mean_test).abs().max()

In [None]:
# Agreement of location log-variances between the batched helper and forward().
(logit_loc_logvar.squeeze() - logit_loc_log_var_test).abs().max()

In [None]:
# On/off mask of detection slots for the (batched) MAP star counts.
is_on_array = wake_sleep_lib.get_is_on_from_n_stars_2d(n_stars_sampled,
                            star_encoder.max_detections)

In [None]:
# Reparameterized location "sample" with the standard deviation fixed at zero:
# the noise term vanishes, so the result is sigmoid(logit_loc_mean) masked by
# is_on_array. Presumably a deliberate deterministic check against the MAP
# locations (the next cell displays sigmoid(logit_loc_mean)).
# logit_loc_logvar contributes only its shape here.
locs_randn = torch.randn(logit_loc_mean.shape)

logit_loc_sd = torch.zeros(logit_loc_logvar.shape)

subimage_locs_sampled = \
    torch.sigmoid(logit_loc_mean + \
                    locs_randn * logit_loc_sd) * \
                    is_on_array.unsqueeze(2).float()

In [None]:
# MAP locations before masking, for comparison with subimage_locs_sampled.
torch.sigmoid(logit_loc_mean)

In [None]:
# Reconstruction and loss of the full image under the MAP catalog and the
# initial (untransformed) PSF. pad presumably trims a border before the loss
# is computed — TODO confirm in psf_transform_lib.get_psf_loss.
map_recon_mean, map_loss = psf_transform_lib.get_psf_loss(full_image.squeeze(), full_background.squeeze(),
                                        map_locs_full_image,
                                        map_fluxes_full_image,
                                        n_stars = map_n_stars_full,
                                        psf = psf_init,
                                        pad = 5, grid = None)

# Alternative: the same loss under the bright Hubble "true" catalog.
# map_recon_mean, map_loss = psf_transform_lib.get_psf_loss(full_image.squeeze(), full_background.squeeze(),
#                                         true_locs.unsqueeze(0),
#                                         true_fluxes.unsqueeze(0),
#                                         n_stars = torch.Tensor([len(true_fluxes)]).type(torch.long),
#                                         psf = psf_init,
#                                         pad = 5, grid = None)

In [None]:
# Display the loss of the MAP reconstruction.
map_loss

In [None]:
# Residual of the MAP reconstruction: reconstruction minus data.
map_resid = map_recon_mean.squeeze().detach() - full_image

In [None]:
# Data, MAP reconstruction, and the *relative* residual (recon - data) / data
# on a symmetric colour scale. Titles make explicit that this panel is
# relative, unlike the absolute-residual panels further down.
fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

axarr[0].matshow(full_image.squeeze())
axarr[0].set_title('observed image')

axarr[1].matshow(map_recon_mean.squeeze().detach())
axarr[1].set_title('MAP reconstruction')

_resid = map_resid / full_image
vmax = _resid.abs().max()
im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax, 
                       cmap=plt.get_cmap('bwr'))
axarr[2].set_title('relative residual (recon - data) / data')
fig.colorbar(im2, ax=axarr[2])


# Check sampling of variational parameters

In [None]:
# Number of draws from the variational distribution in this section.
n_samples = 10

In [None]:
# Draw n_samples catalogs (locations, fluxes, star counts) from the encoder's
# variational distribution. This must come before the star count below: the
# count previously ran first and read sampled_fluxes_full_image before it was
# defined, which raises NameError on a fresh kernel.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
    wake_sleep_lib.sample_star_encoder(star_encoder, 
                                       full_image.unsqueeze(0).unsqueeze(0), 
                                       full_background.unsqueeze(0).unsqueeze(0),
                                       n_samples = n_samples, return_map = False)

# Number of stars switched on (nonzero flux) in each sampled catalog.
torch.sum(sampled_fluxes_full_image > 0, dim = 1)

In [None]:
# Reconstructions and per-sample losses for each sampled catalog, this time
# using the transformed PSF `psf` rather than psf_init.
recon_means, loss = psf_transform_lib.get_psf_loss(full_image.squeeze(), full_background.squeeze(),
                                        sampled_locs_full_image,
                                        sampled_fluxes_full_image,
                                        n_stars = sampled_n_stars_full,
                                        psf = psf,
                                        pad = 5, grid = None)

In [None]:
# Histogram of the per-sample losses.
plt.hist(loss.detach())

In [None]:
# Average loss over the sampled catalogs.
loss.mean()

In [None]:
# One figure per sampled catalog: data, reconstruction, and the absolute
# residual (recon - data) on a symmetric colour scale.
observed = full_image.squeeze()
for i in range(n_samples):
    recon_i = recon_means[i].squeeze().detach()
    _resid = recon_i - observed
    vmax = _resid.abs().max()

    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))
    axarr[0].matshow(observed)
    axarr[1].matshow(recon_i)
    im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])


In [None]:
# Zoom into a fixed window and compare data, reconstruction and the absolute
# residual (recon - data): first for the MAP reconstruction, then for each of
# the n_samples sampled reconstructions. One loop replaces the previously
# copy-pasted plotting code, and the window is named instead of repeated.
ZOOM_ROWS = slice(40, 60)
ZOOM_COLS = slice(40, 60)

zoom_recons = [map_recon_mean] + [recon_means[i] for i in range(n_samples)]
observed_full = full_image.squeeze()

for recon in zoom_recons:
    recon_full = recon.squeeze().detach()

    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    axarr[0].matshow(observed_full[ZOOM_ROWS, ZOOM_COLS])

    axarr[1].matshow(recon_full[ZOOM_ROWS, ZOOM_COLS])

    _resid = (recon_full - observed_full)[ZOOM_ROWS, ZOOM_COLS]
    vmax = _resid.abs().max()
    im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])


In [None]:
import math

# Cross-check get_psf_loss: each sample's loss should equal minus the summed
# inv_kl_lib.eval_normal_logprob of the data given the reconstruction, with
# torch.log(recon) passed as the third argument (presumably a log-variance,
# i.e. variance = mean — TODO confirm), over the interior [5:96, 5:96].
# The 5:96 window presumably matches pad = 5 on a 101-pixel image — TODO confirm.
# The two sums may accumulate in different orders, so exact float equality
# is not guaranteed; compare with a relative tolerance instead.
for i in range(len(loss)): 
    _full_image = full_image[5:96, 5:96]
    _recon_mean = recon_means[i].squeeze()[5:96, 5:96]
    loss_i = inv_kl_lib.eval_normal_logprob(_full_image, _recon_mean, torch.log(_recon_mean)).sum()

    expected = loss[i].item()
    manual = -loss_i.item()
    assert math.isclose(expected, manual, rel_tol = 1e-4, abs_tol = 1e-6), \
        'sample {}: get_psf_loss gives {}, manual loss gives {}'.format(i, expected, manual)

# Check my sampling of fluxes and locations

Here, we sample the number of stars (`n_stars`) and then condition on those sampled counts.

In [None]:
# Number of star-count samples drawn in this section.
n_samples = 10

In [None]:
# Recompute the image / background stamps so this section runs on its own.
image_stamps = star_encoder.get_image_stamps(
    full_image[None, None], locs = None, fluxes = None, trim_images = False
)[0]
background_stamps = star_encoder.get_image_stamps(
    full_background[None, None], locs = None, fluxes = None, trim_images = False
)[0]

In [None]:
# Encoder's last hidden layer for every stamp (detached from autograd).
h = star_encoder._forward_to_last_hidden(image_stamps, background_stamps).detach()
# get log probs over the number of stars, per stamp
log_probs = star_encoder._get_logprobs_from_last_hidden_layer(h)

In [None]:
# sample number of stars
# sample number of stars: n_samples draws from the class probabilities exp(log_probs)
from kl_objective_lib import sample_class_weights
class_probs = log_probs.exp()
n_stars_sampled = sample_class_weights(class_probs, n_samples)

In [None]:
# On/off mask of detection slots for each sampled star count.
is_on_array = wake_sleep_lib.get_is_on_from_n_stars_2d(n_stars_sampled,
                            star_encoder.max_detections)

In [None]:
# Location / flux variational parameters for all sampled star counts at once;
# index i along the leading dimension corresponds to sample i (checked below).
logit_loc_mean, logit_loc_logvar, \
        log_flux_mean, log_flux_logvar = \
            wake_sleep_lib._get_params_from_last_hidden_layer_2dn_stars(star_encoder, h, n_stars_sampled)

In [None]:
# Inspect the shape of the batched location means.
logit_loc_mean.shape

In [None]:
# Check that the batched parameters match the encoder's per-sample method,
# one sample at a time (exact elementwise equality).
for i in range(n_samples):
    per_sample = star_encoder._get_params_from_last_hidden_layer(h, n_stars_sampled[i])
    logit_loc_mean_i, logit_loc_logvar_i, \
        log_flux_mean_i, log_flux_logvar_i = per_sample

    pairs = (
        (logit_loc_mean_i, logit_loc_mean[i]),
        (logit_loc_logvar_i, logit_loc_logvar[i]),
        (log_flux_mean_i, log_flux_mean[i]),
        (log_flux_logvar_i, log_flux_logvar[i]),
    )
    for got, batched in pairs:
        assert torch.all(got == batched)

In [None]:
# Inspect the shape of the on/off mask for the sampled star counts.
is_on_array.shape

In [None]:
# Inspect the shape of the batched location log-variances.
logit_loc_logvar.shape