In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import plotting_utils

# Seed numpy's global RNG so the random subimage choices below are reproducible.
np.random.seed(seed=34534)

# Load the data

In [None]:
use_simulated_data = False

if not use_simulated_data:
    # real SDSS data, loaded through sdss_dataset_lib
    sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

    psf_fit_file = str(sdss_hubble_data.psf_file)
    full_image = sdss_hubble_data.sdss_image.squeeze()
    full_background = sdss_hubble_data.sdss_background.squeeze()

    # reference catalogue: keep only sources whose flux exceeds the cutoff
    bright_flux_cutoff = 1300.
    which_bright = (sdss_hubble_data.fluxes > bright_flux_cutoff)
    true_locs = sdss_hubble_data.locs[which_bright]
    true_fluxes = sdss_hubble_data.fluxes[which_bright]
else:
    # simulated image plus the parameters it was generated from
    psf_fit_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
    full_image = torch.Tensor(np.loadtxt('../../multiband_pcat/Data/sdss_simulated/cts/sdss_simulated-ctsr.txt'))

    true_params = np.load('../../multiband_pcat/Data/sdss_simulated/true_params.npz')
    true_locs = torch.Tensor(true_params['true_locs'].squeeze())
    true_fluxes = torch.Tensor(true_params['true_fluxes'].squeeze())

    # NOTE(review): this branch never defines `full_background`, which later
    # cells (residuals, encoder inference, MSEs) require, so running with
    # use_simulated_data = True fails downstream with a NameError.
    # TODO: decide which background applies to the simulated image.


In [None]:
# make sure the observed image is a torch.Tensor, then report its size
full_image = torch.Tensor(full_image)
print(full_image.size())

In [None]:
# full observed image with its colour scale
full_image_handle = plt.matshow(full_image.squeeze())
plt.colorbar(full_image_handle)

# Load results

In [None]:
# directory of the PCAT run being compared against
results_dir = '../../multiband_pcat/pcat-lion-results/20191007-115851/'

chain_file = results_dir + 'chain.npz'
chain_results = np.load(chain_file)

In [None]:
include_classical_catalogue = True

if include_classical_catalogue: 
    # Condensed ("classical") PCAT catalogue. Columns used here:
    # 0 -> x1 location, 2 -> x0 location, 4 -> flux (pixel units for locations).
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')
    
    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]
    fluxes = pcat_catalog[:, 4]
    
    # remove rows with a NaN in ANY of the three columns that are used
    # (both location columns must be checked, not x1 twice)
    is_na = np.isnan(x0_loc) | np.isnan(x1_loc) | np.isnan(fluxes)
    
    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    fluxes = fluxes[~is_na]
    
    # (n_stars, 2) locations divided by (slen - 1); filter_params below multiplies
    # by the same factor to get back to pixels. np.stack avoids building a tensor
    # from a Python list of ndarrays, which is slow and warns in torch.
    portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis = 1)) / (full_image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes)

# posterior samples from the last 300 entries of the chain
# (presumably the final 300 saved iterations — confirm chain layout)
x1_loc_samples = chain_results['x'][-300:, ].flatten()
x0_loc_samples = chain_results['y'][-300:, ].flatten()

portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten()) 
portillos_est_locs_sampled = torch.Tensor(np.stack([x0_loc_samples, x1_loc_samples], axis = 1)) \
                                / (full_image.shape[-1] - 1)

In [None]:
# TODO: this was chosen empirically looking at image residuals! 
# Need to figure out the correct conversion ... 
fudge_factor = 1 / (1 - 0.83)

# Only real-SDSS fluxes are rescaled; simulated-data fluxes are used as they are.
# NOTE(review): the flux names are rebound to scaled copies, so re-running this
# cell without first re-running the catalogue cell compounds the scaling.
if not use_simulated_data: 
    portillos_est_fluxes, portillos_est_fluxes_sampled = \
        portillos_est_fluxes * fudge_factor, portillos_est_fluxes_sampled * fudge_factor

# Get the reconstruction mean

In [None]:
# Noiseless renderer using the same PSF file and image size. Sky is set to zero,
# so renders are compared below against the background-subtracted image.
simulator = simulated_datasets_lib.StarSimulator(psf_fit_file = psf_fit_file, 
                                                 slen = full_image.shape[-1], 
                                                 sky_intensity = 0.)

# the reconstruction needs the condensed (classical) catalogue
if include_classical_catalogue: 
    n_portillos_stars = len(x0_loc)
    # unsqueeze(0) adds a leading batch dimension of size one
    portillos_recon_mean = simulator.draw_image_from_params(
        locs = portillos_est_locs.unsqueeze(0),
        fluxes = torch.Tensor(portillos_est_fluxes).unsqueeze(0),
        n_stars = torch.LongTensor([n_portillos_stars]),
        add_noise = False).squeeze()

    plt.matshow(portillos_recon_mean)
    plt.colorbar()

In [None]:
# distribution of log10 fluxes in the condensed catalogue
log_portillos_fluxes = torch.log10(portillos_est_fluxes)
plt.hist(log_portillos_fluxes);

In [None]:
if include_classical_catalogue: 
    # reconstruction minus background-subtracted data, displayed relative to the raw image
    portillos_residuals = portillos_recon_mean - (full_image - full_background)
    portillos_relative_residuals = portillos_residuals / full_image

    plt.matshow(portillos_relative_residuals)
    plt.colorbar()

In [None]:
# distribution of log10 fluxes in the reference catalogue
log_true_fluxes = torch.log10(true_fluxes)
plt.hist(log_true_fluxes, bins = 100);

In [None]:
# residuals relative to the background-subtracted image, pooled over all pixels
residual_over_signal = portillos_residuals / (full_image - full_background.squeeze())
hist = plt.hist(residual_over_signal.flatten(), bins = 100)

# Plot subimages

In [None]:
subimage_slen = 10

# candidate starting row / column indices for subimages, spaced subimage_slen apart
x0_vec = np.arange(10, 90, subimage_slen)
x1_vec = x0_vec.copy()

In [None]:
# Pick a random subimage. Calling np.random.choice without a size returns a
# scalar, which avoids NumPy's deprecated int() conversion of a 1-element array.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 4, figsize=(15, 4))

# posterior samples
plotting_utils.plot_subimage(axarr[0], full_image, 
                             portillos_est_locs_sampled, 
                             true_locs, 
                             x0, x1, subimage_slen)
axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));

# condensed catalog
plotting_utils.plot_subimage(axarr[1], full_image,
                             portillos_est_locs, 
                             true_locs, 
                             x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig)
axarr[1].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[2], portillos_recon_mean,
                             portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig)
axarr[2].set_title('reconstructed\n');

# residuals relative to the observed image, on a symmetric colour scale
# set by the largest absolute value inside this subimage
portillos_rel_residuals = portillos_residuals / full_image
vmax = torch.abs(portillos_rel_residuals[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
plotting_utils.plot_subimage(axarr[3], portillos_rel_residuals, 
                             portillos_est_locs, 
                             None, 
                             x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True, 
                             vmax = vmax, vmin = -vmax)

# this panel shows residuals, not a reconstruction
axarr[3].set_title('residuals\n');



# Compare with my neural network (StarNet encoder)

In [None]:
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib

In [None]:
# Encoder architecture. Presumably it must match the checkpoint loaded in the
# next cell for load_state_dict to succeed.
encoder_kwargs = dict(full_slen = 101,
                      stamp_slen = 9,
                      step = 2,
                      edge_padding = 3,
                      n_bands = 1,
                      max_detections = 4)
star_encoder = starnet_vae_lib.StarEncoder(**encoder_kwargs)

In [None]:
# Load trained weights onto the CPU (map_location returns each storage unchanged).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
encoder_state = torch.load('../fits/wake_sleep-portm2-101420129-encoder-iter3', 
                           map_location = lambda storage, loc: storage)
star_encoder.load_state_dict(encoder_state)

# inference mode for the rest of the notebook
star_encoder.eval(); 

In [None]:
# get parameters on the full image 
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0), 
                                           full_background.unsqueeze(0).unsqueeze(0))

In [None]:
# noiseless rendering of the encoder's catalogue
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image, 
                                                  fluxes = map_fluxes_full_image,
                                                  n_stars = map_n_stars_full, 
                                                  add_noise = False)
vae_recon_mean = vae_recon_mean.squeeze()

# compare against the background-subtracted observation
background_subtracted_image = full_image - full_background
vae_residuals = vae_recon_mean - background_subtracted_image

In [None]:
# drop the singleton dimensions returned by the encoder
my_est_locs, my_est_fluxes = map_locs_full_image.squeeze(), map_fluxes_full_image.squeeze()

In [None]:
# Pick a random subimage. Calling np.random.choice without a size returns a
# scalar, which avoids NumPy's deprecated int() conversion of a 1-element array.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# my catalog
plotting_utils.plot_subimage(axarr[0], full_image, my_est_locs, true_locs, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig)
axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));

# reconstruction
plotting_utils.plot_subimage(axarr[1], vae_recon_mean, my_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig)
axarr[1].set_title('reconstructed\n');

# residuals relative to the observed image, on a symmetric colour scale
# set by the largest absolute value inside this subimage
vae_rel_residuals = vae_residuals / full_image
vmax = torch.abs(vae_rel_residuals[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
plotting_utils.plot_subimage(axarr[2], vae_rel_residuals, 
                             my_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True, 
                             vmax = vmax, vmin = -vmax)

axarr[2].set_title('residuals\n');



# Compare

In [None]:
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# Pick a random subimage. Calling np.random.choice without a size returns a
# scalar, which avoids NumPy's deprecated int() conversion of a 1-element array.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

# Portillos catalogue in pixel units, restricted to this subimage and shifted
# to subimage coordinates. It does not depend on the plot row, so it is
# computed once here rather than inside the loop below.
_portillos_est_locs = portillos_est_locs * (full_image.shape[-1] - 1)
which_locs = (_portillos_est_locs[:, 0] > x0) & \
                (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_portillos_est_locs[:, 1] > x1) & \
                (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]])) 

###################
# Plot catalogs
##################
# Both rows show the observed subimage with my catalog, and the Portillos
# catalogue overlaid as cyan crosses.
for j in range(2):
    plotting_utils.plot_subimage(axarr[j, 0], full_image, my_est_locs, true_locs, x0, x1, subimage_slen, 
                                 add_colorbar = True, global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));

    # scatter takes (horizontal, vertical) = (column, row), hence the swapped order
    axarr[j, 0].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions 
#######################
# my reconstruction
plotting_utils.plot_subimage(axarr[0, 1], vae_recon_mean, my_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('vae reconstructed\n');

# Portillos reconstruction
plotting_utils.plot_subimage(axarr[1, 1], portillos_recon_mean, portillos_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig, 
                             color = 'c', marker = 'x')
axarr[1, 1].set_title('portillos reconstructed\n');

######################
# residuals
######################
# shared symmetric colour scale: the larger of the two subimage maxima
vmax1 = torch.abs((vae_residuals / full_image)[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
vmax2 = torch.abs((portillos_residuals / full_image)[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()

vmax = torch.max(torch.Tensor([vmax1, vmax2]))

# my residuals
plotting_utils.plot_subimage(axarr[0, 2], vae_residuals / full_image, 
                             my_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True, 
                             vmax = vmax, vmin = -vmax)

axarr[0, 2].set_title('vae residuals\n');

# portillos residuals
plotting_utils.plot_subimage(axarr[1, 2], portillos_residuals / full_image, 
                             portillos_est_locs, None, x0, x1, subimage_slen, 
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True, 
                             vmax = vmax, vmin = -vmax, 
                             color = 'c', marker = 'x')

axarr[1, 2].set_title('portillos residuals\n');

# Check out some summary statistics

In [None]:
# we only look at locations within 10-90; 
# Portillos doesn't detect on the edge

def filter_params(locs, fluxes, slen, lower = 10, upper = 90): 
    """Keep only the sources whose pixel location lies strictly inside (lower, upper).

    Args:
        locs: 2-d array/tensor with one row per source and two location columns,
            scaled so that multiplying by (slen - 1) gives pixel coordinates.
        fluxes: 1-d array/tensor of fluxes, aligned with the rows of locs.
        slen: side length of the full image, in pixels.
        lower, upper: exclusive pixel bounds applied to both location columns.
            The defaults reproduce the original 10-90 window.

    Returns:
        (locs, fluxes) restricted to the sources inside the window. Locations
        keep their original (unscaled) units.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1
    assert locs.shape[0] == fluxes.shape[0], 'locs and fluxes must have the same number of sources'

    # back to pixel units before comparing against the pixel bounds
    pixel_locs = locs * (slen - 1)
    which_params = (pixel_locs[:, 0] > lower) & (pixel_locs[:, 0] < upper) & \
                   (pixel_locs[:, 1] > lower) & (pixel_locs[:, 1] < upper)

    return locs[which_params], fluxes[which_params]


In [None]:
# restrict every catalogue to the interior window used by filter_params
image_slen = full_image.shape[-1]

my_est_locs, my_est_fluxes = filter_params(my_est_locs, my_est_fluxes, image_slen)
portillos_est_locs, portillos_est_fluxes = filter_params(portillos_est_locs, portillos_est_fluxes, image_slen)
true_locs, true_fluxes = filter_params(true_locs, true_fluxes, image_slen)

In [None]:
# number of sources per catalogue after filtering
for count_label, catalogue_locs in [('my n_stars: ', my_est_locs),
                                    ('portillos n_stars: ', portillos_est_locs),
                                    ('true n_stars: ', true_locs)]:
    print(count_label, len(catalogue_locs))

In [None]:
def _draw_noiseless(locs, fluxes):
    """Render one catalogue with `simulator` as a single-image batch, without noise."""
    return simulator.draw_image_from_params(locs = locs.unsqueeze(0), 
                                            fluxes = fluxes.unsqueeze(0),
                                            n_stars = torch.Tensor([len(locs)]).type(torch.LongTensor), 
                                            add_noise = False).squeeze()

_recon_mine = _draw_noiseless(my_est_locs, my_est_fluxes)
_recon_portillos = _draw_noiseless(portillos_est_locs, portillos_est_fluxes)
_recon_truth = _draw_noiseless(true_locs, true_fluxes)

# side by side: mine, Portillos, truth
fig, axarr = plt.subplots(1, 3, figsize=(15, 6))
for recon_ax, recon_image in zip(axarr, [_recon_mine, _recon_portillos, _recon_truth]):
    recon_ax.matshow(recon_image)

In [None]:
# check out MSEs on the interior 10:90 crop (the same window as filter_params)
crop = slice(10, 90)

_image = full_image[crop, crop] - full_background[crop, crop]

_my_residual = _recon_mine[crop, crop] - _image
_portillos_residual = _recon_portillos[crop, crop] - _image
_true_residual = _recon_truth[crop, crop] - _image

for mse_label, residual in [('my_mse: ', _my_residual),
                            ('portillos_mse: ', _portillos_residual),
                            ('truth_mse: ', _true_residual)]:
    print(mse_label, torch.mean(residual**2))

fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

residual_images = [residual_ax.matshow(residual) for residual_ax, residual
                   in zip(axarr, [_my_residual, _portillos_residual, _true_residual])]

for residual_ax, residual_image in zip(axarr, residual_images):
    fig.colorbar(residual_image, ax = residual_ax)

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(20, 5))

# residual / background-subtracted data, pooled over the crop: mine, Portillos, truth
for hist_ax, residual in zip(axarr, (_my_residual, _portillos_residual, _true_residual)):
    hist_ax.hist((residual / _image).flatten(), bins = 100)

# Get summary statistics

These are rather coarse measures. My completeness measure does not account for several true stars being matched to the same estimated star (in which case not all of those true stars were actually detected); conversely, my true positive rate does not account for several estimated stars being matched to the same true star (in which case only one of those estimated stars is a true positive).

I tried the Hungarian algorithm to find a minimum-cost matching, but it gave odd results because it searches for the permutation that minimizes the **global** cost of the matching.

In [None]:
import image_statistics_lib

In [None]:
# completeness and tpr using locations only (both flux arguments are None)
image_slen = full_image.shape[-1]

my_completeness, my_tpr, _, _ = image_statistics_lib.get_summary_stats(
    my_est_locs, true_locs, image_slen, None, None)
portillos_completeness, portillos_tpr, _, _ = image_statistics_lib.get_summary_stats(
    portillos_est_locs, true_locs, image_slen, None, None)

completeness_fmt = 'my completeness: {:0.3f}'
print(completeness_fmt.format(my_completeness))
print('portillos completeness: {:0.3f}\n'.format(portillos_completeness))

tpr_fmt = 'my true positive rate: {:0.3f}'
print(tpr_fmt.format(my_tpr))
print('portillos true positive rate: {:0.3f}'.format(portillos_tpr))

In [None]:
# same statistics, this time also passing estimated and true fluxes
image_slen = full_image.shape[-1]

my_completeness, my_tpr, _, _ = image_statistics_lib.get_summary_stats(
    my_est_locs, true_locs, image_slen, my_est_fluxes, true_fluxes)
portillos_completeness, portillos_tpr, _, _ = image_statistics_lib.get_summary_stats(
    portillos_est_locs, true_locs, image_slen, portillos_est_fluxes, true_fluxes)

completeness_fmt = 'my completeness: {:0.3f}'
print(completeness_fmt.format(my_completeness))
print('portillos completeness: {:0.3f}\n'.format(portillos_completeness))

tpr_fmt = 'my true positive rate: {:0.3f}'
print(tpr_fmt.format(my_tpr))
print('portillos true positive rate: {:0.3f}'.format(portillos_tpr))

In [None]:
image_slen = full_image.shape[-1]

my_completeness_vec, my_mag_vec = image_statistics_lib.get_completeness_vec(
    my_est_locs, true_locs, image_slen, my_est_fluxes, true_fluxes)[0:2]
portillos_completeness_vec, portillos_mag_vec = image_statistics_lib.get_completeness_vec(
    portillos_est_locs, true_locs, image_slen, portillos_est_fluxes, true_fluxes)[0:2]

# The last entry of each mag vector is dropped before plotting; presumably these
# are bin edges with one more entry than the completeness vector (confirm).
for curve_label, mag_vec, completeness_vec in [('Starnet', my_mag_vec, my_completeness_vec),
                                               ('Portillos', portillos_mag_vec, portillos_completeness_vec)]:
    plt.plot(mag_vec[0:-1], completeness_vec, '--x', label = curve_label)

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
image_slen = full_image.shape[-1]

my_tpr_vec, my_mag_vec = image_statistics_lib.get_tpr_vec(
    my_est_locs, true_locs, image_slen, my_est_fluxes, true_fluxes)[0:2]
portillos_tpr_vec, portillos_mag_vec = image_statistics_lib.get_tpr_vec(
    portillos_est_locs, true_locs, image_slen, portillos_est_fluxes, true_fluxes)[0:2]

# The last entry of each mag vector is dropped before plotting; presumably these
# are bin edges with one more entry than the tpr vector (confirm).
for curve_label, mag_vec, tpr_vec in [('Starnet', my_mag_vec, my_tpr_vec),
                                      ('Portillos', portillos_mag_vec, portillos_tpr_vec)]:
    plt.plot(mag_vec[0:-1], tpr_vec, '--x', label = curve_label)

plt.legend()
plt.xlabel('estimated log flux')
plt.ylabel('true positive rate')

In [None]:
# keep only the encoder tiles whose two coordinates lie strictly within (9, 91)
encoder_tile_coords = star_encoder.tile_coords
which_tile_coords = (encoder_tile_coords[:, 0] > 9) & (encoder_tile_coords[:, 0] < 91) & \
                    (encoder_tile_coords[:, 1] > 9) & (encoder_tile_coords[:, 1] < 91)
_tile_coords = encoder_tile_coords[which_tile_coords, :]

In [None]:
def _n_stars_per_tile(locs, fluxes):
    """Third output of image_utils.get_params_in_patches over `_tile_coords`.

    Presumably this is the per-tile source count; confirm against image_utils.
    """
    return image_utils.get_params_in_patches(_tile_coords, 
                                             locs.unsqueeze(0), 
                                             fluxes.unsqueeze(0), 
                                             slen = star_encoder.full_slen, 
                                             subimage_slen = star_encoder.stamp_slen, 
                                             edge_padding = star_encoder.edge_padding)[2]

my_n_stars = _n_stars_per_tile(my_est_locs, my_est_fluxes)
portillos_n_stars = _n_stars_per_tile(portillos_est_locs, portillos_est_fluxes)
true_n_stars = _n_stars_per_tile(true_locs, true_fluxes)


In [None]:
# among tiles with at least one true star, the fraction where Portillos' count matches exactly
tiles_with_stars = true_n_stars > 0
(portillos_n_stars[tiles_with_stars] == true_n_stars[tiles_with_stars]).float().mean()

In [None]:
# among tiles with at least one true star, the fraction where my count matches exactly
tiles_with_stars = true_n_stars > 0
(my_n_stars[tiles_with_stars] == true_n_stars[tiles_with_stars]).float().mean()

In [None]:
# distribution of per-tile counts from the Portillos catalogue
plt.hist(x = portillos_n_stars)

In [None]:
# distribution of per-tile counts from my catalogue
plt.hist(x = my_n_stars)