In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import wake_lib

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed both PyTorch and NumPy. The encoder used throughout this notebook is a
# torch model, so seeding NumPy alone leaves any torch-side randomness
# unreproducible across kernel restarts.
torch.manual_seed(34534)
np.random.seed(34534)

# Load the data

In [None]:
# Flux threshold: passed as `fmin` to wake_lib.EstimateModelParams and used to
# select "bright" catalog stars (fluxes[:, 0] > f_min) further down.
f_min = 1000.

In [None]:
# Band indices handed to SDSSHubbleData; their count also sets the encoder's n_bands.
bands = [2, 3]

# Catalog file that is loaded together with the SDSS image.
hubble_cat_file = '../hubble_data/NCG7089/' + \
    'hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt'

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    sdssdir='../../celeste_net/sdss_stage_dir/',
    hubble_cat_file=hubble_cat_file,
    bands=bands)

# Add a leading singleton axis in front of the loaded SDSS image.
full_image = sdss_hubble_data.sdss_image.unsqueeze(0)



In [None]:
# Sanity check: show the first loaded band of the SDSS image.
plt.matshow(full_image[0, 0])
plt.title('SDSS image, band index ' + str(bands[0]))
plt.colorbar();

# The encoder

In [None]:
# Encoder settings: the full-image side length comes from the last axis of the
# loaded image and the number of bands from `bands`.
encoder_config = dict(
    full_slen=full_image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
    estimate_flux=False,
)
star_encoder = starnet_vae_lib.StarEncoder(**encoder_config)

# Paths to the stored model parameters

In [None]:
# Checkpoint of the initial encoder; it is loaded as the iteration-0 encoder below.
init_encoder = '../fits/results_2020-02-06/starnet_ri'
# Path prefix for the later encoder checkpoints: '-encoder-iter<k>' is appended
# for every iteration k >= 1.
filename = '../fits/results_2020-02-06/wake-sleep_630x310_ri'

In [None]:
# For each saved iteration: reload the encoder and the fitted model parameters,
# run the encoder on the full image, and print the detections and the loss.
for iteration in range(6):
    iter_suffix = str(iteration) + '.npy'

    # iteration 0 uses the initial encoder; later ones use the saved checkpoints
    encoder_file = init_encoder if iteration == 0 \
        else filename + '-encoder-iter' + str(iteration)

    # model parameters saved for this iteration
    powerlaw_psf_params = torch.Tensor(
        np.load('../fits/results_2020-02-06/powerlaw_psf_params-iter' + iter_suffix))
    planar_background_params = torch.Tensor(
        np.load('../fits/results_2020-02-06/planarback_params-iter' + iter_suffix))
    fluxes = torch.Tensor(
        np.load('../fits/results_2020-02-06/fluxes-iter' + iter_suffix))

    # restore the encoder weights (map_location keeps the loaded storages as
    # they are deserialized) and switch the module to evaluation mode
    encoder_state = torch.load(encoder_file,
                               map_location=lambda storage, loc: storage)
    star_encoder.load_state_dict(encoder_state)
    star_encoder.eval()

    # NOTE(review): the second argument is all ones here, while the
    # summary-statistics cell passes all zeros -- confirm which is intended.
    encoder_out = star_encoder.sample_star_encoder(
        full_image,
        torch.ones(full_image.shape),
        return_map_n_stars=True,
        return_map_star_params=True)
    # only the first and third returned values are used
    map_locs_full_image, _, map_n_stars_full = encoder_out[0:3]

    # estimator initialised at the encoder's detections and the saved parameters
    estimator = wake_lib.EstimateModelParams(
        full_image,
        map_locs_full_image,
        map_n_stars_full,
        init_psf_params=powerlaw_psf_params,
        init_background_params=planar_background_params,
        init_fluxes=fluxes,
        fmin=f_min)

    print('\n')
    print(map_n_stars_full)
    print((map_locs_full_image**2).mean())
    print('**final loss**', estimator.get_loss()[1])

# Check PSFs -- using the true (catalog) star locations and fluxes

In [None]:
# Band index.
# NOTE(review): no later cell visible in this notebook reads `band` -- confirm it is still needed.
band = 0

In [None]:
import psf_transform_lib2

In [None]:
# Keep only the catalog stars whose first-column flux exceeds f_min.
which_bright = sdss_hubble_data.fluxes[:, 0] > f_min
bright_locs = sdss_hubble_data.locs[which_bright]
bright_fluxes = sdss_hubble_data.fluxes[which_bright]

# Add a leading singleton axis, and store the number of selected stars as a
# one-element long tensor.
_locs = bright_locs.unsqueeze(0)
_fluxes = bright_fluxes.unsqueeze(0)
_n_stars = torch.Tensor([bright_locs.shape[0]]).type(torch.LongTensor)



In [None]:
# For every saved iteration, build the estimator at the bright catalog stars
# (their locations, fluxes and count) with that iteration's PSF and background
# parameters, and print the planar-background parameters.
for iteration in range(6):
    iter_suffix = str(iteration) + '.npy'

    powerlaw_psf_params = torch.Tensor(
        np.load('../fits/results_2020-02-06/powerlaw_psf_params-iter' + iter_suffix))
    planar_background_params = torch.Tensor(
        np.load('../fits/results_2020-02-06/planarback_params-iter' + iter_suffix))

    # The estimator is only read by the (currently disabled) loss print below.
    estimator = wake_lib.EstimateModelParams(full_image,
                                             locs=_locs,
                                             n_stars=_n_stars,
                                             init_psf_params=powerlaw_psf_params,
                                             init_background_params=planar_background_params,
                                             init_fluxes=_fluxes,
                                             fmin=f_min)
    # print('**loss**', estimator.get_loss()[1])
    print(planar_background_params)

# Check out summary statistics

In [None]:
# Mean of the dataset's nelec_per_nmgy values; passed as the last argument to
# the image_statistics_lib summary functions below.
n_elect_per_nmgy = sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# Number of saved iterations (0 .. n_iter - 1) evaluated in the summary-statistics loop.
n_iter = 6

In [None]:
# Overall completeness and TPR, one entry per evaluated iteration. The arrays
# have exactly n_iter entries so that every slot is filled by the loop below.
completeness_all = np.zeros(n_iter)
tpr_all = np.zeros(n_iter)

# Left panel: completeness per magnitude bin; right panel: TPR per magnitude bin.
fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

for iteration in range(n_iter):

    # iteration 0 uses the initial encoder; later ones use the saved checkpoints
    if iteration == 0:
        encoder_file = init_encoder
    else:
        encoder_file = filename + '-encoder-iter' + str(iteration)

    # fluxes saved for this iteration, used as the estimated fluxes below
    fluxes = \
        torch.Tensor(np.load('../fits/results_2020-02-06/fluxes-iter' + \
                                str(iteration) + '.npy'))

    star_encoder.load_state_dict(torch.load(encoder_file,
                                   map_location=lambda storage, loc: storage))
    star_encoder.eval()

    # get parameters
    # NOTE(review): the second argument is all zeros here, while the earlier
    # loss loop passes all ones -- confirm which is intended.
    map_locs_full_image, _, map_n_stars_full = \
        star_encoder.sample_star_encoder(full_image,
                                        torch.zeros(full_image.shape),
                                        return_map_n_stars = True,
                                        return_map_star_params = True)[0:3]
    map_fluxes_full_image = fluxes

    # The three statistics helpers take the same six arguments: estimated
    # locations, catalog locations, image side length, estimated first-column
    # fluxes, catalog first-column fluxes, and the electrons-per-nmgy factor.
    stats_args = (map_locs_full_image.squeeze(),
                  _locs.squeeze(0),
                  star_encoder.full_slen,
                  map_fluxes_full_image.squeeze(0)[:, 0],
                  _fluxes[0, :, 0],
                  n_elect_per_nmgy)

    # get summary statistics
    completeness, tpr, _, _ = \
        image_statistics_lib.get_summary_stats(*stats_args)

    completeness_all[iteration] = completeness
    tpr_all[iteration] = tpr

    # get completeness as a function of magnitude
    completeness_vec, mag_vec = \
        image_statistics_lib.get_completeness_vec(*stats_args)[0:2]

    # mag_vec has one more entry than the statistic vector, so the last entry
    # is dropped when plotting
    axarr[0].plot(mag_vec[:-1], completeness_vec, '--x',
                  label = 'starnet_iter' + str(iteration))

    # get tpr as a function of magnitude
    tpr_vec, mag_vec = \
        image_statistics_lib.get_tpr_vec(*stats_args)[0:2]

    axarr[1].plot(mag_vec[:-1], tpr_vec, '--x',
                  label = 'starnet_iter' + str(iteration))

axarr[0].legend()
axarr[0].set_title('completeness by magnitude')
axarr[0].set_xlabel('true mag')
axarr[0].set_ylabel('completeness')

axarr[1].legend()
axarr[1].set_title('tpr by magnitude')
axarr[1].set_xlabel('estimated mag')
axarr[1].set_ylabel('tpr');

In [None]:
# Overall completeness recorded for each iteration by the summary-statistics loop.
completeness_all

In [None]:
# Overall TPR recorded for each iteration by the summary-statistics loop.
tpr_all