In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib
import utils
import psf_transform_lib
import sleep_lib
import wake_lib

import image_utils

import time

import json

# Use the first CUDA device when torch can see one; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Record the torch version next to the results of this run.
print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# Seed numpy and torch so the stochastic steps further down
# (e.g. `add_noise = True`, `np.random.choice`) can be repeated.
np.random.seed(22)
# manual_seed returns a Generator; binding it to `_` keeps the cell output empty.
_ = torch.manual_seed(22)

# Load data

In [None]:
# Dataset object that supplies the SDSS image, its background, and the catalog
# locations / fluxes used below, restricted to bands [2, 3].
# NOTE(review): the same band list is repeated as `bands` in the PSF cell; keep them in sync.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = [2, 3])

In [None]:
# Flux threshold: only catalog sources whose first-band flux exceeds it are
# kept as "true" stars in the next cell.
f_min = 1000.
# Mean of the dataset's electrons-per-nanomaggy values; later cells divide
# fluxes by it before converting to magnitudes.
n_elect_per_nmgy = sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# the observed data
# (unsqueeze adds a leading dimension of size 1; later cells index it as [0, band])
sdss_images_full = sdss_hubble_data.sdss_image.unsqueeze(0)

# get true parameters
backgrounds_full = sdss_hubble_data.sdss_background.unsqueeze(0)

# keep only the catalog sources whose flux in the first loaded band exceeds f_min
which_bright = sdss_hubble_data.fluxes[:, 0] > f_min
true_full_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
true_full_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)

In [None]:
# get PSFs
# NOTE(review): `bands` repeats the list passed to SDSSHubbleData above.
bands = [2, 3]
# NOTE(review): hard-coded path, relative to the notebook's working directory.
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
# initial PSF parameters read from the psField file for the selected bands
init_psf_params = psf_transform_lib.get_psf_params(
                                    psfield_file,
                                    bands = bands)
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
# detach before handing the PSF to the simulator
psf_og = power_law_psf.forward().detach()

In [None]:
# default background
# One row of 3 parameters per band; only the first parameter of each band is
# set (686 and 1123), the other two stay at zero.
# NOTE(review): the two constants below assume len(bands) == 2.
# `wake_lib` is imported with the other project modules in the top import cell.
init_background_params = torch.zeros(len(bands), 3).to(device)
init_background_params[:, 0] = torch.Tensor([686., 1123.])
# init_background_params already lives on `device`, so no further .to() is needed
planar_background = wake_lib.PlanarBackground(image_slen = sdss_images_full.shape[-1],
                            init_background_params = init_background_params)
# detach before handing the background to the simulator
background = planar_background.forward().detach()


In [None]:
# Simulator built from the PSF and planar background defined above, using the
# side length of the SDSS image (last dimension of sdss_images_full).
simulator = simulated_datasets_lib.StarSimulator(psf_og, sdss_images_full.shape[-1], 
                                                 background, 
                                                 transpose_psf = False)

# Load VAE

In [None]:
# Construct the encoder; its weights are loaded from a checkpoint in the next cell.
# presumably these settings must match the ones the checkpoint was trained with — TODO confirm
# NOTE(review): n_bands = 2 is hard-coded; it equals len(bands) for bands = [2, 3].
star_encoder = starnet_lib.StarEncoder(full_slen = sdss_images_full.shape[-1], 
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = 2,
                                            max_detections = 2, 
                                          estimate_flux = True)

In [None]:
# Load the trained weights. The map_location lambda returns each storage
# unchanged, so the checkpoint tensors are deserialized on CPU.
# NOTE(review): no visible cell moves star_encoder to `device`, while the PSF
# and background are moved there — confirm the devices agree when CUDA is available.
star_encoder.load_state_dict(torch.load('../fits/results_2020-02-17/starnet_ri',
                               map_location=lambda storage, loc: storage))
# switch to evaluation mode; the trailing `;` suppresses the cell's echo of the module
star_encoder.eval(); 

# First, check results on data simulated using the Hubble parameters

In [None]:
# Render a noisy image from the Hubble catalog parameters.
# `draw_all_stars` chooses the input: the whole catalog (True) or only the
# bright stars selected with f_min (False).
draw_all_stars = False

if draw_all_stars:
    sim_locs = sdss_hubble_data.locs.unsqueeze(0)
    sim_fluxes = sdss_hubble_data.fluxes.unsqueeze(0)
else:
    sim_locs = true_full_locs
    sim_fluxes = true_full_fluxes

# a single image is drawn, so n_stars holds one count: the number of stars in dim 1
sim_n_stars = torch.Tensor([sim_locs.shape[1]]).type(torch.LongTensor)

sim_images_full = simulator.draw_image_from_params(
    locs = sim_locs,
    fluxes = sim_fluxes,
    n_stars = sim_n_stars,
    add_noise = True)

### check residuals between simulated image and true image

In [None]:
# Index along dimension 1 of the image tensors (i.e. into the loaded bands)
# selecting the band shown in the plots below.
band = 0

In [None]:
# Three panels for band `band`: observed SDSS image, image simulated from the
# catalog, and the log10 residual between them (rows/columns 10:90 only).
f, axarr = plt.subplots(1, 3, figsize=(16, 6))
ax_true, ax_sim, ax_resid = axarr

im0 = ax_true.matshow(sdss_images_full[0, band])
f.colorbar(im0, ax = ax_true)
ax_true.set_title('true sdss image')

im1 = ax_sim.matshow(sim_images_full[0, band])
f.colorbar(im1, ax = ax_sim)
ax_sim.set_title('simulated sdss image')

residual = torch.log10(sim_images_full[0, band]) - torch.log10(sdss_images_full[0, band])
residual_inner = residual[10:90, 10:90]
# symmetric colour limits so that zero sits at the centre of the diverging map
vmax = residual_inner.abs().max()
im2 = ax_resid.matshow(residual_inner, cmap=plt.get_cmap('bwr'), vmin = -vmax, vmax = vmax)
f.colorbar(im2, ax = ax_resid)
ax_resid.set_title('residual')

## Check loss

In [None]:
# check loss 
# Encoder loss on the simulated image against the bright-star catalog; only the
# first five returned values are kept, bound to the names below.
# NOTE(review): the encoder receives the SDSS background (`backgrounds_full`),
# whereas the image was simulated with the planar `background` — confirm this is intended.
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    sleep_lib.get_encoder_loss(star_encoder, sim_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes)[0:5]

In [None]:
# `06f` zero-pads to a minimum width of 6 and uses the default precision of 6 decimals.
print('loss: {:06f}'.format(loss))

In [None]:
# Mean of each of the three loss components returned above.
print(counter_loss.mean())
print(locs_loss.mean())
print(fluxes_loss.mean())

## Check reconstructions

In [None]:
# get parameters on the simulated image 
# Both return_map_* flags are set (presumably requesting MAP estimates — TODO confirm);
# only the first three returned values are kept: locations, fluxes, star counts.
map_locs_sim_image, map_fluxes_sim_image, map_n_stars_sim_image = \
        star_encoder.sample_star_encoder(sim_images_full, backgrounds_full, 
                                        return_map_n_stars = True, 
                                        return_map_star_params = True)[0:3]

In [None]:
# get reconstructed mean
# Render the encoder's estimates through the same simulator, without noise.
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_sim_image, 
                                                fluxes = map_fluxes_sim_image,
                                                 n_stars = map_n_stars_sim_image, 
                                                 add_noise = False).detach()

In [None]:
# Echo the band index used by the plots in this section.
band

In [None]:
# Simulated image, its noiseless reconstruction, and their residual for band
# `band`, each cropped to rows/columns 5:95.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
crop = slice(5, 95)

sim_band = sim_images_full[0, band]
recon_band = vae_recon_mean[0, band]

im0 = axarr[0].matshow(sim_band[crop, crop])
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(recon_band[crop, crop])
fig.colorbar(im1, ax = axarr[1])

# residual shown as 2.5 * (log10(recon) - log10(sim))
residual = torch.log10(recon_band) - torch.log10(sim_band)
_residual = 2.5 * residual[crop, crop]
# symmetric colour limits so that zero sits at the centre of the diverging map
vmax = _residual.abs().max()
im2 = axarr[2].matshow(_residual, cmap=plt.get_cmap('bwr'), vmin = -vmax, vmax = vmax)
fig.colorbar(im2, ax = axarr[2])

In [None]:
def get_which_tile(x0, x1, tile_coords, edge_padding, stamp_slen):
    """Find the tile(s) whose inner region strictly contains the point (x0, x1).

    The inner region of a tile starts `edge_padding` past the tile's corner in
    both coordinates and has side `stamp_slen - 2 * edge_padding`. Both bounds
    are exclusive, so a point lying exactly on a bound is not matched.

    Parameters
    ----------
    x0, x1 : coordinates of the point, compared against columns 0 and 1 of
        `tile_coords` respectively.
    tile_coords : 2-D tensor with one tile corner per row.
    edge_padding : padding on each side of a tile.
    stamp_slen : full side length of a tile.

    Returns
    -------
    (matching rows of `tile_coords`, the index tuple returned by `torch.where`)
    """
    inner_slen = stamp_slen - 2 * edge_padding

    lower = tile_coords + edge_padding
    upper = lower + inner_slen

    inside_x0 = (lower[:, 0] < x0) & (x0 < upper[:, 0])
    inside_x1 = (lower[:, 1] < x1) & (x1 < upper[:, 1])

    indx = torch.where(inside_x0 & inside_x1)

    return tile_coords[indx], indx

In [None]:
# Look up the tile whose inner region contains the point (77, 67); the cell
# displays the matching tile coordinates and their index.
get_which_tile(77, 67, star_encoder.tile_coords, star_encoder.edge_padding, star_encoder.stamp_slen)

In [None]:
# One tile of the simulated image: data, noiseless reconstruction, and the
# residual 2.5 * (log10(recon) - log10(sim)).
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

indx = 1808 # int(np.random.choice(star_encoder.tile_coords.shape[0], 1))

# corner of the selected tile, looked up once instead of in every call
tile_x0 = int(star_encoder.tile_coords[indx, 0])
tile_x1 = int(star_encoder.tile_coords[indx, 1])

# data, with estimated and true locations
results = plotting_utils.plot_subimage(axarr[0], sim_images_full[0, band],
                            map_locs_sim_image.squeeze(),
                            true_full_locs.squeeze(),
                            tile_x0,
                            tile_x1,
                            subimage_slen = star_encoder.stamp_slen,
                            add_colorbar = True,
                            global_fig = f)

# reconstruction, with estimated locations only
plotting_utils.plot_subimage(axarr[1], vae_recon_mean[0, band],
                            map_locs_sim_image.squeeze(),
                            None,
                            tile_x0,
                            tile_x1,
                            subimage_slen = star_encoder.stamp_slen,
                            add_colorbar = True,
                            global_fig = f)

# residual
foo = torch.log10(vae_recon_mean[0, band]) - torch.log10(sim_images_full[0, band])
plotting_utils.plot_subimage(axarr[2], foo * 2.5,
                            map_locs_sim_image.squeeze(),
                            None,
                            tile_x0,
                            tile_x1,
                            subimage_slen = star_encoder.stamp_slen,
                            add_colorbar = True,
                            global_fig = f,
                            diverging_cmap = True)

# Red guide lines at 2 and 4 on the first two panels.
# presumably these mark the tile's inner region for edge_padding = 2, stamp_slen = 7 — TODO confirm
# Written as a loop so the cell no longer echoes the last Line2D object.
for ax in axarr[:2]:
    for pos in (2, 4):
        ax.axvline(x=pos, color = 'r')
    for pos in (2, 4):
        ax.axhline(y=pos, color = 'r')

## Check summary statistics

In [None]:
# Value for the `pad` argument of the summary-statistics call below.
pad = 0

In [None]:
# Completeness and true-positive rate of the encoder's estimates on the
# simulated image, compared with the bright-star catalog. Fluxes are taken
# from the first band only ([:, 0]).
# `pad` comes from the cell above rather than being hard-coded here.
completeness, tpr, completeness_bool, tpr_bool = \
    image_statistics_lib.get_summary_stats(map_locs_sim_image.squeeze(),
                                           true_full_locs.squeeze(),
                                           star_encoder.full_slen,
                                           map_fluxes_sim_image.squeeze(0)[:, 0],
                                           true_full_fluxes.squeeze(0)[:, 0],
                                           n_elect_per_nmgy, pad = pad)

print('completeness: {:0.3f}'.format(completeness))
print('true positive rate: {:0.3f}'.format(tpr))

In [None]:
# Completeness curve for the simulated image, first-band fluxes only.
completeness_vec, mag_vec_c, counts = \
    image_statistics_lib.get_completeness_vec(map_locs_sim_image.squeeze(), 
                                           true_full_locs.squeeze(), 
                                           star_encoder.full_slen, 
                                           map_fluxes_sim_image.squeeze(0)[:, 0], 
                                           true_full_fluxes.squeeze(0)[:, 0], 
                                             n_elect_per_nmgy)

# mag_vec_c has one more entry than completeness_vec is plotted against
# (presumably bin edges — TODO confirm), hence the [0:-1].
# NOTE(review): the x label says 'true log flux' while the variable is named
# mag_vec_c and the TPR plot is labelled in magnitudes — confirm the label.
plt.plot(mag_vec_c[0:-1], completeness_vec, '--x', label = 'sim')
plt.xlabel('true log flux')
plt.ylabel('completeness')


In [None]:
# Third value returned by get_completeness_vec
# (presumably the number of stars per bin — TODO confirm).
counts

In [None]:
# First-band fluxes divided by electrons-per-nanomaggy, then converted to
# magnitudes, for the catalog (true) and for the encoder's estimates (est).
true_mags = sdss_dataset_lib.convert_nmgy_to_mag(true_full_fluxes.squeeze(0)[:, 0] / n_elect_per_nmgy)
est_mags = sdss_dataset_lib.convert_nmgy_to_mag(map_fluxes_sim_image.squeeze(0)[:, 0] / n_elect_per_nmgy)

In [None]:
# Locations of true stars with magnitude below 16 whose completeness_bool
# entry is 0 (presumably the bright stars the encoder missed — TODO confirm).
# NOTE: `indx` is a boolean mask here; other cells reuse the name for an integer tile index.
indx = (completeness_bool == 0) & (true_mags < 16)
true_full_locs.squeeze()[indx]

In [None]:
# True-positive-rate curve for the simulated image, first-band fluxes only.
# NOTE: this rebinds `counts`, replacing the value from the completeness cell.
tpr_vec, mag_vec_t, counts = \
    image_statistics_lib.get_tpr_vec(map_locs_sim_image.squeeze(), 
                                           true_full_locs.squeeze(), 
                                           star_encoder.full_slen, 
                                           map_fluxes_sim_image.squeeze(0)[:, 0], 
                                           true_full_fluxes.squeeze(0)[:, 0], 
                                    n_elect_per_nmgy)


In [None]:
# Third value returned by get_tpr_vec
# (presumably the number of detections per bin — TODO confirm).
counts

In [None]:
# TPR curve for the simulated image; the last entry of mag_vec_t is dropped so
# the x values line up with tpr_vec.
plt.plot(mag_vec_t[0:-1], tpr_vec, '--x', label = 'sim')
plt.xlabel('est magnitude')
plt.ylabel('tpr')


# Now compare against results on the true image

In [None]:
# get parameters on the real SDSS image
# Same call as for the simulated image, with sdss_images_full as the input.
map_locs_sdss_image, map_fluxes_sdss_image, map_n_stars_sdss_image = \
        star_encoder.sample_star_encoder(sdss_images_full, backgrounds_full, 
                                        return_map_n_stars = True, 
                                        return_map_star_params = True)[0:3]

### Compare losses

In [None]:
# Loss on the simulated image, this time with use_l2_loss = True, so it can be
# compared with the loss on the SDSS image computed in the next cell.
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    sleep_lib.get_encoder_loss(star_encoder, sim_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes, 
                               use_l2_loss = True)[0:5]

print('loss: {:06f}'.format(loss))
print(counter_loss.mean())
print(locs_loss.mean())
print(fluxes_loss.mean())

In [None]:
# Same loss on the real SDSS image.
# NOTE: this rebinds loss, counter_loss, locs_loss, fluxes_loss and perm_indx
# from the previous cell.
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    sleep_lib.get_encoder_loss(star_encoder, sdss_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes, 
                               use_l2_loss = True)[0:5]

print('loss: {:06f}'.format(loss))
print(counter_loss.mean())
print(locs_loss.mean())
print(fluxes_loss.mean())

In [None]:
# Completeness and true-positive rate of the encoder's estimates on the real
# SDSS image, compared with the same bright-star catalog.
# NOTE(review): unlike the simulated-image call, no `pad` argument is passed
# here — confirm the function's default matches the value used there.
completeness2, tpr2, completeness2_bool, tpr2_bool = \
    image_statistics_lib.get_summary_stats(map_locs_sdss_image.squeeze(), 
                                           true_full_locs.squeeze(), 
                                           star_encoder.full_slen, 
                                           map_fluxes_sdss_image.squeeze(0)[:, 0], 
                                           true_full_fluxes.squeeze(0)[:, 0], 
                                          n_elect_per_nmgy)
    
print('completeness: {:0.3f}'.format(completeness2))
print('true positive rate: {:0.3f}'.format(tpr2))

In [None]:
# Completeness curve for the real SDSS image; only the first two returned
# values are kept (the third, `counts`, is dropped by the [0:2]).
completeness2_vec, mag2_vec_c = \
    image_statistics_lib.get_completeness_vec(map_locs_sdss_image.squeeze(), 
                                           true_full_locs.squeeze(), 
                                           star_encoder.full_slen, 
                                           map_fluxes_sdss_image.squeeze(0)[:, 0], 
                                           true_full_fluxes.squeeze(0)[:, 0], 
                                             n_elect_per_nmgy)[0:2]

# Overlay the simulated-image curve (computed in an earlier cell) and the SDSS curve.
# NOTE(review): the x label says 'true log flux' while the variables are named
# mag*_vec_c — confirm the label.
plt.plot(mag_vec_c[0:-1], completeness_vec, '--x', label = 'sim')
plt.plot(mag2_vec_c[0:-1], completeness2_vec, '--x', label = 'sdss')
plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')


In [None]:
# TPR curve for the real SDSS image; only the first two returned values are kept.
tpr2_vec, mag2_vec_t = \
    image_statistics_lib.get_tpr_vec(map_locs_sdss_image.squeeze(), 
                                           true_full_locs.squeeze(), 
                                           star_encoder.full_slen, 
                                           map_fluxes_sdss_image.squeeze(0)[:, 0], 
                                           true_full_fluxes.squeeze(0)[:, 0], 
                                    n_elect_per_nmgy)[0:2]

# Overlay the simulated-image curve (computed in an earlier cell) and the SDSS curve.
plt.plot(mag_vec_t[0:-1], tpr_vec, '--x', label = 'sim')
plt.plot(mag2_vec_t[0:-1], tpr2_vec, '--x', label = 'sdss')

plt.legend()
plt.xlabel('est magnitude')
plt.ylabel('tpr')


## Compare reconstructions

In [None]:
# Noiseless reconstruction of the real SDSS image: render the encoder's
# estimates for that image through the simulator.
vae_recon_mean2 = simulator.draw_image_from_params(locs = map_locs_sdss_image, 
                                                fluxes = map_fluxes_sdss_image,
                                                 n_stars = map_n_stars_sdss_image, 
                                                 add_noise = False).detach()

In [None]:
# reconstructed simulated
# Simulated image, its noiseless reconstruction, and their residual for band
# `band`, each cropped to rows/columns 5:95.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
crop = slice(5, 95)

sim_band = sim_images_full[0, band]
recon_band = vae_recon_mean[0, band]

im0 = axarr[0].matshow(sim_band[crop, crop])
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(recon_band[crop, crop])
fig.colorbar(im1, ax = axarr[1])

# residual shown as 2.5 * (log10(recon) - log10(sim))
residual = torch.log10(recon_band) - torch.log10(sim_band)
_residual = 2.5 * residual[crop, crop]
# symmetric colour limits so that zero sits at the centre of the diverging map
vmax = _residual.abs().max()
im2 = axarr[2].matshow(_residual, cmap=plt.get_cmap('bwr'), vmin = -vmax, vmax = vmax)
fig.colorbar(im2, ax = axarr[2])

In [None]:
# reconstructed real
# SDSS image, its noiseless reconstruction, and their residual for band
# `band`, each cropped to rows/columns 5:95.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
crop = slice(5, 95)

sdss_band = sdss_images_full[0, band]
recon2_band = vae_recon_mean2[0, band]

im0 = axarr[0].matshow(sdss_band[crop, crop])
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(recon2_band[crop, crop])
fig.colorbar(im1, ax = axarr[1])

# residual shown as 2.5 * (log10(recon) - log10(sdss))
residual = torch.log10(recon2_band) - torch.log10(sdss_band)
_residual = 2.5 * residual[crop, crop]
# symmetric colour limits so that zero sits at the centre of the diverging map
vmax = _residual.abs().max()
im2 = axarr[2].matshow(_residual, cmap=plt.get_cmap('bwr'), vmin = -vmax, vmax = vmax)
fig.colorbar(im2, ax = axarr[2])

In [None]:
# One randomly chosen tile, shown for the simulated image (top row) and the
# real SDSS image (bottom row): data, noiseless reconstruction, residual.
f, axarr = plt.subplots(2, 3, figsize=(16, 8))

indx = int(np.random.choice(star_encoder.tile_coords.shape[0], 1))

# corner of the selected tile
tile_x0 = int(star_encoder.tile_coords[indx, 0])
tile_x1 = int(star_encoder.tile_coords[indx, 1])

# (image, reconstruction, estimated locations) for each row of the figure
rows = [
    (sim_images_full, vae_recon_mean, map_locs_sim_image),     # simulated data results
    (sdss_images_full, vae_recon_mean2, map_locs_sdss_image),  # real data results
]

for row, (images, recon, map_locs) in enumerate(rows):
    # data, with estimated and true locations
    plotting_utils.plot_subimage(axarr[row, 0], images[0, band],
                                 map_locs.squeeze(),
                                 true_full_locs.squeeze(),
                                 tile_x0,
                                 tile_x1,
                                 subimage_slen = star_encoder.stamp_slen,
                                 add_colorbar = True,
                                 global_fig = f)

    # reconstruction, with estimated locations only
    plotting_utils.plot_subimage(axarr[row, 1], recon[0, band],
                                 map_locs.squeeze(),
                                 None,
                                 tile_x0,
                                 tile_x1,
                                 subimage_slen = star_encoder.stamp_slen,
                                 add_colorbar = True,
                                 global_fig = f)

    # residual, 2.5 * (log10(recon) - log10(data))
    foo = torch.log10(recon[0, band]) - torch.log10(images[0, band])
    plotting_utils.plot_subimage(axarr[row, 2], foo * 2.5,
                                 map_locs.squeeze(),
                                 None,
                                 tile_x0,
                                 tile_x1,
                                 subimage_slen = star_encoder.stamp_slen,
                                 add_colorbar = True,
                                 global_fig = f,
                                 diverging_cmap = True)

# red guide lines at 2 and 4 on the data and reconstruction panels of both rows
for i in range(2):
    for ax in axarr[i, :2]:
        for pos in (2, 4):
            ax.axvline(x=pos, color = 'r')
        for pos in (2, 4):
            ax.axhline(y=pos, color = 'r')