In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import fitsio

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils

# Seed both NumPy and PyTorch so that random draws made through either
# library are reproducible after Restart Kernel -> Run All.
# The trailing `;` keeps the Generator returned by torch.manual_seed out of
# the cell output.
np.random.seed(34534)
torch.manual_seed(34534);

# Load the data

In [None]:
# data parameters
# Parsed from JSON into `data_params`; it is handed to the dataset loader
# below and is also indexed with 'f_min' when the SDSS branch is used.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

In [None]:
# load psf
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'


def _read_psf(filename):
    """Return the data of HDU 0 of `psf_dir + filename`.

    The FITS file is opened in a `with` block so the handle is closed as soon
    as the array has been read, instead of being left open.
    """
    with fitsio.FITS(psf_dir + filename) as fits:
        return fits[0].read()


psf_r = _read_psf('sdss-002583-2-0136-psf-r.fits')
psf_i = _read_psf('sdss-002583-2-0136-psf-i.fits')

# stack the two PSFs along a new leading axis: index 0 is r, index 1 is i
psf_og = np.array([psf_r, psf_i])

In [None]:
# sky intensity: for the r and i band
# (two values, in the same order as the PSFs in `psf_og`; passed as
# `sky_intensity` to the dataset loader and to both simulators below)
sky_intensity = torch.Tensor([686., 1123.])


In [None]:
# Switch between a freshly simulated image (True) and the SDSS/Hubble data (False).
# Either branch defines full_image, full_background, true_locs and true_fluxes.
use_simulated_data = True

if use_simulated_data: 
   
    # number of images requested from the simulated dataset
    n_images = 1

    # build the dataset from the PSFs, the JSON parameters and the sky intensity
    simulated_dataset = \
        simulated_datasets_lib.load_dataset_from_params(psf_og,
                                data_params,
                                n_images = n_images,
                                sky_intensity = sky_intensity,
                                add_noise = True)

    # image and background analysed by the cells below
    full_image = simulated_dataset.images
    full_background = simulated_dataset.background

    # catalog stored on the simulated dataset, used as the ground truth below
    true_locs = simulated_dataset.locs
    true_fluxes = simulated_dataset.fluxes
    true_n_stars = simulated_dataset.n_stars
    
else: 
    # NOTE(review): unlike the simulated branch, this branch never defines
    # `true_n_stars`.
    sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = [2, 3])
    
    # image 
    # unsqueeze(0) prepends an axis of size one -- presumably the image/batch
    # axis, to match the layout of the simulated branch; TODO confirm
    full_image = sdss_hubble_data.sdss_image.unsqueeze(0)
    full_background = sdss_hubble_data.sdss_background.unsqueeze(0)
    
    # true parameters
    # keep only the rows whose flux in column 0 exceeds data_params['f_min']
    which_bright = (sdss_hubble_data.fluxes[:, 0] > data_params['f_min'])
    true_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
    true_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)


In [None]:
# Display the [0, 0] slice of the image tensor -- presumably the first image,
# first band (TODO confirm axis order) -- with a colour scale.
plt.matshow(full_image[0, 0])
plt.colorbar()

# plt.scatter(true_locs[:, 1] * 100, 
#            true_locs[:, 0] * 100)

# Our simulator

In [None]:
# simulator for one band only
# Uses only the first PSF and the first sky intensity (both sliced with 0:1 so
# the leading axis is kept); `slen` is the size of the image's last axis.
simulator1 = simulated_datasets_lib.StarSimulator(psf=psf_og[0:1], 
                                                slen = full_image.shape[-1], 
                                                sky_intensity = sky_intensity[0:1])


In [None]:
# Simulator built from all PSFs and all sky intensities (both bands);
# `slen` is the size of the image's last axis.
simulator = simulated_datasets_lib.StarSimulator(psf=psf_og, 
                                                slen = full_image.shape[-1], 
                                                sky_intensity = sky_intensity)


# First set of results 

In [None]:
# Encoder for the first set of results: configured for a single band.
# NOTE(review): full_slen is hard-coded to 101 while the simulators take
# their side length from full_image.shape[-1] -- confirm the two agree.
star_encoder1 = starnet_vae_lib.StarEncoder(full_slen = 101,
                                           stamp_slen = 7,
                                           step = 2,
                                           edge_padding = 2,
                                           n_bands = 1,
                                           max_detections = 2)


In [None]:
# Load the saved weights into encoder 1.
# The map_location callable returns each storage unchanged -- presumably so
# that a checkpoint saved on GPU loads on a CPU-only machine; TODO confirm.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
star_encoder1.load_state_dict(torch.load('../fits/results_11052019/starnet3', 
                               map_location=lambda storage, loc: storage))
# eval() is called before inference; the trailing `;` hides the cell output
star_encoder1.eval(); 

In [None]:
# Inspect index 0 of the last axis of the true fluxes (kept as a size-one
# axis); this is the same slice that is passed to encoder 1 below.
true_fluxes[:, :, 0:1]

In [None]:
# get image stamps
# Encoder 1 is given the [0:1, 0:1] slice of the image together with the true
# locations and the index-0 slice of the true fluxes; the call returns five
# values, unpacked below.
image_stamps, true_subimage_locs, true_subimage_fluxes, \
    true_subimage_n_stars, true_is_on_array = \
        star_encoder1.get_image_stamps(full_image[0:1, 0:1], 
                                       true_locs, 
                                       true_fluxes[:, :, 0:1], 
                                      trim_images = False)

# When enabled, the true per-stamp star counts (capped at the encoder's
# max_detections) are handed to the encoders; otherwise None is passed.
# `_n_stars` is reused for encoder 2 further down.
use_true_n_stars = False
if use_true_n_stars: 
    _n_stars = true_subimage_n_stars.clamp(max = star_encoder1.max_detections)
else: 
    _n_stars = None


In [None]:
# images_sim = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
#                                 fluxes = sdss_hubble_data.fluxes.unsqueeze(0), 
#                                 n_stars = torch.Tensor([len(sdss_hubble_data.locs)]).type(torch.LongTensor), 
#                                 add_noise = False) - simulator.sky_intensity + full_background
        
# backgrounds_sim = torch.ones(images_sim.shape) * data_params['sky_intensity']
        
# # get parameters on the full image 
# map_locs_full_image1, map_fluxes_full_image1, map_n_stars_full1 = \
#     star_encoder1.get_results_on_full_image(images_sim, 
#                                            backgrounds_sim)

# plt.matshow((images_sim.squeeze() - full_image.squeeze()) / full_image.squeeze(), 
#            cmap=plt.get_cmap('bwr'))
# plt.colorbar()

# get parameters on the full image 
# Only the first three returned values are kept ([0:3]): locations, fluxes and
# number of stars. return_map = True presumably requests MAP estimates rather
# than samples -- TODO confirm against sample_star_encoder.
map_locs_full_image1, map_fluxes_full_image1, map_n_stars_full1 = \
    star_encoder1.sample_star_encoder(full_image[0:1, 0:1], 
                                       full_background[0:1, 0:1], 
                                       n_stars = _n_stars, return_map = True)[0:3]

In [None]:
# get reconstruction means
# Render a noise-free image from encoder 1's estimates with the one-band simulator.
recon_mean1 = simulator1.draw_image_from_params(locs = map_locs_full_image1, 
                                                fluxes = map_fluxes_full_image1,
                                                 n_stars = map_n_stars_full1, 
                                                 add_noise = False)

# reconstruction minus observation, for the [0, 0] slice only
residuals1 = recon_mean1[0, 0] - full_image[0, 0]

# Second set of results

In [None]:
# Encoder for the second set of results: same settings as encoder 1 except
# that it is configured for two bands.
# NOTE(review): full_slen is hard-coded to 101 -- confirm it matches
# full_image.shape[-1].
star_encoder2 = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = 2,
                                            max_detections = 2)

In [None]:
# Load the saved weights into encoder 2.
# The map_location callable returns each storage unchanged -- presumably so
# that a checkpoint saved on GPU loads on a CPU-only machine; TODO confirm.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
star_encoder2.load_state_dict(torch.load('../fits/results_11122019/starnet_ri', 
                               map_location=lambda storage, loc: storage))
# eval() is called before inference; the trailing `;` hides the cell output
star_encoder2.eval(); 

In [None]:
# get parameters on the full image 
# Encoder 2 receives the full (unsliced) image and background. `_n_stars` is
# the value set in the image-stamp cell above (derived from encoder 1's
# stamps when use_true_n_stars is enabled). Only the first three returned
# values are kept.
map_locs_full_image2, map_fluxes_full_image2, map_n_stars_full2 = \
    star_encoder2.sample_star_encoder(full_image, 
                                       full_background, 
                                       n_stars = _n_stars, return_map = True)[0:3]

In [None]:
# reconstruction means
# Render a noise-free image from encoder 2's estimates with the two-band simulator.
recon_mean2 = simulator.draw_image_from_params(locs = map_locs_full_image2, 
                                                fluxes = map_fluxes_full_image2,
                                                 n_stars = map_n_stars_full2, 
                                                 add_noise = False)

# reconstruction minus observation; unlike residuals1 this keeps every axis
residuals2 = recon_mean2 - full_image 

In [None]:
# Pass encoder 2's own estimates back through get_image_stamps and keep
# element [3] of the result -- the position that was unpacked as
# `true_subimage_n_stars` in the encoder-1 call, i.e. presumably the number of
# stars per stamp (TODO confirm).
n_stars2 = star_encoder2.get_image_stamps(full_image, 
                               map_locs_full_image2, map_fluxes_full_image2)[3]

# Check out losses

In [None]:
# Encoder loss for encoder 1, evaluated on the [0:1, 0:1] image/background
# slice against the true locations and the index-0 slice of the true fluxes.
# Five values are returned; the last is discarded.
loss1, counter_loss1, locs_loss1, fluxes_loss1, _ = \
    inv_kl_lib.get_encoder_loss(star_encoder1, full_image[0:1, 0:1], 
                                full_background[0:1, 0:1], 
                                true_locs, 
                                true_fluxes[:, :, 0:1], 
                               use_l2_loss = True)
    
print(loss1)

In [None]:
# Encoder loss for encoder 2, evaluated on the full image/background against
# the full true catalog. Five values are returned; the last is discarded.
loss2, counter_loss2, locs_loss2, fluxes_loss2, _ = \
    inv_kl_lib.get_encoder_loss(star_encoder2, full_image, 
                                full_background,
                                true_locs, 
                                true_fluxes, 
                               use_l2_loss = True)

print(loss2)

In [None]:
# Scatter encoder 2's location loss against encoder 1's, together with the
# identity line for reference.
_locs_l1 = locs_loss1.detach()
_locs_l2 = locs_loss2.detach()
plt.plot(_locs_l1, _locs_l2, 'x')
plt.plot(_locs_l1, _locs_l1, '-')

# both encoders must have a positive loss on exactly the same entries
_locs_positive = locs_loss1 > 0
assert torch.all(_locs_positive == (locs_loss2 > 0))

# mean loss over the positive entries, for each encoder
print(locs_loss1[_locs_positive].mean())
print(locs_loss2[_locs_positive].mean())
print('\n')
# fraction of those entries on which encoder 2 has the smaller loss
_locs_enc2_better = locs_loss2 < locs_loss1
print(_locs_enc2_better[_locs_positive].float().mean())

In [None]:
# Scatter encoder 2's flux loss against encoder 1's, together with the
# identity line for reference.
_fluxes_l1 = fluxes_loss1.detach()
_fluxes_l2 = fluxes_loss2.detach()
plt.plot(_fluxes_l1, _fluxes_l2, 'x')
plt.plot(_fluxes_l1, _fluxes_l1, '-')

# both encoders must have a positive loss on exactly the same entries
_fluxes_positive = fluxes_loss1 > 0
assert torch.all(_fluxes_positive == (fluxes_loss2 > 0))

# mean loss over the positive entries, for each encoder
print(fluxes_loss1[_fluxes_positive].mean())
print(fluxes_loss2[_fluxes_positive].mean())
print('\n')
# fraction of those entries on which encoder 2 has the smaller loss
_fluxes_enc2_better = fluxes_loss2 < fluxes_loss1
print(_fluxes_enc2_better[_fluxes_positive].float().mean())

In [None]:
# Scatter encoder 2's counter loss against encoder 1's, together with the
# identity line for reference.
_counter_l1 = counter_loss1.detach()
_counter_l2 = counter_loss2.detach()
plt.plot(_counter_l1, _counter_l2, 'x')
plt.plot(_counter_l1, _counter_l1, '-')

# mean counter loss of each encoder over all entries
print(counter_loss1.mean())
print(counter_loss2.mean())
print('\n')
# fraction of entries on which encoder 2 has the smaller loss
_counter_enc2_better = counter_loss2 < counter_loss1
print(_counter_enc2_better.float().mean())

# Compare

In [None]:
# side length of the subimages inspected below
subimage_slen = 9

# possible coordinates
# candidate corner coordinates: 10, 19, 28, ... (step = subimage_slen, below 90),
# identical for both axes
x0_vec = np.arange(10, 90, subimage_slen)
x1_vec = np.arange(10, 90, subimage_slen)

In [None]:
# sanity check: residuals1 was built from [0, 0] slices
residuals1.shape

In [None]:
# sanity check: residuals2 keeps the full shape of full_image
residuals2.shape

In [None]:
# inspect encoder 1's estimated locations with size-one axes removed
map_locs_full_image1.squeeze()

In [None]:
# 2 x 3 figure for one randomly chosen subimage:
# column 0 = observed image with both catalogs, column 1 = reconstructions,
# column 2 = relative residuals; row 0 = encoder 1, row 1 = encoder 2.
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# Random corner of the subimage. np.random.choice without a `size` argument
# returns a scalar, so int() no longer has to convert a length-1 array
# (that conversion is deprecated since NumPy 1.25).
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

###################
# Plot catalogs
##################
# Encoder 2's estimated locations, scaled by (slen - 1), restricted to the
# subimage window and shifted so that the window corner is the origin.
# This does not depend on the row, so it is computed once before the loop.
_est_locs = map_locs_full_image2.squeeze() * (full_image.shape[-1] - 1)
which_locs = (_est_locs[:, 0] > x0) & \
                (_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_est_locs[:, 1] > x1) & \
                (_est_locs[:, 1] < (x1 + subimage_slen - 1))
est_locs = (_est_locs[which_locs, :] - torch.Tensor([[x0, x1]]))

# the same "observed" panel is drawn in both rows of column 0
for j in range(2):
    # first catalog
    plotting_utils.plot_subimage(axarr[j, 0], full_image[0, 0],
                                 map_locs_full_image1.squeeze(),
                                 true_locs.squeeze(),
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));

    # second catalog (column 1 of the locations goes on the horizontal axis)
    axarr[j, 0].scatter(est_locs[:, 1], est_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions 
#######################
# first reconstruction
plotting_utils.plot_subimage(axarr[0, 1], recon_mean1[0, 0],
                             map_locs_full_image1.squeeze(),
                             None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('reconstruction 1 \n');

# second reconstruction
plotting_utils.plot_subimage(axarr[1, 1], recon_mean2[0, 0],
                             map_locs_full_image2.squeeze(),
                             None,
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             color = 'c', marker = 'x')
axarr[1, 1].set_title('reconstruction 2 \n');

# ######################
# # residuals
# ######################
# residuals relative to the observed image, for the [0, 0] slice
rel_residuals1 = residuals1 / full_image[0, 0]
rel_residuals2 = residuals2[0, 0] / full_image[0, 0]

# shared, symmetric colour range: largest absolute relative residual of
# either encoder inside the subimage window
vmax1 = torch.abs(rel_residuals1[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
vmax2 = torch.abs(rel_residuals2[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()

vmax = torch.max(torch.Tensor([vmax1, vmax2]))

# first residuals
plotting_utils.plot_subimage(axarr[0, 2], rel_residuals1,
                             map_locs_full_image1.squeeze(),
                             true_locs.squeeze(),
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)

axarr[0, 2].set_title('residuals 1\n');

# second residuals
plotting_utils.plot_subimage(axarr[1, 2], rel_residuals2,
                             map_locs_full_image2.squeeze(),
                             true_locs.squeeze(),
                             x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax,
                             color = 'c')

axarr[1, 2].set_title('residuals 2\n');



# Check out some summary statistics

In [None]:
# we only look at locations within 10-90; 
# Portillos doesn't detect on the edge

def filter_params(locs, fluxes, slen, min_coord = 10, max_coord = 90):
    """Keep only the sources that lie strictly inside an interior window.

    Parameters
    ----------
    locs : 2-d tensor, one row per source, columns 0 and 1 holding the two
        coordinates. They are multiplied by ``slen - 1`` before being
        compared with the window bounds.
    fluxes : 1-d tensor with one entry per row of ``locs``.
    slen : side length used to scale ``locs``.
    min_coord, max_coord : exclusive lower / upper bound applied to both
        scaled coordinates. The defaults reproduce the original fixed
        10-90 window.

    Returns
    -------
    (locs, fluxes) restricted to the selected sources; the returned locations
    are in the same (unscaled) units as the input.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1

    # the scaled copy is only used to build the mask
    scaled_locs = locs * (slen - 1)
    which_params = (scaled_locs[:, 0] > min_coord) & (scaled_locs[:, 0] < max_coord) & \
                        (scaled_locs[:, 1] > min_coord) & (scaled_locs[:, 1] < max_coord)

    return locs[which_params], fluxes[which_params]


In [None]:
# shape of the true fluxes before they are filtered in the next cell
true_fluxes.shape

In [None]:
# Restrict each catalog to the sources selected by filter_params.
# encoder 1: squeeze() removes the size-one axes of its locations and fluxes
est_locs1, est_fluxes1 = filter_params(map_locs_full_image1.squeeze(), 
                                           map_fluxes_full_image1.squeeze(), 
                                           full_image.shape[-1])

# encoder 2: only index 0 of the last flux axis is kept
est_locs2, est_fluxes2 = filter_params(map_locs_full_image2.squeeze(), 
                                           map_fluxes_full_image2[0, :, 0], 
                                           full_image.shape[-1])

# NOTE(review): this rebinds `true_locs` / `true_fluxes` to the filtered
# versions (fluxes become 1-d). The cell is therefore not idempotent:
# re-running it, or any earlier cell that indexes true_fluxes with three
# indices, fails until the data-loading cell has been run again.
true_locs, true_fluxes = filter_params(true_locs.squeeze(), true_fluxes[0, :, 0], 
                                       full_image.shape[-1])

In [None]:
# number of sources left in each catalog after filtering
print('n_stars 1: ', len(est_locs1))
print('n_stars 2: ', len(est_locs2))
print('true n_stars: ', len(true_locs))

In [None]:
# _recon1 = simulator.draw_image_from_params(locs = est_locs1.unsqueeze(0), 
#                                     fluxes = est_fluxes1.unsqueeze(0),
#                                      n_stars = torch.Tensor([len(est_fluxes1)]).type(torch.LongTensor), 
#                                      add_noise = False).squeeze()

# _recon2 = \
#     simulator.draw_image_from_params(locs = est_locs2.unsqueeze(0), 
#                                     fluxes = est_fluxes2.unsqueeze(0),
#                                      n_stars = torch.Tensor([len(est_fluxes2)]).type(torch.LongTensor), 
#                                      add_noise = False).squeeze()

# _recon_truth = \
#     simulator.draw_image_from_params(locs = true_locs.unsqueeze(0), 
#                                     fluxes = true_fluxes.unsqueeze(0),
#                                      n_stars = torch.Tensor([len(true_locs)]).type(torch.LongTensor), 
#                                      add_noise = False).squeeze()


# fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

# axarr[0].matshow(_recon1)
# axarr[1].matshow(_recon2)
# axarr[2].matshow(_recon_truth)

In [None]:
# # check out MSEs

# _image = full_image[10:90, 10:90] 

# _residual1 = (_recon1[10:90, 10:90] - _image)
# _residual2 = (_recon2[10:90, 10:90] - _image)
# _true_residual = _recon_truth[10:90, 10:90] - _image

# print('mse 1: ', torch.mean(_residual1**2))
# print('mse 2: ', torch.mean(_residual2**2))
# print('truth_mse: ', torch.mean(_true_residual**2))

# fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

# vmax = (_residual2 / _image).abs().max()

# im1 = axarr[0].matshow(_residual1 / _image, vmin = -vmax, vmax = vmax, cmap = plt.get_cmap('bwr'))
# im2 = axarr[1].matshow(_residual2 / _image, vmin = -vmax, vmax = vmax, cmap = plt.get_cmap('bwr'))
# im3 = axarr[2].matshow(_true_residual / _image, vmin = -vmax, vmax = vmax, cmap = plt.get_cmap('bwr'))

# fig.colorbar(im1, ax = axarr[0])
# fig.colorbar(im2, ax = axarr[1])
# fig.colorbar(im3, ax = axarr[2])

In [None]:
# fig, axarr = plt.subplots(1, 3, figsize=(20, 5))

# axarr[0].hist((_residual1 / _image).flatten(), bins = 100);

# axarr[1].hist((_residual2 / _image).flatten(), bins = 100);

# axarr[2].hist((_true_residual / _image).flatten(), bins = 100);

# Get summary statistics

These are rather coarse measures. My completeness does not take into account the fact that several true stars might be matched with just one estimated star (so not all the true stars were detected); conversely my true positive rate does not take into account that several estimated stars might be matched with just one true star (so only one estimated star is a true positive). 

I tried the Hungarian algorithm to find a minimal matching, but this gave weird results because we're searching for a permutation that minimizes the **global** cost of the matching. 

In [None]:
import image_statistics_lib

In [None]:
# completeness and tpr using locations only
# (None is passed for both flux arguments; the fourth returned value is not
# needed here and is discarded)
completeness1, tpr1, completeness1_bool, _ = \
    image_statistics_lib.get_summary_stats(est_locs1, true_locs, 
                                           full_image.shape[-1], None, None)
completeness2, tpr2, completeness2_bool, _ = \
    image_statistics_lib.get_summary_stats(est_locs2, true_locs, 
                                           full_image.shape[-1], None, None)

    
print('completeness 1: {:0.3f}'.format(completeness1))
print('completeness 2: {:0.3f}\n'.format(completeness2))

print('true positive rate 1: {:0.3f}'.format(tpr1))
print('true positive rate 2: {:0.3f}'.format(tpr2))

In [None]:
# take into account fluxes
# Same call as in the previous cell but with the estimated and true fluxes
# supplied. This overwrites completeness*/tpr* and the completeness*_bool
# values from that cell and additionally keeps the tpr*_bool values.
completeness1, tpr1, completeness1_bool, tpr1_bool = \
    image_statistics_lib.get_summary_stats(est_locs1, true_locs, 
                                           full_image.shape[-1], 
                                           est_fluxes1, true_fluxes)
    
completeness2, tpr2, completeness2_bool, tpr2_bool = \
    image_statistics_lib.get_summary_stats(est_locs2, true_locs, 
                                           full_image.shape[-1], 
                                           est_fluxes2, true_fluxes)

    
print('completeness1: {:0.3f}'.format(completeness1))
print('completeness2: {:0.3f}\n'.format(completeness2))

print('true positive rate 1 : {:0.3f}'.format(tpr1))
print('true positive rate 2: {:0.3f}'.format(tpr2))

In [None]:
# Completeness per flux bin for each encoder; only the first two returned
# values are kept.
completeness1_vec, mag_vec1 = \
    image_statistics_lib.get_completeness_vec(est_locs1, true_locs, full_image.shape[-1],
                                              est_fluxes1, true_fluxes)[0:2]

completeness2_vec, mag_vec2 = \
    image_statistics_lib.get_completeness_vec(est_locs2, true_locs, full_image.shape[-1],
                                              est_fluxes2, true_fluxes)[0:2]

# mag_vec* has one more entry than the completeness vector is plotted against
# -- presumably bin edges, with the left edge used as x-value; TODO confirm
plt.plot(mag_vec1[0:-1], completeness1_vec, '--x', label = '1')
plt.plot(mag_vec2[0:-1], completeness2_vec, '--x', label = '2')

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
# True positive rate per flux bin for each encoder; only the first two
# returned values are kept. This overwrites mag_vec1 / mag_vec2 from the
# completeness cell.
tpr1_vec, mag_vec1 = \
    image_statistics_lib.get_tpr_vec(est_locs1, true_locs, full_image.shape[-1],
                                    est_fluxes1, true_fluxes)[0:2]

tpr2_vec, mag_vec2 = \
    image_statistics_lib.get_tpr_vec(est_locs2, true_locs, full_image.shape[-1],
                                    est_fluxes2, true_fluxes)[0:2]

# the last entry of mag_vec* is dropped -- presumably because these are bin
# edges; TODO confirm
plt.plot(mag_vec1[0:-1], tpr1_vec, '--x', label = '1')
plt.plot(mag_vec2[0:-1], tpr2_vec, '--x', label = '2')

plt.legend()
plt.xlabel('estimated log flux')
plt.ylabel('true positive rate')

In [None]:
# number of (filtered) true sources whose log10 flux exceeds 5.0
(torch.log10(true_fluxes) > 5.0).sum()

In [None]:
# Indices of the true sources with log10 flux above 5.0 for which
# completeness1_bool is 1 and completeness2_bool is 0 -- presumably sources
# matched by encoder 1 but not by encoder 2 (TODO confirm the meaning of the
# *_bool values). torch.where with a single argument returns a tuple of
# index tensors.
which = torch.where((completeness1_bool == 1) & \
                    (torch.log10(true_fluxes) > 5.0) & \
                    (completeness2_bool == 0))

In [None]:
# display the selected indices
which

In [None]:
# Show the true locations of the sources selected in `which` (computed two
# cells above). Indexing with `which` replaces the hard-coded row 620, which
# is only meaningful for one particular catalog and raises IndexError when
# the catalog has fewer rows.
true_locs[which]

In [None]:
fig, axarr = plt.subplots(3, 2, figsize=(16, 24))

# fixed subimage corner (the random choice is kept for reference)
x0 = 80 # int(np.random.choice(x0_vec, 1)) 
x1 = 15 # int(np.random.choice(x1_vec, 1))


def _scatter_within_subimage(ax, locs, marker):
    """Overlay on `ax`, in orange, the rows of `locs` that fall in the subimage.

    `locs` is scaled by (full_image.shape[-1] - 1); rows lying strictly inside
    the window that starts at (x0, x1) and has side `subimage_slen` are kept,
    shifted so that the window corner is the origin, and drawn with column 1
    on the horizontal axis and column 0 on the vertical axis.
    """
    scaled_locs = locs * (full_image.shape[-1] - 1)
    inside = (scaled_locs[:, 0] > x0) & \
                (scaled_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (scaled_locs[:, 1] > x1) & \
                (scaled_locs[:, 1] < (x1 + subimage_slen - 1))
    shifted_locs = scaled_locs[inside, :] - torch.Tensor([[x0, x1]])
    ax.scatter(shifted_locs[:, 1], shifted_locs[:, 0], color = 'orange', marker = marker)


##########################
# ROW 0: STARS CAUGHT BY ENCODER 1 (left) AND ENCODER 2 (right)
##########################
encoder_results = [('encoder 1', est_locs1, completeness1_bool, tpr1_bool),
                   ('encoder 2', est_locs2, completeness2_bool, tpr2_bool)]

for col, (label, est_locs_k, completeness_k_bool, tpr_k_bool) in enumerate(encoder_results):
    plotting_utils.plot_subimage(axarr[0, col], full_image[0, 0],
                                 est_locs_k,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig,
                                 color = 'b')
    axarr[0, col].set_title('{}; coords: {}\n'.format(label, [x0, x1]));

    # true locations that this encoder missed (orange circles)
    _scatter_within_subimage(axarr[0, col], true_locs[completeness_k_bool == 0], 'o')

    # estimated locations that were false (orange crosses)
    _scatter_within_subimage(axarr[0, col], est_locs_k[tpr_k_bool == 0], 'x')

##########################
# ROW 1, LEFT: TRUE STARS CAUGHT BY ENCODER 1 ONLY
##########################
plotting_utils.plot_subimage(axarr[1, 0], full_image[0, 0],
                             est_locs1,
                             true_locs[(completeness1_bool == 1) & (completeness2_bool == 0)],
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)

##########################
# ROW 1, RIGHT: TRUE STARS CAUGHT BY ENCODER 2 ONLY
##########################
plotting_utils.plot_subimage(axarr[1, 1], full_image[0, 0],
                             est_locs2,
                             true_locs[(completeness2_bool == 1) & (completeness1_bool == 0)],
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)


##########################
# RESIDUALS
##########################
# _resid = (recon_mean1 - full_image).squeeze()
# vmax = torch.abs(_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
# plotting_utils.plot_subimage(axarr[2, 0], _resid, 
#                             est_locs1, None, x0, x1, subimage_slen, 
#                             add_colorbar = True, global_fig = fig,
#                             diverging_cmap = True, 
#                             vmax = vmax, vmin = -vmax)

# # axarr[2, 0].set_title('my residuals\n');

# _resid = (recon_mean2 - full_image).squeeze()
# vmax = torch.abs(_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
# plotting_utils.plot_subimage(axarr[2, 1], _resid, 
#                             est_locs2, None, x0, x1, subimage_slen, 
#                             add_colorbar = True, global_fig = fig,
#                             diverging_cmap = True, 
#                             vmax = vmax, vmin = -vmax)

# # axarr[2, 1].set_title('my residuals\n');


In [None]:
# encoder 2's filtered locations scaled by (slen - 1)
foo = est_locs2 * (full_image.shape[-1] - 1)

# mask of the rows of `foo` lying strictly inside the current subimage
# window, built one coordinate at a time
_inside_axis0 = (foo[:, 0] > x0) & (foo[:, 0] < (x0 + subimage_slen - 1))
_inside_axis1 = (foo[:, 1] > x1) & (foo[:, 1] < (x1 + subimage_slen - 1))
which_locs = _inside_axis0 & _inside_axis1


In [None]:
# log10 of encoder 2's estimated fluxes for the sources inside the window
torch.log10(est_fluxes2[which_locs])

In [None]:
# scratch calculation: the value whose log10 is 4.22 (presumably a flux read
# off the output above)
10**(4.22)

In [None]:
# scratch calculation: log10 of 70000 (presumably a flux, for comparison with
# the log fluxes above)
np.log10(70000.)

In [None]:
# encoder 2's estimated fluxes for the sources inside the window
est_fluxes2[which_locs]