In [None]:
# Setup: third-party and project imports, and device selection.
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import sys
# the project modules live one directory up; make them importable from here
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib
import utils

import inv_kl_objective_lib as inv_kl_lib

import image_utils

import time

import json

# first CUDA device if available, else CPU.
# NOTE(review): `device` is not referenced in the visible cells (tensors and the encoder stay on CPU).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# seed both numpy and torch so the simulated data and the random tile/patch picks are reproducible
np.random.seed(22)
_ = torch.manual_seed(22)

# Draw data

In [None]:
# data parameters
# (simulation settings read from JSON; keys used later include 'slen' and 'f_min')
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

In [None]:
data_params

In [None]:
# number of bands; the PSF cell below supports 1 (r) or 2 (r, i)
n_bands = 1

In [None]:
psf_dir = '../data/'
psf_r = fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits')[0].read()
psf_i = fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits')[0].read()

# just for simplicity ... 
if n_bands == 1: 
    psf_og = np.array([psf_r])
    sky_intensity = torch.Tensor([686.])
elif n_bands == 2: 
    psf_og = np.array([psf_r, psf_i])
    sky_intensity = torch.Tensor([686., 1123.])
else: 
    assert 1 == 2


In [None]:
psf_og.shape

In [None]:
# Draw from the same distribution I used in the sleep phase
n_images = 1

simulated_dataset = \
    simulated_datasets_lib.load_dataset_from_params(psf_og,
                    data_params,
                    sky_intensity = sky_intensity,
                    n_images = n_images,
                    transpose_psf = False, 
                    add_noise = True)
        
images_full = simulated_dataset.images.detach()
backgrounds_full = simulated_dataset.background.detach()
        
# Mask of star slots with positive flux in at least one entry of dim 2 (presumably the band axis).
# NOTE(review): `.squeeze()` gives a 1-D mask only because n_images == 1; with several
# images the `[:, which_on, :]` indexing below would need rework.
which_on = (simulated_dataset.fluxes > 0).any(2).squeeze()
        
# keep only the "on" star slots as the ground-truth catalog
true_full_locs = simulated_dataset.locs[:, which_on, :]
true_full_fluxes = simulated_dataset.fluxes[:, which_on, :]
        
        
simulator = simulated_dataset.simulator

In [None]:
true_full_locs.shape

In [None]:
# histogram of fluxes
# (log10 of the true fluxes, one overlaid histogram per band)
for i in range(n_bands): 
    plt.hist(np.log10(true_full_fluxes[:, :, i].numpy().flatten()), bins = 100);

In [None]:
if n_bands > 1: 
    # histogram of colors
    # color here is log10(band-0 flux / band-1 flux)
    colors = np.log10(true_full_fluxes[:, :, 0].numpy().flatten() / \
                      true_full_fluxes[:, :, 1].numpy().flatten())
    plt.hist(colors, bins = 50);

In [None]:
# color x flux
if n_bands > 1: 
    plt.scatter(colors, np.log10(true_full_fluxes[:, :, 0].numpy().flatten()), marker = 'x', alpha = 0.5)

In [None]:
# sanity checks: smallest pixel value and smallest true flux
images_full.min()

In [None]:
true_full_fluxes.min()

# Load VAE

In [None]:
# Build the encoder.
# NOTE(review): these hyperparameters must match those the checkpoint below was trained
# with; the checkpoint name ('starnet_r') suggests a single r-band model.
star_encoder = starnet_vae_lib.StarEncoder(full_slen = data_params['slen'],
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2, 
                                            n_bands = n_bands,
                                            max_detections = 2)

In [None]:
# the map_location lambda returns each storage unchanged, so tensors load onto CPU
# even when the checkpoint was saved from GPU
star_encoder.load_state_dict(torch.load('../fits/results_2020-01-21/starnet_r',
                               map_location=lambda storage, loc: storage))
star_encoder.eval(); 

In [None]:
# check loss 
# only the first five values returned by get_encoder_loss are used
loss, counter_loss, locs_loss, fluxes_loss, perm_indx = \
    inv_kl_lib.get_encoder_loss(star_encoder, images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes)[0:5]

In [None]:
print('loss: {:06f}'.format(loss))

In [None]:
# means of the star-count, location and flux loss terms
print(counter_loss.mean())
print(locs_loss.mean())
print(fluxes_loss.mean())

# Get image stamps

In [None]:
# get image stamps
# Cut the full image into the encoder's stamps, together with the true per-stamp
# locations/fluxes/counts (clip_max_stars=True presumably caps counts at max_detections).
image_stamps, true_subimage_locs, true_subimage_fluxes, \
    true_subimage_n_stars, true_is_on_array = \
        star_encoder.get_image_stamps(images_full, true_full_locs, true_full_fluxes, 
                                      trim_images = False, clip_max_stars = True)
    
# same tiling for the background; only the stamps are needed
background_stamps = star_encoder.get_image_stamps(backgrounds_full, None, None, 
                                      trim_images = False)[0]

In [None]:
# distribution of true star counts per stamp (one bin per integer count)
foo = plt.hist(true_subimage_n_stars, bins=np.arange(max(true_subimage_n_stars) + 2))[0]
# plt.plot(foo[0] / star_encoder.weights, 'x')

In [None]:
# location loss per stamp, excluding exact zeros
plt.hist(locs_loss.detach()[locs_loss != 0], bins = 100);


# Get inferred parameters on stamps

In [None]:
# Note that these variational parameters are estimated using the true number of stars!
# (outputs, by name: logit-location mean/log-var, log-flux mean/log-var, log-probs of each star count)
stamp_logit_loc_mean, stamp_logit_loc_log_var, \
    stamp_log_flux_mean, stamp_log_flux_log_var, stamp_log_probs = \
        star_encoder(image_stamps, background_stamps, true_subimage_n_stars)

In [None]:
# MAP number of stars per stamp
map_n_stars_stamps = torch.argmax(stamp_log_probs, dim = 1).detach()

In [None]:
# fraction of stamps whose MAP count equals the true count
(map_n_stars_stamps == true_subimage_n_stars).float().mean()

In [None]:
plt.hist(map_n_stars_stamps, bins = np.arange(star_encoder.max_detections + 2))

In [None]:
from itertools import permutations

In [None]:
# all orderings of the max_detections source slots; perm_indx from the loss indexes into this list
perm_list = []
for perm in permutations(range(star_encoder.max_detections)):
    perm_list.append(perm)

In [None]:
perm = np.zeros((image_stamps.shape[0], star_encoder.max_detections))
for i in range(image_stamps.shape[0]): 
    perm[i, :] = perm_list[perm_indx[i]]

### Check parameters

In [None]:
# Reorder per-stamp source slots by a permutation (applied below to the estimated parameters)
def permute_params(locs, fluxes, perm):
    """Reorder the source slots of every batch element according to `perm`.

    Args:
        locs: tensor indexed as ``locs[b, slot, :]`` (per-slot locations).
        fluxes: tensor indexed as ``fluxes[b, slot, :]`` (per-slot fluxes).
        perm: ``(batchsize, max_stars)`` array or tensor of slot indices. Row
            ``b`` says which input slot of element ``b`` goes to each output slot.
            A float numpy array holding whole numbers is accepted; it is cast
            to int64 before indexing.

    Returns:
        ``(locs_perm, fluxes_perm)`` with ``locs_perm[b, i] = locs[b, perm[b, i]]``
        and likewise for fluxes. Each has shape ``(batchsize, max_stars, last dim
        of the input)``.
    """
    # torch only accepts integer/bool index tensors, so cast explicitly
    # (the permutation array built in this notebook may be float-typed).
    perm = torch.as_tensor(perm).long().to(locs.device)

    # batch_idx broadcasts against perm: entry [b, i] selects (b, perm[b, i]).
    batch_idx = torch.arange(perm.shape[0], device=locs.device).unsqueeze(1)

    locs_perm = locs[batch_idx, perm]
    fluxes_perm = fluxes[batch_idx, perm]
    return locs_perm, fluxes_perm

In [None]:
# reorder the estimated (logit-location, log-flux) means by each stamp's permutation
locs_perm, fluxes_perm = permute_params(stamp_logit_loc_mean, stamp_log_flux_mean, perm)

In [None]:
# back to natural scale (sigmoid for locations, exp for fluxes); slots without a true star are zeroed
map_subimage_locs = (torch.sigmoid(locs_perm) * true_is_on_array.unsqueeze(2).float()).detach()
map_subimage_fluxes = (torch.exp(fluxes_perm) * true_is_on_array.unsqueeze(2).float()).detach()

In [None]:
# estimated vs true coordinates (all coordinates pooled); the '-' line is the identity.
# NOTE(review): each axis is masked separately with `> 0`; this assumes both masks pick the
# same entries, i.e. every "on" slot has strictly positive coordinates in both tensors.
plt.plot(map_subimage_locs.flatten()[map_subimage_locs.flatten() > 0].detach(), 
         true_subimage_locs.flatten()[true_subimage_locs.flatten() > 0], '+')

plt.plot(map_subimage_locs.flatten()[map_subimage_locs.flatten() > 0].detach(), 
         map_subimage_locs.flatten()[map_subimage_locs.flatten() > 0].detach(), '-')

plt.xlabel('estimated')
plt.ylabel('truth')

In [None]:
# true vs estimated log fluxes (axes swapped relative to the locations plot above)
plt.plot(torch.log(true_subimage_fluxes.flatten()[true_subimage_fluxes.flatten() > 0]), 
         torch.log(map_subimage_fluxes.flatten()[map_subimage_fluxes.flatten() > 0].detach()), '+')

plt.plot(torch.log(map_subimage_fluxes.flatten()[map_subimage_fluxes.flatten() > 0].detach()), 
         torch.log(map_subimage_fluxes.flatten()[map_subimage_fluxes.flatten() > 0].detach()), '-')

plt.xlabel('truth')
plt.ylabel('estimated')

In [None]:
# distribution of log10(estimated flux) - log10(true flux)
plt.hist(torch.log10(map_subimage_fluxes.flatten()[map_subimage_fluxes.flatten() > 0]).detach() - \
         torch.log10(true_subimage_fluxes.flatten()[true_subimage_fluxes.flatten() > 0]), 
        bins = 100);

In [None]:
# Not just fluxes, but also color?
# here color = log10(band-1 flux / band-0 flux), the inverse of the ratio used in the
# earlier `colors` histogram of the simulated fluxes

if n_bands > 1: 
    # NOTE(review): numerator and denominator are masked independently; they line up only if
    # both bands are positive in exactly the same slots
    map_color = torch.log10(map_subimage_fluxes[:, :, 1].flatten()[map_subimage_fluxes[:, :, 1].flatten() > 0] / \
                            map_subimage_fluxes[:, :, 0].flatten()[map_subimage_fluxes[:, :, 0].flatten() > 0])

    true_color = torch.log10(true_subimage_fluxes[:, :, 1].flatten()[true_subimage_fluxes[:, :, 1].flatten() > 0] / \
                            true_subimage_fluxes[:, :, 0].flatten()[true_subimage_fluxes[:, :, 0].flatten() > 0])

    plt.plot(map_color, true_color, 'x')
    plt.plot(map_color, map_color, '-')

In [None]:

if n_bands > 1: 
    # color vs band-1 flux, truth and estimate overlaid on the first panel (second panel is unused)
    fig, axarr = plt.subplots(1, 2, figsize=(7, 4))

    axarr[0].scatter(true_color, 
                     true_subimage_fluxes[:, :, 1].flatten()[true_subimage_fluxes[:, :, 1].flatten() > 0])

    axarr[0].scatter(map_color, 
                     map_subimage_fluxes[:, :, 1].flatten()[true_subimage_fluxes[:, :, 1].flatten() > 0], )

# Check reconstructions 

In [None]:
# If True, pass the true per-stamp star counts to the encoder instead of letting it infer them.
use_true_n_stars = False
if use_true_n_stars: 
    _n_stars = true_subimage_n_stars
else: 
    _n_stars = None

# get parameters on the full image 
map_locs_full_image, map_fluxes_full_image, map_n_stars_full = \
    star_encoder.sample_star_encoder(images_full, backgrounds_full, 
                                     return_map=True, n_stars = _n_stars)[0:3]
    
if _n_stars is not None: 
    assert map_n_stars_full == torch.sum(_n_stars)

# the MAP star count must agree with the number of returned location/flux rows
assert map_n_stars_full == map_locs_full_image.shape[1]
assert map_n_stars_full == map_fluxes_full_image.shape[1]

In [None]:
# true vs MAP log10-flux distributions on shared bins (`foo` holds the bin edges)
foo = plt.hist(torch.log10(true_full_fluxes).flatten())[1]; 
plt.hist(torch.log10(map_fluxes_full_image.flatten()), bins = foo, alpha = 0.5)

In [None]:
# get reconstructed mean
# (noiseless image rendered from the MAP catalog)
vae_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image, 
                                                fluxes = map_fluxes_full_image,
                                                 n_stars = map_n_stars_full, 
                                                 add_noise = False).detach()

In [None]:
band = 0

In [None]:
# observed image, reconstruction, and relative residual for one band
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
im0 = axarr[0].matshow(images_full[0, band])
fig.colorbar(im0, ax = axarr[0])

im1 = axarr[1].matshow(vae_recon_mean[0, band])
fig.colorbar(im1, ax = axarr[1])

residual = vae_recon_mean[0, band] - images_full[0, band]
# relative residual cropped to [5:95) on both axes.
# NOTE(review): the crop bounds assume a ~100-pixel image (data_params['slen']); confirm.
_residual = (residual / images_full[0, band])[5:95, 5:95]
# (torch.log(vae_recon_mean.squeeze()) - torch.log(images_full.squeeze()))[10:90, 10:90]
# symmetric color limits so zero maps to the center of the diverging colormap
vmax = _residual.abs().max()
im2 = axarr[2].matshow(_residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])

In [None]:
def get_which_tile(x0, x1, tile_coords, edge_padding, stamp_slen):
    """Locate the tiles whose un-padded interior strictly contains the point (x0, x1).

    A tile with corner ``tile_coords[k]`` has an interior that starts ``edge_padding``
    pixels in and spans ``stamp_slen - 2 * edge_padding`` pixels per axis. Both
    interior bounds are exclusive.

    Returns:
        ``(tile_coords[matches], matches)``, where ``matches`` is the tuple
        produced by ``torch.where`` on the containment mask.
    """
    interior_len = stamp_slen - 2 * edge_padding
    interior_start = tile_coords + edge_padding
    row_lo = interior_start[:, 0]
    col_lo = interior_start[:, 1]

    inside_rows = (row_lo < x0) & (x0 < row_lo + interior_len)
    inside_cols = (col_lo < x1) & (x1 < col_lo + interior_len)

    matches = torch.where(inside_rows & inside_cols)
    return tile_coords[matches], matches

# Check image patches

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

indx = int(np.random.choice(image_stamps.shape[0], 1))
# indx = int(np.random.choice(torch.where(true_subimage_n_stars == 2)[0].numpy(), 1))

plotting_utils.plot_subimage(axarr[0], images_full[0, band],
                            map_locs_full_image.squeeze(), 
                            true_full_locs.squeeze(), 
                            int(star_encoder.tile_coords[indx, 0]), 
                            int(star_encoder.tile_coords[indx, 1]), 
                            subimage_slen = star_encoder.stamp_slen, 
                            add_colorbar = True, 
                            global_fig = f)

plotting_utils.plot_subimage(axarr[1], vae_recon_mean[0, band],
                            map_locs_full_image.squeeze(), 
                            None, 
                            int(star_encoder.tile_coords[indx, 0]), 
                            int(star_encoder.tile_coords[indx, 1]), 
                            subimage_slen = star_encoder.stamp_slen, 
                            add_colorbar = True, 
                            global_fig = f)

foo = vae_recon_mean[0, band] - images_full[0, band]
plotting_utils.plot_subimage(axarr[2], foo, 
                            map_locs_full_image.squeeze(), 
                            None, 
                            int(star_encoder.tile_coords[indx, 0]), 
                            int(star_encoder.tile_coords[indx, 1]), 
                            subimage_slen = star_encoder.stamp_slen, 
                            add_colorbar = True, 
                            global_fig = f, 
                            diverging_cmap = True)

axarr[0].axvline(x=2, color = 'r')
axarr[0].axvline(x=4, color = 'r')
axarr[0].axhline(y=2, color = 'r')
axarr[0].axhline(y=4, color = 'r')

axarr[1].axvline(x=2, color = 'r')
axarr[1].axvline(x=4, color = 'r')
axarr[1].axhline(y=2, color = 'r')
axarr[1].axhline(y=4, color = 'r')

# Inspect an arbitrary patch of the image

In [None]:
# patch width in pixels, and the top-left corners of a non-overlapping grid of
# w-pixel patches that starts at the edge padding
w = 9
x0_vec = np.arange(star_encoder.edge_padding, 
                   star_encoder.full_slen - star_encoder.edge_padding - w, 
                  w)

x1_vec = x0_vec

In [None]:
x0_vec

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

x0 = int(np.random.choice(x0_vec, 1))
x1 = int(np.random.choice(x1_vec, 1))

plotting_utils.plot_subimage(axarr[0], images_full[0, band],
                            map_locs_full_image.squeeze(), 
                            true_full_locs.squeeze(), 
                            x0, x1, 
                            subimage_slen = w, 
                            add_colorbar = True, 
                            global_fig = fig)

axarr[0].set_title('observed; coords = {}'.format([x0, x1]));

plotting_utils.plot_subimage(axarr[1], vae_recon_mean[0, band],
                            map_locs_full_image.squeeze(), 
                            None,  
                            x0, x1, 
                            subimage_slen = w, 
                            add_colorbar = True, 
                            global_fig = fig)

axarr[1].set_title('reconstructed; coords = {}'.format([x0, x1]));


residual = (vae_recon_mean[0, band] - images_full[0, band])
plotting_utils.plot_subimage(axarr[2], residual, 
                            map_locs_full_image.squeeze(), 
                            None,  
                            x0, x1, 
                            subimage_slen = w, 
                            add_colorbar = True, 
                            global_fig = fig, diverging_cmap = True)

axarr[2].set_title('residual; coords = {}'.format([x0, x1]));


# Check out some summary statistics 

In [None]:
def filter_params(locs, fluxes, slen, pad):
    """Drop sources whose pixel position falls in the padded border of the image.

    Locations are scaled to pixels as ``locs * (slen - 1)``. A source is kept only
    when both pixel coordinates lie strictly between ``pad`` and ``slen - pad - 1``.

    Args:
        locs: 2-D tensor of locations, one row per source (columns 0 and 1 are used).
        fluxes: 1-D tensor with one flux per source.
        slen: side length of the image in pixels.
        pad: border width in pixels.

    Returns:
        ``(locs[keep], fluxes[keep])`` for the sources inside the border.
    """
    assert len(locs.shape) == 2
    assert len(fluxes.shape) == 1

    pix = locs * (slen - 1)
    upper = slen - pad - 1

    rows_ok = (pix[:, 0] > pad) & (pix[:, 0] < upper)
    cols_ok = (pix[:, 1] > pad) & (pix[:, 1] < upper)
    keep = rows_ok & cols_ok

    return locs[keep], fluxes[keep]

In [None]:
# drop border sources from the true and the MAP catalogs (only the first band's flux is kept)
true_locs, true_fluxes = filter_params(true_full_locs.squeeze(), 
                                          true_full_fluxes.squeeze(0)[:, 0], 
                                          slen = star_encoder.full_slen,
                                          pad = star_encoder.edge_padding)

est_locs, est_fluxes = filter_params(map_locs_full_image.squeeze(), 
                                          map_fluxes_full_image.squeeze(0)[:, 0], 
                                          slen = star_encoder.full_slen,
                                          pad = star_encoder.edge_padding)

In [None]:
# _recon_mean = simulator.draw_image_from_params(
#                                 locs = est_locs.unsqueeze(0), 
#                                 fluxes = est_fluxes.unsqueeze(0),
#                                 n_stars = torch.Tensor([est_locs.shape[0]]).type(torch.LongTensor), 
#                                 add_noise = False).squeeze()

# _recon_truth = \
#     simulator.draw_image_from_params(locs = true_locs.unsqueeze(0), 
#                                     fluxes = true_fluxes.unsqueeze(0),
#                                      n_stars = torch.Tensor([len(true_locs)]).type(torch.LongTensor), 
#                                      add_noise = False).squeeze()


# fig, axarr = plt.subplots(1, 4, figsize=(15, 6))

# axarr[0].matshow(_recon_mean)
# axarr[2].matshow(_recon_truth)

# axarr[1].matshow((vae_recon_mean - images_full.squeeze()))

# axarr[3].matshow((_recon_truth - images_full.squeeze()))


In [None]:
# completeness and tpr using locations only
# (None, None: no fluxes are passed, so matching relies on locations alone)
completeness, tpr = \
    image_statistics_lib.get_summary_stats(est_locs, true_locs, star_encoder.full_slen, None, None)[0:2]
    
print('completeness: {:0.3f}'.format(completeness))
print('true positive rate: {:0.3f}'.format(tpr))

In [None]:
# completeness and tpr incorporating fluxes
completeness, tpr = \
    image_statistics_lib.get_summary_stats(est_locs, true_locs, star_encoder.full_slen, est_fluxes, true_fluxes)[0:2]
    
print('completeness: {:0.3f}'.format(completeness))
print('true positive rate: {:0.3f}'.format(tpr))

In [None]:
# completeness per flux bin; mag_vec presumably holds bin edges, hence the [0:-1] to align with the values
completeness_vec, mag_vec = \
    image_statistics_lib.get_completeness_vec(est_locs, true_locs, star_encoder.full_slen,
                                              est_fluxes, true_fluxes)[0:2]
plt.plot(mag_vec[0:-1], completeness_vec, '--x')
plt.xlabel('true log flux')
plt.ylabel('completeness')

In [None]:
# true positive rate per flux bin
tpr_vec, mag_vec = \
    image_statistics_lib.get_tpr_vec(est_locs, true_locs, star_encoder.full_slen,
                                              est_fluxes, true_fluxes)[0:2]

plt.plot(mag_vec[0:-1], tpr_vec, '--x')
plt.xlabel('estimated log flux')
plt.ylabel('tpr')

# Compare true SDSS image with simulated SDSS image

In [None]:
# Load the SDSS image plus its catalog (locs/fluxes; presumably Hubble-derived, per the class name).
# NOTE(review): bands=[2, 3] yields two bands, while psf_og / star_encoder above use n_bands
# (currently 1). Confirm the simulator and encoder are compatible before running this section.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = [2, 3])

In [None]:
# simulate data using hubble parameters
sim_images_full = simulator.draw_image_from_params(locs = sdss_hubble_data.locs.unsqueeze(0), 
                        fluxes = sdss_hubble_data.fluxes.unsqueeze(0), 
                        n_stars = torch.Tensor([len(sdss_hubble_data.locs)]).type(torch.LongTensor), 
                        add_noise = True) 

# the observed data 
sdss_images_full = sdss_hubble_data.sdss_image.unsqueeze(0)

# get true parameters
backgrounds_full = sdss_hubble_data.sdss_background.unsqueeze(0)

# treat as "true" only the catalog entries whose band-0 flux exceeds the simulation's f_min
which_bright = sdss_hubble_data.fluxes[:, 0] > data_params['f_min']
true_full_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
true_full_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)

In [None]:
band = 1

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 6))

im0 = axarr[0].matshow(sdss_images_full[0, band]); 
f.colorbar(im0, ax = axarr[0])
axarr[0].set_title('true sdss image')

im1 = axarr[1].matshow(sim_images_full[0, band]); 
f.colorbar(im1, ax = axarr[1])
axarr[1].set_title('observed sdss image')


residual = torch.log10(sim_images_full[0, band]) - torch.log10(sdss_images_full[0, band])
vmax = residual[10:90, 10:90].abs().max()
im2 = axarr[2].matshow(residual[10:90, 10:90], vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr')); 
f.colorbar(im2, ax = axarr[2])
axarr[2].set_title('residual')

In [None]:
# tile the real SDSS image and its background like the simulated data
sdss_image_stamps, true_subimage_locs, true_subimage_fluxes, true_subimage_n_stars, true_is_on_array = \
        star_encoder.get_image_stamps(sdss_images_full, true_full_locs, true_full_fluxes, 
                                      trim_images = False)

background_stamps = \
        star_encoder.get_image_stamps(backgrounds_full, true_full_locs, true_full_fluxes, 
                                      trim_images = False)[0]

In [None]:
# star-count accuracy of the encoder's MAP count on the SDSS stamps
logit_loc_mean, logit_loc_log_var, \
        log_flux_mean, log_flux_log_var, log_probs = \
            star_encoder(sdss_image_stamps, background_stamps)
(log_probs.argmax(1) == true_subimage_n_stars).float().mean()

In [None]:
# same accuracy on the image simulated from the catalog (reusing the SDSS background stamps)
sim_image_stamps, true_subimage_locs, true_subimage_fluxes, true_subimage_n_stars, true_is_on_array = \
        star_encoder.get_image_stamps(sim_images_full, true_full_locs, true_full_fluxes, 
                                      trim_images = False)
    
logit_loc_mean, logit_loc_log_var, \
        log_flux_mean, log_flux_log_var, log_probs = \
            star_encoder(sim_image_stamps, background_stamps)
(log_probs.argmax(1) == true_subimage_n_stars).float().mean()

In [None]:
# get parameters on the simulated image 
map_locs_sim_image, map_fluxes_sim_image, map_n_stars_sim_image = \
        star_encoder.sample_star_encoder(sim_images_full, backgrounds_full, 
                                               return_map = True)[0:3]

In [None]:
# get parameters on the sdss image 
map_locs_sdss_image, map_fluxes_sdss_image, map_n_stars_sdss_image = \
        star_encoder.sample_star_encoder(sdss_images_full, backgrounds_full, 
                                               return_map = True)[0:3]

## Check out losses

In [None]:
loss, counter_loss, locs_loss, fluxes_loss, perm = \
    inv_kl_lib.get_encoder_loss(star_encoder, sim_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes)
    
print(loss)

In [None]:
loss, counter_loss, locs_loss, fluxes_loss, perm = \
    inv_kl_lib.get_encoder_loss(star_encoder, sdss_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes)
    
print(loss)

### A more interpretable comparison: the L2 loss

In [None]:
loss, counter_loss_sim, locs_loss_sim, fluxes_loss_sim, _ = \
    inv_kl_lib.get_encoder_loss(star_encoder, sim_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes, use_l2_loss = True)
    
print(loss)

In [None]:
loss, counter_loss_sdss, locs_loss_sdss, fluxes_loss_sdss, perm = \
    inv_kl_lib.get_encoder_loss(star_encoder, sdss_images_full, backgrounds_full, 
                                true_full_locs, true_full_fluxes, use_l2_loss = True)

print(loss)

In [None]:
# per-stamp star-count loss: simulated (x) vs SDSS (y); the '-' line is the identity
plt.plot(counter_loss_sim.detach(), 
         counter_loss_sdss.detach(), '+')
plt.plot(counter_loss_sim.detach(), 
         counter_loss_sim.detach(), '-')
plt.xlabel('sim')
plt.ylabel('sdss')

In [None]:
# fraction of stamps where the count loss is lower on the simulated image
(counter_loss_sim < counter_loss_sdss).float().mean()

In [None]:
# same comparison for the location loss
plt.plot(locs_loss_sim.detach(), 
         locs_loss_sdss.detach(), '+')
plt.plot(locs_loss_sim.detach(), 
         locs_loss_sim.detach(), '-')
plt.xlabel('sim')
plt.ylabel('sdss')

In [None]:
# restricted to stamps with a non-zero simulated location loss
(locs_loss_sim[locs_loss_sim > 0] < locs_loss_sdss[locs_loss_sim > 0]).float().mean()

In [None]:
# same comparison for the flux loss
plt.plot(fluxes_loss_sim.detach(), 
         fluxes_loss_sdss.detach(), '+')
plt.plot(fluxes_loss_sim.detach(), 
         fluxes_loss_sim.detach(), '-')
plt.xlabel('sim')
plt.ylabel('sdss')

In [None]:
indx = int(np.random.choice(star_encoder.tile_coords.shape[0], 1))
x0 = int(star_encoder.tile_coords[indx, 0])
x1 = int(star_encoder.tile_coords[indx, 1])

f, axarr = plt.subplots(1, 3, figsize=(16, 6))

plotting_utils.plot_subimage(axarr[0], sdss_images_full[0, band],
                            map_locs_sdss_image.squeeze(),  
                            true_full_locs.squeeze(),  
                            x0, x1, 
                            subimage_slen = star_encoder.stamp_slen)

axarr[0].axvline(x=3, color = 'r')
axarr[0].axvline(x=5, color = 'r')
axarr[0].axhline(y=3, color = 'r')
axarr[0].axhline(y=5, color = 'r')
axarr[0].set_title('sdss')

plotting_utils.plot_subimage(axarr[1], sim_images_full[0, band],
                            map_locs_sim_image.squeeze(),  
                            true_full_locs.squeeze(),  
                            x0, x1, 
                            subimage_slen = star_encoder.stamp_slen)

axarr[1].axvline(x=3, color = 'r')
axarr[1].axvline(x=5, color = 'r')
axarr[1].axhline(y=3, color = 'r')
axarr[1].axhline(y=5, color = 'r')
axarr[1].set_title('simulated')

foo = torch.log(sim_images_full[0, band]) - torch.log(sdss_images_full[0, band])
plotting_utils.plot_subimage(axarr[2], foo,
                            None, 
                            true_full_locs.squeeze(),  
                            x0, x1, 
                            subimage_slen = star_encoder.stamp_slen, 
                            global_fig = f, add_colorbar = True)


# Compare summary statistics

In [None]:
# drop border sources (first band's flux only) from the catalog and both MAP catalogs
true_locs, true_fluxes = filter_params(true_full_locs.squeeze(), 
                                          true_full_fluxes[0, :, 0], 
                                          slen = star_encoder.full_slen,
                                          pad = star_encoder.edge_padding)

est_locs_sim, est_fluxes_sim = filter_params(map_locs_sim_image.squeeze(), 
                                        map_fluxes_sim_image.squeeze()[:, 0], 
                                        slen = star_encoder.full_slen,
                                        pad = star_encoder.edge_padding)
est_locs_sdss, est_fluxes_sdss = filter_params(map_locs_sdss_image.squeeze(), 
                                        map_fluxes_sdss_image.squeeze()[:, 0], 
                                        slen = star_encoder.full_slen,
                                        pad = star_encoder.edge_padding)

# completeness per flux bin, evaluated against the same catalog on the simulated and the real image
completeness_vec, mag_vec = \
    image_statistics_lib.get_completeness_vec(est_locs_sim, true_locs, star_encoder.full_slen,
                                              est_fluxes_sim, true_fluxes)[0:2]
    
completeness_vec2, mag_vec2 = \
    image_statistics_lib.get_completeness_vec(est_locs_sdss, true_locs, star_encoder.full_slen,
                                              est_fluxes_sdss, true_fluxes)[0:2]

    
plt.plot(mag_vec[0:-1], completeness_vec, '--x', label = 'sim')
plt.plot(mag_vec2[0:-1], completeness_vec2, '--x', label = 'sdss')

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('completeness')


In [None]:
tpr_vec, mag_vec = \
    image_statistics_lib.get_tpr_vec(est_locs_sim, true_locs, star_encoder.full_slen,
                                              est_fluxes_sim, true_fluxes)[0:2]
    
tpr_vec2, mag_vec2 = \
    image_statistics_lib.get_tpr_vec(est_locs_sdss, true_locs, star_encoder.full_slen,
                                              est_fluxes_sdss, true_fluxes)[0:2]

    
plt.plot(mag_vec[0:-1], tpr_vec, '--x', label = 'sim')
plt.plot(mag_vec2[0:-1], tpr_vec2, '--x', label = 'sdss')

plt.legend()
plt.xlabel('true log flux')
plt.ylabel('tpr')


# Look at image patches

In [None]:
# for i in range(1): 
#     fig, axarr = plt.subplots(1, 3, figsize=(16, 6))
#     indx = 1609 # int(np.random.choice(image_stamps.shape[0], 1))
#     # indx = np.random.choice(torch.where(true_subimage_n_stars > 3)[0].numpy(), 1)
    
#     x0 = int(star_encoder.tile_coords[indx, 0])
#     x1 = int(star_encoder.tile_coords[indx, 1]) 
    
#     # plot image stamp
#     im0 = axarr[0].matshow(image_stamps[indx].squeeze())
#     im0 = fig.colorbar(im0, ax=axarr[0])
    
#     # plot true locations      
#     patch_slen = (star_encoder.stamp_slen - 2 * star_encoder.edge_padding)
#     axarr[0].scatter(true_subimage_locs[indx, 0:true_subimage_n_stars[indx], 1] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     true_subimage_locs[indx, 0:true_subimage_n_stars[indx], 0] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     color = 'b')
    
#     axarr[0].scatter(map_subimage_locs[indx, 0:map_n_stars_stamps[indx], 1] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     map_subimage_locs[indx, 0:map_n_stars_stamps[indx], 0] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     color = 'r', marker = 'x')
    
    
#     axarr[0].axvline(x=3, color = 'r')
#     axarr[0].axvline(x=5, color = 'r')
#     axarr[0].axhline(y=3, color = 'r')
#     axarr[0].axhline(y=5, color = 'r')
    
#     axarr[0].set_title('observed; coords {}\n'.format([x0, x1]))
    
#     # plot reconstruction
#     recon_patch = vae_recon_mean[x0:(x0+star_encoder.stamp_slen), 
#                                    x1:(x1+star_encoder.stamp_slen)]
#     im1 = axarr[1].matshow(recon_patch)
    
#     axarr[1].axvline(x=3, color = 'r')
#     axarr[1].axvline(x=5, color = 'r')
#     axarr[1].axhline(y=3, color = 'r')
#     axarr[1].axhline(y=5, color = 'r')
    
#     axarr[1].scatter(map_subimage_locs[indx, 0:map_n_stars_stamps[indx], 1] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     map_subimage_locs[indx, 0:map_n_stars_stamps[indx], 0] * (patch_slen - 1) + \
#                          star_encoder.edge_padding, 
#                     color = 'r', marker = 'x')
#     fig.colorbar(im1, ax=axarr[1])
    
    
#     # plot residual
#     im2 = axarr[2].matshow(((recon_patch - image_stamps[indx].squeeze()))/image_stamps[indx].squeeze())
#     fig.colorbar(im2, ax=axarr[2])

In [None]:
def get_weights_from_n_stars(n_stars):
    """Inverse-frequency weight for every element of `n_stars`.

    Element j receives ``len(n_stars) / (number of elements equal to n_stars[j])``.
    The result is rescaled so the smallest weight is 1, so rare star counts get
    large weights. Star counts that never occur are skipped, which means a gap in
    the range (e.g. no stamp with exactly 1 star) cannot produce NaN weights.

    Args:
        n_stars: 1-D tensor of non-negative integer star counts.

    Returns:
        1-D float32 tensor with one weight per element (empty if `n_stars` is empty).
    """
    n_stars = torch.as_tensor(n_stars).long()
    if n_stars.numel() == 0:
        return torch.zeros(0)

    # counts[k] = number of elements equal to k. Indexing with n_stars only reads
    # counts of values that actually occur, so the division never hits zero.
    counts = torch.bincount(n_stars)
    weights = len(n_stars) / counts[n_stars].float()

    return weights / weights.min()

In [None]:
# inverse-frequency weights for the true per-stamp star counts
weights = get_weights_from_n_stars(true_subimage_n_stars)

In [None]:
torch.unique(weights)

In [None]:
# every stamp with the same count shares a single weight
torch.unique(weights[true_subimage_n_stars == 0])

In [None]:
torch.unique(weights[true_subimage_n_stars == 4])

In [None]:
# weight(4 stars) / weight(0 stars) should equal this count ratio
(true_subimage_n_stars == 0).float().sum() / (true_subimage_n_stars == 4).float().sum()

In [None]:
true_n_stars = true_subimage_n_stars.clone()

In [None]:
counts = torch.zeros(max(true_n_stars) + 1)

In [None]:
for i in range(max(true_n_stars) + 1): 
    counts[i] = torch.sum(true_n_stars == i)

In [None]:
weights = torch.zeros(len(true_n_stars))

for i in range(max(true_n_stars) + 1): 
    weights = weights + len(true_n_stars) / counts[i] * (true_n_stars == i).float()

In [None]:
weights = weights / weights.min()

In [None]:
true_n_stars

In [None]:
torch.histc(true_subimage_n_stars)