In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib


import plotting_utils

# Seed numpy's global RNG so that the random subimage corners drawn with
# np.random.choice in the plotting cells below are reproducible across runs.
np.random.seed(34534)

# Load the data

In [None]:
# Flux threshold: true and estimated stars are kept only if their first-band flux exceeds fmin.
fmin = 1000

In [None]:
# Load the SDSS image, background and reference catalogue (locs / fluxes) for bands [2, 3].
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = [2, 3])

# image and background
full_image = sdss_hubble_data.sdss_image
full_background = sdss_hubble_data.sdss_background

# true parameters: keep only stars whose band-0 flux exceeds fmin
which_bright = (sdss_hubble_data.fluxes[:, 0] > fmin)
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
true_locs.shape

In [None]:
# Convert the image to a torch.Tensor; the simulator and encoder cells below call tensor methods on it.
full_image = torch.Tensor(full_image)
print(full_image.shape)

In [None]:
# Image of the first band.
plt.matshow(full_image[0])
plt.colorbar()

# Get simulator 

In [None]:
import fitsio

In [None]:
psf_dir = '../data/'

# Read each band's PSF image from the primary HDU (ext=0). fitsio.read opens and
# closes the file itself, whereas fitsio.FITS(...)[0].read() never closed the handle.
psf_r = fitsio.read(psf_dir + 'sdss-002583-2-0136-psf-r.fits', ext=0)
psf_i = fitsio.read(psf_dir + 'sdss-002583-2-0136-psf-i.fits', ext=0)

# PSFs stacked one per band, in (r, i) order
psf_og = np.array([psf_r, psf_i])

# per-band sky intensity, same (r, i) order as psf_og
sky_intensity = torch.Tensor([686., 1123.])

In [None]:
# Single-band simulator using only the first (r) PSF and sky level; used when a
# catalogue carries a single flux per star.
simulator1 = simulated_datasets_lib.StarSimulator(
    psf=psf_og[:1],
    slen=full_image.shape[-1],
    sky_intensity=sky_intensity[:1],
)

# Two-band (r, i) simulator.
simulator = simulated_datasets_lib.StarSimulator(
    psf=psf_og,
    slen=full_image.shape[-1],
    sky_intensity=sky_intensity,
)



# Simulation with ground truth

In [None]:
truth_recon = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
                            fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                            n_stars = torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.LongTensor), 
                            add_noise = False).squeeze()

In [None]:
for i in range(len(sdss_hubble_data.bands)): 
    foo = (truth_recon[i] - full_image[i]) / full_image[i]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = - foo.abs().max(), cmap = plt.get_cmap('bwr')) 
    plt.colorbar()

In [None]:
f, axarr = plt.subplots(1, 5, figsize=(12, 6))

for i in range(5): 
    plotting_utils.plot_subimage(axarr[i], full_image[0],
                                 true_locs, 
                                 None, 
                                 x0 = int(np.random.choice(100, 1)), 
                                 x1 = int(np.random.choice(100, 1)), 
                                 subimage_slen = 7)
    axarr[i].set_xticks([]);
    axarr[i].set_yticks([]);
    


# Load results

In [None]:
results_dir = '../../multiband_pcat/pcat-lion-results/20191107-115253/'

# PCAT chain output; its 'f', 'x' and 'y' arrays are read when building the catalogue below.
chain_results = np.load(results_dir + 'chain.npz')

In [None]:
# shape of the flux samples; the first axis presumably indexes bands
# (the catalogue cell below transposes a slice and reads fluxes[:, 0] as band 0)
chain_results['f'].shape

In [None]:
# Multiplicative scale applied to PCAT fluxes before comparing with the reference fluxes.
# fudge_factor = 1 / (1 - 0.83)
# NOTE(review): a single value, sdss_data[0]['gain'][0], is later applied to the fluxes
# of every band; confirm this is intended if the bands have different gains.
fudge_factor = sdss_hubble_data.sdss_data[0]['gain'][0] 

In [None]:
include_classical_catalogue = False

if include_classical_catalogue:
    # Condensed ("classical") PCAT catalogue; columns 0, 2 and 4 are read as
    # x1 (x), x0 (y) and flux respectively.
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')

    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]

    fluxes = pcat_catalog[:, 4] * fudge_factor

    # Drop rows with a missing coordinate or flux, and rows fainter than fmin.
    # Both coordinates must be checked (x1_loc used to be tested twice, x0_loc never).
    is_na = np.isnan(x0_loc) | np.isnan(x1_loc) | np.isnan(fluxes) | (fluxes < fmin)

    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    # This catalogue has one flux per source: keep it as an (n_stars, 1) column so the
    # downstream fluxes[:, band] indexing and the single-band simulator check work.
    fluxes = fluxes[~is_na][:, None]
else:
    # Just take one sample: the last one in the chain, as an (n_stars, n_bands) array.
    fluxes = chain_results['f'][:, -1, ].transpose() * fudge_factor
    is_bright = fluxes[:, 0] > fmin

    x1_loc = chain_results['x'][-1, ].flatten()[is_bright]
    x0_loc = chain_results['y'][-1, ].flatten()[is_bright]
    fluxes = fluxes[is_bright]

# (n_stars, 2) locations as (x0, x1) pixel coordinates divided by (slen - 1).
# np.stack builds one array first; torch.Tensor on a list of ndarrays is very slow.
portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1)) / (full_image.shape[-1] - 1)
portillos_est_fluxes = torch.Tensor(fluxes)


# x1_loc_samples = chain_results['x'][-300:, ].flatten()
# x0_loc_samples = chain_results['y'][-300:, ].flatten()

# portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten()) * fudge_factor
# portillos_est_locs_sampled = torch.Tensor([x0_loc_samples, x1_loc_samples]).transpose(0,1) \
#                                 / (full_image.shape[-1] - 1)
    
# # filter by fmin
# port_which_bright = portillos_est_fluxes_sampled > fmin
# portillos_est_fluxes_sampled = portillos_est_fluxes_sampled[port_which_bright]
# portillos_est_locs_sampled = portillos_est_locs_sampled[port_which_bright]

In [None]:
portillos_est_fluxes.shape

In [None]:
plt.hist(torch.log10(portillos_est_fluxes[:, 0]))

# get reconstruction mean 

In [None]:
# Noise-free reconstruction of the image from the PCAT catalogue (batch of one).
_locs = portillos_est_locs.unsqueeze(0)
_fluxes = portillos_est_fluxes.unsqueeze(0)
# Count taken from the catalogue itself rather than from the x0_loc helper array.
_n_stars = torch.tensor([len(portillos_est_locs)], dtype=torch.long)

# Use the simulator whose number of bands matches the catalogue's flux columns.
_recon_simulator = simulator1 if _fluxes.shape[-1] == 1 else simulator
portillos_recon_mean = _recon_simulator.draw_image_from_params(locs = _locs,
                                                               fluxes = _fluxes,
                                                               n_stars = _n_stars,
                                                               add_noise = False).squeeze(0)

plt.matshow(portillos_recon_mean[0]);
plt.colorbar()

In [None]:
portillos_residuals = portillos_recon_mean - full_image

# Relative residual per band, cropped to rows/cols 5..94, on a symmetric colour scale.
for b in range(2):
    rel_resid = (portillos_residuals[b] / full_image[b])[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Plot subimages

In [None]:
band = 0

In [None]:
subimage_slen = 10

# possible coordinates
x0_vec = np.arange(10, 90, subimage_slen)
x1_vec = np.arange(10, 90, subimage_slen)

In [None]:
# Random window corner; index the length-1 result since int() on a 1-element array
# is deprecated in NumPy >= 1.25.
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])

fig, axarr = plt.subplots(1, 4, figsize=(15, 4))

# posterior samples
# plotting_utils.plot_subimage(axarr[0], full_image[band], 
#                              portillos_est_locs_sampled, 
#                              true_locs, 
#                              x0, x1, subimage_slen)
# axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));
# The posterior-sample panel is disabled (its inputs are commented out where the
# chain is loaded), so hide the otherwise empty first axes.
axarr[0].axis('off')

# condensed catalog
plotting_utils.plot_subimage(axarr[1], full_image[band],
                             portillos_est_locs, 
                             true_locs, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[1].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[2], portillos_recon_mean[band],
                             portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[2].set_title('reconstructed\n');

# residuals, relative to the observed image
plotting_utils.plot_subimage(axarr[3], portillos_residuals[band] / full_image[band], 
                            portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True)

# (this panel was previously mis-titled 'reconstructed')
axarr[3].set_title('residuals\n');



# Compare with my NN 

In [None]:
# Two-band StarNet encoder working on 7x7 stamps laid out with step 2 over a 101x101 image.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=101, stamp_slen=7, step=2, edge_padding=2,
    n_bands=2, max_detections=2)

In [None]:
# Weights for the two-band (r, i) encoder. Earlier checkpoints tried in this cell:
#   ../fits/wake_sleep-loc630x310-reweighted_prior-iwae-10252019-encoder-iter6
#   ../fits/results_11052019/wake_sleep3-loc630x310-encoder-iter6
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
_encoder_weights = torch.load('../fits/results_11122019/starnet_ri', map_location='cpu')
star_encoder.load_state_dict(_encoder_weights)

# switch to evaluation mode
star_encoder.eval();


In [None]:
# Run the encoder on the whole image (batch of one) with return_map=True and keep the
# first three outputs: locations, fluxes and star counts.
# An older API, star_encoder.get_results_on_full_image(image, background), was used before.
_encoder_outputs = star_encoder.sample_star_encoder(full_image.unsqueeze(0),
                                                    full_background.unsqueeze(0),
                                                    return_map=True)
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = _encoder_outputs[:3]

In [None]:
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image, 
                                                fluxes = map_fluxes_full_image,
                                                 n_stars = map_n_stars_full, 
                                                 add_noise = False).squeeze()

vae_residuals = vae_recon_mean - full_image

In [None]:
simulator.sky_intensity

In [None]:
my_est_locs = map_locs_full_image.squeeze() 
my_est_fluxes = map_fluxes_full_image.squeeze()

In [None]:
band = 0

In [None]:
# Random window corner (index the length-1 array; int() on it is deprecated in NumPy >= 1.25).
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# my catalog
plotting_utils.plot_subimage(axarr[0], full_image[band], my_est_locs, true_locs, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[1], vae_recon_mean[band], my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[1].set_title('reconstructed\n');

# residuals: relative residual of this band, computed once; the colour range is made
# symmetric using the largest magnitude inside the displayed window
_vae_rel_resid = (vae_residuals / full_image)[band]
vmax = _vae_rel_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)].abs().max()
plotting_utils.plot_subimage(axarr[2], _vae_rel_resid, 
                            my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True, 
                            vmax = vmax, vmin = -vmax)

axarr[2].set_title('residuals\n');



# Check out MSEs

In [None]:
len(my_est_fluxes)

In [None]:
len(portillos_est_fluxes)

In [None]:
len(true_fluxes)

In [None]:
# reconstructions 
fig, axarr = plt.subplots(1, 3, figsize=(15, 12))

axarr[0].matshow(vae_recon_mean[band])
axarr[1].matshow(portillos_recon_mean[band])
axarr[2].matshow(truth_recon[band])

In [None]:
# check out MSEs of the relative residuals (recon - image) / image on the
# [10:90, 10:90] crop of the selected band
_image = full_image[band, 10:90, 10:90] 

_my_residual = (vae_recon_mean[band, 10:90, 10:90] - _image) / _image
_portillos_residual = (portillos_recon_mean[band, 10:90, 10:90] - _image) / _image
_true_residual = (truth_recon[band, 10:90, 10:90] - _image) / _image

_residuals = {'my': _my_residual,
              'portillos': _portillos_residual,
              'truth': _true_residual}

for name, resid in _residuals.items():
    print(name + '_mse: ', torch.mean(resid**2))

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

# one diverging panel per method, with a colour range symmetric about zero
for ax, (name, resid) in zip(axarr, _residuals.items()):
    lim = resid.abs().max()
    im = ax.matshow(resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    ax.set_title('{} relative residual\n'.format(name))
    fig.colorbar(im, ax=ax)

In [None]:
# Histogram of the raw VAE residual (recon - image): since
# _my_residual = (recon - image) / image, multiplying by _image undoes the division.
plt.hist((_my_residual * _image).flatten() , bins = 100);

# Compare

In [None]:
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# Random window corner (index the length-1 array; int() on it is deprecated in NumPy >= 1.25).
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])

###################
# Plot catalogs
##################
# Portillos' locations in pixel units, restricted to the window and shifted to window
# coordinates. This does not depend on the row, so it is computed once.
_portillos_est_locs = portillos_est_locs * (full_image.shape[-1] - 1)
which_locs = (_portillos_est_locs[:, 0] > x0) & \
                (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_portillos_est_locs[:, 1] > x1) & \
                (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]]))

# Both rows show the observed window with my catalogue, the true stars, and
# Portillos' catalogue as cyan crosses.
for j in range(2):
    plotting_utils.plot_subimage(axarr[j, 0], full_image[band], my_est_locs, true_locs, x0, x1, subimage_slen, 
                                add_colorbar = True, global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));
    axarr[j, 0].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions 
#######################
# my reconstruction
plotting_utils.plot_subimage(axarr[0, 1], vae_recon_mean[band], my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('vae reconstructed\n');

# Portillos reconstruction
plotting_utils.plot_subimage(axarr[1, 1], portillos_recon_mean[band], 
                             portillos_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig, 
                            color = 'c', marker = 'x')
axarr[1, 1].set_title('portillos reconstructed\n');

######################
# residuals
######################
# my residuals
plotting_utils.plot_subimage(axarr[0, 2], (vae_residuals / full_image)[band], 
                            my_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True)

axarr[0, 2].set_title('vae residuals\n');

# portillos residuals
plotting_utils.plot_subimage(axarr[1, 2], (portillos_residuals / full_image)[band], 
                            portillos_est_locs, None, x0, x1, subimage_slen, 
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True, 
                            color = 'c', marker = 'x')

axarr[1, 2].set_title('portillos residuals\n');

# Check out some summary statistics

In [None]:
# We only look at locations strictly inside (lower, upper) pixels on both axes
# (5-95 by default): Portillos doesn't detect on the edge.

def filter_params(locs, fluxes, slen, lower=5, upper=95):
    """Keep only the sources whose pixel location lies strictly inside a square window.

    Args:
        locs: 2-D tensor of locations; columns 0 and 1 are multiplied by (slen - 1)
            to obtain pixel coordinates.
        fluxes: 1-D tensor with one flux per row of ``locs``.
        slen: side length of the image in pixels.
        lower, upper: exclusive pixel bounds applied to both coordinates.

    Returns:
        ``(locs, fluxes)`` restricted to the sources inside the window.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1

    pixel_locs = locs * (slen - 1)
    rows, cols = pixel_locs[:, 0], pixel_locs[:, 1]
    inside = (rows > lower) & (rows < upper) & (cols > lower) & (cols < upper)

    return locs[inside], fluxes[inside]


In [None]:
# Restrict every catalogue to the interior window and to the selected band's fluxes.
# NOTE(review): this rebinds my_est_*, portillos_est_* and true_* in place, so the cell
# is not idempotent (a second run fails indexing the now 1-D flux tensors with [:, band]),
# and earlier cells re-run afterwards will see the filtered catalogues.
_slen = full_image.shape[-1]

my_est_locs, my_est_fluxes = filter_params(my_est_locs, my_est_fluxes[:, band], _slen)
portillos_est_locs, portillos_est_fluxes = filter_params(
    portillos_est_locs, portillos_est_fluxes[:, band], _slen)
true_locs, true_fluxes = filter_params(true_locs, true_fluxes[:, band], _slen)

In [None]:
print('my n_stars: ', len(my_est_locs))
print('portillos n_stars: ', len(portillos_est_locs))
print('true n_stars: ', len(true_locs))

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(20, 5))

axarr[0].hist((_my_residual / _image).flatten(), bins = 100);

axarr[1].hist((_portillos_residual / _image).flatten(), bins = 100);

axarr[2].hist((_true_residual / _image).flatten(), bins = 100);

# Get summary statistics

These are rather coarse measures. My completeness does not account for the fact that several true stars might be matched to the same estimated star (in which case not all of those true stars were actually detected); conversely, my true positive rate does not account for several estimated stars being matched to the same true star (in which case only one of those estimated stars is actually a true positive).

I tried the Hungarian algorithm to find a minimum-cost one-to-one matching, but it gave odd results: it searches for the permutation that minimizes the **global** cost of the matching, rather than pairing each star with a nearby counterpart.

In [None]:
import image_statistics_lib

In [None]:
# completeness and tpr using locations only (no fluxes passed)
# Only the first two outputs are needed. Slicing avoids unpacking into `_`, which in
# IPython shadows the last-output variable for the rest of the session.
my_completeness, my_tpr = \
    image_statistics_lib.get_summary_stats(my_est_locs, true_locs, 
                                           full_image.shape[-1], None, None)[:2]
portillos_completeness, portillos_tpr = \
    image_statistics_lib.get_summary_stats(portillos_est_locs, true_locs, 
                                           full_image.shape[-1], None, None)[:2]

print('my completeness: {:0.3f}'.format(my_completeness))
print('portillos completeness: {:0.3f}\n'.format(portillos_completeness))

print('my true positive rate: {:0.3f}'.format(my_tpr))
print('portillos true positive rate: {:0.3f}'.format(portillos_tpr))

In [None]:
# Completeness and tpr using fluxes as well as locations. The boolean outputs are kept:
# later cells index true_locs with *_complete_bool and the estimates with *_tpr_bool.
_slen = full_image.shape[-1]

(my_completeness, my_tpr,
 my_complete_bool, my_tpr_bool) = image_statistics_lib.get_summary_stats(
    my_est_locs, true_locs, _slen, my_est_fluxes, true_fluxes)

(portillos_completeness, portillos_tpr,
 portillos_complete_bool, portillos_tpr_bool) = image_statistics_lib.get_summary_stats(
    portillos_est_locs, true_locs, _slen, portillos_est_fluxes, true_fluxes)

for line in ['my completeness: {:0.3f}'.format(my_completeness),
             'portillos completeness: {:0.3f}\n'.format(portillos_completeness),
             'my true positive rate: {:0.3f}'.format(my_tpr),
             'portillos true positive rate: {:0.3f}'.format(portillos_tpr)]:
    print(line)

In [None]:
# Completeness as a function of true log flux, for both catalogues.
_slen = full_image.shape[-1]

my_completeness_vec, my_mag_vec = image_statistics_lib.get_completeness_vec(
    my_est_locs, true_locs, _slen, my_est_fluxes, true_fluxes)[:2]
portillos_completeness_vec, portillos_mag_vec = image_statistics_lib.get_completeness_vec(
    portillos_est_locs, true_locs, _slen, portillos_est_fluxes, true_fluxes)[:2]

# *_mag_vec is one element longer than the curve (presumably bin edges), so drop its last entry.
for mag_vec, completeness_vec, label in [(my_mag_vec, my_completeness_vec, 'Starnet'),
                                         (portillos_mag_vec, portillos_completeness_vec, 'Portillos')]:
    plt.plot(mag_vec[:-1], completeness_vec, '--x', label=label)

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
# True positive rate as a function of estimated log flux, for both catalogues.
_slen = full_image.shape[-1]

my_tpr_vec, my_mag_vec, my_counts = image_statistics_lib.get_tpr_vec(
    my_est_locs, true_locs, _slen, my_est_fluxes, true_fluxes)
portillos_tpr_vec, portillos_mag_vec, portillos_counts = image_statistics_lib.get_tpr_vec(
    portillos_est_locs, true_locs, _slen, portillos_est_fluxes, true_fluxes)

for mag_vec, tpr_vec, label in [(my_mag_vec, my_tpr_vec, 'Starnet'),
                                (portillos_mag_vec, portillos_tpr_vec, 'Portillos')]:
    plt.plot(mag_vec[:-1], tpr_vec, '--x', label=label)

plt.legend()
plt.xlabel('estimated log flux')
plt.ylabel('true positive rate')

In [None]:
fig, axarr = plt.subplots(3, 2, figsize=(16, 24))

# Random window corner (index the length-1 array; int() on it is deprecated in NumPy >= 1.25).
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])
subimage_slen = 10

##########################
# ROW 0: STARS CAUGHT BY ME (left) AND BY PORTILLOS (right)
##########################
# Blue: the method's estimates. Orange 'o': true stars it missed.
# Orange 'x': its estimates that matched no true star.
# The selected band's 2-D image is passed, as in the other plot_subimage calls.
for col, (name, est_locs, complete_bool, tpr_bool) in enumerate([
        ('my', my_est_locs, my_complete_bool, my_tpr_bool),
        ('portillos', portillos_est_locs, portillos_complete_bool, portillos_tpr_bool)]):
    ax = axarr[0, col]
    plotting_utils.plot_subimage(ax, full_image[band],
                                 est_locs,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig,
                                 color = 'b')
    ax.set_title('{} results; coords: {}\n'.format(name, [x0, x1]))

    for flagged_locs, marker in [(true_locs[complete_bool == 0], 'o'),
                                 (est_locs[tpr_bool == 0], 'x')]:
        # pixel coordinates of the flagged stars inside the window, relative to its corner
        _locs = flagged_locs * (full_image.shape[-1] - 1)
        which_locs = (_locs[:, 0] > x0) & \
                        (_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                        (_locs[:, 1] > x1) & \
                        (_locs[:, 1] < (x1 + subimage_slen - 1))
        __locs = (_locs[which_locs, :] - torch.Tensor([[x0, x1]]))
        ax.scatter(__locs[:, 1], __locs[:, 0], color = 'orange', marker = marker)

##########################
# ROW 1: TRUE STARS CAUGHT ONLY BY ME (left) / ONLY BY PORTILLOS (right)
##########################
plotting_utils.plot_subimage(axarr[1, 0], full_image[band],
                             my_est_locs,
                             true_locs[(my_complete_bool == 1) & (portillos_complete_bool == 0)],
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)
axarr[1, 0].set_title('true stars caught only by me\n')

plotting_utils.plot_subimage(axarr[1, 1], full_image[band],
                             portillos_est_locs,
                             true_locs[(my_complete_bool == 0) & (portillos_complete_bool == 1)],
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)
axarr[1, 1].set_title('true stars caught only by portillos\n')

##########################
# ROW 2: RECONSTRUCTIONS
##########################
plotting_utils.plot_subimage(axarr[2, 0], vae_recon_mean[band],
                             my_est_locs,
                             None,
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)
axarr[2, 0].set_title('my reconstruction')
plotting_utils.plot_subimage(axarr[2, 1], portillos_recon_mean[band],
                             portillos_est_locs,
                             None,
                             x0, x1, subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)
# (this title was previously written to axarr[2, 0], overwriting 'my reconstruction')
axarr[2, 1].set_title('portillos reconstruction')


In [None]:
# Second, single-band StarNet encoder (9x9 stamps) loaded from an older checkpoint, for comparison.
star_encoder2 = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 9,
                                            step = 2,
                                            edge_padding = 3, 
                                            n_bands = 1,
                                            max_detections = 2)

In [None]:
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
star_encoder2.load_state_dict(torch.load('../fits/starnet-10162019-reweighted', 
                               map_location=lambda storage, loc: storage))
star_encoder2.eval(); 

In [None]:
# get parameters on the full image 
# map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
#     star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0), 
#                                            full_background.unsqueeze(0).unsqueeze(0))

# NOTE(review): star_encoder2 has n_bands = 1, but full_image holds both bands (the
# residual cells loop over full_image[0] and full_image[1]). It is also unsqueezed twice
# here, whereas the two-band star_encoder gets full_image.unsqueeze(0). This looks left
# over from a single-band version of the notebook; presumably a single band, e.g.
# full_image[band:band + 1].unsqueeze(0), should be passed. Confirm before relying on it.
map_locs_full_image2, map_fluxes_full_image2, map_n_stars_full2 = \
    star_encoder2.sample_star_encoder(full_image.unsqueeze(0).unsqueeze(0), 
                                    full_background.unsqueeze(0).unsqueeze(0), 
                                    return_map = True)[0:3]

In [None]:
# Interior-window filter of the single-band estimates (squeezed to the 2-D locs /
# 1-D fluxes that filter_params asserts).
my_est_locs2, my_est_fluxes2 = filter_params(map_locs_full_image2.squeeze(), map_fluxes_full_image2.squeeze(), 
                                           full_image.shape[-1])

In [None]:
my_completeness_vec, my_mag_vec,  = \
    image_statistics_lib.get_completeness_vec(my_est_locs, true_locs, full_image.shape[-1],
                                              my_est_fluxes, true_fluxes)[0:2]

my_completeness_vec2, my_mag_vec2,  = \
    image_statistics_lib.get_completeness_vec(my_est_locs2, true_locs, full_image.shape[-1],
                                              my_est_fluxes2, true_fluxes)[0:2]

portillos_completeness_vec, portillos_mag_vec = \
    image_statistics_lib.get_completeness_vec(portillos_est_locs, true_locs, full_image.shape[-1],
                                              portillos_est_fluxes, true_fluxes)[0:2]

plt.plot(my_mag_vec[0:-1], my_completeness_vec, '--x', label = 'Starnet')
# plt.plot(my_mag_vec2[0:-1], my_completeness_vec2, '--x', label = 'Starnet2')
plt.plot(portillos_mag_vec[0:-1], portillos_completeness_vec, '--x', label = 'Portillos')

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
plt.plot(my_mag_vec[0:-1], my_completeness_vec, '--x', label = 'iter6')
plt.plot(my_mag_vec2[0:-1], my_completeness_vec2, '--x', label = 'iter0')
plt.plot(portillos_mag_vec[0:-1], portillos_completeness_vec, '--x', label = 'Portillos')

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
my_tpr_vec, my_mag_vec, my_counts = \
    image_statistics_lib.get_tpr_vec(my_est_locs, true_locs, full_image.shape[-1],
                                              my_est_fluxes, true_fluxes)
    
my_tpr_vec2, my_mag_vec2, my_counts2 = \
    image_statistics_lib.get_tpr_vec(my_est_locs2, true_locs, full_image.shape[-1],
                                              my_est_fluxes2, true_fluxes)

portillos_tpr_vec, portillos_mag_vec, portillos_counts = \
    image_statistics_lib.get_tpr_vec(portillos_est_locs, true_locs, full_image.shape[-1],
                                              portillos_est_fluxes, true_fluxes)

plt.plot(my_mag_vec[0:-1], my_tpr_vec, '--x', label = 'iter6')
plt.plot(my_mag_vec2[0:-1], my_tpr_vec2, '--x', label = 'iter0')
plt.plot(portillos_mag_vec[0:-1], portillos_tpr_vec, '--x', label = 'Portillos')

plt.legend()
plt.xlabel('estimated log flux')
plt.ylabel('true positive rate')

# Check out parameters in my starnet patches

In [None]:
# Cut the full image into the encoder's stamps (stamp_slen pixels wide, laid out with `step`).
# NOTE(review): star_encoder is two-band (n_bands = 2) and sample_star_encoder was given
# full_image.unsqueeze(0); the second unsqueeze here looks left over from a single-band
# version of the notebook. Confirm the input rank tile_images expects.
image_stamps = \
    image_utils.tile_images(full_image.unsqueeze(0).unsqueeze(0),
                            star_encoder.stamp_slen,
                            star_encoder.step)

In [None]:
# Same tiling for the background (same NOTE on the double unsqueeze applies).
background_stamps = \
    image_utils.tile_images(full_background.unsqueeze(0).unsqueeze(0),
                            star_encoder.stamp_slen,
                            star_encoder.step)

In [None]:
# Keep only tiles whose corner coordinates lie strictly between 9 and 91 on both axes.
which_tile_coords = (star_encoder.tile_coords[:, 0] > 9) & (star_encoder.tile_coords[:, 0] < 91) & \
                        (star_encoder.tile_coords[:, 1] > 9) & (star_encoder.tile_coords[:, 1] < 91)
_tile_coords = star_encoder.tile_coords[which_tile_coords, :]

In [None]:
def _count_stars_in_patches(locs, fluxes):
    """Per-tile star counts of one catalogue over the retained tiles in _tile_coords.

    Wraps image_utils.get_params_in_patches with the encoder's geometry and returns
    its third output, which this notebook uses as the number of stars in each tile.
    """
    return image_utils.get_params_in_patches(_tile_coords,
                                             locs.unsqueeze(0),
                                             fluxes.unsqueeze(0),
                                             slen = star_encoder.full_slen,
                                             subimage_slen = star_encoder.stamp_slen,
                                             edge_padding = star_encoder.edge_padding)[2]


my_n_stars = _count_stars_in_patches(my_est_locs, my_est_fluxes)
portillos_n_stars = _count_stars_in_patches(portillos_est_locs, portillos_est_fluxes)
true_n_stars = _count_stars_in_patches(true_locs, true_fluxes)


In [None]:
from torch.distributions import poisson


In [None]:
# Poisson(mean_stars_per_patch) probabilities of 0..6 stars per patch, used as a
# reference curve for the per-patch histograms.
mean_stars_per_patch = 0.4
poisson_dstr = poisson.Poisson(rate=mean_stars_per_patch)
_star_counts = torch.arange(7, dtype=torch.float)
probs = poisson_dstr.log_prob(_star_counts).exp()

In [None]:
plt.hist(portillos_n_stars, np.arange(0, max(portillos_n_stars) + 3))
plt.plot(np.arange(7), probs * len(portillos_n_stars), color = 'r', marker = 'x')

In [None]:
plt.hist(my_n_stars, np.arange(0, max(my_n_stars) + 3))
plt.plot(np.arange(7), probs * len(my_n_stars), color = 'r', marker = 'x')

In [None]:
plt.hist(true_n_stars, np.arange(0, max(true_n_stars) + 3))
plt.plot(np.arange(7), probs * len(true_n_stars), color = 'r', marker = 'x')

In [None]:
(true_n_stars == portillos_n_stars).float().mean()

In [None]:
(true_n_stars == my_n_stars).float().mean()

In [None]:
log_prob_stamps = \
    star_encoder.forward(image_stamps[which_tile_coords], background_stamps[which_tile_coords])[4]

In [None]:
_tile_coords_filtered = _tile_coords[true_n_stars != my_n_stars, ]

In [None]:
fig, axarr = plt.subplots(3, 2, figsize=(16, 24))

# Pick one tile where my star count disagrees with the truth. `indx` indexes
# _tile_coords_filtered; `tile_indx` is the matching row of _tile_coords, which is
# how true_n_stars and log_prob_stamps are indexed (they were previously indexed
# with `indx`, i.e. the wrong tile).
_mismatch_rows = torch.nonzero(true_n_stars != my_n_stars)[:, 0]
indx = int(np.random.choice(len(_tile_coords_filtered), 1)[0])
tile_indx = int(_mismatch_rows[indx])
x0 = int(_tile_coords_filtered[indx, 0])
x1 = int(_tile_coords_filtered[indx, 1])
subimage_slen = 9

##########################
# STARS CAUGHT BY ME (left) AND BY PORTILLOS (right)
##########################
# Blue: the method's estimates. Orange 'o': true stars it missed.
# Orange 'x': its estimates that matched no true star.
# The selected band's 2-D image is passed, as in the other plot_subimage calls.
for col, (name, est_locs, complete_bool, tpr_bool) in enumerate([
        ('my', my_est_locs, my_complete_bool, my_tpr_bool),
        ('portillos', portillos_est_locs, portillos_complete_bool, portillos_tpr_bool)]):
    ax = axarr[0, col]
    plotting_utils.plot_subimage(ax, full_image[band],
                                 est_locs,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig,
                                 color = 'b')
    ax.set_title('{} results; coords: {}\n'.format(name, [x0, x1]))

    for flagged_locs, marker in [(true_locs[complete_bool == 0], 'o'),
                                 (est_locs[tpr_bool == 0], 'x')]:
        # pixel coordinates of the flagged stars inside the window, relative to its corner
        _locs = flagged_locs * (full_image.shape[-1] - 1)
        which_locs = (_locs[:, 0] > x0) & \
                        (_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                        (_locs[:, 1] > x1) & \
                        (_locs[:, 1] < (x1 + subimage_slen - 1))
        __locs = (_locs[which_locs, :] - torch.Tensor([[x0, x1]]))
        ax.scatter(__locs[:, 1], __locs[:, 0], color = 'orange', marker = marker)

# Red lines: edge_padding pixels in from each side of an encoder stamp.
for i in range(2):
    axarr[0, i].axvline(x=star_encoder.edge_padding, color = 'r')
    axarr[0, i].axvline(x=star_encoder.stamp_slen - star_encoder.edge_padding - 1, color = 'r')
    axarr[0, i].axhline(y=star_encoder.edge_padding, color = 'r')
    axarr[0, i].axhline(y=star_encoder.stamp_slen - star_encoder.edge_padding - 1, color = 'r')

print(true_n_stars[tile_indx])
print(torch.exp(log_prob_stamps[tile_indx]))