In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import fitsio 

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib
import utils

import inv_kl_objective_lib as inv_kl_lib

import image_utils

import time

import json

# Use the first GPU when one is available, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0") if _cuda_available else torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# Fix the numpy and torch RNG seeds so the random draws in this notebook are reproducible.
SEED = 22
np.random.seed(SEED)
_ = torch.manual_seed(SEED)

# Draw data

In [None]:
# Data parameters: the default star-simulation settings, read from JSON.
_data_params_path = '../data/default_star_parameters.json'
with open(_data_params_path, 'r') as _params_file:
    data_params = json.load(_params_file)

In [None]:
# Display the loaded data parameters.
data_params

In [None]:
psf_dir = '../data/'


def _read_psf(fits_path):
    """Read the image stored in HDU 0 of a FITS file.

    The file handle is closed on exit. The original code opened
    `fitsio.FITS(...)` without ever closing it.
    """
    with fitsio.FITS(fits_path) as fits:
        return fits[0].read()


psf_r = _read_psf(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
psf_i = _read_psf(psf_dir + 'sdss-002583-2-0136-psf-i.fits')
# PSFs stacked along a leading band axis, in (r, i) order.
psf_og = np.array([psf_r, psf_i])

n_bands = psf_og.shape[0]

# One sky-intensity value per band. These are presumably in the same (r, i)
# order as psf_og. TODO: confirm their source and units.
sky_intensity = torch.Tensor([686., 1123.])

In [None]:
# Draw from the same distribution I used in the sleep phase.
n_images = 1

simulated_dataset = \
    simulated_datasets_lib.load_dataset_from_params(psf_og,
                                                    data_params,
                                                    sky_intensity = sky_intensity,
                                                    n_images = n_images,
                                                    add_noise = True)

images_full = simulated_dataset.images.detach()
backgrounds_full = simulated_dataset.background.detach()

# The rest of the notebook only looks at image 0 (e.g. images_full[0, band]).
# It also selects stars with a 1-D per-star mask, so only a single image is supported.
assert n_images == 1, 'this notebook assumes a single simulated image'

# Per-star "on" mask for image 0: a star counts as on if any band has positive flux.
# Index with [0] rather than calling .squeeze(), so the mask stays 1-D
# even when there is only one star slot.
which_on = (simulated_dataset.fluxes > 0).any(2)[0]

true_full_locs = simulated_dataset.locs[:, which_on, :]
true_full_fluxes = simulated_dataset.fluxes[:, which_on, :]

simulator = simulated_dataset.simulator

In [None]:
# Display the simulated dataset's star count(s).
simulated_dataset.n_stars

# Load VAE

In [None]:
# Star encoder over the full image.
# n_bands comes from the PSF stack (psf_og) instead of being hard-coded,
# so the encoder and the data cannot drift apart.
# The checkpoint loaded next is presumably for 2 bands (r, i).
star_encoder = starnet_vae_lib.StarEncoder(full_slen = data_params['slen'],
                                            stamp_slen = 7,
                                            step = 2,
                                            edge_padding = 2,
                                            n_bands = n_bands,
                                            max_detections = 2)

In [None]:
# Restore the trained encoder weights, keeping every storage on the CPU,
# then switch the module to evaluation mode.
def _keep_on_cpu(storage, loc):
    return storage


_checkpoint_path = '../fits/results_11122019/starnet_ri'
_state_dict = torch.load(_checkpoint_path, map_location=_keep_on_cpu)
star_encoder.load_state_dict(_state_dict)
star_encoder.eval();

# Image stamps

In [None]:
# Split the full image into stamps via the encoder.
# This also returns the true per-stamp locations, fluxes, star counts and on/off indicators.
(image_stamps, true_subimage_locs, true_subimage_fluxes,
 true_subimage_n_stars, true_is_on_array) = star_encoder.get_image_stamps(
    images_full, true_full_locs, true_full_fluxes,
    trim_images=False, clip_max_stars=True)

# Cut the background into stamps the same way; only the stamps (first output) are used.
_background_outputs = star_encoder.get_image_stamps(
    backgrounds_full, None, None, trim_images=False)
background_stamps = _background_outputs[0]

# First test: sample with the number of stars fixed to the true counts

In [None]:
n_samples = 100  # number of encoder draws used for the z-statistic checks

In [None]:
# Sample full-image parameters from the encoder

In [None]:
# Draw n_samples sets of full-image parameters from the encoder,
# with each tile's star count fixed to the truth.
# log_q_* are presumably the variational log-densities of the draws.
(locs_full_image, fluxes_full_image, n_stars_full,
 log_q_locs, log_q_fluxes, log_q_n_stars) = star_encoder.sample_star_encoder(
    images_full, backgrounds_full,
    n_samples=n_samples,
    n_stars=true_subimage_n_stars)

In [None]:
# Convert the sampled full-image parameters into per-tile (subimage) parameters.
# Result shapes:
#   locations: (n_samples, n_tiles, max_detections, 2)
#   fluxes:    (n_samples, n_tiles, max_detections, n_bands)
n_tiles = star_encoder.tile_coords.shape[0]
subimage_locs_sampled = torch.zeros((n_samples, n_tiles,
                                     star_encoder.max_detections, 2))
subimage_fluxes_sampled = torch.zeros((n_samples, n_tiles,
                                       star_encoder.max_detections, star_encoder.n_bands))

# Process one sample at a time: doing this in one batch freezes my laptop.
# no_grad: the sampled tensors may carry autograd history (later cells call .detach()).
# Copying them into the buffers above would otherwise record a graph for every draw.
with torch.no_grad():
    for i in range(n_samples):
        if i % 20 == 0:
            print(i)

        subimage_locs, subimage_fluxes, n_stars_patches, _ = \
            image_utils.get_params_in_patches(star_encoder.tile_coords,
                                              locs_full_image[i:(i+1)],
                                              fluxes_full_image[i:(i+1)],
                                              star_encoder.full_slen,
                                              star_encoder.stamp_slen,
                                              star_encoder.edge_padding)

        subimage_locs_sampled[i] = subimage_locs
        subimage_fluxes_sampled[i] = subimage_fluxes

        # Star counts were fixed to the truth when sampling, so every tile must match.
        assert torch.all(true_subimage_n_stars == n_stars_patches), \
            'sample {}: per-tile star counts differ from the truth'.format(i)
    

In [None]:
# Get the subimage (per-tile) variational parameters

In [None]:
# Encoder forward pass on the stamps, conditioned on the true per-tile star counts.
(logit_loc_mean, logit_loc_logvar,
 log_flux_mean, log_flux_logvar, log_probs) = star_encoder(
    image_stamps, background_stamps, true_subimage_n_stars)

In [None]:
# Check that the sampled subimage parameters are consistent with a normal distribution

In [None]:
from scipy import stats

In [None]:
# Location z-statistics: the standardized gap between the mean of the sampled
# logit-locations and the encoder's logit-location mean.
# Only star slots that are truly on are kept.
# NOTE(review): samples are averaged before standardizing. If the sampler draws from
# N(logit_loc_mean, exp(logit_loc_logvar)), z is ~ N(0, 1/n_samples), not N(0, 1).
loc_z_all = (utils._logit(subimage_locs_sampled).mean(0) - logit_loc_mean) / \
    torch.exp(0.5 * logit_loc_logvar)

# Select on-slots with a boolean mask instead of multiplying by the indicator and
# dropping zeros. Off slots may hold non-finite values (e.g. the logit of a zero
# placeholder), and inf * 0 = nan would survive a `!= 0` filter.
loc_on_mask = (true_is_on_array > 0).unsqueeze(2).expand_as(loc_z_all)
z_stats = loc_z_all.detach()[loc_on_mask].cpu().numpy()

fig, ax = plt.subplots()
h = ax.hist(z_stats, bins=50, density=True)
# Overlay a normal with the same mean and unbiased std (ddof=1 matches torch's std).
pdf = stats.norm.pdf(h[1], z_stats.mean(), z_stats.std(ddof=1))
ax.plot(h[1], pdf)
ax.set_title('location z-statistics (on stars only)')
ax.set_xlabel('z');

In [None]:
# Flux z-statistics: the standardized gap between the mean of the sampled log-fluxes
# and the encoder's log-flux mean. Only star slots that are truly on are kept.
# NOTE(review): samples are averaged before standardizing. If the sampler draws
# log-fluxes from N(log_flux_mean, exp(log_flux_logvar)), z is ~ N(0, 1/n_samples).
flux_z_all = (torch.log(subimage_fluxes_sampled + 1e-18).mean(0) - log_flux_mean) / \
    torch.exp(0.5 * log_flux_logvar)

# Select on-slots with a boolean mask rather than multiplying by the indicator and
# dropping zeros. This keeps any non-finite values from off slots out of the plot.
flux_on_mask = (true_is_on_array > 0).unsqueeze(2).expand_as(flux_z_all)
z_stats = flux_z_all.detach()[flux_on_mask].cpu().numpy()

fig, ax = plt.subplots()
h = ax.hist(z_stats, bins=50, density=True)
# Overlay a normal with the same mean and unbiased std (ddof=1 matches torch's std).
pdf = stats.norm.pdf(h[1], z_stats.mean(), z_stats.std(ddof=1))
ax.plot(h[1], pdf)
ax.set_title('flux z-statistics (on stars only)')
ax.set_xlabel('z');

# Visually inspect

In [None]:
# Simulator sized to the full image; it renders the reconstructions below.
_full_slen = images_full.shape[-1]
simulator = simulated_datasets_lib.StarSimulator(psf=psf_og,
                                                 slen=_full_slen,
                                                 sky_intensity=sky_intensity)

In [None]:
# MAP locations and fluxes

In [None]:
# One MAP estimate (return_map=True) of the full-image parameters;
# the last three outputs are discarded.
(map_locs_full_image, map_fluxes_full_image, map_n_stars_full,
 _, _, _) = star_encoder.sample_star_encoder(
    images_full, backgrounds_full, n_samples=1, return_map=True)

In [None]:
# Render the MAP parameters.
# add_noise=False is passed explicitly, as for the sampled reconstructions, so this is
# a noiseless mean image (the default value of add_noise is not visible here).
map_recon_mean = simulator.draw_image_from_params(locs = map_locs_full_image,
                                                  fluxes = map_fluxes_full_image,
                                                  n_stars = map_n_stars_full,
                                                  add_noise = False)

In [None]:
band = 0  # band index to plot; 0 = r, 1 = i (the order of psf_og)

In [None]:
# Observed image, MAP reconstruction, and the residual relative to the observed pixels.
fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

axarr[0].matshow(images_full[0, band])
axarr[0].set_title('observed')

axarr[1].matshow(map_recon_mean[0, band].detach())
axarr[1].set_title('map reconstruction')

_resid = (map_recon_mean[0, band].detach() - images_full[0, band]) / images_full[0, band]
# Symmetric color range, so zero residual sits at the center of the diverging colormap.
vmax = _resid.abs().max()
im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                       cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('relative map residual');


In [None]:
# Fewer draws for visual inspection (overrides the earlier n_samples = 100).
n_samples = 10

In [None]:
# Draw n_samples full-image samples, this time with n_stars=None.
# Presumably the encoder then samples the star counts too.
(locs_full_image, fluxes_full_image, n_stars_full,
 log_q_locs, log_q_fluxes, log_q_n_stars) = star_encoder.sample_star_encoder(
    images_full, backgrounds_full,
    n_samples=n_samples,
    n_stars=None)

In [None]:
# Noiseless renderings, one per sampled parameter set.
sampled_recon_mean = simulator.draw_image_from_params(
    locs=locs_full_image,
    fluxes=fluxes_full_image,
    n_stars=n_stars_full,
    add_noise=False,
)


In [None]:
def _plot_recon_panels(observed, recon, recon_title, resid_title):
    """Show an observed image, a reconstruction and their residual side by side.

    observed, recon: 2-D images of the same shape (a single band).
    recon_title, resid_title: titles for the middle and right panels.
    Returns the created figure.
    """
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    axarr[0].matshow(observed)
    axarr[0].set_title('observed')

    axarr[1].matshow(recon)
    axarr[1].set_title(recon_title)

    resid = recon - observed
    # Symmetric color range, so zero residual sits at the center of the diverging colormap.
    vmax = resid.abs().max()
    im = axarr[2].matshow(resid, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im, ax=axarr[2])
    axarr[2].set_title(resid_title)
    return fig


_plot_recon_panels(images_full[0, band], map_recon_mean[0, band].detach(),
                   'map reconstruction', 'map residual')

for i in range(n_samples):
    _plot_recon_panels(images_full[0, band], sampled_recon_mean[i, band].detach(),
                       'sample reconstruction ' + str(i),
                       'sample residual ' + str(i))



In [None]:
# Zoomed-in window passed to plotting_utils.plot_subimage.
# x0, x1 presumably give the window's corner pixel and subimage_slen its side length.
x0 = 30
x1 = 50
subimage_slen = 10


def _plot_subimage_panels(observed, recon, est_locs, recon_label):
    """Zoomed observed / reconstruction / residual panels with estimated locations.

    observed, recon: 2-D single-band images (already detached).
    est_locs: estimated locations drawn on every panel.
    The observed panel also shows the true locations.
    recon_label: prefix for the middle and right panel titles.
    Returns the created figure.
    """
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    plotting_utils.plot_subimage(axarr[0], observed,
                                 est_locs,
                                 true_full_locs.squeeze(),
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)
    axarr[0].set_title('observed')

    plotting_utils.plot_subimage(axarr[1], recon,
                                 est_locs,
                                 None,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)
    axarr[1].set_title(recon_label + ' reconstruction')

    resid = recon - observed
    plotting_utils.plot_subimage(axarr[2], resid,
                                 est_locs,
                                 None,
                                 x0, x1, subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig,
                                 diverging_cmap = True)
    # The residual title belongs on the third panel; previously it overwrote axarr[1].
    axarr[2].set_title(recon_label + ' residual')
    return fig


_plot_subimage_panels(images_full[0, band], map_recon_mean[0, band].detach(),
                      map_locs_full_image.squeeze(), 'map')

for i in range(n_samples):
    _plot_subimage_panels(images_full[0, band], sampled_recon_mean[i, band].detach(),
                          locs_full_image[i].squeeze(), 'sample ' + str(i))

