In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_lib
import sleep_lib
import wake_lib
import plotting_utils

import psf_transform_lib
import image_statistics_lib

# Seed every RNG the notebook draws from: numpy, and torch as well, because the
# encoder sampling later (sample_star_encoder(..., n_samples = 5)) presumably
# draws from torch's RNG -- seeding only numpy would leave it non-reproducible.
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


# Load the data

In [None]:
# Flux threshold (same units as sdss_hubble_data.fluxes): only catalogue stars
# brighter than this in the first loaded band count as ground truth.
f_min = 1e3

In [None]:
# Indices of the SDSS bands to load (the initial encoder checkpoint is named
# 'starnet_ri' -- presumably these are r and i; TODO confirm).
bands = [2, 3]
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    sdssdir='../../celeste_net/sdss_stage_dir/',
    # NOTE: the directory is spelled 'NCG7089' (not 'NGC7089'); presumably this
    # matches the name on disk, so do not "fix" it without renaming the folder.
    hubble_cat_file='../hubble_data/NCG7089/' +
                    'hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt',
    bands=bands)

# Add a leading batch dimension of size 1.
full_image = sdss_hubble_data.sdss_image.unsqueeze(0).to(device)
# Keep the background on the same device as the image: both are passed together
# to star_encoder.sample_star_encoder further down.
full_background = sdss_hubble_data.sdss_background.unsqueeze(0).to(device)

# Ground truth: catalogue stars whose flux in the first loaded band exceeds f_min.
which_bright = sdss_hubble_data.fluxes[:, 0] > f_min
true_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
true_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)
# Build the star count directly as an int64 tensor (no float32 round-trip).
true_nstars = torch.tensor([int(which_bright.sum())], dtype=torch.long)

In [None]:
# Quick look at the first loaded band of the image. The tensor is moved to the
# CPU because matplotlib cannot draw tensors that live on a GPU.
plt.matshow(full_image[0, 0].cpu())
plt.colorbar()
plt.title('full image, band index 0');

# Load SDSS PSF

In [None]:
# Initial power-law PSF parameters read from the SDSS psField file for the
# selected bands (run 2583, camcol 2, field 136 per the path).
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
init_psf_params = psf_transform_lib.get_psf_params(psfield_file, bands=bands).to(device)

# init_psf_params is already on `device`, so it is handed over as-is.
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params)

# Reference PSF rendered from the initial parameters, detached so later
# comparisons against trained PSFs do not track gradients.
psf_og = power_law_psf.forward().detach()


In [None]:
# Planar background with 3 parameters per band. Only column 0 is seeded
# (686 and 1123, one value per loaded band); the other columns start at zero.
# NOTE(review): the two constants assume len(bands) == 2.
init_background_params = torch.zeros(len(bands), 3)
init_background_params[:, 0] = torch.Tensor([686., 1123.])
init_background_params = init_background_params.to(device)

planar_background = wake_lib.PlanarBackground(
    image_slen=full_image.shape[-1],
    init_background_params=init_background_params)
# Rendered background from the initial parameters, without gradient tracking.
background = planar_background.forward().detach()


# Define the VAE encoder (StarNet)

In [None]:
# StarNet encoder covering the full image, with flux estimation enabled.
# Moved to `device` so it lives where full_image lives. State dicts loaded later
# with a CPU map_location are copied into these parameters by load_state_dict.
star_encoder = starnet_lib.StarEncoder(slen=full_image.shape[-1],
                                       patch_slen=8,
                                       step=2,
                                       edge_padding=3,
                                       n_bands=len(bands),
                                       max_detections=2,
                                       estimate_flux=True).to(device)

# Examine encoder losses

In [None]:
# Checkpoint of the encoder used before any wake-sleep iteration (the name
# suggests it was trained on r/i bands -- TODO confirm it matches `bands`).
init_encoder = '../fits/results_2020-02-27/starnet_ri'

In [None]:
# Directory holding the per-iteration wake-sleep outputs: encoder checkpoints,
# loss logs, and the saved PSF / planar-background parameters.
outfolder = '../fits/results_2020-02-27/'

In [None]:
# Number of wake-sleep iterations read back from outfolder. Wake losses and
# PSF/background files are indexed 0..n_iter-1; encoder checkpoints 0..n_iter.
n_iter = 6

In [None]:
# Encoder test losses saved after each wake-sleep iteration (0..n_iter), joined
# into one curve. Red dotted lines mark where each iteration's losses start.
# ndmin=2 keeps `[0]` selecting the first row even for single-row files.
test_loss_chunks = [
    np.loadtxt(outfolder + 'wake-sleep-encoder-iter' + str(i) + '-test_losses', ndmin=2)[0]
    for i in range(n_iter + 1)
]
losses = np.concatenate(test_loss_chunks)

plt.plot(losses, '-x')

# Separator positions come from the real chunk lengths. The previous version
# hard-coded 4 lines at multiples of 11, which ignored n_iter.
chunk_starts = np.cumsum([0] + [len(chunk) for chunk in test_loss_chunks[:-1]])
for start in chunk_starts:
    plt.vlines(x=start, ymin=losses.min(), ymax=losses.max(),
               color='r', linestyle=':')

In [None]:
# Wake-phase losses logged during each wake-sleep iteration (0..n_iter-1),
# joined into one curve. Red dotted lines mark where each iteration starts.
# ndmin=1 keeps single-value files concatenable.
wake_loss_chunks = [
    np.loadtxt(outfolder + 'iter' + str(i) + '-wake_losses', ndmin=1)
    for i in range(n_iter)
]
losses = np.concatenate(wake_loss_chunks)

plt.plot(losses, '-x')

# Separator positions come from the real chunk lengths, not a hard-coded stride of 5.
chunk_starts = np.cumsum([0] + [len(chunk) for chunk in wake_loss_chunks[:-1]])
for start in chunk_starts:
    plt.vlines(x=start, ymin=losses.min(), ymax=losses.max(),
               color='r', linestyle=':')

# Check the PSF/background fits using the true star parameters

In [None]:
# Index into the loaded bands (bands[1] == 3); selects the image/PSF channel
# shown in the residual and PSF-difference plots below.
band = 1

In [None]:
# PSF/background check using the *true* star parameters.
# Stage i = 0 uses the initial PSF and no saved background; stage i > 0 uses the
# parameters saved after wake-sleep iteration i - 1. For each stage, record the
# loss and plot (left) the magnitude-scale reconstruction residual and (right)
# the change of the PSF relative to its initial value psf_og.
psf_loss_vec = np.zeros(n_iter + 1)

for i in range(n_iter + 1):
    if i > 0:
        powerlaw_psf_params = torch.Tensor(
            np.load(outfolder + 'iter' + str(i - 1) + '-powerlaw_psf_params.npy')).to(device)
        planar_background_params = torch.Tensor(
            np.load(outfolder + 'iter' + str(i - 1) + '-planarback_params.npy')).to(device)
    else:
        powerlaw_psf_params = init_psf_params
        planar_background_params = None

    model_params = wake_lib.ModelParams(full_image,
                                        powerlaw_psf_params,
                                        planar_background_params)

    recon_mean, psf_loss_vec[i] = model_params.get_loss(use_cached_stars=False,
                                                        locs=true_locs,
                                                        fluxes=true_fluxes,
                                                        n_stars=true_nstars)
    recon_mean = recon_mean.detach()
    psf_trained = model_params.get_psf().detach()

    fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

    # 2.5 * (log10 recon - log10 data) on the [5:95, 5:95] crop. Moved to the CPU
    # so matplotlib can draw it when running on a GPU.
    residual = (2.5 * (torch.log10(recon_mean[0, band])
                       - torch.log10(full_image[0, band]))[5:95, 5:95]).cpu()
    vmax = residual.abs().max().item()
    im0 = axarr[0].matshow(residual, vmin=-vmax, vmax=vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('residual (band index {})'.format(band))

    # PSF change on the [40:60, 40:60] crop (presumably around the PSF centre;
    # TODO confirm the PSF size).
    psf_diff = (psf_trained - psf_og)[band, 40:60, 40:60].cpu()
    psf_lim = psf_diff.abs().max().item()
    im1 = axarr[1].matshow(psf_diff, vmax=psf_lim, vmin=-psf_lim,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('iter = {}'.format(i - 1))

In [None]:
# Loss with the true star parameters at each stage: index 0 = initial PSF,
# index i > 0 = parameters saved after wake-sleep iteration i - 1.
plt.plot(psf_loss_vec, '-x')
plt.xlabel('stage (0 = init, i = after iter i - 1)')
plt.ylabel('loss with true star parameters');

In [None]:
# For every stage (0 = initial PSF and no saved background; i >= 1 = the
# parameters saved after wake-sleep iteration i - 1), rebuild ModelParams and
# print its planar-background parameters. After the loop, old_back holds the
# background of the last stage.
for i in range(n_iter + 1):
    if i == 0:
        powerlaw_psf_params, planar_background_params = init_psf_params, None
    else:
        powerlaw_psf_params = torch.Tensor(
            np.load(outfolder + 'iter' + str(i - 1) + '-powerlaw_psf_params.npy')
        ).to(device)
        planar_background_params = torch.Tensor(
            np.load(outfolder + 'iter' + str(i - 1) + '-planarback_params.npy')
        ).to(device)

    model_params = wake_lib.ModelParams(full_image,
                                        powerlaw_psf_params,
                                        planar_background_params)
    print(model_params.planar_background.params)
    old_back = model_params.get_background()


In [None]:
# Model built from the initial PSF parameters, with no saved planar background.
model_params = wake_lib.ModelParams(full_image, init_psf_params, None)

In [None]:
# Display the externally fitted planar backgrounds for comparison with model_params.
np.load(file='../data/fitted_planar_backgrounds.npy')

In [None]:
# Planar-background parameters of model_params as built with
# planar_background_params=None (i.e. whatever ModelParams initialises by default).
model_params.planar_background.params

# Detection summary statistics (TPR / PPV) per encoder checkpoint

In [None]:
# Single electrons-per-nanomaggy conversion factor: the mean over all entries of
# sdss_hubble_data.nelec_per_nmgy (presumably one per band/field -- TODO confirm).
n_elect_per_nmgy = sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# Detection quality of each encoder checkpoint against the Hubble ground truth.
# i = -1 is the initial encoder (init_encoder); i >= 0 is the encoder saved after
# wake-sleep iteration i. Overall TPR/PPV are stored at index i + 1 of
# tpr_all/ppv_all. TPR vs. true magnitude and PPV vs. estimated magnitude are
# plotted for every checkpoint.
tpr_all = np.zeros(n_iter + 2)
ppv_all = np.zeros(n_iter + 2)

fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

for i in range(-1, n_iter + 1):
    if i == -1:
        encoder_file = init_encoder
        label = 'starnet_init'
    else:
        encoder_file = outfolder + 'wake-sleep-encoder-iter' + str(i)
        label = 'starnet_iter' + str(i)

    star_encoder.load_state_dict(torch.load(encoder_file,
                                            map_location=lambda storage, loc: storage))
    star_encoder.eval()

    # MAP star parameters on the full image. This is evaluation only, so no
    # autograd graph is needed. full_background is passed as in the other
    # sample_star_encoder calls in this notebook.
    with torch.no_grad():
        map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
            star_encoder.sample_star_encoder(full_image,
                                             full_background,
                                             return_map_n_stars=True,
                                             return_map_star_params=True)[0:3]
    print(label, map_n_stars_full)

    # Arguments shared by the three image_statistics_lib calls. squeeze(0) drops
    # only the batch dimension, so a single star keeps its (1, 2) location shape.
    # This matches how the fluxes were already squeezed.
    stat_args = (map_locs_full_image.squeeze(0),
                 true_locs.squeeze(0),
                 star_encoder.slen,
                 map_fluxes_full_image.squeeze(0)[:, 0],
                 true_fluxes.squeeze(0)[:, 0],
                 n_elect_per_nmgy)

    tpr, ppv, _, _ = image_statistics_lib.get_summary_stats(*stat_args)
    tpr_all[i + 1] = tpr
    ppv_all[i + 1] = ppv

    # TPR as a function of (true) magnitude.
    tpr_vec, mag_vec = image_statistics_lib.get_tpr_vec(*stat_args)[0:2]
    axarr[0].plot(mag_vec[:-1], tpr_vec, '--x', label=label)

    # PPV as a function of (estimated) magnitude.
    ppv_vec, mag_vec = image_statistics_lib.get_ppv_vec(*stat_args)[0:2]
    axarr[1].plot(mag_vec[:-1], ppv_vec, '--x', label=label)

axarr[0].legend()
axarr[0].set_xlabel('true mag')
axarr[0].set_ylabel('tpr')
axarr[0].set_title('TPR vs. true magnitude')

axarr[1].legend()
axarr[1].set_xlabel('estimated mag')
axarr[1].set_ylabel('ppv')
axarr[1].set_title('PPV vs. estimated magnitude');

In [None]:
# Overall PPV per checkpoint: index 0 = initial encoder, index k = after iteration k - 1.
ppv_all

In [None]:
# Overall TPR per checkpoint: index 0 = initial encoder, index k = after iteration k - 1.
tpr_all

In [None]:
# F1 score (harmonic mean of PPV and TPR) per checkpoint. It is defined as 0
# where PPV + TPR is 0 (or nan), instead of producing 0/0 = nan with a warning.
f1_denominator = ppv_all + tpr_all
f1_score = np.divide(2 * ppv_all * tpr_all, f1_denominator,
                     out=np.zeros_like(f1_denominator, dtype=float),
                     where=f1_denominator > 0)
f1_score

# PSF loss using encoder-inferred star parameters

In [None]:
# Index into the loaded bands (bands[0] == 2) used for the residual plots of
# the inferred-parameter reconstructions below.
band = 0

In [None]:
# Show which checkpoint served as the initial (pre-wake-sleep) encoder.
init_encoder

In [None]:
# Loss of each wake-sleep iteration's model when the star parameters come from
# that iteration's encoder (MAP estimates) instead of the ground truth. Also
# plots the magnitude-scale reconstruction residual for every iteration.
wake_losses = torch.zeros(n_iter)
for i in range(n_iter):
    encoder_file = outfolder + 'wake-sleep-encoder-iter' + str(i)

    powerlaw_psf_params = torch.Tensor(
        np.load(outfolder + 'iter' + str(i) + '-powerlaw_psf_params.npy')).to(device)
    planar_background_params = torch.Tensor(
        np.load(outfolder + 'iter' + str(i) + '-planarback_params.npy')).to(device)

    star_encoder.load_state_dict(torch.load(encoder_file,
                                            map_location=lambda storage, loc: storage))
    star_encoder.eval()

    # MAP star parameters. Evaluation only, so the encoder's graph is not needed.
    with torch.no_grad():
        map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
            star_encoder.sample_star_encoder(full_image,
                                             full_background,
                                             return_map_n_stars=True,
                                             return_map_star_params=True)[0:3]

    # Reconstruction under this iteration's PSF and background.
    model_params = wake_lib.ModelParams(full_image,
                                        powerlaw_psf_params,
                                        planar_background_params)

    recon_mean, iter_loss = model_params.get_loss(use_cached_stars=False,
                                                  locs=map_locs_full_image,
                                                  fluxes=map_fluxes_full_image,
                                                  n_stars=map_n_stars_full)
    # Store a detached value so wake_losses does not keep every iteration's
    # autograd graph alive.
    wake_losses[i] = iter_loss.detach()

    # 2.5 * (log10 recon - log10 data) on the [5:95, 5:95] crop, moved to the CPU
    # for matplotlib.
    residual = (2.5 * (torch.log10(recon_mean[0, band].detach())
                       - torch.log10(full_image[0, band]))[5:95, 5:95]).cpu()
    vmax = residual.abs().max().item()
    plt.matshow(residual, vmin=-vmax, vmax=vmax, cmap=plt.get_cmap('bwr'))
    plt.colorbar()
    plt.title('iter = {} (inferred star parameters)'.format(i))
    

In [None]:
# Losses saved by an earlier run, presumably for comparison with wake_losses below.
# NOTE(review): this reads results_2020-02-18, not outfolder (results_2020-02-27) -- confirm intended.
np.load('../fits/results_2020-02-18/map_losses.npy') 

In [None]:
# Per-iteration losses with inferred star parameters, as a numpy array.
wake_losses.detach().cpu().numpy()

In [None]:
# Draw n_samples=5 star configurations from the currently loaded star_encoder,
# with both MAP outputs switched off, and keep the first three returned items.
# NOTE(review): despite the map_* names, these are presumably samples, not MAP estimates.
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = star_encoder.sample_star_encoder(
    full_image,
    full_background,
    n_samples=5,
    return_map_star_params=False,
    return_map_n_stars=False,
)[:3]

In [None]:
# Loss of the sampled star configurations under the PSF/background saved after
# the last wake-sleep iteration (n_iter - 1). The parameters are loaded
# explicitly, rather than relying on whatever powerlaw_psf_params /
# planar_background_params an earlier loop happened to leave behind.
powerlaw_psf_params = torch.Tensor(
    np.load(outfolder + 'iter' + str(n_iter - 1) + '-powerlaw_psf_params.npy')).to(device)
planar_background_params = torch.Tensor(
    np.load(outfolder + 'iter' + str(n_iter - 1) + '-planarback_params.npy')).to(device)

model_params = wake_lib.ModelParams(full_image,
                                    powerlaw_psf_params,
                                    planar_background_params)

# Star parameters are detached so gradients flow only into model_params.
recon_mean, loss = model_params.get_loss(use_cached_stars=False,
                                         locs=map_locs_full_image.detach(),
                                         fluxes=map_fluxes_full_image.detach(),
                                         n_stars=map_n_stars_full.detach())

In [None]:
# Magnitude-scale residual 2.5 * (log10 recon - log10 data) of the first entry
# of recon_mean, on the [5:95, 5:95] crop. Moved to the CPU so matplotlib can
# draw it when running on a GPU.
residual = (2.5 * (torch.log10(recon_mean[0, band].detach())
                   - torch.log10(full_image[0, band]))[5:95, 5:95]).cpu()
vmax = residual.abs().max().item()
plt.matshow(residual, vmin=-vmax, vmax=vmax, cmap=plt.get_cmap('bwr'))
plt.colorbar()


In [None]:
# Gradient of the planar-background parameters, inspected before the backward()
# call below (typically None if no backward pass has run on these parameters yet).
model_params.planar_background.params.grad

In [None]:
# Backpropagate the mean of `loss` (presumably one entry per sampled configuration)
# into the ModelParams parameters.
# NOTE(review): gradients accumulate into .grad, and PyTorch frees the graph after
# backward() by default, so re-running this cell without re-running the get_loss
# cell above will likely raise an error.
loss.mean().backward()

In [None]:
# Gradient of the planar-background parameters after loss.mean().backward().
model_params.planar_background.params.grad

In [None]:
# Raw loss tensor returned by get_loss for the sampled star configurations.
loss

In [None]:
# Gradient of the power-law PSF parameters after loss.mean().backward().
model_params.power_law_psf.params.grad

In [None]:
# Re-display the externally fitted planar backgrounds next to the gradients above.
np.load(file='../data/fitted_planar_backgrounds.npy')