In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import wake_lib

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib2
import image_statistics_lib

# Seed both RNGs: the encoder and estimators are torch models, so seeding
# numpy alone does not make any torch-side randomness reproducible.
np.random.seed(34534)
torch.manual_seed(34534)

# Load the data

In [None]:
# flux threshold: catalog stars with first-column flux <= f_min are dropped from
# the "true" parameters, and the same value is passed to the estimators as `fmin`
f_min = 1e3

In [None]:
# SDSS band indices to load (presumably r and i -- the fit files are suffixed "_ri")
bands = [2, 3]

# NOTE(review): the directory is spelled 'NCG7089' while the file name says
# 'ngc7089'; kept as-is because it must match the on-disk path.
hubble_cat_path = ('../hubble_data/NCG7089/'
                   'hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt')

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    sdssdir='../../celeste_net/sdss_stage_dir/',
    hubble_cat_file=hubble_cat_path,
    bands=bands,
)

# prepend a leading axis; later cells index full_image as [0, band, row, col]
full_image = sdss_hubble_data.sdss_image.unsqueeze(0)

# a single conversion factor, averaged over whatever nelec_per_nmgy holds
n_elect_per_nmgy = sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# Quick look at the raw data: first image channel.
first_channel = full_image[0, 0]
first_channel_img = plt.matshow(first_channel)
plt.colorbar(first_channel_img)

In [None]:
# true parameters: keep catalog stars whose flux in column 0 exceeds f_min,
# with a leading axis added (presumably to match the batched encoder outputs)
which_bright = sdss_hubble_data.fluxes[:, 0] > f_min

true_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
true_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)

# build the int64 count directly instead of creating a float Tensor and
# casting it with the legacy .type(torch.LongTensor)
true_n_stars = torch.tensor([true_locs.shape[1]], dtype=torch.long)

# The encoder

In [None]:
# StarNet encoder over the full image. Fluxes are not produced by the encoder
# (estimate_flux=False); the cells below obtain them from wake_lib estimators
# or from the saved .npy files.
# NOTE(review): full_slen is read from the last image axis only -- assumes the
# relevant extent is the width (or that the image is square); confirm.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=full_image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
    estimate_flux=False,
)

# Check residuals

In [None]:
results_dir = '../fits/results_2020-02-06/'
# encoder weights used as iteration 0 (before any wake-sleep updates)
init_encoder = results_dir + 'starnet_ri'
# prefix of the per-iteration files, e.g. filename + '-encoder-iter1'
filename = results_dir + 'wake-sleep_630x310_ri'

In [None]:
band = 0  # image channel shown in the plots below (presumably the position within `bands`)

### Initial residuals (before training the PSF)

In [None]:
# load the initial encoder weights; the map_location lambda keeps every
# storage where it was deserialized (i.e. on the CPU)
star_encoder.load_state_dict(
    torch.load(init_encoder, map_location=lambda storage, loc: storage)
)
star_encoder.eval();

In [None]:
# get initial parameters: MAP star locations and star counts from the initial
# encoder (the second returned value is discarded; the second argument is an
# all-ones tensor with the image's shape)
map_locs_full_image, _, map_n_stars_full = star_encoder.sample_star_encoder(
    full_image,
    torch.ones(full_image.shape),
    return_map_n_stars=True,
    return_map_star_params=True,
)[0:3]

In [None]:
# load initial psf: one row of 6 power-law PSF parameters per band, read from
# the SDSS psField file
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
powerlaw_psf_params = torch.zeros(len(bands), 6)
for band_pos, sdss_band in enumerate(bands):
    powerlaw_psf_params[band_pos] = \
        psf_transform_lib2.get_psf_params(psfield_file, band = sdss_band)

# initial backgrounds (constant term only)
# NOTE(review): this tensor is NOT passed to estimator0 below
# (init_background_params=None), and the 2-element literal assumes
# len(bands) == 2 -- confirm whether it was meant to be used.
planar_background_params = torch.zeros(len(bands), 3)
planar_background_params[:, 0] = torch.Tensor([686., 1123.])

estimator0 = wake_lib.EstimateModelParams(
    full_image,
    map_locs_full_image,
    map_n_stars_full,
    init_psf_params = powerlaw_psf_params,
    init_background_params = None,
    init_fluxes = None,
    fmin = f_min,
)

# estimate fluxes (and background) for the initial MAP locations
estimator0.optimize_fluxes_background(max_iter = 20)

In [None]:
recon_mean0, init_loss = estimator0.get_loss()

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

# Left panel: 2.5 * log10(recon / data), a magnitude-style residual, on the
# 90x90 window [5:95, 5:95] of the selected channel. Colour scale is symmetric
# about zero. (Assumes strictly positive pixel values -- log10 otherwise gives nan/-inf.)
log_ratio = torch.log10(recon_mean0.detach()) - torch.log10(full_image)
resid = 2.5 * log_ratio[0, band, 5:95, 5:95]
resid_lim = resid.abs().max()
im = axarr[0].matshow(resid, vmin = -resid_lim, vmax = resid_lim,
                      cmap = plt.get_cmap('bwr'))
fig.colorbar(im, ax = axarr[0])

# Right panel: the initial PSF for the selected channel, passed through
# _trim_psf with size 15 (presumably a 15x15 crop).
init_psf = estimator0.get_psf().detach()
trimmed_init_psf = simulated_datasets_lib._trim_psf(init_psf, 15)
im1 = axarr[1].matshow(trimmed_init_psf[band])
fig.colorbar(im1, ax = axarr[1])

In [None]:
# loss of the initial fit (second output of estimator0.get_loss())
init_loss

In [None]:
# Overall completeness and TPR of the initial MAP catalog against the true
# catalog, using the fluxes fit by estimator0 (index 0 of the last axis,
# presumably the first channel).
est_locs0 = map_locs_full_image.squeeze()
est_fluxes0 = estimator0.get_fluxes().detach()[0, :, 0]

completeness, tpr, _, _ = image_statistics_lib.get_summary_stats(
    est_locs0,
    true_locs.squeeze(0),
    star_encoder.full_slen,
    est_fluxes0,
    true_fluxes[0, :, 0],
    n_elect_per_nmgy,
)

completeness, tpr

### Residuals after training

In [None]:
# number of saved wake-sleep iterations (iteration 0 = the initial encoder)
n_iter = 6
# one loss per iteration, filled in by the loops below
loss_vec = torch.zeros(n_iter)

In [None]:
band = 0  # image channel shown in the plots below (presumably the position within `bands`)

In [None]:
# Residuals and PSF change for each saved wake-sleep iteration. Each iteration
# uses that iteration's encoder (MAP locations) together with the PSF,
# background and fluxes saved for it.
for iteration in range(n_iter):
    # iteration 0 is the initial encoder; later iterations load the encoder
    # checkpoint saved after that wake-sleep iteration
    if iteration == 0:
        encoder_file = init_encoder
    else:
        encoder_file = filename + '-encoder-iter' + str(iteration)

    star_encoder.load_state_dict(torch.load(encoder_file,
                                   map_location=lambda storage, loc: storage))
    star_encoder.eval()

    map_locs_full_image, _, map_n_stars_full = \
        star_encoder.sample_star_encoder(full_image,
                                            torch.ones(full_image.shape),
                                            return_map_n_stars = True,
                                            return_map_star_params = True)[0:3]

    # load psf
    powerlaw_psf_params = \
            torch.Tensor(np.load('../fits/results_2020-02-06/powerlaw_psf_params-iter' + \
                                    str(iteration) + '.npy'))

    # load background
    planar_background_params = \
        torch.Tensor(np.load('../fits/results_2020-02-06/planarback_params-iter' + \
                                str(iteration) + '.npy'))

    # load fluxes
    map_fluxes_full_image = \
        torch.Tensor(np.load('../fits/results_2020-02-06/fluxes-iter' + \
                                str(iteration) + '.npy'))

    estimator = wake_lib.EstimateModelParams(full_image,
                                            map_locs_full_image,
                                            map_n_stars_full,
                                            init_psf_params = powerlaw_psf_params,
                                            init_background_params = planar_background_params,
                                            init_fluxes = map_fluxes_full_image,
                                            fmin = f_min)

    fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

    recon_mean, loss = estimator.get_loss()
    # store a plain number so loss_vec does not keep every iteration's autograd
    # history alive (the loss presumably requires grad -- the loss plot further
    # down calls .detach() on loss_vec)
    loss_vec[iteration] = float(loss)

    # left: magnitude-style residual 2.5 * log10(recon / data) on a 90x90 window
    resid = (torch.log10(recon_mean.detach()) - torch.log10(full_image))[0, band, 5:95, 5:95] * 2.5
    im = axarr[0].matshow(resid, vmax = resid.abs().max(), vmin = -resid.abs().max(),
                          cmap = plt.get_cmap('bwr'))
    fig.colorbar(im, ax = axarr[0])

    # right: change of this iteration's PSF relative to the initial PSF
    psf_diff = simulated_datasets_lib._trim_psf(
                        estimator.get_psf().detach() - init_psf, 15)
    im1 = axarr[1].matshow(psf_diff[band], vmax = psf_diff.abs().max(), vmin = -psf_diff.abs().max(),
                    cmap = plt.get_cmap('bwr'))
    fig.colorbar(im1, ax = axarr[1])

In [None]:
# per-iteration losses from the loop above (MAP locations / saved fluxes)
loss_vec

# Check PSFs using the true parameters

In [None]:
band = 0  # image channel shown in the plots below (presumably the position within `bands`)

In [None]:
# number of catalog stars kept as "true" (flux above f_min)
true_n_stars

In [None]:
# Same diagnostics as above, but with the true catalog locations and fluxes
# held fixed. Between iterations, only the loaded PSF and background change.
for iteration in range(n_iter):
    # load psf
    powerlaw_psf_params = \
            torch.Tensor(np.load('../fits/results_2020-02-06/powerlaw_psf_params-iter' + \
                                    str(iteration) + '.npy'))

    # load background
    planar_background_params = \
        torch.Tensor(np.load('../fits/results_2020-02-06/planarback_params-iter' + \
                                str(iteration) + '.npy'))

    estimator = wake_lib.EstimateModelParams(full_image,
                                            true_locs,
                                            true_n_stars,
                                            init_psf_params = powerlaw_psf_params,
                                            init_background_params = planar_background_params,
                                            init_fluxes = true_fluxes,
                                            fmin = f_min)

    fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

    recon_mean, loss = estimator.get_loss()
    # NOTE: this overwrites the MAP-parameter losses stored by the earlier loop.
    # Stored as a plain number so loss_vec does not retain autograd history.
    loss_vec[iteration] = float(loss)

    # left: magnitude-style residual 2.5 * log10(recon / data) on a 90x90 window
    resid = (torch.log10(recon_mean.detach()) - torch.log10(full_image))[0, band, 5:95, 5:95] * 2.5
    im = axarr[0].matshow(resid, vmax = resid.abs().max(), vmin = -resid.abs().max(),
                          cmap = plt.get_cmap('bwr'))
    fig.colorbar(im, ax = axarr[0])

    # right: change of this iteration's PSF relative to the initial PSF
    psf_diff = simulated_datasets_lib._trim_psf(
                        estimator.get_psf().detach() - init_psf, 15)
    im1 = axarr[1].matshow(psf_diff[band], vmax = psf_diff.abs().max(), vmin = -psf_diff.abs().max(),
                    cmap = plt.get_cmap('bwr'))
    fig.colorbar(im1, ax = axarr[1])

In [None]:
# Loss per wake-sleep iteration. loss_vec was last filled by the
# true-parameter loop above, which overwrote the MAP-parameter losses.
plt.plot(loss_vec.detach())
plt.xlabel('iteration')
plt.ylabel('loss')
plt.title('loss with true locations and fluxes');

# Summary statistics

In [None]:
n_iter = 6  # saved wake-sleep iterations to summarise (iteration 0 = the initial encoder)

In [None]:
# Overall completeness / TPR of the MAP catalog at each wake-sleep iteration,
# plus completeness vs. true magnitude and TPR vs. estimated magnitude.
# One slot per iteration: the loop below runs n_iter times, so sizing these
# n_iter + 1 would leave a spurious trailing 0.
completeness_all = np.zeros(n_iter)
tpr_all = np.zeros(n_iter)

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

for iteration in range(n_iter):
    # iteration 0 is the initial encoder; later iterations load the encoder
    # checkpoint saved after that wake-sleep iteration
    if iteration == 0:
        encoder_file = init_encoder
    else:
        encoder_file = filename + '-encoder-iter' + str(iteration)

    star_encoder.load_state_dict(torch.load(encoder_file,
                                   map_location=lambda storage, loc: storage))
    star_encoder.eval()

    # MAP parameters. The second argument matches the residual cells above
    # (all ones), so these statistics describe the same MAP catalog that was
    # plotted there.
    map_locs_full_image, _, map_n_stars_full = \
        star_encoder.sample_star_encoder(full_image,
                                         torch.ones(full_image.shape),
                                         return_map_n_stars = True,
                                         return_map_star_params = True)[0:3]

    map_fluxes_full_image = \
        torch.Tensor(np.load('../fits/results_2020-02-06/fluxes-iter' + \
                                str(iteration) + '.npy'))

    # arguments shared by all three image_statistics_lib calls
    stats_args = (map_locs_full_image.squeeze(),
                  true_locs.squeeze(0),
                  star_encoder.full_slen,
                  map_fluxes_full_image.squeeze(0)[:, 0],
                  true_fluxes[0, :, 0],
                  n_elect_per_nmgy)

    # overall summary statistics
    completeness, tpr, _, _ = image_statistics_lib.get_summary_stats(*stats_args)
    completeness_all[iteration] = completeness
    tpr_all[iteration] = tpr

    # completeness as a function of magnitude (mag_vec[:-1] presumably holds
    # the left bin edges)
    completeness_vec, mag_vec = \
        image_statistics_lib.get_completeness_vec(*stats_args)[0:2]
    axarr[0].plot(mag_vec[:-1], completeness_vec, '--x', label = 'starnet_iter' + str(iteration))

    # tpr as a function of magnitude
    tpr_vec, mag_vec = image_statistics_lib.get_tpr_vec(*stats_args)[0:2]
    axarr[1].plot(mag_vec[:-1], tpr_vec, '--x', label = 'starnet_iter' + str(iteration))

    print('iter', iteration,
          '| mean squared flux:', (map_fluxes_full_image**2).mean().item(),
          '| mean squared loc:', (map_locs_full_image**2).mean().item())

axarr[0].legend()
axarr[0].set_xlabel('true mag')
axarr[0].set_ylabel('completeness')

axarr[1].legend()
axarr[1].set_xlabel('estimated mag')
axarr[1].set_ylabel('tpr');

In [None]:
# fluxes left over from the last iteration of the summary-statistics loop
map_fluxes_full_image

In [None]:
# mean squared true flux, for comparison with the mean squared MAP fluxes
# printed for each iteration in the loop above
(true_fluxes**2).mean()

In [None]:
# overall completeness, one entry per wake-sleep iteration
completeness_all

In [None]:
# overall true-positive rate, one entry per wake-sleep iteration
tpr_all