In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib

import inv_kl_objective_lib as inv_kl_lib

import image_utils

import time

import json

# Prefer the first CUDA device when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_cuda else "cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# PSF file location; echoed so the notebook output records which file was used.
# NOTE(review): the simulator cell builds its PSF from sdss_hubble_data.psf_file,
# not from this path -- confirm both refer to the same file.
psf_fit_file = (
    '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
)
print('psf file: \n', psf_fit_file, sep=' ')

In [None]:
# Seed numpy and torch so the noisy simulation below is reproducible.
_rng_seed = 22
np.random.seed(_rng_seed)
_ = torch.manual_seed(_rng_seed)

# Compare the observed SDSS image with an image simulated from Hubble catalog parameters

In [None]:
# Load the SDSS/Hubble comparison dataset. Attributes read later in this notebook:
# psf_file, slen, locs, fluxes, sdss_image, sdss_background, sdss_image_full,
# and sdss_background_full.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Star simulator configured to match the SDSS/Hubble data (same PSF file and
# image side length) with a constant sky level of 686.
# NOTE(review): units of sky_intensity are presumably image counts -- confirm.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=str(sdss_hubble_data.psf_file),
    slen=sdss_hubble_data.slen,
    sky_intensity=686.,
)

In [None]:
# Simulate an image from the Hubble catalog parameters, then subtract the
# simulator's constant sky level and add the actual SDSS background so the
# simulation is directly comparable with the observed image.
backgrounds_full = sdss_hubble_data.sdss_background.unsqueeze(0)

n_hubble_stars = torch.tensor([len(sdss_hubble_data.locs)], dtype=torch.long)
sim_images_full = simulator.draw_image_from_params(
    locs=sdss_hubble_data.locs.unsqueeze(0),
    fluxes=sdss_hubble_data.fluxes.unsqueeze(0),
    n_stars=n_hubble_stars,
    add_noise=True,
) - simulator.sky_intensity + backgrounds_full

# the observed data
sdss_images_full = sdss_hubble_data.sdss_image.unsqueeze(0)

# "True" parameters used for plot overlays: only catalog stars brighter than
# the cutoff are kept.
# NOTE(review): cutoff is in the catalog's flux units -- confirm the choice of 1300.
BRIGHT_FLUX_THRESHOLD = 1300
which_bright = sdss_hubble_data.fluxes > BRIGHT_FLUX_THRESHOLD
true_full_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
true_full_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)

# Residuals between the simulated and observed SDSS images

In [None]:
# Side by side: observed SDSS image, image simulated from the Hubble catalog,
# and their residual on a log10 scale (positive where the simulation is brighter).
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

im0 = axarr[0].matshow(sdss_images_full.squeeze())
f.colorbar(im0, ax=axarr[0])
axarr[0].set_title('observed sdss image')

im1 = axarr[1].matshow(sim_images_full.squeeze())
f.colorbar(im1, ax=axarr[1])
axarr[1].set_title('simulated sdss image')

residual = torch.log10(sim_images_full.squeeze()) - torch.log10(sdss_images_full.squeeze())
# Symmetric color limits taken from the [10:90, 10:90] window only
# (presumably to keep edge pixels from setting the scale -- confirm).
vmax = residual[10:90, 10:90].abs().max()
im2 = axarr[2].matshow(residual, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
f.colorbar(im2, ax=axarr[2])
axarr[2].set_title('residual: log10(simulated) - log10(observed)');

In [None]:
# Zoom into one subimage of the observed, simulated, and residual images,
# overlaying the locations of the bright "true" catalog stars.
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

x0 = 0
x1 = 20
subimage_slen = 20

# (panel title, image, extra plot_subimage kwargs); only the residual panel
# needs the diverging colormap.
subimage_panels = [
    ('observed sdss image', sdss_images_full.squeeze(), {}),
    ('simulated sdss image', sim_images_full.squeeze(), {}),
    ('residual', residual, {'diverging_cmap': True}),
]
for panel_ax, (panel_title, panel_image, extra_kwargs) in zip(axarr, subimage_panels):
    plotting_utils.plot_subimage(panel_ax, panel_image,
                                 None,
                                 true_full_locs.squeeze(),
                                 x0=x0,
                                 x1=x1,
                                 subimage_slen=subimage_slen,
                                 add_colorbar=True,
                                 global_fig=f,
                                 **extra_kwargs)
    panel_ax.set_title(panel_title)

In [None]:
# Overview of the full SDSS frame; the next cell crops a single-star region from it.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
plt.title('full sdss image');

In [None]:
# Crop a 101 x 101 window (matching the encoder's full_slen below). Image and
# background are cut with the same window so they stay pixel-aligned.
ONE_STAR_WINDOW = (slice(200, 301), slice(1450, 1551))
one_star_img = sdss_hubble_data.sdss_image_full.squeeze()[ONE_STAR_WINDOW]
one_star_back = sdss_hubble_data.sdss_background_full.squeeze()[ONE_STAR_WINDOW]
plt.matshow(one_star_img)
plt.title('single-star crop');

In [None]:
# Encoder configured for the 101 x 101 single-star crop, one band.
_encoder_config = dict(full_slen=101,
                       stamp_slen=9,
                       step=2,
                       edge_padding=3,
                       n_bands=1,
                       max_detections=4)
star_encoder = starnet_vae_lib.StarEncoder(**_encoder_config)

In [None]:
# Load pretrained encoder weights onto the CPU regardless of the device they
# were saved from, then switch to eval mode for inference.
ENCODER_CHECKPOINT = '../fits/starnet_invKL_encoder-10072019'
star_encoder.load_state_dict(torch.load(ENCODER_CHECKPOINT, map_location='cpu'))
star_encoder.eval();

In [None]:
# Prepend two singleton dimensions, giving (1, 1, H, W) inputs
# (presumably batch and band -- confirm against StarEncoder).
_one_star_batch = torch.Tensor(one_star_img)[None, None]
_one_star_back_batch = torch.Tensor(one_star_back)[None, None]
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    star_encoder.get_results_on_full_image(_one_star_batch, _one_star_back_batch)

In [None]:
# Noise-free reconstruction from the encoder's MAP estimates. Mirroring the
# sim_images_full construction, subtract the simulator's constant sky level and
# add the crop's actual background so the reconstruction is comparable with
# one_star_img.
# NOTE(review): assumes the simulator's slen matches the 101 x 101 crop.
vae_recon_mean = simulator.draw_image_from_params(locs=map_locs_full_image,
                                                  fluxes=map_fluxes_full_image,
                                                  n_stars=map_n_stars_full,
                                                  add_noise=False).squeeze() \
    - simulator.sky_intensity + torch.Tensor(one_star_back)

In [None]:
# Single-star crop vs. the encoder's reconstruction, plus their residual on a
# log10 scale (positive where the observed image is brighter).
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

one_star_observed = torch.Tensor(one_star_img).squeeze()

im0 = axarr[0].matshow(one_star_observed)
f.colorbar(im0, ax=axarr[0])
axarr[0].set_title('observed single-star crop')

im1 = axarr[1].matshow(vae_recon_mean.squeeze())
f.colorbar(im1, ax=axarr[1])
axarr[1].set_title('reconstruction')

residual = torch.log10(one_star_observed) - torch.log10(vae_recon_mean.squeeze())
# Symmetric color limits taken from the [10:90, 10:90] window only.
vmax = residual[10:90, 10:90].abs().max()
im2 = axarr[2].matshow(residual, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
f.colorbar(im2, ax=axarr[2])
axarr[2].set_title('residual: log10(observed) - log10(reconstruction)');