In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import kl_objective_lib as kl_lib
import plotting_utils

# Seed every RNG the notebook relies on. NumPy drives the random subimage
# choice below; torch is seeded too so that any torch-side sampling (e.g. in the
# simulator or encoders) is reproducible as well.
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the data

In [None]:
# Start from the default star-simulation parameters, then override the image
# side length, the star-count range (min == max, so presumably a fixed count of
# 2000) and `alpha` for this experiment.
with open('../data/default_star_parameters.json', 'r') as param_file:
    data_params = json.load(param_file)

data_params.update(slen=101,
                   min_stars=2000,
                   max_stars=2000,
                   alpha=0.5)


In [None]:
# Choose the data source: a simulated field (True) or the real SDSS image and
# its reference catalog (False).
use_simulated_data = True

if use_simulated_data:
    psf_fit_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
    n_images = 1

    simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
        psf_fit_file, data_params, n_images=n_images, add_noise=True)

    # Drop singleton dims; the background is the constant 686 everywhere.
    # NOTE(review): 686. is a magic constant — confirm it matches the sky level
    # used when the dataset was simulated.
    full_image = simulated_dataset.images.squeeze()
    full_background = torch.full(full_image.shape, 686.)

    true_locs = simulated_dataset.locs.squeeze()
    true_fluxes = simulated_dataset.fluxes.squeeze()
else:
    sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

    psf_fit_file = str(sdss_hubble_data.psf_file)

    full_image = sdss_hubble_data.sdss_image.squeeze()
    full_background = sdss_hubble_data.sdss_background.squeeze()

    # Reference catalog restricted to entries with flux above 1300.
    is_bright = sdss_hubble_data.fluxes > 1300.
    true_locs = sdss_hubble_data.locs[is_bright]
    true_fluxes = sdss_hubble_data.fluxes[is_bright]


In [None]:
# Quick look at the full image under analysis.
full_image_artist = plt.matshow(full_image.squeeze())
plt.colorbar(full_image_artist)

# Our simulator

In [None]:
# Renderer that turns a catalog (locations, fluxes, star counts) into a model
# image of the same side length as the data. Sky intensity is 0, so its output
# is compared against background-subtracted data further down.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=psf_fit_file,
    slen=full_image.shape[-1],
    sky_intensity=0.,
)


# Encoder trained on KL 

In [None]:
# Encoder architecture; these hyperparameters must match the checkpoint whose
# state dict is loaded into it next.
star_encoder_kl = starnet_vae_lib.StarEncoder(
    full_slen=101,
    stamp_slen=9,
    step=2,
    edge_padding=3,
    n_bands=1,
    max_detections=4,
)

In [None]:
# Load the saved weights onto the CPU and switch the encoder to eval mode.
# NOTE(review): this is the "KL" encoder, yet the checkpoint name contains
# "invKL" — confirm this is the intended file.
_kl_checkpoint = torch.load('../fits/starnet_invKL_encoder-10092019-reweighted_samples',
                            map_location=lambda storage, loc: storage)
star_encoder_kl.load_state_dict(_kl_checkpoint)
star_encoder_kl.eval();

In [None]:
# First element returned by kl_lib.get_kl_loss for the KL-trained encoder on
# the full image (inputs given leading batch and band dims).
_image_batch = full_image[None, None]
_background_batch = full_background[None, None]
kl_lib.get_kl_loss(star_encoder_kl, _image_batch, _background_batch, simulator)[0]

In [None]:
# MAP catalog (locations, fluxes, star count) of the KL-trained encoder on the
# full image, with leading batch and band dims added to the inputs.
_image_batch = full_image[None, None]
_background_batch = full_background[None, None]
map_locs_full_image_kl, map_fluxes_full_image_kl, map_n_stars_full_kl = \
    star_encoder_kl.get_results_on_full_image(_image_batch, _background_batch)

In [None]:
# Noiseless rendering of the KL encoder's MAP catalog.
kl_recon_mean = simulator.draw_image_from_params(
    locs=map_locs_full_image_kl,
    fluxes=map_fluxes_full_image_kl,
    n_stars=map_n_stars_full_kl,
    add_noise=False,
).squeeze()

# Model minus background-subtracted data: positive where the model is brighter.
_data_minus_background = full_image - full_background
kl_residuals = kl_recon_mean - _data_minus_background

# encoder trained on inv. KL

In [None]:
# Same encoder architecture as the KL one; it must match the inverse-KL
# checkpoint loaded next.
star_encoder_invkl = starnet_vae_lib.StarEncoder(
    full_slen=101,
    stamp_slen=9,
    step=2,
    edge_padding=3,
    n_bands=1,
    max_detections=4,
)

In [None]:
# Load the inverse-KL weights onto the CPU and switch to eval mode.
_invkl_checkpoint = torch.load('../fits/starnet_invKL_encoder-10072019',
                               map_location=lambda storage, loc: storage)
star_encoder_invkl.load_state_dict(_invkl_checkpoint)
star_encoder_invkl.eval();

In [None]:
# First element returned by kl_lib.get_kl_loss for the inverse-KL-trained
# encoder on the full image (inputs given leading batch and band dims).
_image_batch = full_image[None, None]
_background_batch = full_background[None, None]
kl_lib.get_kl_loss(star_encoder_invkl, _image_batch, _background_batch, simulator)[0]

In [None]:
# MAP catalog (locations, fluxes, star count) of the inverse-KL encoder on the
# full image, with leading batch and band dims added to the inputs.
_image_batch = full_image[None, None]
_background_batch = full_background[None, None]
map_locs_full_image_invkl, map_fluxes_full_image_invkl, map_n_stars_full_invkl = \
    star_encoder_invkl.get_results_on_full_image(_image_batch, _background_batch)

In [None]:
# Noiseless rendering of the inverse-KL encoder's MAP catalog.
invkl_recon_mean = simulator.draw_image_from_params(
    locs=map_locs_full_image_invkl,
    fluxes=map_fluxes_full_image_invkl,
    n_stars=map_n_stars_full_invkl,
    add_noise=False,
).squeeze()

# Model minus background-subtracted data: positive where the model is brighter.
_data_minus_background = full_image - full_background
invkl_residuals = invkl_recon_mean - _data_minus_background

# Compare

In [None]:
# Side length (pixels) of the zoomed-in subimages compared below.
subimage_slen = 10

# Candidate starting row (x0) and column (x1) indices, stepping by one
# subimage length from 10 up to (but excluding) 90. Two independent arrays.
x0_vec, x1_vec = (np.arange(10, 90, subimage_slen) for _ in range(2))

In [None]:
# Inspect the inverse-KL MAP locations without dumping the whole tensor:
# its shape, plus the first few rows after squeezing to (n_stars, 2).
map_locs_full_image_invkl.shape, map_locs_full_image_invkl.squeeze()[:5]

In [None]:
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# Pick a random subimage (x0 = starting row, x1 = starting column).
# np.random.choice(..., 1) returns a length-1 array; take its element before
# int(), because int() on a 1-element array is deprecated in NumPy >= 1.25.
# The RNG call is unchanged, so the seeded draw is the same as before.
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])

###################
# Plot catalogs
###################
# KL-encoder catalogue: scale to pixel coordinates, keep the stars inside the
# subimage, and shift them to subimage-local coordinates. This does not depend
# on the row being drawn, so compute it once instead of inside the loop.
_est_locs = map_locs_full_image_kl.squeeze() * (full_image.shape[-1] - 1)
which_locs = (_est_locs[:, 0] > x0) & \
                (_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_est_locs[:, 1] > x1) & \
                (_est_locs[:, 1] < (x1 + subimage_slen - 1))
est_locs = (_est_locs[which_locs, :] - torch.Tensor([[x0, x1]]))

# Column 0, both rows: the observed subimage, passed to plot_subimage together
# with the inverse-KL MAP locations and the true locations, and with the KL
# catalogue overlaid as cyan crosses.
for j in range(2):
    plotting_utils.plot_subimage(axarr[j, 0], full_image,
                                 map_locs_full_image_invkl.squeeze(),
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));
    axarr[j, 0].scatter(est_locs[:, 1], est_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions
#######################
# inverse-KL reconstruction
plotting_utils.plot_subimage(axarr[0, 1], invkl_recon_mean,
                             map_locs_full_image_invkl.squeeze(),
                             None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('inv kl reconstruction \n');

# KL reconstruction
plotting_utils.plot_subimage(axarr[1, 1], kl_recon_mean,
                             map_locs_full_image_kl.squeeze(),
                             None,
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             color = 'c', marker = 'x')
axarr[1, 1].set_title('kl reconstruction \n');

#######################
# Residuals
#######################
# Residuals relative to the observed image. Both panels share one symmetric
# colour scale (the larger of the two max |values| in the subimage) so the two
# encoders are directly comparable.
vmax1 = torch.abs((kl_residuals / full_image)[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
vmax2 = torch.abs((invkl_residuals / full_image)[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()

vmax = torch.max(torch.Tensor([vmax1, vmax2]))

# inverse-KL residuals
plotting_utils.plot_subimage(axarr[0, 2], invkl_residuals / full_image,
                             map_locs_full_image_invkl.squeeze(),
                             true_locs,
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)
axarr[0, 2].set_title('inv residuals\n');

# KL residuals
plotting_utils.plot_subimage(axarr[1, 2], kl_residuals / full_image,
                             map_locs_full_image_kl.squeeze(),
                             true_locs,
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax,
                             color = 'c')
axarr[1, 2].set_title('kl residuals\n');


# Check out some summary statistics

In [None]:
# we only look at locations within 10-90 (pixel coordinates) by default;
# Portillos doesn't detect on the edge

def filter_params(locs, fluxes, slen, lower = 10, upper = 90):
    """Keep only the stars whose pixel location lies strictly inside a window.

    Parameters
    ----------
    locs : 2-d tensor, (n_stars, >= 2)
        Star locations. They are multiplied by ``slen - 1`` to get pixel
        coordinates, so presumably they lie in [0, 1]. Only columns 0 and 1
        are tested.
    fluxes : 1-d tensor, (n_stars,)
        Fluxes aligned row-for-row with ``locs``.
    slen : int
        Side length of the image in pixels.
    lower, upper : float
        Strict bounds applied to both pixel coordinates. The defaults
        (10, 90) reproduce the original hard-coded window.

    Returns
    -------
    (locs, fluxes) restricted to the kept stars. ``locs`` stay in their
    original (unscaled) units.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1
    assert locs.shape[0] == fluxes.shape[0], \
        'locs and fluxes must describe the same number of stars'

    pixel_locs = locs * (slen - 1)
    rows = pixel_locs[:, 0]
    cols = pixel_locs[:, 1]
    inside = (rows > lower) & (rows < upper) & (cols > lower) & (cols < upper)

    return locs[inside], fluxes[inside]


In [None]:
# Restrict every catalog (both encoders' MAP estimates and the truth) to the
# same interior pixel window using filter_params.
_slen = full_image.shape[-1]

kl_est_locs, kl_est_fluxes = filter_params(
    map_locs_full_image_kl.squeeze(), map_fluxes_full_image_kl.squeeze(), _slen)

invkl_est_locs, invkl_est_fluxes = filter_params(
    map_locs_full_image_invkl.squeeze(), map_fluxes_full_image_invkl.squeeze(), _slen)

# NOTE: this overwrites true_locs / true_fluxes with their filtered versions,
# so earlier cells that use them see the filtered catalog if re-run afterwards.
true_locs, true_fluxes = filter_params(true_locs, true_fluxes, _slen)

In [None]:
# Number of stars in each catalog after the edge filter.
for _label, _catalog_locs in (('kl n_stars: ', kl_est_locs),
                              ('invkl n_stars: ', invkl_est_locs),
                              ('true n_stars: ', true_locs)):
    print(_label, len(_catalog_locs))

In [None]:
def _draw_noiseless_recon(locs, fluxes):
    """Render a noiseless model image of one filtered catalog.

    ``locs`` (2-d) and ``fluxes`` (1-d) are a single catalog as returned by
    filter_params; a leading batch dim is added for the simulator and removed
    from the result.
    """
    n_stars = torch.tensor([len(fluxes)], dtype=torch.long)
    return simulator.draw_image_from_params(locs = locs.unsqueeze(0),
                                            fluxes = fluxes.unsqueeze(0),
                                            n_stars = n_stars,
                                            add_noise = False).squeeze()


_recon_kl = _draw_noiseless_recon(kl_est_locs, kl_est_fluxes)
_recon_invkl = _draw_noiseless_recon(invkl_est_locs, invkl_est_fluxes)
_recon_truth = _draw_noiseless_recon(true_locs, true_fluxes)

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

for ax, recon, title in zip(axarr,
                            (_recon_kl, _recon_invkl, _recon_truth),
                            ('kl reconstruction',
                             'invkl reconstruction',
                             'true-catalog reconstruction')):
    ax.matshow(recon)
    ax.set_title(title)

In [None]:
# check out MSEs on the central [10, 90) pixel window, matching the region
# kept by filter_params

_image = full_image[10:90, 10:90] - full_background[10:90, 10:90]

_kl_residual = _recon_kl[10:90, 10:90] - _image
_invkl_residual = _recon_invkl[10:90, 10:90] - _image
_true_residual = _recon_truth[10:90, 10:90] - _image

print('kl_mse: ', torch.mean(_kl_residual**2))
print('invkl_mse: ', torch.mean(_invkl_residual**2))
print('truth_mse: ', torch.mean(_true_residual**2))

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

for ax, residual, title in zip(axarr,
                               (_kl_residual, _invkl_residual, _true_residual),
                               ('kl residual', 'invkl residual', 'true-catalog residual')):
    im = ax.matshow(residual)
    fig.colorbar(im, ax = ax)
    ax.set_title(title)

In [None]:
# Distribution of per-pixel relative residuals. `_image` is background-
# subtracted, so it can be zero or negative: the ratio may be inf/NaN, which
# would make hist's automatically detected range non-finite and raise. Drop
# those pixels before histogramming.
fig, axarr = plt.subplots(1, 3, figsize=(20, 5))

for ax, residual, title in zip(axarr,
                               (_kl_residual, _invkl_residual, _true_residual),
                               ('kl', 'invkl', 'true catalog')):
    ratio = (residual / _image).flatten()
    ax.hist(ratio[torch.isfinite(ratio)], bins = 100)
    ax.set_title(title)
    ax.set_xlabel('residual / (image - background)')

# Get summary statistics

These are rather coarse measures. Completeness does not account for several true stars being matched to the same estimated star (in which case not all of those true stars were actually detected). Conversely, the true positive rate does not account for several estimated stars being matched to the same true star (in which case only one of those estimated stars is actually a true positive).

I tried the Hungarian algorithm to find a minimum-cost one-to-one matching, but it gave unintuitive results: it searches for the permutation that minimizes the **global** cost of the matching, which can pair individual stars poorly in order to lower the total cost.

In [None]:
import image_statistics_lib

In [None]:
# completeness and true positive rate using locations only (no fluxes)
kl_completeness, kl_tpr = \
    image_statistics_lib.get_summary_stats(kl_est_locs, true_locs,
                                           full_image.shape[-1], None, None)
invkl_completeness, invkl_tpr = \
    image_statistics_lib.get_summary_stats(invkl_est_locs, true_locs,
                                           full_image.shape[-1], None, None)

print('kl completeness: {:0.3f}'.format(kl_completeness))
print('invkl completeness: {:0.3f}\n'.format(invkl_completeness))

print('kl true positive rate: {:0.3f}'.format(kl_tpr))
print('invkl true positive rate: {:0.3f}'.format(invkl_tpr))

In [None]:
# completeness and true positive rate, now also passing the fluxes
kl_completeness, kl_tpr = \
    image_statistics_lib.get_summary_stats(kl_est_locs, true_locs,
                                           full_image.shape[-1],
                                           kl_est_fluxes, true_fluxes)

invkl_completeness, invkl_tpr = \
    image_statistics_lib.get_summary_stats(invkl_est_locs, true_locs,
                                           full_image.shape[-1],
                                           invkl_est_fluxes, true_fluxes)

print('kl completeness: {:0.3f}'.format(kl_completeness))
print('invkl completeness: {:0.3f}\n'.format(invkl_completeness))

print('kl true positive rate: {:0.3f}'.format(kl_tpr))
print('invkl true positive rate: {:0.3f}'.format(invkl_tpr))

In [None]:
kl_completeness_vec, kl_mag_vec = \
    image_statistics_lib.get_completeness_vec(kl_est_locs, true_locs, full_image.shape[-1],
                                              kl_est_fluxes, true_fluxes)

invkl_completeness_vec, invkl_mag_vec = \
    image_statistics_lib.get_completeness_vec(invkl_est_locs, true_locs, full_image.shape[-1],
                                              invkl_est_fluxes, true_fluxes)

fig, ax = plt.subplots()

# The magnitude vectors are plotted without their last entry, presumably
# because they are bin edges (one longer than the completeness vectors).
ax.plot(kl_mag_vec[0:-1], kl_completeness_vec, '--x', label = 'kl')
ax.plot(invkl_mag_vec[0:-1], invkl_completeness_vec, '--x', label = 'inv kl')

ax.legend()
ax.set_xlabel('true log flux')
ax.set_ylabel('completeness')
ax.set_title('completeness vs. true log flux');

In [None]:
kl_tpr_vec, kl_mag_vec = \
    image_statistics_lib.get_tpr_vec(kl_est_locs, true_locs, full_image.shape[-1],
                                    kl_est_fluxes, true_fluxes)

invkl_tpr_vec, invkl_mag_vec = \
    image_statistics_lib.get_tpr_vec(invkl_est_locs, true_locs, full_image.shape[-1],
                                    invkl_est_fluxes, true_fluxes)

fig, ax = plt.subplots()

# The magnitude vectors are plotted without their last entry, presumably
# because they are bin edges (one longer than the rate vectors).
ax.plot(kl_mag_vec[0:-1], kl_tpr_vec, '--x', label = 'kl')
ax.plot(invkl_mag_vec[0:-1], invkl_tpr_vec, '--x', label = 'invkl')

ax.legend()
ax.set_xlabel('estimated log flux')
ax.set_ylabel('true positive rate')
ax.set_title('true positive rate vs. estimated log flux');