In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import image_statistics_lib

import plotting_utils

# Seed NumPy's global RNG so the np.random.* draws below (e.g. the random tile index
# picked in the stamp-inspection cell) are reproducible.
# NOTE(review): torch is not seeded here; add torch.manual_seed if the encoder's
# posterior sampling should be reproducible as well.
np.random.seed(34534)


# Load the data

In [None]:
# Flux threshold: catalogue stars whose first-band flux is not above fmin are excluded
# from the ground truth below, and the PCAT catalogue is filtered with the same value.
fmin = 1000

In [None]:
# Indices of the SDSS bands to load.
# NOTE(review): with a single band the r-band PSF is used in the simulator section,
# so 2 presumably denotes r -- confirm against SDSSHubbleData.
bands = [2]

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = bands)

# observed image and its sky background
full_image = sdss_hubble_data.sdss_image
full_background = sdss_hubble_data.sdss_background 

# "true" parameters: catalogue stars brighter than fmin in the first loaded band
which_bright = (sdss_hubble_data.fluxes[:, 0] > fmin)
true_locs = sdss_hubble_data.locs[which_bright]
true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
# Histogram (as a density) of the gaps between consecutive true star positions along
# coordinate 1.
locs_sorted = torch.sort(true_locs[:, 1])[0]

x = locs_sorted[1:len(locs_sorted)] - locs_sorted[0:-1]

bins = plt.hist(x, density=True)[1]

# disabled: overlay an exponential density evaluated at the histogram bin edges
# lambd = 1000 # 1 / x.numpy().mean() 
# pdf = lambd * np.exp(-lambd * bins)

# plt.plot(bins, pdf, 'x')

In [None]:
true_locs.shape

In [None]:
# Convert the loaded image to a float tensor. NOTE: this rebinds full_image.
full_image = torch.Tensor(full_image)
print(full_image.shape)

In [None]:
# first band of the observed image
plt.matshow(full_image[0])
plt.colorbar()

# Get simulator 

In [None]:
import fitsio

In [None]:
psf_dir = '../data/'
psf_r = fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits')[0].read()
psf_i = fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits')[0].read()

if len(bands) == 2: 
    psf_og = np.array([psf_r, psf_i])
elif len(bands) == 1: 
    psf_og = np.array([psf_r])
else: 
    assert 1 == 2, 'not implemented error'
    
sky_intensity = full_background.reshape(full_background.shape[0], -1).mean(1)


In [None]:
sky_intensity

In [None]:
simulator1 = simulated_datasets_lib.StarSimulator(psf=psf_og[0:1], 
                                                slen = full_image.shape[-1], 
                                                  transpose_psf = False,
                                                sky_intensity = sky_intensity[0:1])


simulator = simulated_datasets_lib.StarSimulator(psf=psf_og, 
                                                slen = full_image.shape[-1], 
                                                transpose_psf = False,
                                                sky_intensity = sky_intensity)



In [None]:
# simulator.psf = torch.Tensor(psf_trained)
# simulator1.psf = torch.Tensor(psf_trained)

# Simulation with ground truth

In [None]:
# Reconstruction from ALL catalogue stars (sdss_hubble_data.locs / fluxes, not the
# fmin-filtered true_locs) with the SDSS PSF and add_noise = False; a batch axis is
# added for the simulator and squeezed off afterwards.
truth_recon = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
                            fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                            n_stars = torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.LongTensor), 
                            add_noise = False).squeeze(0)

In [None]:
# Relative residual (recon - data) / data for each band, on a symmetric diverging scale.
for i in range(len(sdss_hubble_data.bands)): 
    foo = (truth_recon[i] - full_image[i]) / full_image[i]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = - foo.abs().max(), cmap = plt.get_cmap('bwr')) 
    plt.colorbar()

In [None]:
# Zoom into a 10x10 patch of the LAST band's residual (`foo` is left over from the loop
# above); the colour limits still come from the whole residual image.
plt.matshow(foo[5:15, 55:65], vmax = foo.abs().max(), vmin = - foo.abs().max(), cmap = plt.get_cmap('bwr')) 
plt.colorbar()

# Load Portillo's results

In [None]:
results_dir = '../../multiband_pcat/pcat-lion-results/20191107-115253/'

chain_results = np.load(results_dir + 'chain.npz')

In [None]:
# n bands 
chain_results['f'].shape

In [None]:
# fudge_factor = 1 / (1 - 0.83)
fudge_factor = sdss_hubble_data.sdss_data[0]['gain'][0] 

In [None]:
include_classical_catalogue = True

if include_classical_catalogue: 
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')
    
    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]
        
    fluxes = pcat_catalog[:, 4] * fudge_factor
    
    # remove na
    is_na = (fluxes < fmin) | np.isnan(fluxes)
    
    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    fluxes = fluxes[~is_na]
    
    portillos_est_locs = torch.Tensor([x0_loc, x1_loc]).transpose(0,1) / (full_image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes).unsqueeze(-1)
    
else: 
    # just take one sample 
    fluxes = chain_results['f'][:, -1, ].transpose() * fudge_factor
    
    x1_loc = chain_results['x'][-1, ].flatten()[fluxes[:, 0] > fmin]
    x0_loc = chain_results['y'][-1, ].flatten()[fluxes[:, 0] > fmin]
    
    fluxes = fluxes[fluxes[:, 0] > fmin]
        
    portillos_est_locs = torch.Tensor([x0_loc, x1_loc]).transpose(0,1) / (full_image.shape[-1] - 1)
    portillos_est_fluxes = torch.Tensor(fluxes) 
    

# x1_loc_samples = chain_results['x'][-300:, ].flatten()
# x0_loc_samples = chain_results['y'][-300:, ].flatten()

# portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -300:, ].flatten()) * fudge_factor
# portillos_est_locs_sampled = torch.Tensor([x0_loc_samples, x1_loc_samples]).transpose(0,1) \
#                                 / (full_image.shape[-1] - 1)
    
# # filter by fmin
# port_which_bright = portillos_est_fluxes_sampled > fmin
# portillos_est_fluxes_sampled = portillos_est_fluxes_sampled[port_which_bright]
# portillos_est_locs_sampled = portillos_est_locs_sampled[port_which_bright]

In [None]:
chain_results['n'][-300:].mean()

In [None]:
chain_results['n'][-300:].std()

In [None]:
portillos_est_fluxes.shape

### Get the reconstruction mean of Portillo's catalogue

In [None]:
_locs = portillos_est_locs.unsqueeze(0) 
_fluxes = portillos_est_fluxes.unsqueeze(0)
_n_stars = torch.Tensor([len(x0_loc)]).type(torch.LongTensor)

if _fluxes.shape[-1] == 1:
    portillos_recon_mean = simulator1.draw_image_from_params(locs = _locs, 
                                            fluxes = _fluxes,
                                             n_stars = _n_stars,  
                                             add_noise = False).squeeze(0)
else: 
    portillos_recon_mean = simulator.draw_image_from_params(locs = _locs, 
                                                fluxes = _fluxes,
                                                 n_stars = _n_stars,  
                                                 add_noise = False).squeeze()

plt.matshow(portillos_recon_mean[0]); 
plt.colorbar()

In [None]:
portillos_recon_mean.shape

In [None]:
full_image.shape

In [None]:
portillos_residuals = portillos_recon_mean - full_image

for i in range(portillos_recon_mean.shape[0]): 
    foo = (portillos_residuals[i] / full_image[i])[5:95, 5:95]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = -foo.abs().max(), cmap = plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
plt.matshow(foo[15:35, 50:70], vmax = foo.abs().max(), vmin = - foo.abs().max(), cmap = plt.get_cmap('bwr')) 
plt.colorbar()

# My starnet result

In [None]:
star_encoder1 = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = len(bands),
                                            max_detections = 2)

In [None]:
star_encoder1.load_state_dict(torch.load('../fits/results_11202019/starnet_r', 
                               map_location=lambda storage, loc: storage))


star_encoder1.eval(); 


In [None]:
import time

In [None]:
# get parameters on the full image 
# map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
#     star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0), 
#                                            full_background.unsqueeze(0).unsqueeze(0))

t0 = time.time()
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    star_encoder1.sample_star_encoder(full_image.unsqueeze(0), 
                                    full_background.unsqueeze(0), 
                                    return_map = True)[0:3]
    
print(time.time() - t0)

In [None]:
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image, 
                                                fluxes = map_fluxes_full_image,
                                                 n_stars = map_n_stars_full, 
                                                 add_noise = False).squeeze(0)

vae_residuals = vae_recon_mean - full_image

In [None]:
band = 0

In [None]:
for i in range(vae_residuals.shape[0]): 
    foo = (vae_residuals[i] / full_image[i])[5:95, 5:95]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = -foo.abs().max(), cmap = plt.get_cmap('bwr'))
    plt.colorbar()

# Results after wake-sleep

In [None]:
# Second encoder with the same constructor arguments as star_encoder1, loaded with the
# wake-sleep weights (file name says iteration 6).
star_encoder2 = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = len(bands),
                                            max_detections = 2)

In [None]:
# the map_location lambda keeps every loaded tensor on the CPU
star_encoder2.load_state_dict(torch.load('../fits/results_11202019/wake-sleep_630x310_r-encoder-iter6', 
                               map_location=lambda storage, loc: storage))


star_encoder2.eval(); 


In [None]:
# get parameters on the full image 
# (older API, kept for reference)
# map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
#     star_encoder.get_results_on_full_image(full_image.unsqueeze(0).unsqueeze(0), 
#                                            full_background.unsqueeze(0).unsqueeze(0))

# MAP estimates (locations, fluxes, star counts) of the wake-sleep encoder
map_locs_full_image2, map_fluxes_full_image2, map_n_stars_full2 = \
    star_encoder2.sample_star_encoder(full_image.unsqueeze(0), 
                                    full_background.unsqueeze(0), 
                                    return_map = True)[0:3]

In [None]:
# reconstruction (add_noise = False) from the wake-sleep MAP catalogue and its residual
vae_recon_mean2 = simulator.draw_image_from_params(locs = map_locs_full_image2, 
                                                fluxes = map_fluxes_full_image2,
                                                 n_stars = map_n_stars_full2, 
                                                 add_noise = False).squeeze(0)

vae_residuals2 = vae_recon_mean2 - full_image

In [None]:
band = 0

In [None]:
# relative residual per band, cropped to [5:95, 5:95]
for i in range(vae_residuals2.shape[0]): 
    foo = (vae_residuals2[i] / full_image[i])[5:95, 5:95]
    plt.matshow(foo, vmax = foo.abs().max(), vmin = -foo.abs().max(), cmap = plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
map_n_stars_full2

# Check out some summary statistics

In [None]:
map_n_stars_full

In [None]:
map_n_stars_full2

In [None]:
len(portillos_est_fluxes)

In [None]:
len(true_fluxes)

In [None]:
my_completeness1, my_tpr1, my_complete_bool1, my_tpr_bool = \
    image_statistics_lib.get_summary_stats(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0])

my_completeness2, my_tpr2, my_complete_bool2, my_tpr_bool = \
    image_statistics_lib.get_summary_stats(map_locs_full_image2.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image2.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0])
    

portillos_completeness, portillos_tpr, portillos_complete_bool, portillos_tpr_bool = \
    image_statistics_lib.get_summary_stats(portillos_est_locs, true_locs, 
                                           full_image.shape[-1], 
                                           portillos_est_fluxes[:, 0], 
                                           true_fluxes[:, 0])

    
print('my completeness 1: {:0.3f}'.format(my_completeness1))
print('my completeness 2: {:0.3f}'.format(my_completeness2))
print('portillos completeness: {:0.3f}\n'.format(portillos_completeness))

print('my true positive rate 1: {:0.3f}'.format(my_tpr1))
print('my true positive rate 2: {:0.3f}'.format(my_tpr2))
print('portillos true positive rate: {:0.3f}'.format(portillos_tpr))

In [None]:
portillos_locs_error, portillos_fluxes_error = image_statistics_lib.get_l1_error(portillos_est_locs, true_locs, 
                                       full_image.shape[-1], 
                                       portillos_est_fluxes[:, 0], 
                                       true_fluxes[:, 0])

my_locs_error1, my_fluxes_error1 = image_statistics_lib.get_l1_error(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0])

my_locs_error2, my_fluxes_error2 = image_statistics_lib.get_l1_error(map_locs_full_image2.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image2.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0])

In [None]:
#### loc errors 
print((my_locs_error1.mean(), my_locs_error1.std() / np.sqrt(len(my_locs_error1))))
print((my_locs_error2.mean(), my_locs_error2.std() / np.sqrt(len(my_locs_error1))))
print((portillos_locs_error.mean(), portillos_locs_error.std() / np.sqrt(len(portillos_locs_error))))


In [None]:
# flux errors 
print((my_fluxes_error1.mean(), my_fluxes_error1.std() / np.sqrt(len(my_fluxes_error1))))
print((my_fluxes_error2.mean(), my_fluxes_error2.std() / np.sqrt(len(my_fluxes_error1))))
print((portillos_fluxes_error.mean(), portillos_fluxes_error.std() / np.sqrt(len(portillos_fluxes_error))))


# Compare sleep phase vs. Portillo

In [None]:
save_figs = False

In [None]:
my_completeness_vec1, my_comp_mag_vec1 = \
    image_statistics_lib.get_completeness_vec(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]

portillos_completeness, portillos_comp_mag_vec = \
    image_statistics_lib.get_completeness_vec(portillos_est_locs.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           portillos_est_fluxes.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]


In [None]:
my_tpr_vec1, my_tpr_mag_vec1 = \
    image_statistics_lib.get_tpr_vec(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]

portillos_tpr, portillos_tpr_mag_vec = \
    image_statistics_lib.get_tpr_vec(portillos_est_locs.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           portillos_est_fluxes.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

axarr[0].plot(my_comp_mag_vec1[0:-1], my_completeness_vec1, 'r--x', label = 'VI sleep-phase')
axarr[0].plot(portillos_comp_mag_vec[0:-1], portillos_completeness, 'b--x', label = 'Portillos')

# axarr[0].legend()
axarr[0].set_xlabel('true log flux', size = 16)
axarr[0].set_ylabel('TPR', size = 16)


axarr[1].plot(my_tpr_mag_vec1[0:-1], my_tpr_vec1, 'r--x', label = 'VI sleep-phase')
axarr[1].plot(portillos_tpr_mag_vec[0:-1], portillos_tpr, 'b--x', label = 'Portillos')

axarr[1].legend(fontsize = 16)
axarr[1].set_xlabel('estimated log flux', size = 16)
axarr[1].set_ylabel('PPV', size = 16)

fig.tight_layout()

if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/sleep_vs_portillos.png')

# Compare sleep vs wake-sleep

In [None]:
true_mags = torch.log10(true_fluxes[:, 0]).numpy()

In [None]:
mag_vec = np.percentile(true_mags, np.arange(0, 110, 10))

In [None]:
my_completeness_vec1, my_comp_mag_vec1 = \
    image_statistics_lib.get_completeness_vec(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]

my_completeness_vec2, my_comp_mag_vec2 = \
    image_statistics_lib.get_completeness_vec(map_locs_full_image2.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image2.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], 
                                             mag_vec = None)[0:2]

In [None]:
mags = torch.log10(map_fluxes_full_image.squeeze(0)[:, 0]).numpy()
mag_vec = np.percentile(mags, np.arange(0, 110, 10))

In [None]:
my_tpr_vec1, my_tpr_mag_vec1, counts_vec  = \
    image_statistics_lib.get_tpr_vec(map_locs_full_image.squeeze(0), 
                                           true_locs, 
                                           full_image.shape[-1], 
                                           map_fluxes_full_image.squeeze(0)[:, 0], 
                                           true_fluxes[:, 0], mag_vec = None)

my_tpr_vec2, my_tpr_mag_vec2, counts_vec2 = \
    image_statistics_lib.get_tpr_vec(map_locs_full_image2.squeeze(0), 
                                       true_locs, 
                                       full_image.shape[-1], 
                                       map_fluxes_full_image2.squeeze(0)[:, 0], 
                                       true_fluxes[:, 0], mag_vec = None)

# plt.plot(my_mag_vec1[0:-1], my_tpr_vec1, '--x', label = 'starnet-iter0')
# plt.plot(my_mag_vec2[0:-1], my_tpr_vec2, '--x', label = 'starnet-iter6')

# plt.legend()
# plt.xlabel('true log flux')
# plt.ylabel('tpr')

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

# TPR
axarr[0].plot(my_comp_mag_vec1[0:-1], my_completeness_vec1, 'r--x', label = 'sleep-phase only')
axarr[0].plot(my_comp_mag_vec2[0:-1], my_completeness_vec2, '--x', color = 'orange', label = 'wake-sleep')

axarr[0].legend(fontsize = 14)
axarr[0].set_xlabel('true log flux', size = 16)
axarr[0].set_ylabel('TPR', size = 16)

# PPV
axarr[1].plot(my_tpr_mag_vec1[0:-1], my_tpr_vec1, 'r--x', label = 'VI sleep-phase only ')
axarr[1].plot(my_tpr_mag_vec2[0:-1], my_tpr_vec2, '--x', color = 'orange', label = 'VI wake-sleep')

# axarr[0].legend()
axarr[1].set_xlabel('true log flux', size = 16)
axarr[1].set_ylabel('PPV', size = 16)

fig.tight_layout()

if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/wake_sleep_curves.png')

In [None]:
my_complete_bool1.shape

In [None]:
my_complete_bool2.shape

In [None]:
_true_locs, _true_fluxes = image_statistics_lib.filter_params(true_locs, true_fluxes[:, 0], 101)

In [None]:
# which_bool = (torch.log10(_true_fluxes).flatten() > 4.0) & \
#                 (my_complete_bool1 == 0) & \
#                 (my_complete_bool2 == 1)
        

which_bool = (torch.log10(_true_fluxes).flatten() > 5.0)

In [None]:
which_indx = torch.nonzero(which_bool)

In [None]:
which_indx

In [None]:
_true_locs[44]

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))


x0 = 85
x1 = 15
plotting_utils.plot_subimage(axarr[0], full_image[0], 
                             map_locs_full_image.squeeze(), 
                             true_locs, 
                             x0, 
                             x1, subimage_slen = 10, 
                            add_colorbar = True, 
                             global_fig = fig)

  
plotting_utils.plot_subimage(axarr[1], full_image[0], 
                             map_locs_full_image2.squeeze(), 
                             true_locs, 
                             x0, 
                             x1, subimage_slen = 10, 
                            add_colorbar = True, 
                             global_fig = fig)


# Let's take a closer look at sleep vs. wake-sleep

### Training of PSF

In [None]:
import psf_transform_lib

In [None]:
psf_transform = psf_transform_lib.PsfLocalTransform(torch.Tensor(psf_og),
                                    full_image.shape[-1], 
                                    kernel_size = 3)

psf_transform.load_state_dict(torch.load('../fits/results_11202019/wake-sleep_630x310_r-psf_transform-iter5', 
                                             map_location=lambda storage, loc: storage))


In [None]:
simulator_psf_trained = simulated_datasets_lib.StarSimulator(psf=psf_og, 
                                                slen = full_image.shape[-1], 
                                                transpose_psf = False,
                                                sky_intensity = sky_intensity)

simulator_psf_trained.psf = psf_transform.forward().detach()

In [None]:
foo = ((truth_recon - full_image) / full_image)[band, 5:95, 5:95]
plt.matshow(foo, vmax = foo.abs().max(), vmin = -foo.abs().max(), cmap = plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
(foo).abs().mean()

In [None]:
truth_recon_trained = simulator_psf_trained.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
                            fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                            n_stars = torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.LongTensor), 
                            add_noise = False).squeeze(0)

In [None]:
foo = ((truth_recon_trained - full_image) / full_image)[band, 5:95, 5:95]
plt.matshow(foo, vmax = foo.abs().max(), vmin = -foo.abs().max(), cmap = plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
(foo).abs().mean()

In [None]:
true_fluxes.shape

In [None]:
fig, axarr = plt.subplots(1,1, figsize=(5, 4))

foo = ((truth_recon - full_image) / truth_recon)
plotting_utils.plot_subimage(axarr, foo[band] - foo[band].mean(), 
                             true_locs[true_fluxes[:, 0] > 10**(4.3)],
                             None, 
                             x0 = 55, 
                            x1 = 5, 
                            subimage_slen = 25, 
                            diverging_cmap = True, 
                            color = 'navy', marker = 'o', 
                            add_colorbar = True, 
                             vmax = 0.21, vmin = -0.21, 
                            global_fig = fig)

axarr.set_title('residuals \n', size = 16)

fig.tight_layout()

if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/psf_misfit_ex_patch.png')

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

foo = ((truth_recon - full_image) / full_image)
plotting_utils.plot_subimage(axarr[0], foo[band] - foo[band].mean(), 
                             true_locs[true_fluxes[:, 0] > 10**(4.3)],
                             None, 
                             x0 = 55, 
                            x1 = 5, 
                            subimage_slen = 25, 
                            diverging_cmap = True, 
                            color = 'navy', marker = 'o', 
                            add_colorbar = True, 
                             vmax = 0.21, vmin = -0.21, 
                            global_fig = fig)

foo1 = ((truth_recon_trained - full_image) / full_image)
plotting_utils.plot_subimage(axarr[1], foo1[band] - foo1[band].mean(), 
                             true_locs[true_fluxes[:, 0] > 10**(4.3)],
                             None, 
                             x0 = 55, 
                            x1 = 5, 
                            subimage_slen = 25, 
                            diverging_cmap = True, 
                            color = 'navy', marker = 'o', 
                            add_colorbar = True, 
                             vmax = 0.21, vmin = -0.21, 
                            global_fig = fig)


axarr[0].set_title('residuals: SDSS psf \n', size = 16)
axarr[1].set_title('residuals: wake-sleep psf \n', size = 16)

fig.tight_layout()

if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/residuals_psf_training.png')

In [None]:
save_figs

# Show that reverse KL does not work

In [None]:
star_encoder_kl = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = psf_og.shape[0],
                                            max_detections = 2, 
                                            fmin = 1000.)

star_encoder_kl.load_state_dict(torch.load('../fits/results_11202019/kl_starnet2', 
                               map_location=lambda storage, loc: storage))
star_encoder_kl.eval(); 


In [None]:
# get parameters on the full image 
map_locs_full_imagekl, map_fluxes_full_imagekl, map_n_stars_fullkl = \
    star_encoder_kl.sample_star_encoder(full_image.unsqueeze(0), 
                                    full_background.unsqueeze(0), 
                                    return_map = True)[0:3]

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))

x0_vec = [55, 56, 23]
x1_vec = [19, 71, 32]
for i in range(3): 
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    
    plotting_utils.plot_subimage(axarr[i], full_image[0], 
                                     map_locs_full_imagekl.squeeze(), 
                                     true_locs, 
                                     x0, x1, subimage_slen = 10, 
                                    add_colorbar = True, 
                                     global_fig = fig)

fig.tight_layout()
if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/reverse_kl_fails.png')

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))

x0_vec = [55, 56, 23]
x1_vec = [19, 71, 32]
for i in range(3): 
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    
    plotting_utils.plot_subimage(axarr[i], full_image[0], 
                                     map_locs_full_image.squeeze(), 
                                     true_locs, 
                                     x0, x1, subimage_slen = 10, 
                                    add_colorbar = True, 
                                     global_fig = fig)
fig.tight_layout()
if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/forward_kl_better.png')

In [None]:
save_figs = True

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(10, 4))

plotting_utils.plot_subimage(axarr[0], full_image[0], 
                                     map_locs_full_imagekl.squeeze(), 
                                     true_locs, 
                                     x0, x1, subimage_slen = 10, 
                                    add_colorbar = True, 
                                     global_fig = fig)

plotting_utils.plot_subimage(axarr[1], full_image[0], 
                                     map_locs_full_image.squeeze(), 
                                     true_locs, 
                                     x0, x1, subimage_slen = 10, 
                                    add_colorbar = True, 
                                     global_fig = fig)

axarr[0].set_title('E-step inferred locations\n', size = 16)
axarr[1].set_title('Sleep phase inferred locations\n', size = 16)
fig.tight_layout()

if save_figs: 
    fig.savefig('../../qualifying_exam_slides/figures/kl_vs_invkl.png')

# Sample images

In [None]:
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

x0_vec = [53, 31, 41, 32]
x1_vec = [70, 83, 23, 64]


for i in range(4): 
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    subimage_slen = 10
    
    plotting_utils.plot_subimage(axarr[i // 2, i % 2], full_image[0], 
                                         None, 
                                         true_locs, 
                                         x0_vec[i], 
                                         x1_vec[i], subimage_slen = 10, 
                                        add_colorbar = True, 
                                         global_fig = fig)
    
    
#     axarr[i // 2, i % 2].set_title('observed; coords: {}\n'.format([x0, x1]));

    # portillos catalogue
#     _portillos_est_locs = portillos_est_locs * (full_image.shape[-1] - 1)
#     which_locs = (_portillos_est_locs[:, 0] > x0) & \
#                     (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
#                     (_portillos_est_locs[:, 1] > x1) & \
#                     (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
#     portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]])) 
#     axarr[i // 2, i % 2].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

plt.tight_layout()

if save_figs: 
    plt.savefig('../../qualifying_exam_slides/figures/portillos_m2_image_subpatches.png')

In [None]:
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

x0_vec = [53, 31, 41, 32]
x1_vec = [70, 83, 23, 64]


for i in range(4): 
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    subimage_slen = 10
    
    plotting_utils.plot_subimage(axarr[i // 2, i % 2], full_image[0], 
                                         map_locs_full_image2.squeeze(), 
                                         true_locs, 
                                         x0_vec[i], 
                                         x1_vec[i], subimage_slen = 10, 
                                        add_colorbar = True, 
                                         global_fig = fig)
    
    
#     axarr[i // 2, i % 2].set_title('observed; coords: {}\n'.format([x0, x1]));

    # portillos catalogue
    _portillos_est_locs = portillos_est_locs * (full_image.shape[-1] - 1)
    which_locs = (_portillos_est_locs[:, 0] > x0) & \
                    (_portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                    (_portillos_est_locs[:, 1] > x1) & \
                    (_portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
    portillos_locs = (_portillos_est_locs[which_locs, :] - torch.Tensor([[x0, x1]])) 
    axarr[i // 2, i % 2].scatter(portillos_locs[:, 1], portillos_locs[:, 0], color = 'c', marker = 'x')

plt.tight_layout()
if save_figs: 
    plt.savefig('../../qualifying_exam_slides/sample_figures.png')

In [None]:
image_stamps, _, _, \
    subimage_n_stars, _ = \
        star_encoder1.get_image_stamps(full_image.unsqueeze(0), 
                                       true_locs.unsqueeze(0), 
                                       true_fluxes.unsqueeze(0))
        
background_stamps = star_encoder1.get_image_stamps(full_background.unsqueeze(0),
                            locs = None, fluxes = None, trim_images = False)[0]

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

# One randomly chosen tile with the wake-sleep MAP locations and the true
# locations overlaid. Only the first panel is drawn; the other two are kept
# as commented-out code.
indx = int(np.random.choice(image_stamps.shape[0], 1))
# indx = int(np.random.choice(torch.where(true_subimage_n_stars == 2)[0].numpy(), 1))

tile_x0 = int(star_encoder2.tile_coords[indx, 0])
tile_x1 = int(star_encoder2.tile_coords[indx, 1])

plotting_utils.plot_subimage(axarr[0], full_image[band],
                             map_locs_full_image2.squeeze(),
                             true_locs,
                             tile_x0,
                             tile_x1,
                             subimage_slen=star_encoder2.stamp_slen,
                             add_colorbar=True,
                             global_fig=f)

# plotting_utils.plot_subimage(axarr[1], [0, band],
#                             map_locs_full_image.squeeze(), 
#                             None, 
#                             int(star_encoder.tile_coords[indx, 0]), 
#                             int(star_encoder.tile_coords[indx, 1]), 
#                             subimage_slen = star_encoder.stamp_slen, 
#                             add_colorbar = True, 
#                             global_fig = f)

# foo = vae_recon_mean[0, band] - images_full[0, band]
# plotting_utils.plot_subimage(axarr[2], foo, 
#                             map_locs_full_image.squeeze(), 
#                             None, 
#                             int(star_encoder.tile_coords[indx, 0]), 
#                             int(star_encoder.tile_coords[indx, 1]), 
#                             subimage_slen = star_encoder.stamp_slen, 
#                             add_colorbar = True, 
#                             global_fig = f, 
#                             diverging_cmap = True)

# Red guide lines at offsets 2 and 4 on the first two panels
# (vertical lines first, then horizontal, panel by panel).
for ax in axarr[:2]:
    for offset in (2, 4):
        ax.axvline(x=offset, color='r')
    for offset in (2, 4):
        ax.axhline(y=offset, color='r')

# Uncertainties?

In [None]:
# 100 posterior samples (not the MAP) for the full image from the wake-sleep
# encoder; only locations, fluxes and star counts are kept.
(sampled_locs_full_image,
 sampled_fluxes_full_image,
 sampled_n_stars_full) = star_encoder2.sample_star_encoder(
    full_image.unsqueeze(0),
    full_background.unsqueeze(0),
    return_map=False,
    n_samples=100,
)[:3]

In [None]:
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

# Corners of the subimages to display. Coordinate 0 is compared with x0 and
# plotted on the y-axis, coordinate 1 with x1 on the x-axis.
x0_vec = [53, 31, 41, 32]
x1_vec = [70, 83, 23, 64]
subimage_slen = 10

# Posterior samples in pixel units; this does not depend on the subimage, so
# compute it once instead of on every iteration.
_sampled_locs = sampled_locs_full_image * (full_image.shape[-1] - 1)

for i in range(4):
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    ax = axarr[i // 2, i % 2]

    # posterior samples strictly inside this subimage
    which_locs = (_sampled_locs[:, :, 0] > x0) & \
                    (_sampled_locs[:, :, 0] < (x0 + subimage_slen - 1)) & \
                    (_sampled_locs[:, :, 1] > x1) & \
                    (_sampled_locs[:, :, 1] < (x1 + subimage_slen - 1))
    # the boolean mask over the first two dims selects an (n_selected, 2)
    # tensor; shift it into subimage coordinates
    sampled_locs = _sampled_locs[which_locs, :] - torch.Tensor([[x0, x1]])
    ax.scatter(sampled_locs[:, 1], sampled_locs[:, 0],
               color='r', marker='x', alpha=0.1)

    # MAP estimates and true locations
    plotting_utils.plot_subimage(ax, full_image[0],
                                 map_locs_full_image2.squeeze(),
                                 true_locs,
                                 x0,
                                 x1, subimage_slen=subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig, color='green')

plt.tight_layout()
if save_figs:
    plt.savefig('../../qualifying_exam_slides/figures/sample_figures_my_posterior_samples.png')

In [None]:
# Condition on the TRUE number of stars: tile the image and the true catalogue
# into stamps, keeping the per-stamp true locations, fluxes and star counts
# (stars per stamp clipped via clip_max_stars=True).
(image_stamps,
 true_subimage_locs,
 true_subimage_fluxes,
 true_subimage_n_stars,
 true_is_on_array) = star_encoder1.get_image_stamps(
    full_image.unsqueeze(0),
    true_locs.unsqueeze(0),
    true_fluxes.unsqueeze(0),
    trim_images=False,
    clip_max_stars=True,
)

background_stamps = star_encoder1.get_image_stamps(
    full_background.unsqueeze(0), None, None, trim_images=False)[0]

In [None]:
# The true per-stamp star counts are passed to the encoder, so these
# variational parameters are conditioned on the true number of stars.
(stamp_logit_loc_mean, stamp_logit_loc_log_var,
 stamp_log_flux_mean, stamp_log_flux_log_var,
 stamp_log_probs) = star_encoder1(image_stamps, background_stamps,
                                  true_subimage_n_stars)

In [None]:
# We really just want the permutation `perm_indx`. It is used below to reorder
# star_encoder1's per-stamp parameters (computed on star_encoder1's tiling),
# so compute it with star_encoder1 as well. The matching obtained for
# star_encoder2's estimates need not be the right one for star_encoder1's.
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    inv_kl_lib.get_encoder_loss(star_encoder1, full_image.unsqueeze(0), 
                                full_background.unsqueeze(0),
                                true_locs.unsqueeze(0), 
                                true_fluxes.unsqueeze(0))[0:5]

In [None]:
# encoder loss returned by inv_kl_lib.get_encoder_loss in the previous cell
loss

In [None]:
from itertools import permutations

# All orderings of the star slots; perm_indx[i] selects the ordering for stamp i.
perm_list = list(permutations(range(star_encoder1.max_detections)))

# Integer dtype is required: these values are used as tensor indices in
# permute_params, and torch rejects floating-point index arrays (np.zeros
# defaults to float64).
perm = np.zeros((image_stamps.shape[0], star_encoder1.max_detections), dtype=np.int64)
for i in range(image_stamps.shape[0]):
    perm[i, :] = perm_list[int(perm_indx[i])]

In [None]:
# permute true parameters 
def permute_params(locs, fluxes, perm):
    """Reorder the star slots of ``locs`` and ``fluxes`` according to ``perm``.

    Args:
        locs: tensor indexed as ``locs[b, s, :]`` (batch, star slot, coordinate).
        fluxes: tensor indexed as ``fluxes[b, s, :]`` (batch, star slot, band).
        perm: integer-valued array or tensor of shape (batchsize, max_stars).
            ``perm[b, i]`` is the source slot placed in output slot ``i``.
            Float-typed arrays (e.g. built with ``np.zeros``) are cast to long.

    Returns:
        ``(locs_perm, fluxes_perm)`` with ``locs_perm[b, i] = locs[b, perm[b, i]]``
        and likewise for fluxes. Both have shape (batchsize, max_stars, last dim
        of the input) and keep the input dtype.
    """
    # torch only accepts integer/bool index tensors.
    perm = torch.as_tensor(perm).long()
    batchsize = perm.shape[0]

    # A (batchsize, 1) row index broadcasts against perm's (batchsize, max_stars).
    rows = torch.arange(batchsize).unsqueeze(1)
    locs_perm = locs[rows, perm]
    fluxes_perm = fluxes[rows, perm]

    return locs_perm, fluxes_perm

In [None]:
# Apply the same per-stamp permutation `perm` to the encoder means and to the
# encoder log-variances.
stamp_logit_loc_mean, stamp_log_flux_mean = \
    permute_params(stamp_logit_loc_mean, stamp_log_flux_mean, perm)
stamp_logit_loc_log_var, stamp_log_flux_log_var = permute_params(
    stamp_logit_loc_log_var, stamp_log_flux_log_var, perm)

In [None]:
import utils

In [None]:
# Because we conditioned on the true number of stars, the estimated and true
# parameter tensors must be nonzero in exactly the same slots.
locs_on_mask = true_subimage_locs != 0
fluxes_on_mask = true_subimage_fluxes != 0
assert ((stamp_logit_loc_mean != 0) == locs_on_mask).all()
assert ((stamp_log_flux_mean != 0) == fluxes_on_mask).all()

# true parameters on the encoder's (logit / log) scale, zeroed in empty slots
true_subimage_logit_locs = utils._logit(true_subimage_locs) * locs_on_mask.float()
true_subimage_log_fluxes = torch.log(true_subimage_fluxes + 1e-16) * fluxes_on_mask.float()

In [None]:
# Estimated vs true logit-locations over the slots that hold a true star.
# One mask is applied to both tensors so the plotted pairs stay aligned.
# Masking the true logits by "!= 0" would drop a star at loc == 0.5
# (logit 0) and would keep empty slots if _logit(0) is not finite.
loc_is_on = (true_subimage_locs != 0).flatten()
est_logit_locs = stamp_logit_loc_mean.flatten()[loc_is_on].detach().numpy()
true_logit_locs = true_subimage_logit_locs.flatten()[loc_is_on].numpy()

plt.plot(est_logit_locs, true_logit_locs, '+')

# identity line for reference
plt.plot(true_logit_locs, true_logit_locs, '-')

plt.xlabel('estimated')
plt.ylabel('truth')

In [None]:
# Estimated vs true log-fluxes on the slots where the estimate is nonzero,
# plus the identity line for reference.
flux_mask = stamp_log_flux_mean.flatten() != 0
est_log_fluxes = stamp_log_flux_mean.flatten()[flux_mask].detach().numpy()

plt.plot(est_log_fluxes,
         true_subimage_log_fluxes.flatten()[flux_mask].numpy(), '+')
plt.plot(true_subimage_log_fluxes.flatten()[flux_mask].detach().numpy(),
         true_subimage_log_fluxes.flatten()[flux_mask].numpy(), '-')

plt.xlabel('estimated')
plt.ylabel('truth')

In [None]:
# z-scores (mean - truth) / sd over the slots that hold a true star.
# One mask per parameter keeps the mean, truth and variance aligned. The
# earlier per-tensor "!= 0" masks could disagree with each other, e.g. a
# log-variance of exactly 0, a true logit of 0, or a non-finite logit in an
# empty slot, and silently pair the wrong entries.
loc_is_on = (true_subimage_locs != 0).flatten()
flux_is_on = (true_subimage_fluxes != 0).flatten()

zscore_locs = (stamp_logit_loc_mean.flatten()[loc_is_on] -
               true_subimage_logit_locs.flatten()[loc_is_on]) / \
    torch.exp(0.5 * stamp_logit_loc_log_var.flatten()[loc_is_on])

zscore_fluxes = (stamp_log_flux_mean.flatten()[flux_is_on] -
                 true_subimage_log_fluxes.flatten()[flux_is_on]) / \
    torch.exp(0.5 * stamp_log_flux_log_var.flatten()[flux_is_on])

In [None]:
def _plot_zscore_hist(ax, zscores, title):
    """Density histogram of `zscores` on `ax` with the N(0, 1) density overlaid."""
    _, edges, _ = ax.hist(zscores, bins=100, density=True)

    std_normal = torch.distributions.normal.Normal(loc=0, scale=1)
    density = torch.exp(std_normal.log_prob(torch.Tensor(edges)))
    ax.plot(edges, density.numpy(), color='red', linewidth=2)

    ax.set_xlabel('z-score', size=16)
    ax.set_ylabel('density', size=16)
    ax.set_title(title, size=16)


fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

# logit-location z-scores are clipped above at 10 before binning
_plot_zscore_hist(axarr[0], zscore_locs.detach().clamp(max=10),
                  'z-score for logit-locations')
_plot_zscore_hist(axarr[1], zscore_fluxes.detach(), 'z-score for log-fluxes')

fig.tight_layout()

if save_figs:
    fig.savefig('../../qualifying_exam_slides/figures/zscores.png')

In [None]:
# 300 posterior samples (not the MAP) from the wake-sleep encoder.
# `return_map` was set to the undefined name `Fa` (NameError); False is intended.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
    star_encoder2.sample_star_encoder(full_image.unsqueeze(0), 
                                      full_background.unsqueeze(0), 
                                      return_map=False, 
                                      n_samples=300)[0:3]

In [None]:
# NOTE(review): presumably the MAP star count from star_encoder2, defined in an
# earlier cell not shown here — confirm
map_n_stars_full2

In [None]:
# Completeness and TPR of every posterior sample against the true catalogue,
# via image_statistics_lib.get_summary_stats.
n_samples = sampled_locs_full_image.shape[0]
completeness_sampled = torch.zeros(n_samples)
tpr_sampled = torch.zeros(n_samples)

# Loop over the actual number of samples. A hard-coded 300 either indexes
# past the end or leaves trailing zeros that bias the means below.
for i in range(n_samples):
    my_completeness1, my_tpr1, _, _ = \
        image_statistics_lib.get_summary_stats(sampled_locs_full_image[i], 
                                               true_locs, 
                                               full_image.shape[-1], 
                                               sampled_fluxes_full_image[i][:, 0], 
                                               true_fluxes[:, 0])

    completeness_sampled[i] = my_completeness1
    tpr_sampled[i] = my_tpr1

In [None]:
# mean completeness over the encoder's posterior samples
completeness_sampled.mean()

In [None]:
# mean TPR over the encoder's posterior samples
tpr_sampled.mean()

In [None]:
# Portillos posterior samples: the last `n_chain_samples` draws of the chain.
# Fluxes are scaled by `fudge_factor`. Stars not brighter than `fmin` get both
# their flux and their location zeroed.
n_chain_samples = 300

fluxes = chain_results['f'][0, -n_chain_samples:, ] * fudge_factor

x1_loc = chain_results['x'][-n_chain_samples:, ]
x0_loc = chain_results['y'][-n_chain_samples:, ]

is_bright = fluxes > fmin
fluxes = fluxes * is_bright
x0_loc = x0_loc * is_bright
x1_loc = x1_loc * is_bright

# (chain sample, star, coordinate), divided by (slen - 1). np.stack gives the
# same layout as the former list -> Tensor -> double transpose, without
# torch's slow list-of-ndarrays conversion.
portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], 2)) / (full_image.shape[-1] - 1)
portillos_est_fluxes = torch.Tensor(fluxes)

In [None]:
# number of Portillos stars (nonzero flux) in each chain sample
n_stars = (portillos_est_fluxes > 0).sum(1)

# noiseless images for the first two chain samples; show the second one
portillos_recon_mean = simulator1.draw_image_from_params(
    locs=portillos_est_locs[:2],
    fluxes=portillos_est_fluxes[:2].unsqueeze(2),
    n_stars=n_stars[:2],
    add_noise=False,
).squeeze(0)

plt.matshow(portillos_recon_mean[1, 0]); 
plt.colorbar()

In [None]:
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

# Corners of the subimages to display. Coordinate 0 is compared with x0 and
# plotted on the y-axis, coordinate 1 with x1 on the x-axis.
x0_vec = [53, 31, 41, 32]
x1_vec = [70, 83, 23, 64]
subimage_slen = 10

# Portillos posterior samples in pixel units (independent of the subimage,
# so computed once)
_sampled_locs = portillos_est_locs * (full_image.shape[-1] - 1)

for i in range(4):
    x0 = x0_vec[i]
    x1 = x1_vec[i]
    ax = axarr[i // 2, i % 2]

    # posterior samples strictly inside this subimage, in subimage coordinates
    which_locs = (_sampled_locs[:, :, 0] > x0) & \
                    (_sampled_locs[:, :, 0] < (x0 + subimage_slen - 1)) & \
                    (_sampled_locs[:, :, 1] > x1) & \
                    (_sampled_locs[:, :, 1] < (x1 + subimage_slen - 1))
    sampled_locs = _sampled_locs[which_locs, :] - torch.Tensor([[x0, x1]])
    ax.scatter(sampled_locs[:, 1], sampled_locs[:, 0],
               color='r', marker='x', alpha=0.1)

    # true locations only (no MAP estimate overlaid)
    plotting_utils.plot_subimage(ax, full_image[0],
                                 None,
                                 true_locs,
                                 x0,
                                 x1, subimage_slen=subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig, color='green')

plt.tight_layout()

# guarded by the notebook's save flag so re-running does not overwrite the slides
if save_figs:
    plt.savefig('../../qualifying_exam_slides/figures/sample_figures_port_posterior_samples.png')

In [None]:
# Completeness and TPR of each Portillos posterior sample against the true
# catalogue, via image_statistics_lib.get_summary_stats.
n_samples = portillos_est_fluxes.shape[0]
completeness_sampled = torch.zeros(n_samples)
tpr_sampled = torch.zeros(n_samples)

for i in range(n_samples):
    sample_completeness, sample_tpr, _, _ = image_statistics_lib.get_summary_stats(
        portillos_est_locs[i],
        true_locs,
        full_image.shape[-1],
        portillos_est_fluxes[i],
        true_fluxes[:, 0],
    )
    completeness_sampled[i] = sample_completeness
    tpr_sampled[i] = sample_tpr

In [None]:
# mean completeness over the Portillos posterior samples
completeness_sampled.mean()

In [None]:
# mean TPR over the Portillos posterior samples
tpr_sampled.mean()

In [None]:
# average number of Portillos stars (flux > 0) per posterior sample
(portillos_est_fluxes > 0).sum(1).float().mean()

# Results at the center of the cluster?

In [None]:
# rows 870:970, cols 160:260 of the full SDSS frame (leading index 0)
plt.matshow(sdss_hubble_data.sdss_image_full[0][870:970, 160:260])

# Guarded by the notebook's save flag like the other figures.
# NOTE(review): another cell also writes sdss_image_center.png.
if save_figs:
    plt.savefig('../../qualifying_exam_slides/figures/sdss_image_center.png')

In [None]:
# SDSS/Hubble data with offsets x0=870, x1=160, for the selected bands
sdss_hubble_data_center = sdss_dataset_lib.SDSSHubbleData(bands=bands,
                                                          x0=870,
                                                          x1=160)


In [None]:
# number of catalogue flux entries (all bands) above the notebook-wide
# brightness cutoff `fmin` (1000), rather than a duplicated literal
(sdss_hubble_data_center.fluxes > fmin).sum()

In [None]:
# leading index 0 of the center subimage
plt.matshow(sdss_hubble_data_center.sdss_image[0])

In [None]:
# MAP catalogue of the center subimage from the wake-sleep encoder
map_locs, map_fluxes, map_n_stars = star_encoder2.sample_star_encoder(
    full_image=sdss_hubble_data_center.sdss_image.unsqueeze(0),
    full_background=sdss_hubble_data_center.sdss_background.unsqueeze(0),
    return_map=True,
)[:3]

In [None]:
# noiseless image implied by the MAP catalogue
# NOTE(review): uses `simulator`, not `simulator1` from the top of the
# notebook — confirm it is defined in an earlier cell
recon_mean = simulator.draw_image_from_params(
    locs=map_locs,
    fluxes=map_fluxes,
    n_stars=map_n_stars,
    add_noise=False,
).squeeze(0)

In [None]:
# residual (reconstruction - data) on a symmetric diverging colour scale
foo = recon_mean - sdss_hubble_data_center.sdss_image
max_abs_resid = foo.abs().max()
plt.matshow(foo[0], vmax=max_abs_resid, vmin=-max_abs_resid,
            cmap=plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
# Four random 10x10 subimages of the center image, with the catalogue
# locations whose first-band flux exceeds 10000 overlaid.
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))
subimage_slen = 10
bright_center_locs = sdss_hubble_data_center.locs[sdss_hubble_data_center.fluxes[:, 0] > 10000.]

for i in range(4): 
    x0 = int(np.random.choice(90))
    x1 = int(np.random.choice(90))

    # The overlaid locations come from sdss_hubble_data_center, so plot its
    # image. Plotting `full_image` here paired the catalogue with a different field.
    plotting_utils.plot_subimage(axarr[i // 2, i % 2],
                                 sdss_hubble_data_center.sdss_image[0],
                                 None,
                                 bright_center_locs,
                                 x0,
                                 x1, subimage_slen=subimage_slen,
                                 add_colorbar=True,
                                 global_fig=fig)




# ROC curves on image stamps?

In [None]:
# Re-tile the full image and true catalogue into stamps for the ROC analysis;
# keep the stamps and the per-stamp true star counts.
stamp_outputs = star_encoder1.get_image_stamps(full_image.unsqueeze(0),
                                               true_locs.unsqueeze(0),
                                               true_fluxes.unsqueeze(0))
image_stamps, _, _, subimage_n_stars, _ = stamp_outputs

# matching background stamps (untrimmed)
background_stamps = star_encoder1.get_image_stamps(
    full_background.unsqueeze(0), locs=None, fluxes=None, trim_images=False)[0]

In [None]:
# distribution of the true number of stars per stamp
plt.hist(subimage_n_stars)

In [None]:
# Sleep-only encoder: the 5th output holds per-stamp log-probabilities;
# 1 - column 0 is used as the "stamp contains a star" score
# (column 0 presumably = P(no stars) — confirm).
log_probs1 = star_encoder1(image_stamps, background_stamps)[4]
probs1 = log_probs1.exp()
is_on_probs1 = 1 - probs1[:, 0]

In [None]:
def get_roc_curve(is_on_probs, true_n_stars, seq):
    """Compute ROC points for a per-tile "contains a star" score.

    At threshold ``t`` a tile is predicted "on" when ``is_on_probs >= t``.
    It is truly "on" when ``true_n_stars > 0``.

    Args:
        is_on_probs: scores, one per tile (flattened to 1-D).
        true_n_stars: true star counts, one per tile (flattened to 1-D).
        seq: thresholds as a tensor, NumPy array or list.

    Returns:
        ``(tpr_vec, fpr_vec)``: float tensors shaped like ``seq``. An entry is
        NaN when there are no positive (resp. negative) tiles.
    """
    thresholds = torch.as_tensor(seq)
    scores = torch.as_tensor(is_on_probs).reshape(1, -1)
    n_stars = torch.as_tensor(true_n_stars).reshape(1, -1)

    is_pos = n_stars > 0
    is_neg = n_stars == 0

    # (n_thresholds, n_tiles): the prediction for every tile at every threshold
    predicted = scores >= thresholds.reshape(-1, 1)

    tpr_vec = (predicted & is_pos).sum(1).float() / is_pos.sum().float()
    fpr_vec = (predicted & is_neg).sum(1).float() / is_neg.sum().float()

    return tpr_vec.reshape(thresholds.shape), fpr_vec.reshape(thresholds.shape)
        

In [None]:
# ROC curves of the per-stamp "is on" scores for the sleep-only (orange) and
# wake-sleep (red) encoders, on a common threshold grid.
thresholds = torch.arange(0, 1.05, step=0.05)

tpr_vec1, fpr_vec1 = get_roc_curve(is_on_probs1, subimage_n_stars, thresholds)

log_probs2 = star_encoder2(image_stamps, background_stamps)[4]
probs2 = log_probs2.exp()
is_on_probs2 = 1 - probs2[:, 0]
tpr_vec2, fpr_vec2 = get_roc_curve(is_on_probs2, subimage_n_stars, thresholds)

for fpr_curve, tpr_curve, colour in ((fpr_vec1, tpr_vec1, 'orange'),
                                     (fpr_vec2, tpr_vec2, 'red')):
    plt.plot(fpr_curve.numpy(), tpr_curve.numpy(), '-x', color=colour)

In [None]:
# Per-tile "is on" probabilities for Portillos need many chain draws: use the
# last 500. Fluxes are scaled by `fudge_factor`. Stars not brighter than
# `fmin` get BOTH flux and location zeroed. Zeroing only the locations left
# them with flux > 0, so they were counted as stars at (0, 0) by the
# `(port_flux_samples > 0)` counts and the per-tile counts downstream.
n_port_samples = 500

port_flux_samples = chain_results['f'][0, -n_port_samples:, :] * fudge_factor
is_bright = port_flux_samples > fmin

port_flux_samples = port_flux_samples * is_bright
x1_loc = chain_results['x'][-n_port_samples:, ] * is_bright
x0_loc = chain_results['y'][-n_port_samples:, ] * is_bright

# (chain sample, star, coordinate), divided by (slen - 1)
port_locs_samples = torch.Tensor(np.stack([x0_loc, x1_loc], 2)) / (full_image.shape[-1] - 1)
port_flux_samples = torch.Tensor(port_flux_samples)

In [None]:
# Sanity check that the chain was loaded correctly: noiseless image from the
# first chain sample only.
first_sample = slice(0, 1)
portillos_recon_mean = simulator.draw_image_from_params(
    locs=port_locs_samples[first_sample],
    fluxes=port_flux_samples[first_sample].unsqueeze(2),
    n_stars=(port_flux_samples > 0).sum(1)[first_sample],
    add_noise=False,
).squeeze()

In [None]:
# relative residual (reconstruction - data) / data, rows and cols 5..94 only,
# on a symmetric diverging colour scale
relative_resid = (portillos_recon_mean - full_image[0]) / full_image[0]
foo = relative_resid[5:95, 5:95]
colour_limit = foo.abs().max()
plt.matshow(foo, vmax=colour_limit, vmin=-colour_limit, cmap=plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
# Number of Portillos stars in each of star_encoder1's tiles, for every chain
# sample: shape (n_chain_samples, n_tiles).
port_n_stars_sampled = torch.zeros(port_locs_samples.shape[0], star_encoder1.tile_coords.shape[0])

# Done one sample at a time because doing it all at once freezes my laptop.
# The loop bound previously referenced the undefined `port_sampled_locs`.
for i in range(port_locs_samples.shape[0]):
    port_n_stars_sampled[i] = image_utils.get_params_in_patches(star_encoder1.tile_coords,
                                          port_locs_samples[i:(i+1)],
                                          port_flux_samples[i:(i+1)].unsqueeze(2),
                                          star_encoder1.full_slen,
                                          star_encoder1.stamp_slen,
                                          star_encoder1.edge_padding)[2]

    # progress indicator
    if i % 50 == 0:
        print(i)

In [None]:
# distribution of Portillos star counts over all (chain sample, tile) pairs
plt.hist(port_n_stars_sampled.flatten())

In [None]:
# per tile: fraction of chain samples with at least one Portillos star
is_on_probs_port = torch.mean((port_n_stars_sampled > 0).float(), dim=0)

In [None]:
# smallest nonzero per-tile "is on" frequency (resolution of the estimate)
is_on_probs_port[is_on_probs_port > 0].min()

In [None]:
# ROC for Portillos: per-tile "is on" frequencies against the TRUE per-tile
# star counts `subimage_n_stars`, the same truth used for the encoders.
# `n_stars` holds per-chain-sample Portillos star counts, which cannot be
# compared tile-by-tile. The grid is fine because, with 500 chain samples,
# the frequencies are multiples of 1/500.
tpr_vec_port, fpr_vec_port = get_roc_curve(is_on_probs_port, subimage_n_stars, 
                                           # torch.arange(0, 1 + 2 / 300, step = 1/300))
                                           torch.arange(0, 1.06, step=1/500))

plt.plot(fpr_vec_port.numpy(), tpr_vec_port.numpy(), '-x')
plt.plot(fpr_vec1.numpy(), tpr_vec1.numpy(), '-x')

In [None]:
# Stamp-level ROC curves: both encoders and Portillos' catalogue.
plt.plot(fpr_vec1.numpy(), tpr_vec1.numpy(), '-x', color='orange', label='sleep only')
plt.plot(fpr_vec2.numpy(), tpr_vec2.numpy(), '-x', color='red', label='wake-sleep')
plt.plot(fpr_vec_port.numpy(), tpr_vec_port.numpy(), '-x', color='blue', label='Portillos')

plt.xlabel('False positive rate', size=16)
plt.ylabel('True positive rate', size=16)
plt.legend()

# `save_figs` is the flag used by every other figure; `save_fig` was undefined (NameError)
if save_figs:
    plt.savefig('../../qualifying_exam_slides/figures/roc_curve.png')

# Band misalignment

In [None]:
# bands 2 and 3 (plotted below as r and i) loaded WITHOUT band alignment
sdss_hubble_data_bands = sdss_dataset_lib.SDSSHubbleData(
    bands=[2, 3], align_bands=False)


In [None]:
# r band, i band, and their difference for the unaligned data.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

band_r = sdss_hubble_data_bands.sdss_image[0]
band_i = sdss_hubble_data_bands.sdss_image[1]

im0 = axarr[0].matshow(band_r)
fig.colorbar(im0, ax=axarr[0])
axarr[0].set_title('r band image \n', size=16)

im1 = axarr[1].matshow(band_i)
fig.colorbar(im1, ax=axarr[1])
axarr[1].set_title('i band image \n', size=16)

# symmetric diverging colour scale for the difference
diff = band_r - band_i
im2 = axarr[2].matshow(diff, vmax=diff.abs().max(), vmin=-diff.abs().max(), cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('r - i \n', size=16)

fig.tight_layout()

# guarded by the notebook's save flag like the other figures
if save_figs:
    fig.savefig('../../qualifying_exam_slides/figures/misaligned_bands.png')

In [None]:
# same x0=870, x1=160 region, this time with the dataset's default bands
sdss_hubble_data_center = sdss_dataset_lib.SDSSHubbleData(x0=870,
                                                          x1=160)


In [None]:
# leading index 0 of the center subimage, with a colorbar
plt.matshow(sdss_hubble_data_center.sdss_image[0])
plt.colorbar()

# Guarded by the notebook's save flag like the other figures.
# NOTE(review): another cell also writes sdss_image_center.png.
if save_figs:
    plt.savefig('../../qualifying_exam_slides/figures/sdss_image_center.png')

In [None]:
# zoom: leading index 0, rows 900:950, cols 180:250 of the full SDSS frame
plt.matshow(sdss_hubble_data.sdss_image_full[0, 900:950, 180:250])

In [None]:
# three random 10x10 subimages of the center image, with no catalogue overlaid
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))

for i, ax in enumerate(axarr):
    x0 = int(np.random.choice(90))
    x1 = int(np.random.choice(90))
    plotting_utils.plot_subimage(ax, sdss_hubble_data_center.sdss_image[0],
                                 None, None,
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=fig)

fig.tight_layout()
# if save_figs: 
#     plt.savefig('../../qualifying_exam_slides/figures/forward_kl_better.png')

In [None]:
# display the repr of the unaligned two-band dataset object
sdss_hubble_data_bands