In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed both numpy and torch. Later cells draw random encoder samples
# (e.g. `n_samples = 10`), which presumably use torch's RNG; np.random.seed
# alone does not make those reproducible.
np.random.seed(34534)
torch.manual_seed(34534)

# Load the data

In [None]:
# Flux threshold: only true sources with flux strictly above f_min are kept in
# the ground-truth catalog (`which_bright` in the data-loading cell).
f_min = 1000.

In [None]:
# Load the dataset object that carries the SDSS image, background, PSF file
# and the catalog of true locations / fluxes.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

# Path to the PSF fit file, as a plain string.
psf_fit_file = str(sdss_hubble_data.psf_file)

# Image and background with singleton dimensions removed.
full_image, full_background = (sdss_hubble_data.sdss_image.squeeze(),
                               sdss_hubble_data.sdss_background.squeeze())

# Ground-truth catalog restricted to sources brighter than f_min.
which_bright = sdss_hubble_data.fluxes > f_min
true_locs, true_fluxes = (sdss_hubble_data.locs[which_bright],
                          sdss_hubble_data.fluxes[which_bright])


In [None]:
# Show the full image (it was already squeezed when loaded).
plt.matshow(full_image)
plt.colorbar()

# Our simulator: the initial PSF

In [None]:
from copy import deepcopy

# PSF evaluated at (0, 0) from the fitted PSF file, then passed through
# _expand_psf with the full image side length (presumably padding it to the
# image size) and converted to a torch tensor.
psf_og = sdss_psf.psf_at_points(0, 0, psf_fit_file = psf_fit_file)
psf_init = torch.Tensor(
    simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1]))

In [None]:
# Inspect the shape of the raw PSF returned by sdss_psf.psf_at_points.
psf_og.shape

In [None]:
# Shape of the expanded PSF tensor.
psf_init.size()

In [None]:
# Mean squared value of the expanded PSF.
psf_init.pow(2).mean()

# Define the VAE encoder

In [None]:
# StarNet encoder over the full image: single band, stamp side 7, step 2,
# edge padding 2, at most 2 detections.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen = full_image.shape[-1],
    n_bands = 1,
    stamp_slen = 7,
    step = 2,
    edge_padding = 2,
    max_detections = 2)

# Define the PSF transform

In [None]:
# Local transform of the initial PSF (kernel_size = 3); its parameters are
# what wake_optimizer below optimizes.
psf_transform = psf_transform_lib.PsfLocalTransform(
    torch.Tensor(psf_og), full_image.shape[-1], kernel_size = 3)

# Examine encoder and PSF-transform test losses

In [None]:
# Encoder checkpoint loaded as iteration 0 in the summary-statistics loop.
init_encoder = '../fits/results_11052019/starnet'

In [None]:
# Path prefix of the wake-sleep run's checkpoints and loss files.
# NOTE: this already starts with '../fits/', so it should not be prefixed again.
filename = '../fits/results_11052019/wake_sleep-loc630x310'

In [None]:
# NOTE(review): scratch cell. It is a bare string whose only effect is echoing
# the results directory as cell output.
'../fits/results_11052019/'

In [None]:
# Number of wake-sleep iterations whose checkpoints / loss files are read below.
n_iter = 6

In [None]:
# Encoder test losses saved after each wake-sleep iteration (files
# ...-encoder-test_losses-iter1 .. -iter{n_iter}), concatenated into one curve.
# `filename` already begins with '../fits/', so it is used as-is; the extra
# '../fits/' prefix used previously only resolved through path normalization.
encoder_losses_per_iter = [
    np.atleast_1d(np.loadtxt(filename + '-encoder-test_losses-iter' + str(i))[0])
    for i in range(1, n_iter + 1)
]
losses = np.concatenate(encoder_losses_per_iter)

plt.plot(losses, '-x')

# Dotted red lines mark where each iteration's losses start. The offsets are
# derived from the loaded arrays instead of the previous hard-coded
# `range(4)` / `i * 11`, which marked only 4 of the n_iter iterations.
iter_starts = np.cumsum([0] + [iter_losses.size for iter_losses in encoder_losses_per_iter[:-1]])
for start in iter_starts:
    plt.vlines(x = start, ymin = losses.min(), ymax = losses.max(),
               color = 'r', linestyle = ':')

In [None]:
# PSF-transform test losses saved after each wake-sleep iteration (files
# ...-psf_transform-test_losses-iter0 .. -iter{n_iter - 1}), concatenated into
# one curve; the final loss of each iteration is printed.
psf_losses_per_iter = []
for i in range(n_iter):
    # `filename` already begins with '../fits/'; no extra prefix is needed.
    # atleast_1d keeps a single-value file from becoming a 0-d array, which
    # would break both `[-1]` and the concatenation.
    losses_iter = np.atleast_1d(
        np.loadtxt(filename + '-psf_transform-test_losses-iter' + str(i)))
    psf_losses_per_iter.append(losses_iter)

    print(losses_iter[-1])

losses = np.concatenate(psf_losses_per_iter)
plt.plot(losses, '-x')

# Mark where each iteration's losses start, using the loaded array lengths
# rather than an assumed fixed length of 4 per iteration.
iter_starts = np.cumsum([0] + [iter_losses.size for iter_losses in psf_losses_per_iter[:-1]])
for start in iter_starts:
    plt.vlines(x = start, ymin = losses.min(), ymax = losses.max(),
               color = 'r', linestyle = ':')

In [None]:
from torch import optim

In [None]:
# Learning rate for the PSF-transform parameter group in wake_optimizer.
psf_lr = 0.1

In [None]:
# Adam over the PSF-transform parameters only, with weight decay 1e-5.
# NOTE(review): wake_optimizer is not referenced anywhere later in this notebook.
wake_optimizer = optim.Adam(
    [dict(params = psf_transform.parameters(), lr = psf_lr)],
    weight_decay = 1e-5)



# Check the PSFs using the true catalog parameters

In [None]:
# Rebuild the PSF transform from the initial PSF, then step through the saved
# PSF-transform checkpoints. Index 0 is the untrained transform; index i > 0
# loads checkpoint iter{i - 1}. For each, reconstruct the image from the true
# catalog, record the PSF loss, and plot the relative residual next to the
# change in PSF relative to psf_init.
psf_transform = psf_transform_lib.PsfLocalTransform(torch.Tensor(psf_og),
                                    full_image.shape[-1],
                                    kernel_size = 3)

psf_loss_vec = np.zeros(n_iter + 1)

psf_old = psf_transform.forward()

for i in range(n_iter + 1):
    if i > 0:
        # `filename` already begins with '../fits/'; no extra prefix is needed.
        psf_transform.load_state_dict(torch.load(filename + '-psf_transform-iter' + str(i - 1),
                                                 map_location=lambda storage, loc: storage))

    # Evaluate the transformed PSF once per iteration and reuse it below.
    psf_current = psf_transform.forward()

    recon_mean, psf_loss_vec[i] = \
        psf_transform_lib.get_psf_loss(full_image.unsqueeze(0).unsqueeze(0),
                                        full_background.unsqueeze(0).unsqueeze(0),
                                        true_locs.unsqueeze(0),
                                        true_fluxes.unsqueeze(0),
                                        n_stars = (true_fluxes.unsqueeze(0) > 0).sum(1),
                                        psf = psf_current,
                                        pad = 5)

    fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

    # Relative residual of the reconstruction, with a fixed symmetric colour scale.
    residual = ((recon_mean.squeeze().detach() - full_image) / full_image)
    vmax = 0.7
    im0 = axarr[0].matshow(residual, vmin = -vmax, vmax = vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im0, ax = axarr[0])

    # [40:60, 40:60] window of the change in PSF relative to psf_init.
    psf_change = (psf_current.detach() - psf_init)[40:60, 40:60]
    psf_change_max = psf_change.abs().max()
    im1 = axarr[1].matshow(psf_change, vmax = psf_change_max, vmin = -psf_change_max,
                           cmap = plt.get_cmap('bwr'))
    fig.colorbar(im1, ax = axarr[1])
    axarr[1].set_title('initial psf' if i == 0 else 'iter = {}'.format(i - 1))

    # Largest absolute change in the PSF since the previous step.
    diff = (psf_current - psf_old).abs().max()
    print(diff)
    psf_old = psf_current

In [None]:
# PSF loss on the true catalog: index 0 is the untrained transform, then one
# point per checkpoint.
plt.plot(np.arange(len(psf_loss_vec)), psf_loss_vec, '-x')

In [None]:
# Zoomed comparison on a subimage. For the untrained PSF (i = 0) and each
# PSF-transform checkpoint iter{i - 1}, plot the image, the reconstruction from
# the true catalog, and the relative residual. (x0, x1, subimage_slen) are
# forwarded to plotting_utils.plot_subimage to select the window.
x0 = 60
x1 = 10
subimage_slen = 20

psf_transform = psf_transform_lib.PsfLocalTransform(torch.Tensor(psf_og),
                                    full_image.shape[-1],
                                    kernel_size = 3)

psf_loss_vec = np.zeros(n_iter + 1)

for i in range(n_iter + 1):
    if i > 0:
        # `filename` already begins with '../fits/'; no extra prefix is needed.
        psf_transform.load_state_dict(torch.load(filename + '-psf_transform-iter' + str(i - 1),
                                                 map_location=lambda storage, loc: storage))

    recon_mean, psf_loss_vec[i] = \
        psf_transform_lib.get_psf_loss(full_image.unsqueeze(0).unsqueeze(0),
                                        full_background.unsqueeze(0).unsqueeze(0),
                                        true_locs.unsqueeze(0),
                                        true_fluxes.unsqueeze(0),
                                        n_stars = (true_fluxes.unsqueeze(0) > 0).sum(1),
                                        psf = psf_transform.forward(),
                                        pad = 5)

    # Detach the reconstruction once; it is used both for plotting and for the residual.
    recon_image = recon_mean.squeeze().detach()
    resid = (recon_image - full_image) / full_image

    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
    plotting_utils.plot_subimage(axarr[0], full_image,
                                 None,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig)

    plotting_utils.plot_subimage(axarr[1], recon_image,
                                 None,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig)

    plotting_utils.plot_subimage(axarr[2], resid,
                                 None,
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig,
                                 diverging_cmap = True,
                                 vmax = 0.5)

# Summary statistics: completeness and TPR

In [None]:
def filter_params(locs, fluxes, slen, lower = 10, upper = 90):
    """Keep only sources whose pixel coordinates lie strictly inside (lower, upper).

    Parameters
    ----------
    locs : 2-D array / tensor of shape (n_sources, >= 2)
        Source locations scaled to [0, 1]. They are converted to pixels by
        multiplying by (slen - 1); columns 0 and 1 are both bounds-checked.
    fluxes : 1-D array / tensor of length n_sources
        Fluxes aligned row-for-row with `locs`.
    slen : int
        Side length of the image in pixels.
    lower, upper : float
        Exclusive pixel bounds applied to both coordinates. The defaults
        (10, 90) reproduce the original hard-coded interior window.

    Returns
    -------
    (locs, fluxes) restricted to the sources inside the window, with the same
    type as the inputs.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1
    # A length mismatch would otherwise surface as an opaque indexing error.
    assert locs.shape[0] == fluxes.shape[0], \
        'locs and fluxes must describe the same number of sources'

    pixel_locs = locs * (slen - 1)
    which_params = (pixel_locs[:, 0] > lower) & (pixel_locs[:, 0] < upper) & \
                   (pixel_locs[:, 1] > lower) & (pixel_locs[:, 1] < upper)

    return locs[which_params], fluxes[which_params]


In [None]:
# Restrict the ground-truth catalog to the interior window used by filter_params.
# NOTE: this rebinds true_locs / true_fluxes, so earlier cells re-run after this
# one will see the filtered catalog.
true_locs, true_fluxes = filter_params(locs = true_locs,
                                       fluxes = true_fluxes,
                                       slen = full_image.shape[-1])

In [None]:
# For the initial encoder (index 0) and each wake-sleep encoder checkpoint,
# compute completeness / TPR on the interior window and plot them against
# log10 flux (true flux for completeness, estimated flux for TPR).
image_slen = full_image.shape[-1]

completeness_all = np.zeros(n_iter + 1)
tpr_all = np.zeros(n_iter + 1)

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

for i in range(n_iter + 1):
    # Index 0 is the initial encoder. `filename` already begins with '../fits/',
    # so it is used without an extra prefix.
    encoder_file = init_encoder if i == 0 else filename + '-encoder-iter' + str(i)
    star_encoder.load_state_dict(torch.load(encoder_file,
                                            map_location=lambda storage, loc: storage))
    star_encoder.eval()

    # MAP estimates of locations / fluxes / number of stars on the full image.
    map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
        star_encoder.sample_star_encoder(full_image.unsqueeze(0).unsqueeze(0),
                                         full_background.unsqueeze(0).unsqueeze(0),
                                         return_map = True)[0:3]

    # Restrict the estimates to the same interior window as the true catalog.
    est_locs, est_fluxes = filter_params(map_locs_full_image.squeeze(),
                                         map_fluxes_full_image.squeeze(),
                                         image_slen)

    # Overall completeness / TPR, taking fluxes into account.
    completeness, tpr, completeness1_bool, tpr1_bool = \
        image_statistics_lib.get_summary_stats(est_locs, true_locs,
                                               image_slen,
                                               est_fluxes, true_fluxes)
    completeness_all[i] = completeness
    tpr_all[i] = tpr

    # Completeness as a function of true log10 flux (mag_vec1 presumably holds
    # bin edges, hence the [:-1]).
    completeness1_vec, mag_vec1, _ = \
        image_statistics_lib.get_completeness_vec(est_locs, true_locs, image_slen,
                                                  est_fluxes, true_fluxes)
    axarr[0].plot(mag_vec1[:-1], completeness1_vec, '--x', label = 'starnet_iter' + str(i))

    # TPR as a function of estimated log10 flux.
    tpr_vec, mag_vec, _ = \
        image_statistics_lib.get_tpr_vec(est_locs, true_locs, image_slen,
                                         est_fluxes, true_fluxes)
    axarr[1].plot(mag_vec[:-1], tpr_vec, '--x', label = 'starnet_iter' + str(i))

axarr[0].legend()
axarr[0].set_xlabel('true log10 flux')
axarr[0].set_ylabel('completeness')

axarr[1].legend()
axarr[1].set_xlabel('estimated log10 flux')
axarr[1].set_ylabel('tpr')

In [None]:
# Scratch calculation: 10**2.5 is about 316 (presumably converting a log10-flux
# value read off the plots above back to flux).
10**2.5

# PSF loss on inferred parameters

In [None]:
import wake_sleep_lib

In [None]:
# NOTE(review): this loads '../fits/starnet-10172019-no_reweighting', not
# `init_encoder`. Confirm this older checkpoint is intended here.
star_encoder.load_state_dict(
    torch.load('../fits/starnet-10172019-no_reweighting',
               map_location = lambda storage, loc: storage))
star_encoder.eval()

# MAP catalog from the encoder.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
    wake_sleep_lib.sample_star_encoder(star_encoder,
                                       full_image.unsqueeze(0).unsqueeze(0),
                                       full_background.unsqueeze(0).unsqueeze(0),
                                       return_map = True)[0:3]

# Reconstruction and loss with the *initial* PSF.
recon_mean, init_loss = psf_transform_lib.get_psf_loss(
    full_image, full_background,
    sampled_locs_full_image, sampled_fluxes_full_image,
    n_stars = sampled_n_stars_full,
    psf = psf_init,
    pad = 5)
print(init_loss)

# Relative residual on the [10:90, 10:90] interior with a symmetric colour scale.
residual = ((recon_mean.squeeze().detach() - full_image) / full_image)[10:90, 10:90]
vmax = residual.abs().max()
plt.matshow(residual, vmin = -vmax, vmax = vmax, cmap = plt.get_cmap('bwr'))
plt.colorbar()

# Mean of squared residual divided by the image, on the same interior.
print((((recon_mean.squeeze().detach() - full_image)**2) / full_image)[10:90, 10:90].mean())

In [None]:
# For each wake-sleep iteration, reconstruct the image from the encoder's MAP
# catalog and plot the raw residual on the [10:90, 10:90] interior.
# Index 0 pairs `init_encoder` with an untrained PSF transform. Index i > 0 pairs
# encoder checkpoint iter{i} with PSF checkpoint iter{i - 1}, the same PSF
# indexing as the true-parameter PSF check. Only PSF checkpoints iter0 ..
# iter{n_iter - 1} are read elsewhere, so loading iter{i} for i = n_iter (as
# before) points at a checkpoint that likely does not exist.
# NOTE(review): index 0 previously used '../fits/starnet-10172019-no_reweighting';
# `init_encoder` is the initial encoder of the `filename` run.
psf_transform = psf_transform_lib.PsfLocalTransform(torch.Tensor(psf_og),
                                    full_image.shape[-1],
                                    kernel_size = 3)

psf_loss_vec = np.zeros(n_iter + 1)
for i in range(n_iter + 1):
    if i == 0:
        encoder_file = init_encoder
    else:
        # `filename` already begins with '../fits/'; no extra prefix is needed.
        encoder_file = filename + '-encoder-iter' + str(i)
        psf_transform.load_state_dict(torch.load(filename + '-psf_transform-iter' + str(i - 1),
                                                 map_location=lambda storage, loc: storage))

    star_encoder.load_state_dict(torch.load(encoder_file,
                                            map_location=lambda storage, loc: storage))
    star_encoder.eval()

    sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
        wake_sleep_lib.sample_star_encoder(star_encoder,
                                           full_image.unsqueeze(0).unsqueeze(0),
                                           full_background.unsqueeze(0).unsqueeze(0),
                                           return_map = True)[0:3]

    recon_mean, psf_loss_vec[i] = \
        psf_transform_lib.get_psf_loss(full_image,
                                       full_background,
                                       sampled_locs_full_image, sampled_fluxes_full_image,
                                       n_stars = sampled_n_stars_full,
                                       psf = psf_transform.forward(),
                                       pad = 5)

    # Raw (unnormalized) residual on the interior with a symmetric colour scale.
    residual = (recon_mean.squeeze().detach() - full_image)[10:90, 10:90]
    vmax = residual.abs().max()
    plt.matshow(residual, vmin = -vmax, vmax = vmax, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

    print((((recon_mean.squeeze().detach() - full_image)**2) / full_image)[10:90, 10:90].mean())

In [None]:
# PSF loss on the encoder's MAP catalog, one point per iteration.
plt.plot(range(len(psf_loss_vec)), psf_loss_vec)

In [None]:
# NOTE(review): reloads the older '../fits/starnet-10172019-no_reweighting'
# checkpoint rather than `init_encoder`. Confirm this is intended.
star_encoder.load_state_dict(
    torch.load('../fits/starnet-10172019-no_reweighting',
               map_location = lambda storage, loc: storage))

In [None]:
# NOTE(review): this PSF checkpoint comes from a different run
# ('wake_sleep-altm2-10212019') than `filename`. Confirm this is intended.
psf_transform.load_state_dict(torch.load('../fits/wake_sleep-altm2-10212019-psf_transform-iter3',
                                         map_location=lambda storage, loc: storage))

# Use the PSF produced by the transform just loaded. This was previously
# `psf_init`, which silently discarded the checkpoint above despite the name
# `psf_trained`.
psf_trained = psf_transform.forward().detach()

In [None]:
# MAP catalog from the encoder. `[0:3]` keeps (locs, fluxes, n_stars), matching
# the other wake_sleep_lib.sample_star_encoder calls in this notebook. Without
# it, the three-way unpacking fails whenever the function returns more than
# three values.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
        wake_sleep_lib.sample_star_encoder(star_encoder,
                                           full_image.unsqueeze(0).unsqueeze(0),
                                           full_background.unsqueeze(0).unsqueeze(0),
                                           return_map = True)[0:3]

# Reconstruct with psf_trained.
recon_mean, _ = \
        psf_transform_lib.get_psf_loss(full_image,
                                       full_background,
                                       sampled_locs_full_image, sampled_fluxes_full_image,
                                       sampled_n_stars_full,
                                       psf = psf_trained,
                                       pad = 5)

# Relative residual on the [10:90, 10:90] interior with a symmetric colour scale.
resid = ((recon_mean.squeeze().detach() - full_image) / full_image)[10:90, 10:90]
vmax = resid.abs().max()
plt.matshow(resid,
            cmap = plt.get_cmap('bwr'),
            vmin = -vmax, vmax = vmax)
plt.colorbar()

In [None]:
# Ten random (non-MAP) samples from the encoder. `[0:3]` keeps (locs, fluxes,
# n_stars), matching the other sample_star_encoder calls; without it the
# unpacking fails whenever the function returns more than three values.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
        wake_sleep_lib.sample_star_encoder(star_encoder,
                                           full_image.unsqueeze(0).unsqueeze(0),
                                           full_background.unsqueeze(0).unsqueeze(0),
                                           return_map = False,
                                           n_samples = 10)[0:3]

In [None]:
# Reconstructions and losses for the sampled catalogs under the current
# transformed PSF (detached). recon_loss presumably holds one loss per sample.
recon_mean, recon_loss = psf_transform_lib.get_psf_loss(
    full_image, full_background,
    sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full,
    psf = psf_transform.forward().detach(),
    pad = 5)

In [None]:
# Histogram of the sampled reconstruction losses. Detach first: matplotlib
# converts its input through numpy, which raises on a tensor that still
# requires grad (detach is a no-op otherwise).
plt.hist(recon_loss.detach())

In [None]:
# Fifth-largest reconstruction loss among the samples, as a Python float.
recon_loss.topk(5)[0].min().item()

In [None]:
# Random (non-MAP) samples from the encoder with the default n_samples. `[0:3]`
# keeps (locs, fluxes, n_stars), matching the other sample_star_encoder calls;
# without it the unpacking fails whenever the function returns more than three values.
sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full = \
        wake_sleep_lib.sample_star_encoder(star_encoder,
                                           full_image.unsqueeze(0).unsqueeze(0),
                                           full_background.unsqueeze(0).unsqueeze(0),
                                           return_map = False)[0:3]

# NOTE(review): despite its name, `map_loss` is computed from non-MAP samples
# (return_map = False).
recon_mean, map_loss = \
            psf_transform_lib.get_psf_loss(full_image,
                                           full_background,
                                           sampled_locs_full_image,
                                           sampled_fluxes_full_image,
                                           sampled_n_stars_full,
                                           psf = psf_transform.forward().detach(),
                                           pad = 5)

In [None]:
# For each of the first few samples, plot image / reconstruction / relative
# residual on a subimage, overlaying that sample's locations and the true
# catalog. (x0, x1, subimage_slen) are forwarded to plotting_utils.plot_subimage.
x0 = 10
x1 = 10
subimage_slen = 10

# Plot at most five samples. The sampling cell above does not set n_samples,
# so fewer than five reconstructions may be available.
n_plot = min(5, recon_mean.shape[0])
for i in range(n_plot):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    # Detach once. The reconstruction panel previously received a tensor that
    # could still require grad, unlike the residual computed from it.
    recon_i = recon_mean[i].squeeze().detach()
    resid = (recon_i - full_image) / full_image

    plotting_utils.plot_subimage(axarr[0], full_image,
                                 sampled_locs_full_image[i],
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig)

    plotting_utils.plot_subimage(axarr[1], recon_i,
                                 sampled_locs_full_image[i],
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig)

    plotting_utils.plot_subimage(axarr[2], resid,
                                 sampled_locs_full_image[i],
                                 true_locs,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True, global_fig = fig,
                                 diverging_cmap = True)