In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import kl_objective_lib as kl_lib
import plotting_utils

import psf_transform_lib
import image_statistics_lib

# Seed every RNG this notebook can touch (torch as well as numpy) so that a
# fresh "Restart & Run All" is reproducible.
SEED = 34534
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the data

In [None]:
# Data object providing the image, its background, the PSF file and the
# locations / fluxes treated as ground truth in this notebook.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

# The PSF file is kept as a plain string.
psf_fit_file = str(sdss_hubble_data.psf_file)

# Image and background with their singleton dimensions squeezed out.
full_image, full_background = (
    sdss_hubble_data.sdss_image.squeeze(),
    sdss_hubble_data.sdss_background.squeeze(),
)

# "True" parameters: keep only the entries whose flux exceeds 1300.
which_bright = sdss_hubble_data.fluxes > 1300.
true_locs, true_fluxes = (
    sdss_hubble_data.locs[which_bright],
    sdss_hubble_data.fluxes[which_bright],
)


In [None]:
# Show the full image; the colorbar is attached to that image explicitly.
image_plot = plt.matshow(full_image.squeeze())
plt.colorbar(image_plot)

# Disabled overlay of the bright "true" locations:
# plt.scatter(true_locs[:, 1] * 100, 
#            true_locs[:, 0] * 100)

# Our simulator

In [None]:
# Simulator built from the PSF fit file, with side length equal to the last
# dimension of the image and zero sky intensity.
simulator_kwargs = dict(psf_fit_file=psf_fit_file,
                        slen=full_image.shape[-1],
                        sky_intensity=0.)
simulator = simulated_datasets_lib.StarSimulator(**simulator_kwargs)


# Define the VAE encoder

In [None]:
# Encoder configuration, gathered in one place before construction.
# NOTE(review): presumably these values must match the checkpoints that are
# loaded into `star_encoder` later -- confirm before changing any of them.
encoder_config = {
    'full_slen': full_image.shape[-1],
    'stamp_slen': 9,
    'step': 2,
    'edge_padding': 3,
    'n_bands': 1,
    'max_detections': 4,
}
star_encoder = starnet_vae_lib.StarEncoder(**encoder_config)



# Define the PSF transform

In [None]:
# Local PSF transform built on the simulator's `psf_og`, using the simulator's
# side length and kernel_size = 3.
original_psf = torch.Tensor(simulator.psf_og)
psf_transform = psf_transform_lib.PsfLocalTransform(original_psf,
                                                    simulator.slen,
                                                    kernel_size=3)

# Examine encoder losses

In [None]:
# Encoder test losses, one file per wake-sleep iteration (1..3).
# Only the first row of each file is used.
loss_file_prefix = '../fits/wake_sleep-portm2-101420129-encoder-test_losses-iter'
losses_per_iter = [np.loadtxt(loss_file_prefix + str(i))[0] for i in range(1, 4)]

# Concatenate once instead of growing an array inside the loop; `losses` is an
# ndarray from the start rather than switching from list to array.
losses = np.concatenate(losses_per_iter)

plt.plot(losses, '-x')

# Dotted red lines at the boundaries between iterations. Their positions are
# derived from the number of losses actually loaded per iteration rather than
# assuming exactly 3, so they stay aligned if a file has a different length.
# With 3 losses per file this gives x = 0, 3, 6, 9, as before.
iter_boundaries = np.cumsum([0] + [len(losses_iter) for losses_iter in losses_per_iter])
plt.vlines(x = iter_boundaries, ymin = losses.min(), ymax = losses.max(),
           color = 'r', linestyle = ':');

# Check PSFs -- with true parameters

In [None]:
# Run the true locations / fluxes through the encoder's stamp extraction.
# Every input gets leading singleton dimensions first
# (presumably batch / band axes -- TODO confirm against StarEncoder).
image_input = full_image.unsqueeze(0).unsqueeze(0)
stamp_outputs = star_encoder.get_image_stamps(image_input,
                                              true_locs.unsqueeze(0),
                                              true_fluxes.unsqueeze(0),
                                              trim_images=False)

# Five values come back; only the second and third are used in this notebook.
_, subimage_locs, subimage_fluxes, _, _ = stamp_outputs

In [None]:
# PSF-transform loss on the real image.
# Entry 0 is computed with `psf_transform` in whatever state it currently
# holds; entries 1-3 use the fitted checkpoints iter0, iter1 and iter2.
# NOTE(review): entry 0 is the *initial* transform only on a fresh
# top-to-bottom run -- after this cell has run once, `psf_transform` holds the
# iter2 weights.
psf_loss_vec = np.zeros(4)

# Loop-invariant pieces: the image / background inputs and the path prefix.
image_input = full_image.unsqueeze(0).unsqueeze(0)
background_input = full_background.unsqueeze(0).unsqueeze(0)
psf_checkpoint_prefix = '../fits/wake_sleep-portm2-101420129-psf_transform-iter'

for i in range(4):
    if i >= 1:
        # map_location keeps the loaded storages where they were deserialised.
        checkpoint = torch.load(psf_checkpoint_prefix + str(i - 1),
                                map_location=lambda storage, loc: storage)
        psf_transform.load_state_dict(checkpoint)

    recon_mean, psf_loss = psf_transform_lib.get_psf_transform_loss(
        image_input,
        background_input,
        subimage_locs,
        subimage_fluxes,
        star_encoder.tile_coords,
        star_encoder.stamp_slen,
        star_encoder.edge_padding,
        simulator,
        psf_transform)
    psf_loss_vec[i] = psf_loss

    # Relative reconstruction error, cropped to rows / columns 10..89.
    relative_error = (recon_mean.squeeze().detach() - full_image) / full_image
    residual = relative_error[10:90, 10:90]

    # Symmetric colour limits put zero error at the centre of the colormap.
    vmax = residual.abs().max()
    plt.matshow(residual, vmin=-vmax, vmax=vmax, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# Rebuild the simulator from scratch (PSF fit file, image side length, zero
# sky intensity) so that later cells read from a freshly constructed object.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=psf_fit_file,
    slen=full_image.shape[-1],
    sky_intensity=0.,
)

In [None]:
# Simulator PSF versus the output of the PSF transform in its current state,
# each shown over rows / columns 45..55.
init_psf = simulator.psf
trained_psf = psf_transform.forward()

psf_window = (slice(45, 56), slice(45, 56))
plt.matshow(init_psf[psf_window])
plt.matshow(trained_psf[psf_window].detach())

In [None]:
# Simulator PSF minus transformed PSF over rows / columns 45..55.
psf_difference = init_psf[45:56, 45:56] - trained_psf[45:56, 45:56].detach()
difference_plot = plt.matshow(psf_difference)
plt.colorbar(difference_plot)

# Check out summary statistics

In [None]:
def filter_params(locs, fluxes, slen, lower = 10, upper = 90): 
    """Keep only the sources whose pixel location lies strictly inside a box.

    Locations are converted to pixel units by multiplying with ``slen - 1``.
    A source is kept when both of its first two coordinates are strictly
    greater than ``lower`` and strictly smaller than ``upper``.

    Parameters
    ----------
    locs : array or tensor with two dimensions, one row per source
        Source locations (presumably normalised to [0, 1] -- TODO confirm).
        Only columns 0 and 1 are inspected.
    fluxes : array or tensor with one dimension
        One flux per source, aligned with the rows of ``locs``.
    slen : int
        Side length of the image, used to scale ``locs`` to pixel units.
    lower, upper : number, optional
        Exclusive pixel bounds of the box. The defaults (10 and 90) are the
        values that used to be hard-coded here.

    Returns
    -------
    tuple
        ``(locs, fluxes)`` restricted to the kept sources; ``locs`` is
        returned in its original (unscaled) units.

    Raises
    ------
    AssertionError
        If ``locs`` is not 2-dimensional, ``fluxes`` is not 1-dimensional, or
        they do not describe the same number of sources.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1
    assert locs.shape[0] == fluxes.shape[0], \
        'locs and fluxes must describe the same number of sources'

    # Element-wise test of every coordinate against the exclusive bounds ...
    pixel_locs = locs * (slen - 1)
    inside = (pixel_locs > lower) & (pixel_locs < upper)

    # ... then require it to hold for both of the first two coordinates.
    which_params = inside[:, 0] & inside[:, 1]

    return locs[which_params], fluxes[which_params]


In [None]:
# Restrict the true locations / fluxes to the box kept by `filter_params`.
# The names are rebound in place; filtering an already filtered set changes
# nothing, so re-running this cell is harmless.
image_slen = full_image.shape[-1]
true_locs, true_fluxes = filter_params(locs=true_locs,
                                       fluxes=true_fluxes,
                                       slen=image_slen)

In [None]:
# Detection metrics for four encoder checkpoints. Index 0 is the
# `starnet_invKL_encoder` fit; indices 1-3 are the wake-sleep encoder
# checkpoints iter1..iter3.
completeness_all = np.zeros(4)
tpr_all = np.zeros(4)

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))
completeness_ax, tpr_ax = axarr

# Quantities that do not change from one checkpoint to the next.
image_slen = full_image.shape[-1]
image_input = full_image.unsqueeze(0).unsqueeze(0)
background_input = full_background.unsqueeze(0).unsqueeze(0)

for i in range(4):
    # Choose the checkpoint file for this index, then load it.
    if i == 0:
        encoder_file = '../fits/starnet_invKL_encoder-10092019-reweighted_samples'
    else:
        encoder_file = '../fits/wake_sleep-portm2-101420129-encoder-iter' + str(i)
    encoder_state = torch.load(encoder_file,
                               map_location=lambda storage, loc: storage)
    star_encoder.load_state_dict(encoder_state)
    star_encoder.eval()

    # Encoder estimates on the full image, restricted to the same box as the
    # true parameters.
    map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
        star_encoder.get_results_on_full_image(image_input, background_input)

    est_locs, est_fluxes = filter_params(map_locs_full_image.squeeze(),
                                         map_fluxes_full_image.squeeze(),
                                         image_slen)

    # Overall completeness / tpr (fluxes are passed in as well).
    summary_stats = image_statistics_lib.get_summary_stats(est_locs, true_locs,
                                                           image_slen,
                                                           est_fluxes, true_fluxes)
    completeness, tpr, completeness1_bool, tpr1_bool = summary_stats
    completeness_all[i] = completeness
    tpr_all[i] = tpr

    curve_label = 'starnet_iter' + str(i)

    # Completeness curve; the last entry of the returned magnitude vector is
    # dropped when plotting.
    completeness1_vec, mag_vec1, _ = \
        image_statistics_lib.get_completeness_vec(est_locs, true_locs, image_slen,
                                                  est_fluxes, true_fluxes)
    completeness_ax.plot(mag_vec1[:-1], completeness1_vec, '--x', label = curve_label)

    # tpr curve, plotted the same way.
    tpr_vec, mag_vec, _ = \
        image_statistics_lib.get_tpr_vec(est_locs, true_locs, image_slen,
                                         est_fluxes, true_fluxes)
    tpr_ax.plot(mag_vec[0:-1], tpr_vec, '--x', label = curve_label)

completeness_ax.legend()
completeness_ax.set_xlabel('true log10 flux')
completeness_ax.set_ylabel('completeness')

tpr_ax.legend()
tpr_ax.set_xlabel('estimated log10 flux')
tpr_ax.set_ylabel('tpr')

In [None]:
# Overall completeness per encoder checkpoint (index 0 = starnet_invKL_encoder
# fit, 1-3 = wake-sleep encoder checkpoints). Axis labels let the figure stand
# on its own; the trailing ';' hides the repr of the last call.
plt.plot(completeness_all)
plt.xlabel('encoder checkpoint index')
plt.ylabel('completeness');

In [None]:
# Overall tpr per encoder checkpoint (index 0 = starnet_invKL_encoder fit,
# 1-3 = wake-sleep encoder checkpoints). Axis labels let the figure stand on
# its own; the trailing ';' hides the repr of the last call.
plt.plot(tpr_all)
plt.xlabel('encoder checkpoint index')
plt.ylabel('tpr');