In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_psf
import image_utils 

import plotting_utils

# seed NumPy's global RNG: it drives the np.random.choice calls that pick the
# subimages shown in the plotting cells below
np.random.seed(34534)

# Load the data

In [None]:
# SDSS psField file used as the PSF source for the StarSimulator built below
# (presumably run 2583, camcol 2, field 136 per the psField naming — confirm)
psf_fit_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

In [None]:
# simulated image read as a plain-text 2-D array of pixel values
# (the '-ctsr' suffix presumably means counts, r band — TODO confirm)
full_image = np.loadtxt('../../multiband_pcat/Data/sdss_simulated/cts/sdss_simulated-ctsr.txt')

In [None]:
# convert to a float32 torch tensor; the simulator/encoder code below works on tensors
full_image = torch.Tensor(full_image)

In [None]:
# display the full observed image with a colorbar
plt.matshow(full_image.squeeze())
plt.colorbar()

In [None]:
# true parameters 
# ground-truth simulation parameters; keys used later in this notebook are
# 'true_locs', 'true_fluxes' and 'sky_intensity'

true_params = np.load('../../multiband_pcat/Data/sdss_simulated/true_params.npz')
true_params.keys()

In [None]:
# true locations scaled to pixel units by multiplying by (slen - 1); cells below
# divide by the same factor to return to the stored scale (presumably [0, 1])
true_locs = torch.Tensor(true_params['true_locs'].squeeze() * (full_image.shape[-1] - 1))

# Load results

In [None]:
# output directory of one PCAT run; chain.npz holds the saved chain, whose
# 'x', 'y' and 'f' arrays are read in the next cell
results_dir = '../../multiband_pcat/pcat-lion-results/20190930-194553/'

chain_results = np.load(results_dir + 'chain.npz')

In [None]:
include_classical_catalogue = True

if include_classical_catalogue:
    # condensed ("classical") PCAT catalog: column 0 -> x1 (column index),
    # column 2 -> x0 (row index), column 4 -> flux
    pcat_catalog = np.loadtxt(results_dir + 'classical_catalog.txt')

    x1_loc = pcat_catalog[:, 0]
    x0_loc = pcat_catalog[:, 2]
    fluxes = pcat_catalog[:, 4]

    # drop entries where either coordinate or the flux is NaN
    # (previously x1_loc was tested twice and NaNs in x0_loc slipped through)
    is_na = np.isnan(x0_loc) | np.isnan(x1_loc) | np.isnan(fluxes)

    x1_loc = x1_loc[~is_na]
    x0_loc = x0_loc[~is_na]
    fluxes = fluxes[~is_na]

    # (n_stars, 2) tensor of (x0, x1) locations
    portillos_est_locs = torch.Tensor(np.stack([x0_loc, x1_loc], axis=1))
    portillos_est_fluxes = torch.Tensor(fluxes)

# keep only the last saved sweeps of the chain (earlier ones presumably burn-in)
n_chain_samples = 300

x1_loc_samples = chain_results['x'][-n_chain_samples:].flatten()
x0_loc_samples = chain_results['y'][-n_chain_samples:].flatten()

# flux samples from index 0 of the leading axis (presumably the first band)
portillos_est_fluxes_sampled = torch.Tensor(chain_results['f'][0, -n_chain_samples:].flatten())
# (n_samples * n_sources, 2) tensor of pooled (x0, x1) sample locations
portillos_est_locs_sampled = torch.Tensor(np.stack([x0_loc_samples, x1_loc_samples], axis=1))

# Reconstruct the image from the PCAT condensed catalog

In [None]:
simulator = simulated_datasets_lib.StarSimulator(psf_fit_file=psf_fit_file,
                                                slen=full_image.shape[-1],
                                                sky_intensity=686.0)

# a PCAT reconstruction can only be rendered from the condensed catalog
if include_classical_catalogue:
    image_slen = full_image.shape[-1]

    # add a leading batch dimension; locations are divided by (slen - 1)
    pcat_locs_batch = portillos_est_locs.unsqueeze(0) / (image_slen - 1)
    pcat_fluxes_batch = torch.Tensor(fluxes).unsqueeze(0)
    pcat_n_stars = torch.tensor([len(x0_loc)], dtype=torch.long)

    # noise-free image implied by the catalog
    portillos_recon_mean = simulator.draw_image_from_params(
        locs=pcat_locs_batch,
        fluxes=pcat_fluxes_batch,
        n_stars=pcat_n_stars,
        add_noise=False,
    ).squeeze()

    plt.matshow(portillos_recon_mean)
    plt.colorbar()

In [None]:
# list the arrays stored in the PCAT chain file
chain_results.keys()

In [None]:
# flux distribution of the condensed PCAT catalog; `fluxes` only exists when
# the classical catalog was loaded above
if include_classical_catalogue:
    plt.hist(fluxes)
    plt.xlabel('flux')
    plt.ylabel('count')
    plt.title('PCAT condensed-catalog fluxes');

In [None]:
# PCAT residuals: reconstruction minus observed image (the VAE residuals later
# in the notebook use the same sign convention)
if include_classical_catalogue: 
    portillos_residuals = portillos_recon_mean - full_image
    
    plt.matshow(portillos_residuals)
    plt.colorbar()

# Plot subimages

In [None]:
subimage_slen = 11

# possible starting (row x0, column x1) indices of non-overlapping
# subimage_slen x subimage_slen windows; +1 because np.arange excludes its stop
# and the last valid start is (image length - subimage_slen)
x0_vec = np.arange(0, full_image.shape[-2] - subimage_slen + 1, subimage_slen)
x1_vec = np.arange(0, full_image.shape[-1] - subimage_slen + 1, subimage_slen)

In [None]:
# this figure shows the condensed catalog and its reconstruction, so it needs
# the classical catalog loaded above
assert include_classical_catalogue, \
    'set include_classical_catalogue = True to plot the PCAT subimages'

# random subimage; np.random.choice without `size` returns a scalar, avoiding the
# deprecated int() conversion of a length-1 array
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 4, figsize=(15, 4))

# panel 0: observed subimage with locations from the PCAT posterior samples
plotting_utils.plot_subimage(axarr[0], full_image, portillos_est_locs_sampled,
                             true_locs, x0, x1, subimage_slen)
axarr[0].set_title('posterior samples; coords: {}\n'.format([x0, x1]));

# panel 1: observed subimage with the condensed catalog
plotting_utils.plot_subimage(axarr[1], full_image, portillos_est_locs, true_locs, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[1].set_title('condensed catalog; coords: {}\n'.format([x0, x1]));

# panel 2: noise-free reconstruction from the condensed catalog
plotting_utils.plot_subimage(axarr[2], portillos_recon_mean, portillos_est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[2].set_title('reconstructed\n');

# panel 3: residuals relative to the observed pixel values, on a color scale
# symmetric about zero
pcat_relative_residuals = portillos_residuals / full_image
vmax = torch.abs(pcat_relative_residuals[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
plotting_utils.plot_subimage(axarr[3], pcat_relative_residuals,
                             portillos_est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)
axarr[3].set_title('residuals\n');



# Compare with the StarNet neural-network encoder

In [None]:
import starnet_vae_lib
import inv_KL_objective_lib as objectives_lib

In [None]:
# StarNet encoder; these constructor arguments presumably have to match the ones
# the checkpoint loaded below was trained with (load_state_dict checks parameter
# shapes) — TODO confirm against the training config.
# full_slen = 101 matches the 101 x 101 crop fed to the encoder further down.
star_encoder = starnet_vae_lib.StarEncoder(full_slen = 101,
                                            stamp_slen = 9,
                                            step = 4,
                                            edge_padding = 2, 
                                            n_bands = 1,
                                            max_detections = 6)

In [None]:
# load the trained encoder weights onto the CPU, whatever device they were
# saved from, then switch the network to evaluation mode
encoder_state = torch.load('../fits/starnet_invKL_encoder_batched_images_400stars_smallpatch2',
                           map_location='cpu')
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

In [None]:
# check the image size against the encoder's full_slen (101)
full_image.shape

In [None]:
# tile images
_true_locs = (true_locs / (full_image.shape[-1] - 1)).unsqueeze(0)
_true_fluxes = torch.Tensor(true_params['true_fluxes'])

# because I messed up the step ... 
_full_image = full_image[0:101, 0:101].unsqueeze(0).unsqueeze(0)
image_stamps, true_subimage_locs, true_subimage_fluxes, true_n_stars, is_on_array = \
        star_encoder.get_image_stamps(_full_image, _true_locs, _true_fluxes, 
                                      trim_images = False)

In [None]:
# get variational parameters
# get variational parameters for every stamp, conditioning on the true sky
# intensity; this is inference only, so skip building the autograd graph
with torch.no_grad():
    logit_loc_mean, logit_loc_log_var, \
        log_flux_mean, log_flux_log_var, log_probs = \
            star_encoder(image_stamps, torch.Tensor(true_params['sky_intensity']))

In [None]:
# get map estimates for image patches
# per-patch star count: argmax of log_probs over dim 1 (assumes dim 1 indexes
# the number of stars — TODO confirm)
map_n_stars = torch.argmax(log_probs, dim = 1)
# NOTE: replaces the is_on_array returned by get_image_stamps (built from the
# true catalog) with the one implied by the estimated counts
is_on_array = objectives_lib.get_is_on_from_n_stars(map_n_stars, star_encoder.max_detections)

# point estimates: sigmoid of the logit-location means and exp of the log-flux
# means; multiplying by is_on_array presumably zeroes detections beyond the
# estimated count
map_locs = torch.sigmoid(logit_loc_mean).detach()
map_fluxes = torch.exp(log_flux_mean).detach() * is_on_array

In [None]:
# get map estimates for full image 
# stitch the per-patch estimates back into a full-image catalog (locations,
# fluxes, star count) using the encoder's tiling: tile_coords, stamp_slen and
# edge_padding.
# NOTE(review): full_image.shape[-1] is passed even though the encoder saw a
# 101 x 101 crop; these only agree if full_image is 101 wide — confirm.
map_locs_full_image, map_fluxes_full_image, n_stars = \
    image_utils.get_full_params_from_patch_params(map_locs, map_fluxes,
                                                    is_on_array,
                                                    star_encoder.tile_coords,
                                                    full_image.shape[-1],
                                                    star_encoder.stamp_slen,
                                                    star_encoder.edge_padding,
                                                    star_encoder.batchsize)

In [None]:
# render the VAE catalog without noise using the same simulator as the PCAT
# reconstruction, so the two reconstructions are directly comparable
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image, 
                                                fluxes = map_fluxes_full_image,
                                                 n_stars = n_stars, 
                                                 add_noise = False).squeeze()

# reconstruction minus observed image (same sign as portillos_residuals)
residuals = vae_recon_mean - full_image

In [None]:
# starting indices of subimages that stay inside the encoder's region, away
# from the edge padding; +1 because np.arange excludes its stop and the last
# valid start is full_slen - edge_padding - subimage_slen
x0_vec = np.arange(star_encoder.edge_padding,
                   star_encoder.full_slen - star_encoder.edge_padding - subimage_slen + 1,
                   subimage_slen)

# the encoder works on a square full_slen x full_slen image, so columns reuse the row starts
x1_vec = x0_vec

x0_vec

In [None]:
# random subimage inside the encoder's region; np.random.choice without `size`
# returns a scalar, avoiding the deprecated int() conversion of a length-1 array
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# panel 0: observed subimage with the VAE catalog. Drop only the leading batch
# dim (squeeze(0)) so a single-star catalog keeps its (n, 2) shape, and rescale
# by (slen - 1) back to pixel units.
est_locs = map_locs_full_image.squeeze(0) * (full_image.shape[-1] - 1)
plotting_utils.plot_subimage(axarr[0], full_image, est_locs, true_locs, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[0].set_title('observed; coords: {}\n'.format([x0, x1]));

# panel 1: noise-free reconstruction from the VAE catalog
plotting_utils.plot_subimage(axarr[1], vae_recon_mean, est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[1].set_title('reconstructed\n');

# panel 2: residuals relative to the observed pixel values, on a color scale
# symmetric about zero
vae_relative_residuals = residuals / full_image
vmax = torch.abs(vae_relative_residuals[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]).max()
plotting_utils.plot_subimage(axarr[2], vae_relative_residuals,
                             est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)
axarr[2].set_title('residuals\n');



# Compare the VAE and PCAT catalogs on the same subimage

In [None]:
# subimage starting indices available for the comparison below
x0_vec

In [None]:
# the PCAT overlay below needs the condensed catalog
assert include_classical_catalogue, \
    'set include_classical_catalogue = True to compare with PCAT'

# random subimage; np.random.choice without `size` returns a scalar
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

# one panel only (the previous 1x2 layout left the second axis blank):
# observed subimage with the VAE catalog and true locations, plus the PCAT
# condensed catalog overlaid as cyan crosses
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# drop only the batch dim and rescale to pixel units
est_locs = map_locs_full_image.squeeze(0) * (full_image.shape[-1] - 1)
plotting_utils.plot_subimage(ax, full_image, est_locs, true_locs, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
ax.set_title('observed; coords: {}\n'.format([x0, x1]));

# PCAT sources with x0 in (x0, x0 + subimage_slen - 1) and x1 in
# (x1, x1 + subimage_slen - 1), shifted to subimage coordinates
in_subimage = (portillos_est_locs[:, 0] > x0) & \
                (portillos_est_locs[:, 0] < (x0 + subimage_slen - 1)) & \
                (portillos_est_locs[:, 1] > x1) & \
                (portillos_est_locs[:, 1] < (x1 + subimage_slen - 1))
pcat_locs_in_subimage = portillos_est_locs[in_subimage, :] - torch.Tensor([[x0, x1]])
# columns are (x0 = row, x1 = column), so x1 goes on the horizontal axis
ax.scatter(pcat_locs_in_subimage[:, 1], pcat_locs_in_subimage[:, 0], color = 'c', marker = 'x');


In [None]:
# side-by-side VAE vs PCAT reconstructions and residuals for the subimage
# (x0, x1) and est_locs chosen in the previous cell
assert include_classical_catalogue, \
    'set include_classical_catalogue = True to compare with PCAT'

fig, axarr = plt.subplots(1, 4, figsize=(30, 6))

window = (slice(x0, x0 + subimage_slen), slice(x1, x1 + subimage_slen))
vae_relative_residuals = residuals / full_image
pcat_relative_residuals = portillos_residuals / full_image

# one symmetric color limit shared by both residual panels, so their colors are
# directly comparable
vmax = max(float(torch.abs(vae_relative_residuals[window]).max()),
           float(torch.abs(pcat_relative_residuals[window]).max()))

# panel 0: VAE reconstruction
plotting_utils.plot_subimage(axarr[0], vae_recon_mean, est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[0].set_title('vae reconstructed\n');

# panel 1: VAE residuals relative to the observed pixel values
plotting_utils.plot_subimage(axarr[1], vae_relative_residuals,
                             est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)
axarr[1].set_title('vae residuals\n');

# panel 2: PCAT condensed-catalog reconstruction
plotting_utils.plot_subimage(axarr[2], portillos_recon_mean, portillos_est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig)
axarr[2].set_title('portillos reconstructed\n');

# panel 3: PCAT residuals relative to the observed pixel values
plotting_utils.plot_subimage(axarr[3], pcat_relative_residuals,
                             portillos_est_locs, None, x0, x1, subimage_slen,
                             add_colorbar = True, global_fig = fig,
                             diverging_cmap = True,
                             vmax = vmax, vmin = -vmax)
axarr[3].set_title('portillos residuals\n');

