In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import sdss_dataset_lib

import inv_kl_objective_lib as inv_kl_lib

import image_utils

import time

import json

# Prefer the first GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# load PSF
psf_fit_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both numpy and torch so the simulated data are reproducible.
np.random.seed(22)
torch.manual_seed(22);

# Draw data

In [None]:
# Data parameters: start from the defaults on disk, then override the image size,
# the (fixed) number of stars and alpha for this experiment.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(slen = 101, min_stars = 2000, max_stars = 2000, alpha = 0.5)


In [None]:
use_simulated_data = True
use_hubble_params = False


def _make_simulator_for(hubble_data):
    """Build a StarSimulator from `hubble_data`'s PSF file and image size, using the
    sky intensity from `data_params`."""
    return simulated_datasets_lib.StarSimulator(psf_fit_file=str(hubble_data.psf_file),
                                                slen = hubble_data.slen,
                                                sky_intensity = data_params['sky_intensity'])


def _bright_true_params(hubble_data):
    """Return (locs, fluxes) of the catalog stars whose flux exceeds data_params['f_min'],
    each with a leading batch dimension of size 1."""
    which_bright = hubble_data.fluxes > data_params['f_min']
    return (hubble_data.locs[which_bright].unsqueeze(0),
            hubble_data.fluxes[which_bright].unsqueeze(0))


# A DataLoader is only built for the fully simulated case. Reset it so a loader left
# over from an earlier run of this cell cannot silently leak into later cells.
loader = None

if use_simulated_data:
    print('simulating data')
    if not use_hubble_params:
        # Draw from the same distribution I simulated training data from
        n_images = 1

        simulated_dataset = \
            simulated_datasets_lib.load_dataset_from_params(psf_fit_file,
                                    data_params,
                                    n_images = n_images,
                                    add_noise = True)

        batchsize = 1

        loader = torch.utils.data.DataLoader(
                         dataset=simulated_dataset,
                         batch_size=batchsize,
                         shuffle=False)

        # One image with batch size 1: the first batch is the whole dataset.
        data = next(iter(loader))
        true_full_fluxes = data['fluxes']
        true_full_locs = data['locs']
        images_full = data['image']
        backgrounds_full = torch.ones(images_full.shape) * data_params['sky_intensity']

        simulator = simulated_dataset.simulator

    else:
        # Simulate my own data, but use the Hubble catalog parameters.
        sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()
        simulator = _make_simulator_for(sdss_hubble_data)

        # Every catalog star (not only the bright ones) is rendered into the image.
        images_full = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0),
                                fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
                                n_stars = torch.tensor([len(sdss_hubble_data.locs)], dtype=torch.long),
                                add_noise = True)

        backgrounds_full = torch.ones(images_full.shape) * data_params['sky_intensity']

        # True parameters: only stars brighter than f_min.
        true_full_locs, true_full_fluxes = _bright_true_params(sdss_hubble_data)

else:
    print('loading data')
    # Use Hubble images and Hubble parameters.
    sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()
    images_full = sdss_hubble_data.sdss_image.unsqueeze(0)
    backgrounds_full = sdss_hubble_data.sdss_background.unsqueeze(0)

    true_full_locs, true_full_fluxes = _bright_true_params(sdss_hubble_data)

    simulator = _make_simulator_for(sdss_hubble_data)




In [None]:
# # true parameters
# batchsize = 1

# loader = torch.utils.data.DataLoader(
#                  dataset=simulated_dataset,
#                  batch_size=batchsize,
#                  shuffle=False)

# for _, data in enumerate(loader):
#     true_full_fluxes = data['fluxes']
#     true_full_locs = data['locs']
#     images_full = data['image']
    
#     break
# backgrounds_full = torch.ones((image_stamps.shape[0], 1, 1, 1)) * data_params['sky_intensity']

# data = np.load('../fits/testing_data.npz')
# images_full = torch.Tensor(data['images'][0:1])
# backgrounds_full = torch.ones((image_stamps.shape[0], 1, 1, 1)) * data_params['sky_intensity']
# true_full_locs = torch.Tensor(data['true_locs'][0:1])
# true_full_fluxes = torch.Tensor(data['true_fluxes'][0:1])



In [None]:
# Histogram of log10 true fluxes over the full image.
log10_true_full_fluxes = np.log10(true_full_fluxes.numpy().ravel())
plt.hist(log10_true_full_fluxes)

In [None]:
full_image_2d = images_full.squeeze()
plt.matshow(full_image_2d);

# Load VAE

In [None]:
# Encoder architecture. Presumably this must match the configuration of the
# checkpoint loaded below; load_state_dict will fail on mismatched parameter shapes.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen = data_params['slen'],
    stamp_slen = 9,
    step = 2,
    edge_padding = 3,
    n_bands = 1,
    max_detections = 4,
)

In [None]:
# Load the trained weights onto the CPU regardless of where they were saved, then
# switch to evaluation mode.
encoder_state = torch.load('../fits/starnet_invKL_encoder-10072019', map_location='cpu')
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

In [None]:
# eval_star_encoder_loss needs a DataLoader, and one is only built when the data are
# fully simulated (use_simulated_data and not use_hubble_params). In the other
# configurations `loader` is undefined or None, so skip instead of raising.
if use_simulated_data and not use_hubble_params:
    eval_loss = inv_kl_lib.eval_star_encoder_loss(star_encoder, loader, train = False)
else:
    eval_loss = None
    print('no DataLoader for this data source; skipping eval_star_encoder_loss')
eval_loss

In [None]:
losses = np.loadtxt('../fits/test_losses_2000stars_smallpatch5')

# Plot row 0 of the saved losses, skipping its first entry, against the entry index.
n_losses = losses.shape[1]
loss_steps = np.arange(n_losses - 1)
loss_row = losses[0, 1:n_losses]
plt.plot(loss_steps, loss_row)
plt.scatter(loss_steps, loss_row, marker = 'x')


# Get image stamps

In [None]:
backgrounds_full.size()

In [None]:
# Tile the full image (and its true catalog) into stamps.
stamp_outputs = star_encoder.get_image_stamps(images_full, true_full_locs, true_full_fluxes,
                                              trim_images = False)
image_stamps, true_subimage_locs, true_subimage_fluxes, true_n_stars, is_on_array = stamp_outputs

# Apply the same tiling to the background; only the stamps themselves are kept.
background_stamps = star_encoder.get_image_stamps(backgrounds_full, None, None,
                                                  trim_images = False)[0]

In [None]:
# One integer-aligned bin per possible star count per stamp.
n_stars_bins = np.arange(max(true_n_stars) + 2)
plt.hist(true_n_stars, bins=n_stars_bins)

In [None]:
# Recompute the encoder loss on the full image; perm is reused below to align estimates.
encoder_loss_outputs = inv_kl_lib.get_encoder_loss(star_encoder, images_full, backgrounds_full,
                                                   true_full_locs, true_full_fluxes)
loss, counter_loss, locs_loss, fluxes_loss, perm = encoder_loss_outputs

In [None]:
loss_message = 'loss: {:06f}'.format(loss)
print(loss_message)

In [None]:
# Only the nonzero per-entry location losses are shown.
nonzero_locs_loss = locs_loss.detach()[locs_loss != 0]
plt.hist(nonzero_locs_loss, bins = 100);

# get inferred parameters

In [None]:
# Cap the true per-stamp counts at the encoder's max_detections before conditioning on them.
_true_n_stars = true_n_stars.clamp(max = star_encoder.max_detections)

logit_loc_mean, logit_loc_log_var, log_flux_mean, log_flux_log_var, log_probs = \
    star_encoder(image_stamps, data_params['sky_intensity'], _true_n_stars)

# check parameters

In [None]:
# Reorder per-row parameters (used below on the estimated parameters) by a permutation.
def permute_params(locs, fluxes, perm):
    """Reorder the stars of each row according to ``perm``.

    For every row ``b`` and output slot ``i`` the result takes star ``perm[b, i]``:
    ``locs_perm[b, i] = locs[b, perm[b, i]]`` and ``fluxes_perm[b, i] = fluxes[b, perm[b, i]]``.

    Args:
        locs: tensor indexed as ``locs[b, k, :]``.
        fluxes: tensor indexed as ``fluxes[b, k]``.
        perm: integer indices of shape (batchsize, n_out); may have fewer columns
            than ``locs`` has stars.

    Returns:
        ``(locs_perm, fluxes_perm)`` of shapes (batchsize, n_out, locs.shape[2]) and
        (batchsize, n_out). Unlike the previous loop-and-copy version, the outputs keep
        the dtype/device of the inputs instead of always being CPU float32; gradients
        still flow through the indexing.
    """
    perm = torch.as_tensor(perm, dtype=torch.long)
    # Column vector of row indices; broadcasting it against perm gathers every slot at once
    # instead of looping over the columns of perm.
    rows = torch.arange(perm.shape[0], device=perm.device).unsqueeze(1)
    return locs[rows, perm], fluxes[rows, perm]

In [None]:
# Align the estimated per-stamp parameters with the ordering `perm` from get_encoder_loss.
locs_perm, fluxes_perm = permute_params(locs = logit_loc_mean, fluxes = log_flux_mean, perm = perm)

In [None]:
# Zero out slots that are off. Note map_fluxes holds *log* fluxes here
# (fluxes_perm comes from log_flux_mean).
is_on_float = is_on_array.float()
map_locs = torch.sigmoid(locs_perm) * is_on_float.unsqueeze(2)
map_fluxes = fluxes_perm * is_on_float

In [None]:
# Estimated vs true within-stamp locations; zero entries are treated as "off".
# NOTE(review): the two masks are computed independently, so points only pair up if
# both select the same (stamp, star, coordinate) entries. Confirm that this holds.
map_locs_on = map_locs.flatten()
map_locs_on = map_locs_on[map_locs_on > 0].detach()

true_locs_flat = true_subimage_locs.flatten()
true_locs_on = true_locs_flat[true_locs_flat > 0]

plt.plot(map_locs_on, true_locs_on, '+')
# identity line for reference
plt.plot(map_locs_on, map_locs_on, '-')

plt.xlabel('estimated')
plt.ylabel('truth')

In [None]:
# Estimated vs true log fluxes of the "on" stars.
# map_fluxes holds *log* fluxes multiplied by the on-mask, so an "on" estimate can be
# negative. Selecting with `!= 0` (not `> 0`) keeps those entries. Otherwise they are
# dropped while their true counterparts are kept, and the points get mispaired.
map_log_fluxes_on = map_fluxes.flatten()
map_log_fluxes_on = map_log_fluxes_on[map_log_fluxes_on != 0].detach()

true_fluxes_flat = true_subimage_fluxes.flatten()
true_log_fluxes_on = torch.log(true_fluxes_flat[true_fluxes_flat > 0])

plt.plot(map_log_fluxes_on, true_log_fluxes_on, '+')
# identity line for reference
plt.plot(map_log_fluxes_on, map_log_fluxes_on, '-')

plt.xlabel('estimated')
plt.ylabel('truth')

# Check reconstructions 

In [None]:
# _true_n_stars = true_n_stars.clone()
# _true_n_stars[true_n_stars > star_encoder.max_detections] = star_encoder.max_detections
# probs = objectives_lib.get_one_hot_encoding_from_int(_true_n_stars, star_encoder.max_detections + 1) + 0.1

# probs = probs / probs.sum(dim = 1).unsqueeze(1)

# log_probs = torch.log(probs)

# is_on_array = objectives_lib.get_is_on_from_n_stars(log_probs.argmax(1), star_encoder.max_detections)

# logit_loc_mean = objectives_lib._logit(true_subimage_locs[:, 0:star_encoder.max_detections, :]) * \
#                     is_on_array.unsqueeze(2).float()
    
# logit_loc_log_var = -5 * torch.ones(logit_loc_mean.shape) * is_on_array.unsqueeze(2).float()

# log_flux_mean = torch.log(true_subimage_fluxes[:, 0:star_encoder.max_detections].clamp(min = 0.1)) * \
#                     is_on_array.float()

# log_flux_log_var = -5 * torch.ones(log_flux_mean.shape) * is_on_array.float()

# loss, counter_loss, locs_loss, fluxes_loss, perm = \
#     objectives_lib.get_params_loss(logit_loc_mean, logit_loc_log_var, \
#                         log_flux_mean, log_flux_log_var, log_probs,
#                         true_subimage_locs, true_subimage_fluxes, _true_n_stars)

In [None]:
# MAP estimates for each stamp. Set use_true_n_stars to condition on the (capped) true
# count instead of the encoder's most probable count.
# NOTE: this overwrites is_on_array, map_locs and map_fluxes from earlier cells;
# map_fluxes now holds fluxes, not log fluxes.
use_true_n_stars = False
map_n_stars = _true_n_stars if use_true_n_stars else torch.argmax(log_probs, dim = 1)

is_on_array = inv_kl_lib.get_is_on_from_n_stars(map_n_stars, star_encoder.max_detections)
is_on_float = is_on_array.float()

map_locs = torch.sigmoid(logit_loc_mean).detach() * is_on_float.unsqueeze(2)
map_fluxes = torch.exp(log_flux_mean).detach() * is_on_float

In [None]:
# Convert per-stamp parameters to parameters on the full image using the encoder's tiling.
encoder_tiling_args = (star_encoder.tile_coords,
                       star_encoder.full_slen,
                       star_encoder.stamp_slen,
                       star_encoder.edge_padding,
                       star_encoder.batchsize)
map_locs_full_image, map_fluxes_full_image, n_stars = \
    image_utils.get_full_params_from_patch_params(map_locs, map_fluxes, *encoder_tiling_args)

In [None]:
# Noise-free reconstruction of the full image from the MAP catalog.
vae_recon_mean = simulator.draw_image_from_params(
    locs = map_locs_full_image,
    fluxes = map_fluxes_full_image,
    n_stars = n_stars,
    add_noise = False)
vae_recon_mean = vae_recon_mean.squeeze()

In [None]:
observed_image = images_full.squeeze()
recon_image = vae_recon_mean.squeeze()
residual = recon_image - observed_image

# observed | reconstruction | reconstruction - observed, each with its own colorbar
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
im0 = axarr[0].matshow(observed_image)
fig.colorbar(im0, ax = axarr[0])
im1 = axarr[1].matshow(recon_image)
fig.colorbar(im1, ax = axarr[1])
im2 = axarr[2].matshow(residual)
fig.colorbar(im2, ax = axarr[2])

# check image patches

In [None]:
import plotting_utils

In [None]:
# Number of randomly chosen stamps to inspect.
n_stamps_to_plot = 1

# Side length of a stamp's interior patch (the stamp minus its padding on each side).
patch_slen = (star_encoder.stamp_slen - 2 * star_encoder.edge_padding)

# Red guide-line positions derived from the encoder geometry. For stamp_slen = 9 and
# edge_padding = 3 these are 2 and 6, the values that used to be hard-coded.
guide_positions = (star_encoder.edge_padding - 1,
                   star_encoder.stamp_slen - star_encoder.edge_padding)


def _patch_locs_to_pixels(patch_locs, patch_slen, edge_padding):
    """Map within-patch locations (scaled by patch_slen - 1) to stamp pixel coordinates."""
    return patch_locs * (patch_slen - 1) + edge_padding


def _draw_guides(ax, positions):
    """Draw red vertical and horizontal guide lines at each position."""
    for pos in positions:
        ax.axvline(x=pos, color = 'r')
        ax.axhline(y=pos, color = 'r')


def _scatter_stamp_locs(ax, stamp_locs, n, patch_slen, edge_padding, **scatter_kwargs):
    """Scatter the first n locations of one stamp (column 1 on x, column 0 on y)."""
    pix = _patch_locs_to_pixels(stamp_locs[0:n], patch_slen, edge_padding)
    ax.scatter(pix[:, 1], pix[:, 0], **scatter_kwargs)


for plot_idx in range(n_stamps_to_plot):
    fig, axarr = plt.subplots(1, 3, figsize=(16, 6))
    # scalar draw: int() of a size-1 array is deprecated in recent NumPy
    indx = int(np.random.choice(image_stamps.shape[0]))

    x0 = int(star_encoder.tile_coords[indx, 0])
    x1 = int(star_encoder.tile_coords[indx, 1])

    # observed stamp with true (blue) and estimated (red x) locations
    im0 = axarr[0].matshow(image_stamps[indx].squeeze())
    fig.colorbar(im0, ax=axarr[0])
    _scatter_stamp_locs(axarr[0], true_subimage_locs[indx], true_n_stars[indx],
                        patch_slen, star_encoder.edge_padding, color = 'b')
    _scatter_stamp_locs(axarr[0], map_locs[indx], map_n_stars[indx],
                        patch_slen, star_encoder.edge_padding, color = 'r', marker = 'x')
    _draw_guides(axarr[0], guide_positions)
    axarr[0].set_title('observed; coords {}\n'.format([x0, x1]))

    # reconstruction cut from the full-image reconstruction at this stamp's tile coordinates
    recon_patch = vae_recon_mean[x0:(x0 + star_encoder.stamp_slen),
                                 x1:(x1 + star_encoder.stamp_slen)]
    im1 = axarr[1].matshow(recon_patch)
    _draw_guides(axarr[1], guide_positions)
    _scatter_stamp_locs(axarr[1], map_locs[indx], map_n_stars[indx],
                        patch_slen, star_encoder.edge_padding, color = 'r', marker = 'x')
    fig.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('reconstruction\n')

    # residual: reconstruction minus observed stamp
    im2 = axarr[2].matshow((recon_patch - image_stamps[indx].squeeze()))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('reconstruction - observed\n')

In [None]:
# check that using the full image params get the same image
# f, axarr = plt.subplots(1, 3, figsize=(16, 6))

# plotting_utils.plot_subimage(axarr[0], images_full.squeeze(),
#                             map_locs_full_image.squeeze() * (images_full.shape[-1] - 1), 
#                             true_full_locs.squeeze() * (images_full.shape[-1] - 1), 
#                             int(star_encoder.tile_coords[indx, 0]), 
#                             int(star_encoder.tile_coords[indx, 1]), 
#                             subimage_slen = star_encoder.stamp_slen)

# axarr[0].axvline(x=2, color = 'r')
# axarr[0].axvline(x=6, color = 'r')
# axarr[0].axhline(y=2, color = 'r')
# axarr[0].axhline(y=6, color = 'r')

# Check an arbitrary patch of the full image

In [None]:
# Top-left corners of w x w windows, spaced w apart and kept clear of the padded border.
# Both axes share the same array object.
w = 10
x0_vec = x1_vec = np.arange(star_encoder.edge_padding,
                            star_encoder.full_slen - star_encoder.edge_padding - w,
                            w)

In [None]:
np.asarray(x0_vec)

In [None]:
# One random w x w window of the full image: observed (estimated and true locations),
# reconstruction, and residual.
# The figure must be bound to `fig`. Previously it was bound to `f` while
# `global_fig = fig` pointed at a stale figure from an earlier cell, so the colorbars
# were attached to the wrong figure.
fig, axarr = plt.subplots(1, 3, figsize=(16, 6))

# scalar draws: int() of a size-1 array is deprecated in recent NumPy
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

residual = (vae_recon_mean.squeeze() - images_full.squeeze())

# (title label, image to show, true locations to overlay or None)
panels = [('observed', images_full.squeeze(), true_full_locs.squeeze()),
          ('reconstructed', vae_recon_mean.squeeze(), None),
          ('residual', residual, None)]

# Each panel gets its own title. Previously the 'residual' title was written onto
# axarr[1], overwriting 'reconstructed'.
for ax, (label, panel_image, panel_true_locs) in zip(axarr, panels):
    plotting_utils.plot_subimage(ax, panel_image,
                                 map_locs_full_image.squeeze(),
                                 panel_true_locs,
                                 x0, x1,
                                 subimage_slen = w,
                                 add_colorbar = True,
                                 global_fig = fig)
    ax.set_title('{}; coords = {}'.format(label, [x0, x1]))


# Check out some summary statistics 

In [None]:
def filter_params(locs, fluxes, slen, pad):
    """Keep only the stars lying strictly inside the padded border of the image.

    ``locs`` is (n, >=2) and is rescaled by ``slen - 1`` to pixel units (presumably the
    inputs are in [0, 1]). A star is kept when both of its first two pixel coordinates
    lie strictly between ``pad`` and ``slen - pad - 1``. ``fluxes`` (n,) is filtered
    with the same mask.

    Returns:
        (locs, fluxes) restricted to the kept stars; locs keep their original scale.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1

    pixel_locs = locs * (slen - 1)
    lower, upper = pad, slen - pad - 1

    def _strictly_inside(coord):
        return (coord > lower) & (coord < upper)

    keep = _strictly_inside(pixel_locs[:, 0]) & _strictly_inside(pixel_locs[:, 1])
    return locs[keep], fluxes[keep]

In [None]:
# Restrict both catalogs to stars away from the padded border of the full image.
filter_kwargs = dict(slen = star_encoder.full_slen, pad = star_encoder.edge_padding)

true_locs, true_fluxes = filter_params(true_full_locs.squeeze(),
                                       true_full_fluxes.squeeze(),
                                       **filter_kwargs)

est_locs, est_fluxes = filter_params(map_locs_full_image.squeeze(),
                                     map_fluxes_full_image.squeeze(),
                                     **filter_kwargs)

In [None]:
# Noise-free reconstructions from the border-filtered estimated and true catalogs.
_recon_mean = simulator.draw_image_from_params(
                                locs = est_locs.unsqueeze(0),
                                fluxes = est_fluxes.unsqueeze(0),
                                n_stars = torch.tensor([est_locs.shape[0]], dtype=torch.long),
                                add_noise = False).squeeze()

_recon_truth = simulator.draw_image_from_params(
                                locs = true_locs.unsqueeze(0),
                                fluxes = true_fluxes.unsqueeze(0),
                                n_stars = torch.tensor([len(true_locs)], dtype=torch.long),
                                add_noise = False).squeeze()

observed_image = images_full.squeeze()

# Columns 0/1: estimated recon and its residual. Columns 2/3: true recon and its residual.
# Column 1 previously showed the residual of the *unfiltered* vae_recon_mean, which does
# not match the image in column 0 or the pairing used in columns 2/3.
panels = [('estimated recon', _recon_mean),
          ('estimated recon - observed', _recon_mean - observed_image),
          ('true recon', _recon_truth),
          ('true recon - observed', _recon_truth - observed_image)]

fig, axarr = plt.subplots(1, 4, figsize=(15, 6))
for ax, (title, panel) in zip(axarr, panels):
    ax.matshow(panel)
    ax.set_title(title)


In [None]:
# Matrix of location errors between true and estimated stars.
def get_locs_error(locs, true_locs):
    """Pairwise max-coordinate (L-infinity) distance between true and estimated locations.

    Returns a (n_true, n_est) tensor whose [t, e] entry is
    max_k |locs[e, k] - true_locs[t, k]|.
    """
    # rows index true stars, columns index estimated stars
    pairwise_diff = locs[None] - true_locs[:, None]
    return pairwise_diff.abs().max(dim = 2)[0]

In [None]:
# Pairwise location errors in pixels, (true stars) x (estimated stars).
locs_error = get_locs_error(est_locs * (star_encoder.full_slen - 1),
                            true_locs * (star_encoder.full_slen - 1))
is_close = locs_error < 0.5

# completeness: for each true star, is there at least one estimated star that is close?
print(torch.any(is_close, dim = 1).float().mean())
# true positive rate: for each estimated star, is there at least one true star that is close?
print(torch.any(is_close, dim = 0).float().mean())

In [None]:
# Flux errors, so the matching can also take fluxes into account.
def get_fluxes_error(fluxes, true_fluxes):
    """Pairwise absolute difference of log10 fluxes, shape (n_true, n_est)."""
    log_est = torch.log10(fluxes)
    log_true = torch.log10(true_fluxes)
    # rows index true stars, columns index estimated stars
    return (log_est[None] - log_true[:, None]).abs()


In [None]:
fluxes_error = get_fluxes_error(fluxes = est_fluxes, true_fluxes = true_fluxes)

In [None]:
# A match requires both location error < 0.5 pixel and log10 flux error < 0.5.
is_match = (locs_error < 0.5) & (fluxes_error < 0.5)

# completeness: for each true star, is there at least one matching estimated star?
print(torch.any(is_match, dim = 1).float().mean())
# true positive rate: for each estimated star, is there at least one matching true star?
print(torch.any(is_match, dim = 0).float().mean())

In [None]:
# Distribution of log10 fluxes of the true stars kept by filter_params.
log10_true_kept_fluxes = torch.log10(true_fluxes)
plt.hist(log10_true_kept_fluxes)

In [None]:
# Completeness as a function of true log10 flux. For each bin, this is the fraction of
# true stars that have an estimated star within 0.5 pixel (max-coordinate distance).
# Bins containing no true stars give NaN.
# NOTE: "mag" here is log10(flux), not an astronomical magnitude.

# Set to True to also require the log10 flux error to be < 0.5.
require_flux_match = False
bin_width = 0.5

true_mag = torch.log10(true_fluxes)

max_mag = torch.ceil(true_mag.max())
min_mag = torch.floor(true_mag.min())

# Bin edges from min_mag up to and including max_mag. np.arange(min_mag, max_mag, 0.5)
# stopped one step short of max_mag, silently dropping the brightest stars.
n_bins = int(round((float(max_mag) - float(min_mag)) / bin_width))
mag_vec = float(min_mag) + bin_width * np.arange(n_bins + 1)

completeness_vec = np.zeros(len(mag_vec) - 1)

for i in range(len(mag_vec) - 1):
    # Half-open bins [lo, hi); the last bin also includes its upper edge, so no star on
    # an edge is lost.
    is_last_bin = i == len(mag_vec) - 2
    below_upper = (true_mag <= mag_vec[i + 1]) if is_last_bin else (true_mag < mag_vec[i + 1])
    which_true = (true_mag >= mag_vec[i]) & below_upper

    bin_locs_error = get_locs_error(est_locs * (star_encoder.full_slen - 1),
                                    true_locs[which_true] * (star_encoder.full_slen - 1))
    is_detected = bin_locs_error < 0.5
    if require_flux_match:
        is_detected = is_detected & (get_fluxes_error(est_fluxes, true_fluxes[which_true]) < 0.5)

    completeness_vec[i] = torch.any(is_detected, dim = 1).float().mean()

In [None]:
# Completeness plotted at the lower edge of each log10-flux bin.
bin_lower_edges = mag_vec[:-1]
plt.plot(bin_lower_edges, completeness_vec, '--x')
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
# True positive rate as a function of estimated log10 flux. For each bin, this is the
# fraction of estimated stars that have a true star within 0.5 pixel (max-coordinate
# distance). Bins span the range of the true log10 fluxes; estimated stars outside that
# range fall in no bin. Bins containing no estimated stars give NaN.

# Set to True to also require the log10 flux error to be < 0.5.
require_flux_match = False
bin_width = 0.5

true_mag = torch.log10(true_fluxes)
est_mag = torch.log10(est_fluxes)

max_mag = torch.ceil(true_mag.max())
min_mag = torch.floor(true_mag.min())

# Bin edges from min_mag up to and including max_mag. np.arange(min_mag, max_mag, 0.5)
# stopped one step short of max_mag and dropped the top bin.
n_bins = int(round((float(max_mag) - float(min_mag)) / bin_width))
mag_vec = float(min_mag) + bin_width * np.arange(n_bins + 1)

tpr_vec = np.zeros(len(mag_vec) - 1)

for i in range(len(mag_vec) - 1):
    # Half-open bins [lo, hi); the last bin also includes its upper edge.
    is_last_bin = i == len(mag_vec) - 2
    below_upper = (est_mag <= mag_vec[i + 1]) if is_last_bin else (est_mag < mag_vec[i + 1])
    which_est = (est_mag >= mag_vec[i]) & below_upper

    bin_locs_error = get_locs_error(est_locs[which_est] * (star_encoder.full_slen - 1),
                                    true_locs * (star_encoder.full_slen - 1))
    is_matched = bin_locs_error < 0.5
    if require_flux_match:
        is_matched = is_matched & (get_fluxes_error(est_fluxes[which_est], true_fluxes) < 0.5)

    tpr_vec[i] = torch.any(is_matched, dim = 0).float().mean()

plt.plot(mag_vec[0:-1], tpr_vec, '--x')
plt.xlabel('estimated log flux')
plt.ylabel('tpr')