In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import fitsio

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import psf_transform_lib
import wake_lib 

import starnet_lib
import sleep_lib
import plotting_utils

# Seed every RNG the notebook draws from so Restart & Run All reproduces the
# same figures: numpy (np.random.choice in the image-patch cell) and torch
# (presumably used for the noise in the simulated-data branch -- confirm).
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# Data parameters

In [None]:
# Data parameters: default star-simulation settings read from JSON.
# 'slen' and 'f_min' are used further down in this notebook.
with open('../data/default_star_parameters.json', 'r') as params_file:
    data_params = json.loads(params_file.read())

In [None]:
# Display the loaded parameter dict ('slen' and 'f_min' are read later on).
data_params

# PSF

In [None]:
# Power-law PSF initialised from previously fitted parameters stored in
# ../data/fitted_powerlaw_psf_params.npy. psf_og is the rendered PSF, detached
# from autograd because it is only used as a fixed input (StarSimulator and
# load_dataset_from_params below).
#
# Alternative, not used here: derive the parameters from an SDSS psField file with
#   psf_transform_lib.get_psf_params(psfield_file, bands = bands)
# e.g. psfield_file = './../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'.
# `bands` is only defined further down, so it would have to be moved above this cell.
psf_param_array = np.load('../data/fitted_powerlaw_psf_params.npy')
init_psf_params = torch.Tensor(psf_param_array)

power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
psf_og = power_law_psf.forward().detach()


# Background

In [None]:
# Planar background initialised from previously fitted parameters stored in
# ../data/fitted_planar_backgrounds.npy, rendered at the data_params['slen']
# image size and detached from autograd.
#
# Alternative, not used here: a flat initialisation,
#   torch.zeros(len(bands), 3) with column 0 set to torch.Tensor([686., 1123.]);
# `bands` is only defined further down, so it would have to be moved above this cell.
background_param_array = np.load('../data/fitted_planar_backgrounds.npy')
init_background_params = torch.Tensor(background_param_array)

planar_background = wake_lib.PlanarBackground(
    image_slen = data_params['slen'],
    init_background_params = init_background_params.to(device))
background = planar_background.forward().detach()


In [None]:
# Simulator built from the fitted PSF and background; used below to render
# noise-free reconstructions from the encoders' MAP catalogues.
simulator = simulated_datasets_lib.StarSimulator(
    psf_og,
    data_params['slen'],
    background,
    transpose_psf = False,
)

# Get Data

In [None]:
# True: use the SDSS image and catalogue from sdss_dataset_lib.SDSSHubbleData;
# False: simulate one noisy image from the fitted PSF and background.
use_real_data = True

In [None]:
# Band indices passed to SDSSHubbleData and used for the encoders' n_bands.
# Presumably SDSS r and i (ugriz indices 2, 3), matching the 'starnet_ri'
# checkpoints -- TODO confirm.
# NOTE(review): the PSF / background .npy files loaded earlier are presumably
# fitted to these same bands.
bands = [2, 3]

In [None]:
if not use_real_data:
    # Simulate a single noisy image from the fitted PSF and background.
    n_images = 1
    star_dataset = simulated_datasets_lib.load_dataset_from_params(
        psf_og,
        data_params,
        background = background,
        n_images = n_images,
        transpose_psf = False,
        add_noise = True,
    )

    full_image = star_dataset.images
    full_background = background.unsqueeze(0)

    # Ground truth: the first n_stars entries for the single simulated image.
    true_locs = star_dataset.locs.squeeze(0)[:int(star_dataset.n_stars)]
    true_fluxes = star_dataset.fluxes.squeeze(0)[:int(star_dataset.n_stars)]
else:
    # Real SDSS cutout at pixel offset (x0, x1).
    x0, x1 = 630, 310
    sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0 = x0, x1 = x1, bands = bands)

    # Image and background, each with a leading batch dimension added.
    full_image = sdss_hubble_data.sdss_image.unsqueeze(0)
    full_background = sdss_hubble_data.sdss_background.unsqueeze(0)

    # Ground truth: catalogue entries whose first-band flux exceeds f_min.
    which_bright = sdss_hubble_data.fluxes[:, 0] > data_params['f_min']
    true_locs = sdss_hubble_data.locs[which_bright]
    true_fluxes = sdss_hubble_data.fluxes[which_bright]


In [None]:
# Quick look at the first channel of the single input image.
# detach().cpu() lets matplotlib handle the tensor even if it lives on a GPU
# or carries autograd history (e.g. in the simulated-data branch).
plt.matshow(full_image[0, 0].detach().cpu())
plt.colorbar()
plt.title('input image, channel 0');

# First set of results: `starnet_ri` encoder

In [None]:
# Encoder for the 'starnet_ri' checkpoint. These architecture arguments presumably
# have to match the ones used at training time for load_state_dict to succeed.
star_encoder1 = starnet_lib.StarEncoder(
    full_slen = data_params['slen'],
    stamp_slen = 7,
    step = 2,
    edge_padding = 2,
    n_bands = len(bands),
    max_detections = 2,
    estimate_flux = True,
)

In [None]:
# Load trained weights onto the CPU and switch to inference mode.
# map_location='cpu' is equivalent to the identity `lambda storage, loc: storage`.
star_encoder1.load_state_dict(
    torch.load('../fits/results_2020-02-17/starnet_ri', map_location = 'cpu'))
star_encoder1.eval();

In [None]:
# MAP locations, fluxes and number of stars from starnet_ri on the full image.
# This is pure inference, so run it under no_grad: no autograd graph is built, and
# the returned tensors do not require grad, so they can go straight to
# numpy / matplotlib further down.
with torch.no_grad():
    map_locs_full_image1, map_fluxes_full_image1, map_n_stars_full1 = \
        star_encoder1.sample_star_encoder(full_image,
                                          full_background,
                                          n_stars = None,
                                          return_map_n_stars = True,
                                          return_map_star_params = True)[0:3]

In [None]:
# log10-flux distribution: starnet_ri MAP estimates vs. ground truth.
# Bin edges are computed from BOTH sets, so true fluxes outside the estimated
# range are no longer silently dropped. Non-finite values (log10 of non-positive
# fluxes, e.g. any zero padding) are excluded so the automatic range stays finite.
log_est_fluxes1 = torch.log10(map_fluxes_full_image1.detach().cpu().flatten())
log_true_fluxes = torch.log10(true_fluxes.detach().cpu().flatten())
log_est_fluxes1 = log_est_fluxes1[torch.isfinite(log_est_fluxes1)].numpy()
log_true_fluxes = log_true_fluxes[torch.isfinite(log_true_fluxes)].numpy()

bins = np.histogram_bin_edges(np.concatenate([log_est_fluxes1, log_true_fluxes]), bins = 100)
plt.hist(log_est_fluxes1, bins = bins, label = 'starnet_ri (MAP)');
plt.hist(log_true_fluxes, bins = bins, alpha = 0.5, label = 'true');
plt.xlabel('log10 flux')
plt.legend();

In [None]:
# Noise-free image rendered from the starnet_ri MAP catalogue (add_noise=False),
# compared against the data below.
vae_recon_mean1 = simulator.draw_image_from_params(
    map_locs_full_image1,
    map_fluxes_full_image1,
    map_n_stars_full1,
    add_noise = False,
)

# Second set of results: `starnet_ri_fitted-back-psf` encoder

In [None]:
# Encoder for the 'starnet_ri_fitted-back-psf' checkpoint; same architecture
# arguments as star_encoder1.
star_encoder2 = starnet_lib.StarEncoder(
    full_slen = data_params['slen'],
    stamp_slen = 7,
    step = 2,
    edge_padding = 2,
    n_bands = len(bands),
    max_detections = 2,
    estimate_flux = True,
)


In [None]:
# Load trained weights onto the CPU and switch to inference mode.
# map_location='cpu' is equivalent to the identity `lambda storage, loc: storage`.
star_encoder2.load_state_dict(
    torch.load('../fits/results_2020-02-17/starnet_ri_fitted-back-psf', map_location = 'cpu'))
star_encoder2.eval();

In [None]:
# MAP locations, fluxes and number of stars from starnet_ri_fitted-back-psf on
# the full image. This is pure inference, so run it under no_grad: no autograd
# graph is built, and the returned tensors do not require grad, so they can go
# straight to numpy / matplotlib further down.
with torch.no_grad():
    map_locs_full_image2, map_fluxes_full_image2, map_n_stars_full2 = \
        star_encoder2.sample_star_encoder(full_image,
                                          full_background,
                                          n_stars = None,
                                          return_map_n_stars = True,
                                          return_map_star_params = True)[0:3]

In [None]:
# log10-flux distribution: starnet_ri_fitted-back-psf MAP estimates vs. ground truth.
# Bin edges are computed from BOTH sets, so true fluxes outside the estimated
# range are no longer silently dropped. Non-finite values (log10 of non-positive
# fluxes, e.g. any zero padding) are excluded so the automatic range stays finite.
log_est_fluxes2 = torch.log10(map_fluxes_full_image2.detach().cpu().flatten())
log_true_fluxes = torch.log10(true_fluxes.detach().cpu().flatten())
log_est_fluxes2 = log_est_fluxes2[torch.isfinite(log_est_fluxes2)].numpy()
log_true_fluxes = log_true_fluxes[torch.isfinite(log_true_fluxes)].numpy()

bins = np.histogram_bin_edges(np.concatenate([log_est_fluxes2, log_true_fluxes]), bins = 100)
plt.hist(log_est_fluxes2, bins = bins, label = 'starnet_ri_fitted-back-psf (MAP)');
plt.hist(log_true_fluxes, bins = bins, alpha = 0.5, label = 'true');
plt.xlabel('log10 flux')
plt.legend();

In [None]:
# Noise-free image rendered from the starnet_ri_fitted-back-psf MAP catalogue
# (add_noise=False), compared against the data below.
vae_recon_mean2 = simulator.draw_image_from_params(
    map_locs_full_image2,
    map_fluxes_full_image2,
    map_n_stars_full2,
    add_noise = False,
)

# Check reconstructions

In [None]:
# Channel index into full_image and the reconstructions for the plots below.
# Presumably channels follow the order of `bands`, so 0 -> bands[0] -- confirm.
band = 0

In [None]:
# Data | reconstruction | residual for the starnet_ri encoder.
# The residual is 2.5 * (log10(recon) - log10(data)) = 2.5 * log10(recon / data),
# drawn on a symmetric diverging colour scale.
# A fixed margin is trimmed from every side, derived from the actual image size
# rather than hard-coding [5:95] (which only matches a 100-pixel image).
# Both images are moved to the CPU first, so the subtraction and the plotting also
# work when either tensor lives on a GPU or carries autograd history.
edge = 5
data_band = full_image[0, band].detach().cpu()
recon_band = vae_recon_mean1[0, band].detach().cpu()
crop_rows = slice(edge, data_band.shape[0] - edge)
crop_cols = slice(edge, data_band.shape[1] - edge)

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('starnet_ri')

im0 = axarr[0].matshow(data_band[crop_rows, crop_cols])
fig.colorbar(im0, ax = axarr[0])
axarr[0].set_title('data')

im1 = axarr[1].matshow(recon_band[crop_rows, crop_cols])
fig.colorbar(im1, ax = axarr[1])
axarr[1].set_title('reconstruction')

residual = torch.log10(recon_band) - torch.log10(data_band)
mag_residual = (residual * 2.5)[crop_rows, crop_cols]
vmax = float(mag_residual.abs().max())
im2 = axarr[2].matshow(mag_residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])
axarr[2].set_title('2.5 log10(recon / data)');

In [None]:
# Data | reconstruction | residual for the starnet_ri_fitted-back-psf encoder.
# The residual is 2.5 * (log10(recon) - log10(data)) = 2.5 * log10(recon / data),
# drawn on a symmetric diverging colour scale.
# A fixed margin is trimmed from every side, derived from the actual image size
# rather than hard-coding [5:95] (which only matches a 100-pixel image).
# Both images are moved to the CPU first, so the subtraction and the plotting also
# work when either tensor lives on a GPU or carries autograd history.
edge = 5
data_band = full_image[0, band].detach().cpu()
recon_band = vae_recon_mean2[0, band].detach().cpu()
crop_rows = slice(edge, data_band.shape[0] - edge)
crop_cols = slice(edge, data_band.shape[1] - edge)

fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('starnet_ri_fitted-back-psf')

im0 = axarr[0].matshow(data_band[crop_rows, crop_cols])
fig.colorbar(im0, ax = axarr[0])
axarr[0].set_title('data')

im1 = axarr[1].matshow(recon_band[crop_rows, crop_cols])
fig.colorbar(im1, ax = axarr[1])
axarr[1].set_title('reconstruction')

residual = torch.log10(recon_band) - torch.log10(data_band)
mag_residual = (residual * 2.5)[crop_rows, crop_cols]
vmax = float(mag_residual.abs().max())
im2 = axarr[2].matshow(mag_residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])
axarr[2].set_title('2.5 log10(recon / data)');

# Summary statistics

In [None]:
# MAP number of stars from the starnet_ri encoder (return_map_n_stars=True).
map_n_stars_full1

In [None]:
# MAP number of stars from the starnet_ri_fitted-back-psf encoder (return_map_n_stars=True).
map_n_stars_full2

In [None]:
import image_statistics_lib

In [None]:
# Conversion factor handed to every image_statistics_lib call below. Presumably
# electrons per nanomaggy, used there to turn fluxes into magnitudes -- TODO
# confirm the value and its source.
n_elect_per_nmgy = 856.

In [None]:
# Completeness and true-positive rate of the starnet_ri MAP catalogue against the
# ground truth, using band-0 fluxes.
# squeeze(0) drops only the leading batch dimension. A bare squeeze() would also
# collapse the star axis when exactly one star is present. true_locs / true_fluxes
# come out of the data-loading cell with the star axis first (no batch dimension),
# so they are passed as-is.
completeness1, tpr1, _, _ = \
    image_statistics_lib.get_summary_stats(map_locs_full_image1.squeeze(0),
                                           true_locs,
                                           star_encoder1.full_slen,
                                           map_fluxes_full_image1.squeeze(0)[:, 0],
                                           true_fluxes[:, 0],
                                           n_elect_per_nmgy)
print('completeness: {:0.3f}'.format(completeness1))
print('true positive rate: {:0.3f}'.format(tpr1))

In [None]:
# Completeness and true-positive rate of the starnet_ri_fitted-back-psf MAP
# catalogue against the ground truth, using band-0 fluxes.
# squeeze(0) drops only the leading batch dimension. A bare squeeze() would also
# collapse the star axis when exactly one star is present. true_locs / true_fluxes
# come out of the data-loading cell with the star axis first (no batch dimension),
# so they are passed as-is.
completeness2, tpr2, _, _ = \
    image_statistics_lib.get_summary_stats(map_locs_full_image2.squeeze(0),
                                           true_locs,
                                           star_encoder2.full_slen,
                                           map_fluxes_full_image2.squeeze(0)[:, 0],
                                           true_fluxes[:, 0],
                                           n_elect_per_nmgy)
print('completeness: {:0.3f}'.format(completeness2))
print('true positive rate: {:0.3f}'.format(tpr2))

In [None]:
# Completeness per magnitude bin for both encoders (band-0 fluxes).
# squeeze(0) drops only the leading batch dimension, so a single detection keeps
# its star axis; true_locs / true_fluxes already have the star axis first.
# get_completeness_vec presumably returns bin edges (one more than the number of
# completeness values), hence mag_vec[0:-1] on the x axis -- confirm.
completeness_vec1, mag_vec1 = \
    image_statistics_lib.get_completeness_vec(map_locs_full_image1.squeeze(0),
                                              true_locs,
                                              star_encoder1.full_slen,
                                              map_fluxes_full_image1.squeeze(0)[:, 0],
                                              true_fluxes[:, 0],
                                              n_elect_per_nmgy)[0:2]

plt.plot(mag_vec1[0:-1], completeness_vec1, '--x', label = 'starnet1')

completeness_vec2, mag_vec2 = \
    image_statistics_lib.get_completeness_vec(map_locs_full_image2.squeeze(0),
                                              true_locs,
                                              star_encoder2.full_slen,
                                              map_fluxes_full_image2.squeeze(0)[:, 0],
                                              true_fluxes[:, 0],
                                              n_elect_per_nmgy)[0:2]

plt.plot(mag_vec2[0:-1], completeness_vec2, '--x', label = 'starnet2')

plt.legend()
# The x values are the magnitude binning (mag_vec), not log flux.
plt.xlabel('true magnitude')
plt.ylabel('completeness')
plt.title('completeness vs. true magnitude');

In [None]:
# True-positive rate per estimated-magnitude bin for both encoders (band-0 fluxes).
# squeeze(0) drops only the leading batch dimension, so a single detection keeps
# its star axis; true_locs / true_fluxes already have the star axis first.
# get_tpr_vec presumably returns bin edges (one more than the number of tpr
# values), hence mag_vec[0:-1] on the x axis -- confirm.
tpr_vec1, mag_vec1 = \
    image_statistics_lib.get_tpr_vec(map_locs_full_image1.squeeze(0),
                                     true_locs,
                                     star_encoder1.full_slen,
                                     map_fluxes_full_image1.squeeze(0)[:, 0],
                                     true_fluxes[:, 0],
                                     n_elect_per_nmgy)[0:2]

plt.plot(mag_vec1[0:-1], tpr_vec1, '--x', label = 'starnet1')

tpr_vec2, mag_vec2 = \
    image_statistics_lib.get_tpr_vec(map_locs_full_image2.squeeze(0),
                                     true_locs,
                                     star_encoder2.full_slen,
                                     map_fluxes_full_image2.squeeze(0)[:, 0],
                                     true_fluxes[:, 0],
                                     n_elect_per_nmgy)[0:2]

plt.plot(mag_vec2[0:-1], tpr_vec2, '--x', label = 'starnet2')

plt.legend()
plt.xlabel('est magnitude')
plt.ylabel('tpr')
plt.title('true positive rate vs. estimated magnitude');

# Check image patches

In [None]:
# Zoom in on one randomly chosen tile. One row per encoder:
#   col 0: data with true and estimated locations,
#   col 1: reconstruction with estimated locations,
#   col 2: residual 2.5 * log10(recon / data) on a diverging colour map.
f, axarr = plt.subplots(2, 3, figsize=(16, 8))

# np.random.choice(n, 1) returns a 1-element array. Take element [0] rather than
# calling int() on the array, which recent NumPy deprecates for ndim > 0.
# The random draw is unchanged.
indx = int(np.random.choice(star_encoder1.tile_coords.shape[0], 1)[0])
f.suptitle('tile index {}'.format(indx))

encoder_rows = [(star_encoder1, map_locs_full_image1, vae_recon_mean1),
                (star_encoder2, map_locs_full_image2, vae_recon_mean2)]

for row, (encoder, map_locs, recon_mean) in enumerate(encoder_rows):
    # squeeze(0) drops only the batch dimension, so a single detection keeps its
    # star axis; true_locs already has the star axis first.
    est_locs = map_locs.squeeze(0)
    tile_coord0 = int(encoder.tile_coords[indx, 0])
    tile_coord1 = int(encoder.tile_coords[indx, 1])
    common_kwargs = dict(subimage_slen = encoder.stamp_slen,
                         add_colorbar = True,
                         global_fig = f)

    # data, with true and estimated locations
    plotting_utils.plot_subimage(axarr[row, 0], full_image[0, band],
                                 est_locs, true_locs,
                                 tile_coord0, tile_coord1, **common_kwargs)

    # reconstruction, with estimated locations only
    plotting_utils.plot_subimage(axarr[row, 1], recon_mean[0, band],
                                 est_locs, None,
                                 tile_coord0, tile_coord1, **common_kwargs)

    # residual
    log_residual = torch.log10(recon_mean[0, band]) - torch.log10(full_image[0, band])
    plotting_utils.plot_subimage(axarr[row, 2], log_residual * 2.5,
                                 est_locs, None,
                                 tile_coord0, tile_coord1,
                                 diverging_cmap = True, **common_kwargs)