In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import inv_kl_objective_lib as inv_kl_lib

import starnet_vae_lib

import plotting_utils

# Seed every random number generator the notebook relies on so that
# Restart & Run All reproduces the same figures. NumPy drives the random
# sub-image choices; torch is seeded as well because a later cell draws an
# image with add_noise = True (presumably through torch's RNG -- confirm
# against simulated_datasets_lib).
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED);

# Load the data

In [None]:
# Dataset object that exposes the image, background, PSF file and catalogue
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

# path of the PSF fit file, as a plain string
psf_fit_file = str(sdss_hubble_data.psf_file)

# image and background with singleton dimensions dropped
full_image, full_background = (sdss_hubble_data.sdss_image.squeeze(),
                               sdss_hubble_data.sdss_background.squeeze())

# catalogue entries with flux above 1300 serve as the reference ("true") parameters
which_bright = sdss_hubble_data.fluxes > 1300.
true_locs, true_fluxes = (sdss_hubble_data.locs[which_bright],
                          sdss_hubble_data.fluxes[which_bright])

In [None]:
# make sure the image is a torch tensor, then report its dimensions
full_image = torch.Tensor(full_image)
print(full_image.size())

# the data

In [None]:
# Wider field around the analysed image; the red square spans
# columns 310-410 and rows 4-104 of the displayed crop.
_wide_field = sdss_hubble_data.sdss_image_full.squeeze()[626:1175, 0:414]
plt.matshow(_wide_field)

_col_lo, _col_hi = 310, 410
_row_lo, _row_hi = 4, 104
for _col in (_col_lo, _col_hi):
    plt.vlines(x=_col, ymin=_row_lo, ymax=_row_hi, color = 'red')
plt.hlines(y=_row_lo, xmin=_col_lo, xmax = _col_hi, color = 'red')
plt.hlines(y=_row_hi, xmin=_col_lo, xmax = _col_hi, color = 'red')

In [None]:
# the image analysed in the rest of the notebook (no colorbar is drawn)
plt.matshow(torch.squeeze(full_image))

In [None]:
# side length, in pixels, of the sub-images shown below
subimage_slen = 10

# candidate upper-left corners along each axis: 10, 20, ..., 80
x0_vec = np.arange(10, 90, subimage_slen)
x1_vec = x0_vec.copy()

In [None]:
# Show four randomly chosen sub-images; the true locations are passed to the
# plotting helper for each of them.
fig, axarr = plt.subplots(2, 2, figsize=(8, 8))

for ax in axarr.flat:
    # np.random.choice(..., 1) returns a length-1 array. Index the single
    # element explicitly: calling int() on an array with ndim > 0 is
    # deprecated since NumPy 1.25.
    x0 = int(np.random.choice(x0_vec, 1)[0])
    x1 = int(np.random.choice(x1_vec, 1)[0])
    plotting_utils.plot_subimage(ax, full_image,
                                 None,
                                 true_locs,
                                 x0, x1, subimage_slen)

# Load neural network 

In [None]:
# Encoder configuration, collected in one place so the values are easy to scan
_encoder_config = dict(full_slen = 101,
                       stamp_slen = 9,
                       step = 2,
                       edge_padding = 3,
                       n_bands = 1,
                       max_detections = 4)

star_encoder = starnet_vae_lib.StarEncoder(**_encoder_config)

# This is what the neural network sees

In [None]:
# Add two leading singleton dimensions to the image and one to the catalogue
# before asking the encoder for its stamps.
_image_batch = full_image.unsqueeze(0).unsqueeze(0)
_locs_batch = true_locs.unsqueeze(0)
_fluxes_batch = true_fluxes.unsqueeze(0)

image_stamps, subimage_locs, subimage_fluxes, n_stars, is_on_array = \
    star_encoder.get_image_stamps(_image_batch, _locs_batch, _fluxes_batch,
                                  trim_images=False)

In [None]:
# absolute value of the top-left 9 x 14 corner of the image
plt.matshow(torch.abs(full_image[:9, :14]))

In [None]:
# colour limits taken from the same 9 x 14 corner; reused for the stamp plots
_corner_abs = full_image[0:9, 0:14].abs()
vmax = _corner_abs.max()
vmin = _corner_abs.min()

In [None]:
# Plot the first four stamps together with the true locations that fall
# strictly inside each stamp. The stamp should match the corresponding
# region of the full image.
f, axarr = plt.subplots(1, 4, figsize=(16, 8))

_stamp_slen = star_encoder.stamp_slen
_edge_padding = star_encoder.edge_padding

# true locations scaled by (full_slen - 1); identical for every stamp, so
# computed once outside the loop
_true_locs = true_locs * (star_encoder.full_slen - 1)

for i in range(4):
    indx = i # int(np.random.choice(image_stamps.shape[0], 1))

    # coordinates of this stamp, as stored by the encoder
    x0 = star_encoder.tile_coords[indx, 0]
    x1 = star_encoder.tile_coords[indx, 1]

    axarr[i].matshow(image_stamps[indx].squeeze(), vmin = vmin, vmax = vmax)

    # true locations strictly inside the stamp
    which_locs = ((_true_locs[:, 0] > x0.float()) & (_true_locs[:, 1] > x1.float())) & \
                    (_true_locs[:, 0] < (x0 + _stamp_slen).float() - 1) & \
                    (_true_locs[:, 1] < (x1 + _stamp_slen).float() - 1)

    # shift to stamp-relative coordinates; column index goes on the x axis
    axarr[i].scatter(_true_locs[which_locs, 1] - x1,
                     _true_locs[which_locs, 0] - x0,
                     marker = 'o', color = 'b')

    # red lines at edge_padding pixels from each side of the stamp
    for _boundary in (_edge_padding, _stamp_slen - _edge_padding - 1):
        axarr[i].axvline(x=_boundary, color = 'r')
        axarr[i].axhline(y=_boundary, color = 'r')

    # label the columns with full-image coordinates; the number of ticks
    # follows the stamp size instead of a hard-coded 9
    plt.sca(axarr[i]);
    plt.xticks(range(_stamp_slen), np.arange(x1, x1 + _stamp_slen));


# Load Portillos results

In [None]:
# directory holding the PCAT output files
results_dir = '../../multiband_pcat/pcat-lion-results/20191007-115851/'

# archive with the sampled chain (arrays 'x', 'y' and 'f' are read below)
_chain_file = results_dir + 'chain.npz'
chain_results = np.load(_chain_file)

In [None]:
include_classical_catalogue = True

if include_classical_catalogue:
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')

    # catalogue columns read here: 0 -> x1, 2 -> x0, 4 -> flux
    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]
    fluxes = pcat_catalog[:, 4]

    # Drop every source with a NaN in either coordinate or in the flux.
    # (The mask used to test x1_loc twice, so a NaN in x0_loc slipped through.)
    is_na = np.isnan(x0_loc) | np.isnan(x1_loc) | np.isnan(fluxes)

    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    fluxes = fluxes[~is_na]

    # (n_sources, 2) tensor of [x0, x1], divided by (slen - 1); other cells
    # multiply by (slen - 1) to get back to pixel coordinates. Stacking in
    # NumPy first avoids building a tensor from a list of ndarrays.
    portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1)) \
                            / (full_image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes)

# samples from the chain: the last 300 entries along the first axis of
# 'x' / 'y' (second axis of 'f'), flattened into one long vector
x1_loc_samples = chain_results['x'][-300:, ].flatten()
x0_loc_samples = chain_results['y'][-300:, ].flatten()

portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten())
portillos_est_locs_sampled = torch.Tensor(np.stack([x0_loc_samples, x1_loc_samples], axis=1)) \
                                / (full_image.shape[-1] - 1)

In [None]:
# TODO: this was chosen empirically looking at image residuals! 
# Need to figure out the correct conversion ... 
# NOTE: both flux tensors below are overwritten with a rescaled copy of
# themselves, so running this cell a second time applies the factor twice.
# Re-run the catalogue-loading cell before re-running this one.
# NOTE: portillos_est_fluxes is only defined when
# include_classical_catalogue is True.
fudge_factor = 1 / (1 - 0.83)
portillos_est_fluxes_sampled = portillos_est_fluxes_sampled * fudge_factor
portillos_est_fluxes = portillos_est_fluxes * fudge_factor

### get reconstruction mean 

In [None]:
# simulator sized to the observed image
simulator = simulated_datasets_lib.StarSimulator(psf_fit_file=psf_fit_file,
                                                slen = full_image.shape[-1],
                                                sky_intensity = 686.)

# only works if we have the classical catalogue
if include_classical_catalogue:
    # add a leading singleton dimension to every argument
    _locs = portillos_est_locs.unsqueeze(0)
    # portillos_est_fluxes is already a tensor, so no re-wrapping is needed
    _fluxes = portillos_est_fluxes.unsqueeze(0)
    _n_stars = torch.Tensor([len(x0_loc)]).type(torch.LongTensor)

    # noiseless image drawn from the catalogue parameters
    portillos_recon_mean = simulator.draw_image_from_params(locs = _locs,
                                                fluxes = _fluxes,
                                                 n_stars = _n_stars,
                                                 add_noise = False).squeeze()

    plt.matshow(portillos_recon_mean);
    plt.colorbar()

    # Kept inside the guard: the residuals need portillos_recon_mean, which
    # does not exist when the classical catalogue is skipped.
    portillos_residuals = portillos_recon_mean.squeeze() - full_image

# Load sleep phase 1 NN results

In [None]:
# Restore the saved encoder weights. The map_location callable returns the
# storage it is given unchanged.
_sleep_fit_file = '../fits/starnet_invKL_encoder-10092019-reweighted_samples'
_sleep_state = torch.load(_sleep_fit_file,
                          map_location=lambda storage, loc: storage)
star_encoder.load_state_dict(_sleep_state)
star_encoder.eval();

In [None]:
# get parameters on the full image; image and background each get two
# leading singleton dimensions before being passed to the encoder
_image_batch = full_image.unsqueeze(0).unsqueeze(0)
_background_batch = full_background.unsqueeze(0).unsqueeze(0)

map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    star_encoder.get_results_on_full_image(_image_batch, _background_batch)

In [None]:
# noiseless image drawn from the encoder's estimates
_map_params = dict(locs = map_locs_full_image,
                   fluxes = map_fluxes_full_image,
                   n_stars = map_n_stars_full,
                   add_noise = False)
vae_recon_mean = simulator.draw_image_from_params(**_map_params).squeeze()

# reconstruction minus observation
vae_residuals = vae_recon_mean - full_image

In [None]:
# drop singleton dimensions from the encoder's estimates
my_est_locs, my_est_fluxes = (map_locs_full_image.squeeze(),
                              map_fluxes_full_image.squeeze())

# Compare

In [None]:
# 2 x 3 comparison on one random sub-image.
# Rows: starnet (top) / portillos (bottom).
# Columns: observed image with catalogues, reconstruction, relative residuals.
fig, axarr = plt.subplots(2, 3, figsize=(15, 12))

# np.random.choice(..., 1) returns a length-1 array; index the element
# explicitly because int() on an array with ndim > 0 is deprecated since
# NumPy 1.25.
x0 = int(np.random.choice(x0_vec, 1)[0])
x1 = int(np.random.choice(x1_vec, 1)[0])

###################
# Plot catalogs
##################
# portillos catalogue in pixel units, restricted to the sub-image and shifted
# to its origin; this does not depend on the row, so it is computed once
_portillos_est_locs = portillos_est_locs * (full_image.shape[-1] - 1)
which_locs = (_portillos_est_locs[:, 0] > x0) & \
                (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (_portillos_est_locs[:, 1] > x1) & \
                (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]]))

# the observed panel is identical in both rows
for j in range(2):
    # my catalog
    plotting_utils.plot_subimage(axarr[j, 0], full_image, my_est_locs, true_locs, x0, x1, subimage_slen,
                                add_colorbar = True, global_fig = fig)
    axarr[j, 0].set_title('observed; coords: {}\n'.format([x0, x1]));

    # portillos catalogue
    axarr[j, 0].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

#######################
# Reconstructions
#######################
# my reconstruction
plotting_utils.plot_subimage(axarr[0, 1], vae_recon_mean, my_est_locs, None, x0, x1, subimage_slen,
                            add_colorbar = True, global_fig = fig)
axarr[0, 1].set_title('starnet reconstructed\n');

# Portillos reconstruction
plotting_utils.plot_subimage(axarr[1, 1], portillos_recon_mean, portillos_est_locs, None, x0, x1, subimage_slen,
                            add_colorbar = True, global_fig = fig,
                            color = 'c', marker = 'x')
axarr[1, 1].set_title('portillos reconstructed\n');

######################
# residuals
######################
# residuals relative to the observed image
_vae_rel_resid = vae_residuals / full_image
_portillos_rel_resid = portillos_residuals / full_image

# shared symmetric colour scale: largest absolute relative residual of either
# method inside the sub-image
vmax1 = torch.abs(_vae_rel_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
vmax2 = torch.abs(_portillos_rel_resid[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()

vmax = torch.max(torch.Tensor([vmax1, vmax2]))

# my residuals
plotting_utils.plot_subimage(axarr[0, 2], _vae_rel_resid,
                            my_est_locs, None, x0, x1, subimage_slen,
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True,
                            vmax = vmax, vmin = -vmax)

axarr[0, 2].set_title('starnet residuals\n');

# portillos residuals
plotting_utils.plot_subimage(axarr[1, 2], _portillos_rel_resid,
                            portillos_est_locs, None, x0, x1, subimage_slen,
                            add_colorbar = True, global_fig = fig,
                            diverging_cmap = True,
                            vmax = vmax, vmin = -vmax,
                            color = 'c', marker = 'x')

axarr[1, 2].set_title('portillos residuals\n');

# Check out some summary statistics

In [None]:
# we only look at locations within 10-90; 
# Portillos doesn't detect on the edge

def filter_params(locs, fluxes, slen, min_coord=10, max_coord=90):
    """Keep only the sources whose pixel coordinates lie strictly inside a window.

    Parameters
    ----------
    locs : 2-d tensor
        One row per source; columns 0 and 1 are the two coordinates. They are
        multiplied by ``slen - 1`` to obtain pixel coordinates.
    fluxes : 1-d tensor
        One flux per source, indexed in the same order as ``locs``.
    slen : int
        Side length of the image in pixels.
    min_coord, max_coord : number, optional
        Exclusive lower / upper bound, in pixels, applied to both coordinates.
        The defaults (10 and 90) reproduce the original hard-coded window.

    Returns
    -------
    (locs, fluxes)
        The selected rows of ``locs`` (unscaled) and the matching ``fluxes``.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1

    # convert to pixel coordinates only for the comparison
    _locs = locs * (slen - 1)
    which_params = (_locs[:, 0] > min_coord) & (_locs[:, 0] < max_coord) & \
                        (_locs[:, 1] > min_coord) & (_locs[:, 1] < max_coord)

    return locs[which_params], fluxes[which_params]


In [None]:
# Restrict all three catalogues to the same interior window.
# Each name is rebound to its filtered version.
_slen = full_image.shape[-1]

my_est_locs, my_est_fluxes = filter_params(my_est_locs, my_est_fluxes, _slen)
portillos_est_locs, portillos_est_fluxes = filter_params(portillos_est_locs, portillos_est_fluxes, _slen)
true_locs, true_fluxes = filter_params(true_locs, true_fluxes, _slen)

These are rather coarse measures. My completeness does not take into account the fact that several true stars might be matched with just one estimated star (so not all the true stars were detected); conversely my true positive rate does not take into account that several estimated stars might be matched with just one true star (so only one estimated star is a true positive). 

I tried the Hungarian algorithm to find a minimal matching, but this gave weird results because we're searching for a permutation that minimizes the **global** cost of the matching. 

In [None]:
import image_statistics_lib

In [None]:
# Completeness against magnitude bin for both catalogues
my_completeness_vec, my_mag_vec, my_counts = image_statistics_lib.get_completeness_vec(
    my_est_locs, true_locs, full_image.shape[-1], my_est_fluxes, true_fluxes)

portillos_completeness_vec, portillos_mag_vec, portillos_counts = image_statistics_lib.get_completeness_vec(
    portillos_est_locs, true_locs, full_image.shape[-1], portillos_est_fluxes, true_fluxes)

for _mags, _rates, _counts, _color, _label in (
        (my_mag_vec, my_completeness_vec, my_counts, 'b', 'Starnet'),
        (portillos_mag_vec, portillos_completeness_vec, portillos_counts, 'r', 'Portillos')):
    # the mag vector has one more entry than the rate vector, so drop the last one
    plt.plot(_mags[0:-1], _rates, '--x', color = _color, label = _label)
    # sqrt(p (1 - p) / n) per bin; the error bars span three of these
    sd = np.sqrt(_rates * (1 - _rates) / _counts)
    plt.errorbar(_mags[0:-1], _rates,
                 yerr = 3 * sd,
                 marker = 'x', linestyle = '--', color = _color, alpha = 0.7)

plt.ylim(0.3, 1.01)
plt.legend(loc = 'lower right')
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
# True positive rate against magnitude bin for both catalogues
my_tpr_vec, my_mag_vec, my_counts = image_statistics_lib.get_tpr_vec(
    my_est_locs, true_locs, full_image.shape[-1], my_est_fluxes, true_fluxes)

portillos_tpr_vec, portillos_mag_vec, portillos_counts = image_statistics_lib.get_tpr_vec(
    portillos_est_locs, true_locs, full_image.shape[-1], portillos_est_fluxes, true_fluxes)

for _mags, _rates, _counts, _line_color, _bar_color, _label in (
        (my_mag_vec, my_tpr_vec, my_counts, 'blue', 'b', 'Starnet'),
        (portillos_mag_vec, portillos_tpr_vec, portillos_counts, 'r', 'r', 'Portillos')):
    # the mag vector has one more entry than the rate vector, so drop the last one
    plt.plot(_mags[0:-1], _rates, '--x', color = _line_color, label = _label)
    # sqrt(p (1 - p) / n) per bin; the error bars span three of these
    sd = np.sqrt(_rates * (1 - _rates) / _counts)
    plt.errorbar(_mags[0:-1], _rates,
                 yerr = 3 * sd,
                 marker = 'x', linestyle = '--', color = _bar_color, alpha = 0.7)

plt.legend(loc = 'lower right')
plt.xlabel('estimated log flux')
plt.ylabel('true positive rate')
plt.ylim(0.3, 1.01)


# Examining PSF transforms

In [None]:
# This is the original psf and its residuals
# number of sources in the complete (unfiltered) catalogue, as a LongTensor
_n_sources = torch.Tensor([len(sdss_hubble_data.locs)]).type(torch.LongTensor)

# draw with noise from the complete catalogue
_noisy_draw = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0),
                                               fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                                               n_stars = _n_sources,
                                               add_noise = True)

# replace the simulator's constant sky level by the dataset's background
full_image_sim = _noisy_draw - simulator.sky_intensity + \
                    sdss_hubble_data.sdss_background.unsqueeze(0)

In [None]:
# residual of the simulated image relative to the observed one
resid = (full_image_sim.squeeze() - full_image) / full_image

# symmetric colour limits around zero for the diverging colormap
vmax = torch.abs(resid).max()
_bwr = plt.get_cmap('bwr')
plt.matshow(resid, vmax = vmax, vmin = -vmax, cmap=_bwr)

plt.colorbar()

# Wake-sleep results

In [None]:
# Summary statistics indexed by wake-sleep iteration. Only iterations 0 and 3
# are evaluated, so entries 1 and 2 stay at zero.
completeness_all = np.zeros(4)
tpr_all = np.zeros(4)

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

_initial_fit = '../fits/starnet_invKL_encoder-10092019-reweighted_samples'
_wake_sleep_prefix = '../fits/wake_sleep-portm2-101420129-encoder-iter'

# Iterate over the evaluated iterations directly. The original loop also
# loaded the checkpoints of iterations 1 and 2 and then skipped them, which
# cost two needless file reads and required those files to exist.
for i in (0, 3):
    # iteration 0 is the sleep-phase fit; later iterations are the
    # wake-sleep checkpoints
    _fit_file = _initial_fit if i == 0 else _wake_sleep_prefix + str(i)
    star_encoder.load_state_dict(torch.load(_fit_file,
                                   map_location=lambda storage, loc: storage))
    star_encoder.eval();

    # get parameters
    map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
        star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0),
                                               full_background.unsqueeze(0).unsqueeze(0))

    # same interior window as applied to the other catalogues
    est_locs, est_fluxes = filter_params(map_locs_full_image.squeeze(),
                                           map_fluxes_full_image.squeeze(),
                                           full_image.shape[-1])

    # take into account fluxes
    completeness, tpr, completeness1_bool, tpr1_bool = \
        image_statistics_lib.get_summary_stats(est_locs, true_locs,
                                               full_image.shape[-1],
                                               est_fluxes, true_fluxes)
    completeness_all[i] = completeness
    tpr_all[i] = tpr

    # get completeness as a function of magnitude
    completeness1_vec, mag_vec1, _ = \
        image_statistics_lib.get_completeness_vec(est_locs, true_locs, full_image.shape[-1],
                                                  est_fluxes, true_fluxes)

    axarr[0].plot(mag_vec1[:-1], completeness1_vec, '--x', label = 'starnet_iter' + str(i))

    tpr_vec, mag_vec, _ = \
        image_statistics_lib.get_tpr_vec(est_locs, true_locs, full_image.shape[-1],
                                        est_fluxes, true_fluxes)

    axarr[1].plot(mag_vec[0:-1], tpr_vec, '--x', label = 'starnet_iter' + str(i))

# RECALL PORTILLOS RESULTS
portillos_completeness_vec, portillos_mag_vec, portillos_counts = \
    image_statistics_lib.get_completeness_vec(portillos_est_locs, true_locs, full_image.shape[-1],
                                              portillos_est_fluxes, true_fluxes)
axarr[0].plot(portillos_mag_vec[0:-1], portillos_completeness_vec, '-x', label = 'Portillos')

portillos_tpr_vec, portillos_mag_vec, portillos_counts = \
    image_statistics_lib.get_tpr_vec(portillos_est_locs, true_locs, full_image.shape[-1],
                                              portillos_est_fluxes, true_fluxes)
axarr[1].plot(portillos_mag_vec[0:-1], portillos_tpr_vec, '-x', label = 'Portillos')

axarr[0].legend()
axarr[0].set_xlabel('true log10 flux')
axarr[0].set_ylabel('completeness')

axarr[1].legend()
axarr[1].set_xlabel('estimated log10 flux')
axarr[1].set_ylabel('tpr')




# Check out the trained PSF

In [None]:
# reset simulator
# Build a fresh simulator so that it starts again from the PSF fit file
_image_slen = full_image.shape[-1]
simulator = simulated_datasets_lib.StarSimulator(psf_fit_file=psf_fit_file,
                                                 slen = _image_slen,
                                                 sky_intensity = 686.)

In [None]:
# number of sources in the complete catalogue, as a LongTensor
_n_sources = torch.Tensor([len(sdss_hubble_data.locs)]).type(torch.LongTensor)

# noiseless draw from the complete catalogue with the simulator's current PSF
_clean_draw = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0),
                                               fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                                               n_stars = _n_sources,
                                               add_noise = False)

# replace the simulator's constant sky level by the dataset's background
recon_mean_init = _clean_draw - simulator.sky_intensity + \
                    sdss_hubble_data.sdss_background.unsqueeze(0)

In [None]:
import psf_transform_lib

In [None]:
# local PSF transform built on the simulator's original PSF
_psf_og = torch.Tensor(simulator.psf_og)
psf_transform = psf_transform_lib.PsfLocalTransform(_psf_og,
                                                    simulator.slen,
                                                    kernel_size = 3)

In [None]:
# Restore the transform parameters saved at wake-sleep iteration 2. The
# map_location callable returns the storage it is given unchanged.
_psf_fit_file = '../fits/wake_sleep-portm2-101420129-psf_transform-iter2'
_psf_state = torch.load(_psf_fit_file,
                        map_location=lambda storage, loc: storage)
psf_transform.load_state_dict(_psf_state)

In [None]:
# PSF currently used by the simulator vs. the output of the loaded transform
init_psf = simulator.psf
trained_psf = psf_transform.forward()

# only the 45:56 window of each PSF is displayed
_window = (slice(45, 56), slice(45, 56))
_init_crop = init_psf[_window]
_trained_crop = trained_psf[_window].detach()

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

im0 = axarr[0].matshow(_init_crop)
axarr[0].set_title('initial sdss psf \n')
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(_trained_crop)
axarr[1].set_title('wake-sleep trained psf \n')
fig.colorbar(im1, ax = axarr[1])

# difference on a diverging colormap with symmetric limits
diff = _trained_crop - _init_crop
_diff_max = diff.abs().max()
im2 = axarr[2].matshow(diff, vmax = _diff_max, vmin = -_diff_max, cmap=plt.get_cmap('bwr'))
axarr[2].set_title('diff \n')
fig.colorbar(im2, ax = axarr[2])

In [None]:
# Swap the trained PSF into the simulator so that later draws use it.
# NOTE: trained_psf is stored as returned by psf_transform.forward(), without
# .detach(); the cells that plot results derived from it call .detach() first.
# NOTE: this mutates the simulator in place; re-run the "reset simulator" cell
# to get the initial PSF back.
simulator.psf = trained_psf

In [None]:
# number of sources in the complete catalogue, as a LongTensor
_n_sources = torch.Tensor([len(sdss_hubble_data.locs)]).type(torch.LongTensor)

# same noiseless draw as for recon_mean_init, now with the simulator's
# current PSF
_clean_draw = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0),
                                               fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                                               n_stars = _n_sources,
                                               add_noise = False)

# replace the simulator's constant sky level by the dataset's background
recon_mean_trained = _clean_draw - simulator.sky_intensity + \
                        sdss_hubble_data.sdss_background.unsqueeze(0)

In [None]:
# Relative residuals of the two reconstructions (left: initial PSF,
# right: trained PSF), restricted to rows and columns 10-90.
fig, axarr = plt.subplots(1, 2, figsize=(10, 4))

_interior = (slice(10, 90), slice(10, 90))
_bwr = plt.get_cmap('bwr')

# each panel gets its own symmetric colour limits
resid = ((recon_mean_init.detach().squeeze() - full_image) / full_image)[_interior]
vmax = resid.abs().max()
im0 = axarr[0].matshow(resid, vmax = vmax, vmin = -vmax, cmap=_bwr)
fig.colorbar(im0, ax=axarr[0])

resid = ((recon_mean_trained.detach().squeeze() - full_image) / full_image)[_interior]
vmax = resid.abs().max()
im1 = axarr[1].matshow(resid, vmax = vmax, vmin = -vmax, cmap=_bwr)
fig.colorbar(im1, ax=axarr[1])
