In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import kl_objective_lib as kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Fix the random seeds so the sampled catalogues and plots are reproducible.
SEED = 34534
np.random.seed(SEED)
# The sampling in this notebook draws from torch's RNG (e.g. the explicit
# torch.randn calls further down), which np.random.seed does not affect.
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



# Load the data

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0=650, x1=120)

# location of the PSF fit file, as a plain string
psf_fit_file = str(sdss_hubble_data.psf_file)

# observed image and background with singleton dimensions squeezed out
full_image, full_background = (sdss_hubble_data.sdss_image.squeeze(),
                               sdss_hubble_data.sdss_background.squeeze())

# reference ("true") catalogue, restricted to sources with flux above 1000
which_bright = sdss_hubble_data.fluxes > 1000.
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
# Show the observed image with an intensity colorbar.
observed_image = full_image.squeeze()
plt.matshow(observed_image)
plt.colorbar()

# Our simulator: load the PSF

In [None]:
from copy import deepcopy

# PSF from the SDSS fit file, evaluated via psf_at_points at (0, 0)
psf_og = sdss_psf.psf_at_points(0, 0, psf_fit_file=psf_fit_file)

# expand the PSF to the image side length, then wrap it as a torch tensor
expanded_psf = simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1])
psf_init = torch.Tensor(expanded_psf)

# Define the VAE encoder

In [None]:
# Encoder architecture settings; full_slen follows the loaded image size.
encoder_config = dict(full_slen=full_image.shape[-1],
                      stamp_slen=9,
                      step=2,
                      edge_padding=3,
                      n_bands=1,
                      max_detections=2)
star_encoder = starnet_vae_lib.StarEncoder(**encoder_config)

# Load the fitted VAE encoder weights

In [None]:
encoder_weights_path = '../fits/wake_sleep-loc650x120-iwae-10232019-encoder-iter2'
# Returning the storage unchanged keeps every tensor on CPU, whatever device it was saved from.
encoder_state = torch.load(encoder_weights_path,
                           map_location=lambda storage, loc: storage)
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

# Define the PSF transform

In [None]:
# Local PSF transform with a 3x3 kernel, built from the original PSF.
psf_og_tensor = torch.Tensor(psf_og)
psf_transform = psf_transform_lib.PsfLocalTransform(psf_og_tensor,
                                                    full_image.shape[-1],
                                                    kernel_size=3)

In [None]:
# PSF produced by the transform (forward() is called directly); this `psf` is
# what the sampled-catalogue reconstructions below use.
psf = psf_transform.forward()

# Check out MAP estimates

In [None]:
# Two leading singleton dims (same as .unsqueeze(0).unsqueeze(0)).
image_4d = full_image[None, None]
background_4d = full_background[None, None]

# return_map=True requests the MAP catalogue (per the flag name) from a single sample.
(map_locs_full_image, map_fluxes_full_image, map_n_stars_full,
 _, _, _) = wake_sleep_lib.sample_star_encoder(star_encoder,
                                               image_4d,
                                               background_4d,
                                               n_samples=1,
                                               return_map=True)

In [None]:
# The same estimates from the encoder's earlier get_results_on_full_image
# method; the next cell asserts they match the sampler's MAP output.
(map_locs_full_image_test,
 map_fluxes_full_image_test,
 map_n_stars_full_test) = star_encoder.get_results_on_full_image(full_image[None, None],
                                                                 full_background[None, None])

In [None]:
# Both implementations must agree exactly: fluxes, then locations, then counts.
for from_old_impl, from_sampler in ((map_fluxes_full_image_test, map_fluxes_full_image),
                                    (map_locs_full_image_test, map_locs_full_image),
                                    (map_n_stars_full_test, map_n_stars_full)):
    assert torch.all(from_old_impl == from_sampler)

In [None]:
# Reconstruct the image from the MAP catalogue and score it with get_psf_loss.
# NOTE(review): this uses the untransformed psf_init, while the sampled
# reconstructions below use the transformed `psf` — confirm this is intended.
map_recon_mean, map_loss = psf_transform_lib.get_psf_loss(
    full_image.squeeze(), full_background.squeeze(),
    map_locs_full_image,
    map_fluxes_full_image,
    n_stars=map_n_stars_full,
    psf=psf_init,
    pad=5,
    grid=None,
)

In [None]:
# Loss returned by get_psf_loss for the MAP reconstruction.
map_loss

In [None]:
# MAP residual (reconstruction minus observed image), with autograd history dropped.
map_resid = map_recon_mean.detach().squeeze() - full_image

In [None]:
# Observed image, MAP reconstruction and relative residual (recon - image) / image.
fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

axarr[0].matshow(full_image.squeeze())
axarr[0].set_title('observed image')

axarr[1].matshow(map_recon_mean.squeeze().detach())
axarr[1].set_title('map reconstruction')

# relative residual, drawn on a colour scale symmetric about zero
_resid = map_resid / full_image
vmax = _resid.abs().max()
im2 = axarr[2].matshow(_resid, vmax=vmax, vmin=-vmax,
                       cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('relative map residual');


# Check sampling of the variational parameters

In [None]:
# Number of variational samples drawn in the cells below.
n_samples = 10

In [None]:
# Draw n_samples catalogues from the encoder (return_map=False, i.e. random draws).
image_4d = full_image[None, None]
background_4d = full_background[None, None]
(sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full,
 _, _, _) = wake_sleep_lib.sample_star_encoder(star_encoder, image_4d, background_4d,
                                               n_samples=n_samples, return_map=False)

In [None]:
# Reconstruct and score each sampled catalogue with the PSF from psf_transform.
recon_means, loss = psf_transform_lib.get_psf_loss(
    full_image.squeeze(), full_background.squeeze(),
    sampled_locs_full_image,
    sampled_fluxes_full_image,
    n_stars=sampled_n_stars_full,
    psf=psf,
    pad=5,
    grid=None,
)

In [None]:
# Histogram of the per-sample loss: the negative Gaussian log-likelihood of the
# image under each sampled reconstruction (verified by the self-check cell below).
plt.hist(loss.detach())
plt.xlabel('loss (negative log-likelihood)')
plt.ylabel('count')
plt.title('loss over ' + str(n_samples) + ' sampled catalogues');

In [None]:
# Mean loss across the sampled catalogues.
loss.mean()

In [None]:
def _show_full_triplet(recon, recon_title, resid_title):
    """Draw the observed image, a reconstruction and their residual side by side.

    The residual (recon - image) uses a blue-white-red map centred on zero.
    """
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))
    observed = full_image.squeeze()

    axarr[0].matshow(observed)

    axarr[1].matshow(recon)
    axarr[1].set_title(recon_title)

    residual = recon - observed
    bound = residual.abs().max()
    resid_im = axarr[2].matshow(residual, vmax=bound, vmin=-bound,
                                cmap=plt.get_cmap('bwr'))
    fig.colorbar(resid_im, ax=axarr[2])
    axarr[2].set_title(resid_title)


# MAP catalogue first ...
_show_full_triplet(map_recon_mean.squeeze().detach(),
                   'map reconstruction', 'map residual')

# ... then every sampled catalogue
for i in range(n_samples):
    _show_full_triplet(recon_means[i].squeeze().detach(),
                       'sample reconstruction ' + str(i),
                       'sample residual ' + str(i))



In [None]:
# Zoom into a subimage_slen-wide window at (x0, x1): first for the MAP
# catalogue, then for each sampled catalogue.
x0 = 40
x1 = 40
subimage_slen = 10


def _plot_subimage_triplet(est_locs, recon, recon_title, resid_title):
    """Plot the observed image, a reconstruction and their residual in the window.

    est_locs: estimated locations, passed to plot_subimage as its locs argument
        on every panel (true_locs is also passed for the image panel).
    recon: detached reconstruction, same shape as full_image.squeeze().
    """
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    plotting_utils.plot_subimage(axarr[0], full_image.squeeze(),
                                 est_locs,
                                 true_locs.squeeze(),
                                 x0, x1, subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig)

    plotting_utils.plot_subimage(axarr[1], recon,
                                 est_locs,
                                 None,
                                 x0, x1, subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig)
    axarr[1].set_title(recon_title)

    resid = recon - full_image.squeeze()
    plotting_utils.plot_subimage(axarr[2], resid,
                                 est_locs,
                                 None,
                                 x0, x1, subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig,
                                 diverging_cmap=True)
    # the residual title goes on the residual panel, not the reconstruction panel
    axarr[2].set_title(resid_title)


# detach like every other plotting cell, so no autograd history reaches matplotlib
_plot_subimage_triplet(map_locs_full_image.squeeze(),
                       map_recon_mean.squeeze().detach(),
                       'map reconstruction', 'map residual')

for i in range(n_samples):
    _plot_subimage_triplet(sampled_locs_full_image[i].squeeze(),
                           recon_means[i].squeeze().detach(),
                           'sampled reconstruction', 'sampled residual')



In [None]:
# Cropped comparison (pixels 40:60 along both axes) of image, reconstruction
# and residual, for the MAP catalogue and then each sampled catalogue.
crop = (slice(40, 60), slice(40, 60))


def _plot_cropped_triplet(recon, recon_title, resid_title):
    """Plot the cropped image, reconstruction and residual; residual colours are symmetric about 0."""
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    axarr[0].matshow(full_image.squeeze()[crop])

    axarr[1].matshow(recon[crop])
    axarr[1].set_title(recon_title)

    resid = (recon - full_image.squeeze())[crop]
    vmax = resid.abs().max()
    im2 = axarr[2].matshow(resid, vmax=vmax, vmin=-vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title(resid_title)


_plot_cropped_triplet(map_recon_mean.squeeze().detach(),
                      'map reconstruction', 'map residual')

for i in range(n_samples):
    # title the sample figures like the MAP figure so they can be told apart
    _plot_cropped_triplet(recon_means[i].squeeze().detach(),
                          'sample reconstruction ' + str(i),
                          'sample residual ' + str(i))


In [None]:
# Recompute each sample's loss directly. It should equal the negative Gaussian
# log-density of the cropped image, with mean = variance = the reconstruction.
# NOTE(review): the [5:96] crop is assumed to match what get_psf_loss keeps
# with pad = 5 on this image — TODO confirm if the image size changes.
crop = (slice(5, 96), slice(5, 96))
_full_image = full_image[crop]  # loop-invariant, so computed once
for i in range(len(loss)):
    _recon_mean = recon_means[i].squeeze()[crop]
    loss_i = inv_kl_lib.eval_normal_logprob(_full_image, _recon_mean,
                                            torch.log(_recon_mean)).sum()

    # Compare with a tolerance: the two sums may be reduced in a different
    # order, so exact float equality is too strict.
    assert torch.isclose(loss[i], -loss_i), (i, loss[i].item(), -loss_i.item())

# Check `_get_params_from_last_hidden_layer_2dn_stars` against `_get_params_from_last_hidden_layer`

Here, we condition on a sampled set of `n_stars` (one draw per sample).

In [None]:
# Number of samples for the per-sample parameter checks below.
n_samples = 10

In [None]:
# Cut the image and background into the encoder's stamps (trim_images=False),
# keeping only the first element that get_image_stamps returns.
image_stamps, background_stamps = (
    star_encoder.get_image_stamps(tensor[None, None], locs=None, fluxes=None,
                                  trim_images=False)[0]
    for tensor in (full_image, full_background)
)

In [None]:
# Last hidden layer of the encoder, detached from the autograd graph.
h = (star_encoder
     ._forward_to_last_hidden(image_stamps, background_stamps)
     .detach())
# Log-probabilities over the number of stars, derived from the hidden layer.
log_probs = star_encoder._get_logprobs_from_last_hidden_layer(h)

In [None]:
# Sample the number of stars from the class probabilities exp(log_probs).
sample_class_weights = kl_lib.sample_class_weights
n_stars_sampled = sample_class_weights(log_probs.exp(), n_samples)

In [None]:
# Indicator (presumably per detection slot) of which detections are on, given n_stars_sampled.
n_max_detections = star_encoder.max_detections
is_on_array = wake_sleep_lib.get_is_on_from_n_stars_2d(n_stars_sampled, n_max_detections)

In [None]:
# Variational parameters for all samples at once, conditioned on n_stars_sampled.
(logit_loc_mean, logit_loc_logvar,
 log_flux_mean, log_flux_logvar) = wake_sleep_lib._get_params_from_last_hidden_layer_2dn_stars(
    star_encoder, h, n_stars_sampled)

In [None]:
# Inspect logit_loc_mean indexed at [0][0] (the first sample's first entry).
logit_loc_mean[0][0]

In [None]:
# For every sample, the single-sample getter must reproduce the matching
# slice of the batched parameters exactly.
batched_params = (logit_loc_mean, logit_loc_logvar, log_flux_mean, log_flux_logvar)
for i in range(n_samples):
    logit_loc_mean_i, logit_loc_logvar_i, log_flux_mean_i, log_flux_logvar_i = \
        star_encoder._get_params_from_last_hidden_layer(h, n_stars_sampled[i])
    per_sample_params = (logit_loc_mean_i, logit_loc_logvar_i,
                         log_flux_mean_i, log_flux_logvar_i)
    for single, batched in zip(per_sample_params, batched_params):
        assert torch.all(single == batched[i])

# Sample locations and fluxes

In [None]:
# Standard deviations from log-variances: sd = exp(logvar / 2).
logit_loc_sd = logit_loc_logvar.mul(0.5).exp()
log_flux_sd = log_flux_logvar.mul(0.5).exp()

In [None]:
# Reparameterised draws, mean + sd * N(0, 1) noise, zeroed where a detection is off.
# torch.randn_like creates the noise with the same shape, dtype and device as the
# encoder outputs. Nothing in this notebook moves star_encoder to `device`, so
# putting the noise on `device` could mismatch the outputs on a CUDA machine.

# sample locations
locs_randn = torch.randn_like(logit_loc_mean)

logit_locs_sampled = logit_loc_mean + locs_randn * logit_loc_sd
subimage_locs_sampled = \
    torch.sigmoid(logit_locs_sampled) * is_on_array.unsqueeze(3).float()

# sample fluxes
fluxes_randn = torch.randn_like(log_flux_mean)
log_flux_sampled = log_flux_mean + fluxes_randn * log_flux_sd
subimage_fluxes_sampled = \
    torch.exp(log_flux_sampled) * is_on_array.float()


In [None]:
# Short name for the Gaussian log-density helper; inv_kl_lib is inv_kl_objective_lib.
eval_normal_logprob = inv_kl_lib.eval_normal_logprob

In [None]:
# Per-sample log q terms. Only detections that are on contribute to the
# location and flux terms.
is_on = is_on_array.float()

loc_log_q = eval_normal_logprob(logit_locs_sampled, logit_loc_mean,
                                logit_loc_logvar) * is_on.unsqueeze(3)
log_q_locs = loc_log_q.view(n_samples, -1).sum(1)

flux_log_q = eval_normal_logprob(log_flux_sampled, log_flux_mean,
                                 log_flux_logvar) * is_on
log_q_fluxes = flux_log_q.view(n_samples, -1).sum(1)

# Pick out log q(n_stars) at the sampled counts. The transposes move the sample
# axis to dim 1 for gather and then back. NOTE(review): assumes n_stars_sampled
# is (n_samples, n_patches) and log_probs is (n_patches, n_classes) — confirm.
picked = torch.gather(log_probs, 1, n_stars_sampled.transpose(0, 1))
log_q_n_stars = picked.transpose(0, 1).sum(1)

In [None]:
# log q of the sampled locations, one value per sample.
log_q_locs

In [None]:
# log q of the sampled fluxes, one value per sample.
log_q_fluxes

In [None]:
# log q of the sampled star counts, summed over the second axis (one value per sample).
log_q_n_stars

# Check out my IWAE loss

In [None]:
# Same sampler as before, now also returning the log q terms of each draw.
image_4d = full_image[None, None]
background_4d = full_background[None, None]
(sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full,
 log_q_locs, log_q_fluxes, log_q_n_stars) = wake_sleep_lib.sample_star_encoder(
    star_encoder, image_4d, background_4d, n_samples,
    return_map=False, return_log_q=True)

In [None]:
# Reconstruction and per-sample negative log-likelihood for the new draws.
recon_means, neg_logprob = psf_transform_lib.get_psf_loss(
    full_image.squeeze(), full_background.squeeze(),
    sampled_locs_full_image,
    sampled_fluxes_full_image,
    n_stars=sampled_n_stars_full,
    psf=psf,
    pad=5,
    grid=None,
)

In [None]:
 - neg_logprob  # log-likelihood of the image under each sampled catalogue

In [None]:
# log q of the locations, as returned by sample_star_encoder(return_log_q=True).
log_q_locs

In [None]:
# log q of the fluxes, as returned by sample_star_encoder(return_log_q=True).
log_q_fluxes

In [None]:
# log q of the star counts, as returned by sample_star_encoder(return_log_q=True).
log_q_n_stars

In [None]:
# Per-sample log weight: log-likelihood (-neg_logprob) minus the three log q terms.
log_pq = -(neg_logprob + log_q_locs + log_q_fluxes + log_q_n_stars)

In [None]:
# Per-sample log weights shifted by log(n_samples), i.e. log(w_i / n_samples).
log_pq - np.log(n_samples)

In [None]:
# IWAE estimate log((1/n_samples) * sum_i exp(log_pq[i])), computed stably with logsumexp.
torch.logsumexp(log_pq - np.log(n_samples), 0)

In [None]:
# NOTE(review): stray scratch cell, unrelated to the analysis above — consider deleting.
np.pi