In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib
import starnet_vae_lib
import psf_transform_lib

import utils

import json

# Run on the first CUDA GPU when one is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Record the library version next to the results.
print('torch version: ', torch.__version__)

In [None]:
# Pin both global random streams (torch and numpy) so that re-running the
# notebook from a fresh kernel draws the same samples.
_ = torch.manual_seed(24534)
np.random.seed(43534)

# Load data

In [None]:
# Build the dataset object with its default constructor arguments. Later cells
# read its `sdss_image`, `sdss_background`, `psf_file`, `locs` and `fluxes`
# attributes.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Give the image and its background a leading axis of length one and move
# both onto the compute device.
full_image, full_background = (
    tensor.unsqueeze(0).to(device)
    for tensor in (sdss_hubble_data.sdss_image,
                   sdss_hubble_data.sdss_background)
)

# Define VAE

In [None]:
# Instantiate the encoder for an image whose side length is taken from the
# last axis of `full_image`.
# NOTE(review): these settings presumably have to match the ones the saved
# checkpoint was trained with -- confirm before changing any of them.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=full_image.shape[-1],
    stamp_slen=9,
    step=2,
    edge_padding=3,
    n_bands=1,
    max_detections=2,
)

In [None]:
# Deserialize the saved weights onto the CPU first, so the checkpoint loads no
# matter which device it was written from.
encoder_state = torch.load('../fits/starnet-10162019-reweighted',
                           map_location=lambda storage, loc: storage)
star_encoder.load_state_dict(encoder_state)

# `full_image` and `full_background` were moved to `device`; the encoder has to
# live there too, otherwise a CUDA run fails with a device mismatch.
# (No-op when `device` is the CPU.)
star_encoder.to(device)

# Inference only: switch to evaluation mode. The trailing `;` suppresses the
# module repr in the cell output.
star_encoder.eval();

# Sample parameters

In [None]:
# Number of draws requested from the encoder below; the plotting and
# loss-check loops iterate over the same count.
n_samples = 10

In [None]:
# Draw `n_samples` sets of locations, fluxes and star counts from the encoder
# (return_map = False), and also request the log_q terms (return_log_q = True).
_sampled = star_encoder.sample_star_encoder(
    full_image,
    full_background,
    n_samples,
    return_map=False,
    return_log_q=True,
)
(sampled_locs_full_image,
 sampled_fluxes_full_image,
 sampled_n_stars_full,
 log_q_locs,
 log_q_fluxes,
 log_q_n_stars) = _sampled


# Get loss and reconstruction 

In [None]:
# The dataset object records which PSF file it uses; convert it to a plain
# string before passing it on.
psf_fit_file = str(sdss_hubble_data.psf_file)

In [None]:
# Evaluate the PSF described by `psf_fit_file` at the point (0, 0).
# NOTE(review): units / coordinate convention of the two positional arguments
# are not visible here -- confirm against sdss_psf.psf_at_points.
psf_og = sdss_psf.psf_at_points(0, 0, psf_fit_file = psf_fit_file)

In [None]:
# Call `_expand_psf` with the image's last-axis length (presumably this pads
# the PSF out to the full image size -- confirm), then wrap the result in a
# torch tensor on the compute device.
_psf_expanded = simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1])
psf = torch.Tensor(_psf_expanded).to(device)

In [None]:
# Reconstruct the image from every sampled catalogue and score it against the
# observed image. `pad = 5` must agree with the crop used in the manual loss
# check further down.
recon_means, neg_logprob = psf_transform_lib.get_psf_loss(
    full_image,
    full_background,
    sampled_locs_full_image,
    sampled_fluxes_full_image,
    n_stars=sampled_n_stars_full,
    psf=psf,
    pad=5,
)

In [None]:
# matplotlib needs host memory: copy the observed image to the CPU once,
# outside the loop, since it is the same for every sample.
observed = full_image.squeeze().detach().cpu()

# One figure per sample: observed image, reconstruction, and residual.
for i in range(n_samples):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    axarr[0].matshow(observed)
    axarr[0].set_title('observed')

    # Detach from the autograd graph and bring to the CPU before plotting.
    recon = recon_means[i].squeeze().detach().cpu()
    axarr[1].matshow(recon)
    axarr[1].set_title('sample reconstruction '+ str(i))

    # Symmetric colour limits put a zero residual at the white midpoint of the
    # diverging 'bwr' colormap.
    _resid = recon - observed
    vmax = _resid.abs().max().item()
    im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('sample residual ' + str(i))



# Check losses match

In [None]:
# NOTE(review): this looks like a leftover from interactive debugging -- the
# `for i in range(n_samples)` loop in the next cell rebinds `i` before it is
# read, so this assignment has no effect on a top-to-bottom run.
i = 0

In [None]:
# Recompute the loss by hand for every sample and check that it agrees with the
# value returned by `get_psf_loss`. `pad` must equal the `pad` passed there.
pad = 5
slen = full_image.shape[-1]

# The cropped observed image is identical for every sample, so slice it once.
_full_image = full_image.squeeze()[pad:(slen - pad), pad:(slen - pad)]

for i in range(n_samples):
    # Crop the border and floor the mean at 100.
    # NOTE(review): the floor presumably mirrors a clamp inside get_psf_loss --
    # confirm if that function changes.
    _recon_means = recon_means[i, 0, pad:(slen - pad), pad:(slen - pad)].clamp(min = 100)

    # log(mean) is passed as the third argument (presumably the log-variance,
    # i.e. variance equal to the mean -- confirm against utils).
    manual_neg_logprob = -utils.eval_normal_logprob(_full_image,
                    _recon_means,
                    torch.log(_recon_means)).sum()

    # Compare with a relative tolerance: the two values are large floating-point
    # sums that may be accumulated in a different order, so exact equality is
    # brittle.
    assert torch.allclose(manual_neg_logprob.detach(),
                          neg_logprob[i].detach(),
                          rtol = 1e-5), \
        'loss mismatch for sample {}: {} vs {}'.format(
            i, manual_neg_logprob.item(), neg_logprob[i].item())

# Check loss at MAP value

In [None]:
# Query the encoder once with return_map = True (presumably this returns its
# MAP estimate instead of a random draw -- confirm against
# sample_star_encoder).
# NOTE: log_q_locs / log_q_fluxes / log_q_n_stars are rebound here and replace
# the values produced by the earlier sampling cell.
_map_estimate = star_encoder.sample_star_encoder(
    full_image,
    full_background,
    n_samples=1,
    return_map=True,
    return_log_q=True,
)
(map_locs_full_image,
 map_fluxes_full_image,
 map_n_stars_full,
 log_q_locs,
 log_q_fluxes,
 log_q_n_stars) = _map_estimate

In [None]:
# Score the MAP catalogue with the same PSF and padding as the sampled ones so
# that the two losses are directly comparable.
map_recon_mean, map_neg_logprob = psf_transform_lib.get_psf_loss(
    full_image,
    full_background,
    map_locs_full_image,
    map_fluxes_full_image,
    n_stars=map_n_stars_full,
    psf=psf,
    pad=5,
)

In [None]:
# Display the loss returned by get_psf_loss for the MAP parameters.
map_neg_logprob

In [None]:
# Display the per-sample losses for comparison with the MAP loss shown above.
neg_logprob

# Check loss at true value

In [None]:
# Count the catalogue entries whose flux is strictly positive.
_has_positive_flux = sdss_hubble_data.fluxes > 0
true_n_stars = _has_positive_flux.sum()

In [None]:
# Score the catalogue stored on the dataset object with the same PSF and
# padding as above. The image, background and PSF were moved to `device`, so
# the catalogue tensors are moved there too; otherwise a CUDA run mixes devices.
# The star count is built directly as an integer tensor instead of going
# through a float `torch.Tensor` and a cast.
true_recon_mean, true_neg_logprob = psf_transform_lib.get_psf_loss(
    full_image,
    full_background,
    sdss_hubble_data.locs.unsqueeze(0).to(device),
    sdss_hubble_data.fluxes.unsqueeze(0).to(device),
    n_stars=torch.tensor([int(true_n_stars)], dtype=torch.long, device=device),
    psf=psf,
    pad=5,
)

In [None]:
    # Same three-panel figure as for the samples, but for the reconstruction
    # built from the dataset's own catalogue. The titles say so explicitly
    # instead of reusing the "sample <i>" wording with a stale loop index.
    fig, axarr = plt.subplots(1, 3, figsize=(15, 6))

    # matplotlib needs host memory, so detach and copy to the CPU first.
    _observed = full_image.squeeze().detach().cpu()
    _true_recon = true_recon_mean.squeeze().detach().cpu()

    axarr[0].matshow(_observed)
    axarr[0].set_title('observed')

    axarr[1].matshow(_true_recon)
    axarr[1].set_title('reconstruction at true parameters')

    # Symmetric colour limits put a zero residual at the white midpoint of the
    # diverging 'bwr' colormap.
    _resid = _true_recon - _observed
    vmax = _resid.abs().max().item()
    im2 = axarr[2].matshow(_resid, vmax = vmax, vmin = -vmax,
                           cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('residual at true parameters')



In [None]:
# Display the loss at the catalogue values, for comparison with the sampled and
# MAP losses shown earlier.
true_neg_logprob