In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import kl_objective_lib as kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed every RNG the notebook draws from so Restart & Run All reproduces the
# figures: numpy, and torch (the variational samples from
# `sample_star_encoder` are presumably drawn with torch — seeding numpy alone
# would not make them reproducible). `torch.manual_seed` also seeds CUDA.
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the data

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0 = 650, x1 = 120)

# psf file
psf_fit_file = str(sdss_hubble_data.psf_file)

# image and sky background (singleton dims squeezed away)
full_image = sdss_hubble_data.sdss_image.squeeze()
full_background = sdss_hubble_data.sdss_background.squeeze()

# "true" parameters: keep only catalog stars brighter than this flux cut.
# Units are those of `sdss_hubble_data.fluxes` — NOTE(review): confirm.
bright_flux_threshold = 1000.
which_bright = (sdss_hubble_data.fluxes > bright_flux_threshold)
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
# Observed full image with its intensity colorbar; trailing `;` suppresses the
# text repr of the last matplotlib call.
plt.matshow(full_image.squeeze())
plt.colorbar()
plt.title('observed image');

# Our simulator: the PSF used to render stars

In [None]:
from copy import deepcopy
# PSF evaluated at point (0, 0) from the fitted PSF file loaded above; this
# single PSF is what the reconstructions below use for every star.
# Reuse `psf_fit_file` rather than re-deriving the path, so there is one source of truth.
psf_og = sdss_psf.psf_at_points(0, 0, psf_fit_file = psf_fit_file)

# Expand the PSF to the full image side length (presumably zero-padding — TODO confirm
# against simulated_datasets_lib._expand_psf).
psf_init = torch.Tensor(simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1]))

psf_init.shape

# Define the VAE encoder

In [None]:
# StarNet encoder over the full image. These hyper-parameters presumably have to
# match the ones used to train the checkpoint loaded in the next cell — TODO
# confirm against the training script.
star_encoder = starnet_vae_lib.StarEncoder(
    n_bands = 1,
    max_detections = 2,
    stamp_slen = 9,
    step = 2,
    edge_padding = 3,
    full_slen = full_image.shape[-1],
)

# Load the trained VAE weights

In [None]:
# Restore trained weights, mapping every saved tensor onto the CPU, then switch
# the encoder to evaluation mode (`;` hides the module repr).
star_encoder.load_state_dict(
    torch.load('../fits/starnet-10162019-reweighted', map_location='cpu'))
star_encoder.eval();

# Check the MAP estimates

In [None]:
# MAP catalog (locations, fluxes, star count) for the full image.
# This is inference only, so run it under no_grad: it avoids building an
# autograd graph, and every tensor derived from these outputs can be plotted
# without an explicit `.detach()`.
# The two unsqueezes add leading batch/band dims — presumably (1, 1, h, w); TODO confirm.
with torch.no_grad():
    map_locs_full_image, map_fluxes_full_image, map_n_stars_full, _, _, _ = \
        star_encoder.sample_star_encoder(full_image.unsqueeze(0).unsqueeze(0),
                                         full_background.unsqueeze(0).unsqueeze(0),
                                         n_samples = 1, return_map = True)

In [None]:
# MAP reconstruction: render the MAP catalog through the PSF and add the sky
# background, lifted with two leading singleton dims like the encoder inputs.
map_recon_mean = (
    simulated_datasets_lib.plot_multiple_stars(full_image.shape[-1],
                                               map_locs_full_image,
                                               map_n_stars_full,
                                               map_fluxes_full_image,
                                               psf_init)
    + full_background[None, None]
)

In [None]:
# Signed MAP residual on the squeezed image grid: reconstruction minus observation.
map_resid = map_recon_mean.detach().squeeze() - full_image

In [None]:
# Observed image vs. MAP reconstruction. Unlike the later cells, the residual
# here is *relative* to the observed pixel values (map_resid / full_image),
# shown on a symmetric diverging color scale.
fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

axarr[0].matshow(full_image.squeeze())
axarr[0].set_title('observed')

axarr[1].matshow(map_recon_mean.squeeze().detach())
axarr[1].set_title('map reconstruction')

_resid = map_resid / full_image
vmax = _resid.abs().max()
im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                       cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('map relative residual');


# Check samples drawn from the variational distribution

In [None]:
# Number of draws from the variational distribution; used both for sampling and
# for the per-sample figure loops below.
n_samples = 10

In [None]:
# Draw `n_samples` catalogs (locations, fluxes, star counts) from the
# variational distribution. Inference only: run under no_grad so the samples
# carry no autograd history into the plotting cells.
with torch.no_grad():
    sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full, _, _, _ = \
        star_encoder.sample_star_encoder(full_image.unsqueeze(0).unsqueeze(0),
                                         full_background.unsqueeze(0).unsqueeze(0),
                                         n_samples = n_samples, return_map = False)

In [None]:
# Display the sampled star counts for a quick sanity check
# (presumably one count per sample — verify against sample_star_encoder).
sampled_n_stars_full

In [None]:
# One reconstruction per variational sample: render each sampled catalog through
# the PSF and add the sky background (lifted with two leading singleton dims).
recon_means = (
    simulated_datasets_lib.plot_multiple_stars(full_image.shape[-1],
                                               sampled_locs_full_image,
                                               sampled_n_stars_full,
                                               sampled_fluxes_full_image,
                                               psf_init)
    + full_background[None, None]
)

In [None]:
def _plot_recon_and_residual(recon, recon_title, resid_title):
    """Plot observed image | reconstruction | residual in one row.

    Args:
        recon: reconstructed image tensor; singleton dims are squeezed away and
            it is compared against the global ``full_image``.
        recon_title: title for the middle (reconstruction) panel.
        resid_title: title for the right (residual) panel.

    Returns:
        The created matplotlib figure. The residual is reconstruction minus
        observed, drawn on a symmetric diverging color scale.
    """
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    observed = full_image.squeeze()
    axarr[0].matshow(observed)
    axarr[0].set_title('observed')

    recon = recon.squeeze().detach()
    axarr[1].matshow(recon)
    axarr[1].set_title(recon_title)

    resid = recon - observed
    # symmetric limits so zero residual maps to the center (white) of 'bwr'
    vmax = resid.abs().max()
    im = axarr[2].matshow(resid, vmax = vmax, vmin = -vmax,
                          cmap=plt.get_cmap('bwr'))
    fig.colorbar(im, ax=axarr[2])
    axarr[2].set_title(resid_title)
    return fig


_plot_recon_and_residual(map_recon_mean, 'map reconstruction', 'map residual')

for i in range(n_samples):
    _plot_recon_and_residual(recon_means[i],
                             'sample reconstruction ' + str(i),
                             'sample residual ' + str(i))



In [None]:
# Zoom into a subimage and overlay catalog locations: estimated locations (MAP,
# then each sample) on every panel, true bright-star locations on the observed panel.
x0 = 30
x1 = 50
subimage_slen = 10

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

# `.detach()` everywhere: encoder outputs (and images rendered from them) may
# carry autograd history, which cannot be converted to numpy for plotting.
plotting_utils.plot_subimage(axarr[0], full_image.squeeze(),
                            map_locs_full_image.squeeze().detach(),
                            true_locs.squeeze(),
                            x0, x1, subimage_slen,
                            add_colorbar = True,
                            global_fig = fig)
axarr[0].set_title('observed')

plotting_utils.plot_subimage(axarr[1], map_recon_mean.squeeze().detach(),
                            map_locs_full_image.squeeze().detach(),
                            None,
                            x0, x1, subimage_slen,
                            add_colorbar = True,
                            global_fig = fig)
axarr[1].set_title('map reconstruction')

_resid = (map_recon_mean.squeeze().detach() - full_image.squeeze())
plotting_utils.plot_subimage(axarr[2], _resid.squeeze(),
                            map_locs_full_image.squeeze().detach(),
                            None,
                            x0, x1, subimage_slen,
                            add_colorbar = True,
                            global_fig = fig,
                            diverging_cmap = True)
# residual title belongs on the residual panel (previously overwrote axarr[1])
axarr[2].set_title('map residual')

for i in range(n_samples):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    plotting_utils.plot_subimage(axarr[0], full_image.squeeze(),
                                sampled_locs_full_image[i].squeeze().detach(),
                                true_locs.squeeze(),
                                x0, x1, subimage_slen,
                                add_colorbar = True,
                                global_fig = fig)
    axarr[0].set_title('observed')

    plotting_utils.plot_subimage(axarr[1], recon_means[i].squeeze().detach(),
                                sampled_locs_full_image[i].squeeze().detach(),
                                None,
                                x0, x1, subimage_slen,
                                add_colorbar = True,
                                global_fig = fig)
    axarr[1].set_title('sampled reconstruction ' + str(i))

    _resid = (recon_means[i].squeeze().detach() - full_image.squeeze())
    plotting_utils.plot_subimage(axarr[2], _resid.squeeze(),
                                sampled_locs_full_image[i].squeeze().detach(),
                                None,
                                x0, x1, subimage_slen,
                                add_colorbar = True,
                                global_fig = fig,
                                diverging_cmap = True)
    # residual title belongs on the residual panel (previously overwrote axarr[1])
    axarr[2].set_title('sampled residual ' + str(i))



In [None]:
# Zoomed comparison on a fixed 20x20 window (indices 40-59 on both axes):
# observed vs. reconstruction and their residual, first for the MAP estimate,
# then for each variational sample.
zoom_window = (slice(40, 60), slice(40, 60))

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

axarr[0].matshow(full_image.squeeze()[zoom_window])
axarr[0].set_title('observed')

axarr[1].matshow(map_recon_mean.squeeze().detach()[zoom_window])
axarr[1].set_title('map reconstruction')

_resid = (map_recon_mean.squeeze().detach() - full_image.squeeze())[zoom_window]
vmax = _resid.abs().max()
im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                       cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('map residual')


for i in range(n_samples):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    axarr[0].matshow(full_image.squeeze()[zoom_window])
    axarr[0].set_title('observed')

    axarr[1].matshow(recon_means[i].squeeze().detach()[zoom_window])
    axarr[1].set_title('sample reconstruction ' + str(i))

    _resid = (recon_means[i].squeeze().detach() - full_image.squeeze())[zoom_window]
    vmax = _resid.abs().max()
    im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('sample residual ' + str(i))
