In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib
import psf_transform_lib
import utils

import sleep_lib

import image_utils

import time

import json

# Prefer the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
import fitsio

In [None]:
# Fixed seeds so the numpy and torch random draws below are reproducible.
NUMPY_SEED = 453
TORCH_SEED = 786

np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)  # bind the returned Generator to `_` so it is not echoed

# Data parameters

In [None]:
# data parameters
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

In [None]:
data_params

# The PSF

In [None]:
bands = [2, 3]
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
init_psf_params = psf_transform_lib.get_psf_params(
                                    psfield_file,
                                    bands = bands)
# init_psf_params = torch.Tensor(np.load('../data/fitted_powerlaw_psf_params.npy'))
power_law_psf = psf_transform_lib.PowerLawPSF(init_psf_params.to(device))
psf_og = power_law_psf.forward().detach()


In [None]:
plt.matshow(simulated_datasets_lib._trim_psf(psf_og, 15)[0])

In [None]:
psf_og.shape

In [None]:
n_elect_per_nmgy = 856.

# Draw data

In [None]:
import wake_lib
# init_background_params = torch.zeros(len(bands), 3).to(device)
# init_background_params[:, 0] = torch.Tensor([686., 1123.])
init_background_params = torch.Tensor(np.load('../data/fitted_planar_backgrounds.npy'))
planar_background = wake_lib.PlanarBackground(image_slen = data_params['slen'],
                            init_background_params = init_background_params.to(device))
background = planar_background.forward().detach()


In [None]:
# Draw from the same distribution I used in the sleep phase
n_images = 1

simulated_dataset = \
    simulated_datasets_lib.load_dataset_from_params(psf_og,
                    data_params,
                    background = background,
                    n_images = n_images,
                    transpose_psf = False, 
                    add_noise = True)

images = simulated_dataset.images.detach()
backgrounds = simulated_dataset.background.detach()
        
which_on = (simulated_dataset.fluxes > 0).any(2).squeeze()
        
true_locs = simulated_dataset.locs[:, which_on, :]
true_fluxes = simulated_dataset.fluxes[:, which_on, :]
        
simulator = simulated_dataset.simulator

In [None]:
b = 0
plt.matshow(images[0, b])

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

for i in range(3): 
    x0 = int(np.random.choice(images.shape[-1] - 10, 1))
    x1 = int(np.random.choice(images.shape[-1] - 10, 1))

    plotting_utils.plot_subimage(axarr[i], images[0, b],
                                None, 
                                true_locs.squeeze(), 
                                x0, x1, 
                                patch_slen = 10, 
                                add_colorbar = True, 
                                global_fig = f)


In [None]:
images.min()

In [None]:
true_fluxes.shape

In [None]:
true_locs.shape

In [None]:
# histogram of fluxes
for i in range(psf_og.shape[0]): 
    plt.hist(np.log10(true_fluxes[:, :, i].numpy().flatten()), bins = 50);

In [None]:
n_bands = len(bands)

In [None]:
if n_bands > 1: 
    # histogram of colors
    colors = -2.5 * np.log10(true_fluxes[:, :, 1].numpy().flatten() / \
                      true_fluxes[:, :, 0].numpy().flatten())
    plt.hist(colors, bins = 50);

In [None]:
# color x flux
if n_bands > 1: 
    plt.scatter(colors, np.log10(true_fluxes[:, :, 0].numpy().flatten()), marker = 'x', alpha = 0.5)

In [None]:
images.min()

In [None]:
true_fluxes.min()

# Load the trained StarNet encoder (VAE)

In [None]:
star_encoder = starnet_lib.StarEncoder(slen = data_params['slen'],
                                            patch_slen = 8,
                                            step = 2,
                                            edge_padding = 3, 
                                            n_bands = n_bands,
                                            max_detections = 2)

In [None]:
star_encoder.load_state_dict(torch.load('../fits/results_2020-02-27/starnet_ri',
                               map_location=lambda storage, loc: storage))
star_encoder.eval(); 

In [None]:
# check loss 
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    sleep_lib.get_inv_kl_loss(star_encoder, images, backgrounds, 
                                true_locs, true_fluxes)[0:5]

In [None]:
print('loss: {:06f}'.format(loss))

In [None]:
print(counter_loss.mean())
print(locs_loss.mean())
print(fluxes_loss.mean())

# Get image patches

In [None]:
# get image patches
image_patches, true_patch_locs, true_patch_fluxes, \
    true_patch_n_stars, true_is_on_array = \
        star_encoder.get_image_patches(images, true_locs, true_fluxes, 
                                      clip_max_stars = True)

In [None]:
foo = plt.hist(true_patch_n_stars, bins=np.arange(max(true_patch_n_stars) + 2))[0]
# plt.plot(foo[0] / star_encoder.weights, 'x')

In [None]:
plt.hist(locs_loss.detach()[locs_loss != 0], bins = 100);


# Get inferred parameters on patches

In [None]:
# Note that these variational parameters are estimated using the true number of stars!
patch_loc_mean, patch_loc_log_var, \
    patch_log_flux_mean, patch_log_flux_log_var, patch_log_probs = \
        star_encoder(image_patches, true_patch_n_stars)

In [None]:
map_n_stars_patches = torch.argmax(patch_log_probs, dim = 1).detach()

In [None]:
(map_n_stars_patches == true_patch_n_stars).float().mean()

In [None]:
plt.hist(map_n_stars_patches, bins = np.arange(star_encoder.max_detections + 2))
plt.hist(true_patch_n_stars, bins = np.arange(star_encoder.max_detections + 2), alpha = 0.2)

In [None]:
from itertools import permutations

In [None]:
perm_list = []
for perm in permutations(range(star_encoder.max_detections)):
    perm_list.append(perm)

In [None]:
perm = np.zeros((image_patches.shape[0], star_encoder.max_detections))
for i in range(image_patches.shape[0]): 
    perm[i, :] = perm_list[perm_indx[i]]

### Check estimated parameters against the truth

In [None]:
# Reorder per-patch parameters into the slot order given by `perm`
# (applied below to the encoder's estimated parameters).
def permute_params(locs, fluxes, perm): 
    """Reorder the star slots of every row of `locs` and `fluxes` according to `perm`.

    Args:
        locs: tensor indexed as locs[batch, slot, :] (last axis: the 2 location coordinates).
        fluxes: tensor indexed as fluxes[batch, slot, band].
        perm: (batchsize, max_stars) numpy array, list or tensor of slot indices.
            Float-typed arrays holding integral values are accepted and cast to int64.

    Returns:
        (locs_perm, fluxes_perm), where locs_perm[b, j] = locs[b, perm[b, j]] and
        fluxes_perm[b, j] = fluxes[b, perm[b, j]], with shapes
        (batchsize, max_stars, locs.shape[-1]) and (batchsize, max_stars, n_bands).
        Outputs keep the dtype and device of the corresponding input.
    """
    perm = torch.as_tensor(perm).long()

    def _take_slots(params):
        # rows (batchsize, 1) broadcasts against idx (batchsize, max_stars), so
        # params[rows, idx][b, j] == params[b, idx[b, j]].
        idx = perm.to(params.device)
        rows = torch.arange(idx.shape[0], device = params.device).unsqueeze(1)
        return params[rows, idx]

    return _take_slots(locs), _take_slots(fluxes)

In [None]:
locs_perm, log_fluxes_perm = permute_params(patch_loc_mean, patch_log_flux_mean, perm)
log_var_locs_perm, log_var_log_fluxes_perm = \
    permute_params(patch_loc_log_var, patch_log_flux_log_var, perm)

In [None]:
map_patch_locs = (locs_perm * \
                     true_is_on_array.unsqueeze(2).float()).detach()
map_patch_fluxes = \
    (torch.exp(log_fluxes_perm) * \
                           true_is_on_array.unsqueeze(2).float()).detach()

In [None]:
plt.plot(map_patch_locs.flatten()[map_patch_locs.flatten() > 0].detach(), 
         true_patch_locs.flatten()[true_patch_locs.flatten() > 0], '+')

plt.plot(map_patch_locs.flatten()[map_patch_locs.flatten() > 0].detach(), 
         map_patch_locs.flatten()[map_patch_locs.flatten() > 0].detach(), '-')

plt.xlabel('estimated')
plt.ylabel('truth')

In [None]:
plt.plot(torch.log10(true_patch_fluxes.flatten()[true_patch_fluxes.flatten() > 0]), 
         torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()), '+')

plt.plot(torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()), 
         torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()), '-')

plt.plot(torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()), 
         torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()) + 0.4, 'r:')

plt.plot(torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()), 
     torch.log10(map_patch_fluxes.flatten()[map_patch_fluxes.flatten() > 0].detach()) - 0.4, 'r:')

plt.xlabel('truth')
plt.ylabel('estimated')

In [None]:
est = log_fluxes_perm[true_patch_fluxes > 0]
truth = torch.log(true_patch_fluxes[true_patch_fluxes > 0])

est_sd = torch.exp(0.5 * log_var_log_fluxes_perm[true_patch_fluxes > 0])

zscore = (est - truth) / est_sd

plt.hist(zscore.detach(), bins = 50); 

In [None]:
est = locs_perm[true_patch_locs > 0]
truth = true_patch_locs[true_patch_locs > 0]

est_sd = torch.exp(0.5 * log_var_locs_perm[true_patch_locs > 0])

zscore = (est - truth) / est_sd

plt.hist(zscore.detach(), bins = 50); 

In [None]:
zscore.mean()

In [None]:
zscore.var()

In [None]:
# Not just fluxes, but also color?
if n_bands > 1: 
    map_color = \
        torch.log10(map_patch_fluxes[:, :, 1].flatten()[map_patch_fluxes[:, :, 1].flatten() > 0] / \
                            map_patch_fluxes[:, :, 0].flatten()[map_patch_fluxes[:, :, 0].flatten() > 0])

    true_color = \
        torch.log10(true_patch_fluxes[:, :, 1].flatten()[true_patch_fluxes[:, :, 1].flatten() > 0] / \
                            true_patch_fluxes[:, :, 0].flatten()[true_patch_fluxes[:, :, 0].flatten() > 0])

    plt.plot(map_color, true_color, 'x')
    plt.plot(map_color, map_color, '-')

In [None]:
if n_bands > 1: 
    fig, axarr = plt.subplots(1, 2, figsize=(7, 4))

    axarr[0].scatter(true_color, 
                     true_patch_fluxes[:, :, 1].flatten()[true_patch_fluxes[:, :, 1].flatten() > 0])

    axarr[0].scatter(map_color, 
                     map_patch_fluxes[:, :, 1].flatten()[true_patch_fluxes[:, :, 1].flatten() > 0], )

# Check reconstructions

In [None]:
use_true_n_stars = False
if use_true_n_stars: 
    _n_stars = true_patch_n_stars
else: 
    _n_stars = None

# get parameters on the full image 
map_locs, map_fluxes, map_n_stars = \
    star_encoder.sample_star_encoder(images, 
                                     return_map_n_stars = True,
                                     return_map_star_params = True, 
                                     patch_n_stars = _n_stars)[0:3]
    
if _n_stars is not None: 
    assert map_n_stars == torch.sum(_n_stars)

assert map_n_stars == map_locs.shape[1]
assert map_n_stars == map_fluxes.shape[1]

In [None]:
plt.hist(torch.log10(map_fluxes.flatten()), bins = 100)
foo = plt.hist(torch.log10(true_fluxes).flatten(), alpha = 0.5, bins = 100)[1]; 

In [None]:
# get reconstructed mean
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs, 
                                                fluxes = map_fluxes,
                                                 n_stars = map_n_stars, 
                                                 add_noise = False).detach()

In [None]:
band = 0

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
im0 = axarr[0].matshow(images[0, band][5:95, 5:95])
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(vae_recon_mean[0, band][5:95, 5:95])
fig.colorbar(im1, ax = axarr[1])

residual = torch.log10(vae_recon_mean[0, band]) - torch.log10(images[0, band])
_residual = (residual * 2.5)[5:95, 5:95]
# (torch.log(vae_recon_mean.squeeze()) - torch.log(images.squeeze()))[10:90, 10:90]
vmax = _residual.abs().max()
im2 = axarr[2].matshow(_residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])

In [None]:
plt.hist(_residual.flatten())

In [None]:
torch.where(_residual < -2.5)

In [None]:
plt.matshow(_residual[40:50, 10:20]); 
plt.colorbar()

In [None]:
def get_which_tile(x0, x1, tile_coords, edge_padding, patch_slen): 
    """Find the tiles whose interior strictly contains the point (x0, x1).

    A tile's interior starts `edge_padding` pixels in from its corner in `tile_coords`
    and has side `patch_slen - 2 * edge_padding`; both bounds are exclusive.

    Returns:
        (matching rows of tile_coords, the index tuple produced by torch.where)
    """
    interior_slen = patch_slen - 2 * edge_padding
    lower = tile_coords + edge_padding
    upper = lower + interior_slen

    inside_dim0 = (lower[:, 0] < x0) & (x0 < upper[:, 0])
    inside_dim1 = (lower[:, 1] < x1) & (x1 < upper[:, 1])
    indx = torch.where(inside_dim0 & inside_dim1)

    return tile_coords[indx], indx

In [None]:
# tile(s) whose interior contains the hard-coded pixel (21, 67)
get_which_tile(x0 = 21, x1 = 67,
               tile_coords = star_encoder.tile_coords,
               edge_padding = star_encoder.edge_padding,
               patch_slen = star_encoder.patch_slen)

# Check image patches

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

# Pick a random tile. choice(..., 1) returns a length-1 array: index it instead of calling
# int() on the array (deprecated since NumPy 1.25); the RNG stream is unchanged.
indx = int(np.random.choice(image_patches.shape[0], 1)[0])
# indx = int(np.random.choice(torch.where(true_patch_n_stars == 2)[0].numpy(), 1)[0])
x0 = int(star_encoder.tile_coords[indx, 0])
x1 = int(star_encoder.tile_coords[indx, 1])

# Data with estimated and true locations; `boo` is inspected by the cells below.
boo = plotting_utils.plot_subimage(axarr[0], images[0, band],
                            map_locs.squeeze(), 
                            true_locs.squeeze(), 
                            x0, x1, 
                            patch_slen = star_encoder.patch_slen, 
                            add_colorbar = True, 
                            global_fig = f)

# MAP reconstruction with estimated locations
plotting_utils.plot_subimage(axarr[1], vae_recon_mean[0, band],
                            map_locs.squeeze(), 
                            None, 
                            x0, x1, 
                            patch_slen = star_encoder.patch_slen, 
                            add_colorbar = True, 
                            global_fig = f)

# residual 2.5 * log10(recon / data) on a diverging colormap
log_ratio = torch.log10(vae_recon_mean[0, band]) - torch.log10(images[0, band])
plotting_utils.plot_subimage(axarr[2], log_ratio * 2.5, 
                            map_locs.squeeze(), 
                            None, 
                            x0, x1, 
                            patch_slen = star_encoder.patch_slen, 
                            add_colorbar = True, 
                            global_fig = f, 
                            diverging_cmap = True)

# Outline the tile's central region (edge_padding pixels excluded on each side). Assuming
# plot_subimage uses matshow's default extent (pixel k spans [k - 0.5, k + 0.5]), the edges
# fall on half-integers: 2.5 and 4.5 for edge_padding=3, patch_slen=8.
lo = star_encoder.edge_padding - 0.5
hi = star_encoder.patch_slen - star_encoder.edge_padding - 0.5
for ax in axarr[:2]:
    ax.axvline(x=lo, color = 'r')
    ax.axvline(x=hi, color = 'r')
    ax.axhline(y=lo, color = 'r')
    ax.axhline(y=hi, color = 'r')

In [None]:
sdss_dataset_lib.convert_nmgy_to_mag(true_fluxes[0, boo[0], :] / n_elect_per_nmgy)

In [None]:
sdss_dataset_lib.convert_nmgy_to_mag(map_fluxes[0, boo[1], :] / n_elect_per_nmgy)

In [None]:
true_locs[0, boo[0], :]

In [None]:
map_locs[0, boo[1], :]

# Summary statistics: true positive rate (TPR) and positive predictive value (PPV)

In [None]:
pad = 5

In [None]:
# tpr and ppv 
tpr, ppv, tpr_bool, ppv_bool = \
    image_statistics_lib.get_summary_stats(map_locs.squeeze(), 
                                           true_locs.squeeze(), 
                                           star_encoder.slen, 
                                           map_fluxes.squeeze(0)[:, 0], 
                                           true_fluxes.squeeze(0)[:, 0], 
                                          n_elect_per_nmgy, pad = pad)
    
print('tpr: {:0.3f}'.format(tpr))
print('ppv: {:0.3f}'.format(ppv))

In [None]:
tpr_vec, mag_vec, counts = \
    image_statistics_lib.get_tpr_vec(map_locs.squeeze(), 
                                           true_locs.squeeze(), 
                                           star_encoder.slen, 
                                           map_fluxes.squeeze(0)[:, 0], 
                                           true_fluxes.squeeze(0)[:, 0], 
                                             n_elect_per_nmgy, pad = pad)

plt.plot(mag_vec[0:-1], tpr_vec, '--x')
plt.xlabel('true log flux')
plt.ylabel('tpr')

In [None]:
ppv_vec, mag_vec = \
    image_statistics_lib.get_ppv_vec(map_locs.squeeze(), 
                                        true_locs.squeeze(), 
                                        star_encoder.slen, 
                                        map_fluxes.squeeze(0)[:, 0], 
                                        true_fluxes.squeeze(0)[:, 0], 
                                        n_elect_per_nmgy, pad = pad)[0:2]

plt.plot(mag_vec[0:-1], ppv_vec, '--x')
plt.xlabel('estimated log flux')
plt.ylabel('ppv')

# Posterior samples

In [None]:
n_samples = 5

In [None]:
sampled_locs, sampled_fluxes, sampled_n_stars = \
    star_encoder.sample_star_encoder(images,  
                                     patch_n_stars = None, 
                                    return_map_n_stars = False,
                                    return_map_star_params = False, 
                                    n_samples = n_samples)[0:3]

In [None]:
sampled_n_stars

In [None]:
sampled_fluxes.shape

In [None]:
# TPR / PPV of each posterior sample, with the same `pad` as the MAP summary statistics.
tpr_sampled = torch.zeros(n_samples)
ppv_sampled = torch.zeros(n_samples)
for i in range(n_samples): 
    # only the first sampled_n_stars[i] slots of sample i are evaluated
    n_stars_i = sampled_n_stars[i]
    tpr_sampled[i], ppv_sampled[i] = \
        image_statistics_lib.get_summary_stats(sampled_locs[i][0:n_stars_i], 
                                               true_locs.squeeze(), 
                                               star_encoder.slen, 
                                               sampled_fluxes[i][0:n_stars_i, 0], 
                                               true_fluxes.squeeze(0)[:, 0], 
                                              n_elect_per_nmgy, pad = pad)[0:2]

print('tpr: {:0.3f}'.format(tpr_sampled.mean()))
# this is the positive predictive value, not the true positive rate
print('ppv: {:0.3f}'.format(ppv_sampled.mean()))

In [None]:
# get reconstructed mean
recon_sampled = simulator.draw_image_from_params(locs = sampled_locs, 
                                                fluxes = sampled_fluxes,
                                                 n_stars = sampled_n_stars, 
                                                 add_noise = False).detach()

In [None]:
recon_sampled.shape

In [None]:
for i in range(n_samples): 
    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
    im0 = axarr[0].matshow(images[0, band][5:95, 5:95])
    fig.colorbar(im0, ax = axarr[0])

    im1 = axarr[1].matshow(recon_sampled[i, band][5:95, 5:95])
    fig.colorbar(im1, ax = axarr[1])

    residual = torch.log10(recon_sampled[i, band]) - torch.log10(images[0, band])
    _residual = (residual * 2.5)[5:95, 5:95]
    # (torch.log(vae_recon_mean.squeeze()) - torch.log(images.squeeze()))[10:90, 10:90]
    vmax = _residual.abs().max()
    im2 = axarr[2].matshow(_residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax = axarr[2])

In [None]:
# Tile to inspect (hard-coded; a random choice would be
# int(np.random.choice(image_patches.shape[0], 1)[0])).
indx = 1483
assert indx < star_encoder.tile_coords.shape[0], 'tile index out of range'

def _plot_tile_comparison(recon_image, est_locs, tile_indx):
    """Plot data | reconstruction | residual for one encoder tile and outline its centre.

    Args:
        recon_image: reconstruction of band `band` on the same pixel grid as images[0, band].
        est_locs: estimated star locations passed to plot_subimage.
        tile_indx: row of star_encoder.tile_coords holding the tile's corner.
    """
    x0 = int(star_encoder.tile_coords[tile_indx, 0])
    x1 = int(star_encoder.tile_coords[tile_indx, 1])
    subimage_kwargs = dict(patch_slen = star_encoder.patch_slen, add_colorbar = True)

    fig, axs = plt.subplots(1, 3, figsize=(16, 4))
    plotting_utils.plot_subimage(axs[0], images[0, band], est_locs, true_locs.squeeze(),
                                 x0, x1, global_fig = fig, **subimage_kwargs)
    plotting_utils.plot_subimage(axs[1], recon_image, est_locs, None,
                                 x0, x1, global_fig = fig, **subimage_kwargs)
    # 2.5 * log10(recon / data): the same residual scale as the earlier tile plot
    log_ratio = 2.5 * torch.log10(recon_image / images[0, band])
    plotting_utils.plot_subimage(axs[2], log_ratio, est_locs, None,
                                 x0, x1, global_fig = fig, diverging_cmap = True,
                                 **subimage_kwargs)

    # Outline the tile's central region (edge_padding pixels excluded on each side). Assuming
    # matshow's default extent (pixel k spans [k - 0.5, k + 0.5]), pixel edges fall on
    # half-integers: 2.5 and 4.5 for edge_padding=3, patch_slen=8.
    lo = star_encoder.edge_padding - 0.5
    hi = star_encoder.patch_slen - star_encoder.edge_padding - 0.5
    for ax in axs[:2]:
        for edge in (lo, hi):
            ax.axvline(x = edge, color = 'r')
            ax.axhline(y = edge, color = 'r')

# MAP reconstruction first, then each posterior sample
_plot_tile_comparison(vae_recon_mean[0, band], map_locs.squeeze(), indx)
for i in range(n_samples): 
    _plot_tile_comparison(recon_sampled[i, band], sampled_locs[i], indx)

In [None]:
sampled_locs_full_image, _, sampled_n_stars_full = \
    star_encoder.sample_star_encoder(images, 
                                     patch_n_stars = None, 
                                    return_map_n_stars = False,
                                    return_map_star_params = False, 
                                    n_samples = 100)[0:3]

In [None]:
f, axarr = plt.subplots(1, 1, figsize=(4, 4))

foo = sampled_locs_full_image.view(-1, 2)

indx = int(np.random.choice(image_patches.shape[0], 1))

plotting_utils.plot_subimage(axarr, images[0, band],
                            foo[foo[:, 0] > 0], # map_locs.squeeze(), 
                            true_locs.squeeze(), 
                            int(star_encoder.tile_coords[indx, 0]), 
                            int(star_encoder.tile_coords[indx, 1]), 
                            patch_slen = star_encoder.patch_slen, 
                            add_colorbar = True, 
                            global_fig = f, alpha = 0.2)
