In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import fitsio

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_objective_lib

import utils

import json

# Run on the first CUDA device when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# load PSF
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'


def _read_primary_hdu(fits_path):
    """Return the data stored in HDU 0 of the FITS file at `fits_path`.

    The file is opened in a `with` block so that it is closed as soon as the
    data have been read, instead of relying on garbage collection of a
    temporary `fitsio.FITS` object.
    """
    with fitsio.FITS(fits_path) as fits_file:
        return fits_file[0].read()


psf_r = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
psf_g = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-g.fits')

# stack the two PSFs along a new leading axis: index 0 is r, index 1 is g
psf_og = np.array([psf_r, psf_g])

# the number of bands is the length of that leading axis
n_bands = psf_og.shape[0]

In [None]:
# Seed both random number generators used by this notebook (torch and numpy).
# The two calls are independent of each other, so their order does not matter.
_ = torch.manual_seed(24534)  # `_ =` keeps the return value out of the cell output
np.random.seed(43534)

# Draw data

In [None]:
# data parameters: start from the defaults stored on disk ...
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.loads(fp.read())

# ... then override the 'f_max' entry for this notebook
data_params.update(f_max=1e5)

print(data_params)


In [None]:
# keep the 'max_stars' entry of the data parameters in its own variable
max_stars = data_params['max_stars']

In [None]:
# number of images requested from the simulator
n_images = 4

# Build the simulated dataset from the PSF stack and the data parameters.
# The sky intensity has two entries, presumably one per band -- TODO confirm.
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    n_images=n_images,
    sky_intensity=torch.Tensor([686, 1000]),
    add_noise=True,
)

In [None]:
# get loader; the batch size equals the number of images requested above
batchsize = n_images

loader = torch.utils.data.DataLoader(star_dataset,
                                     batch_size=batchsize,
                                     shuffle=False)

# `loader.dataset` is the `star_dataset` object handed to the loader above.
# NOTE(review): presumably this (re)draws the star parameters and images --
# confirm in simulated_datasets_lib.
star_dataset.set_params_and_images()

In [None]:
# Only the first batch is needed, so take it directly from a fresh iterator.
# An empty loader raises StopIteration here rather than leaving the names
# below undefined.
data = next(iter(loader))

true_full_fluxes = data['fluxes']
true_full_locs = data['locs']
full_images = data['image']
full_backgrounds = data['background']

In [None]:
# display the 'n_stars' entry of the batch
data['n_stars']

In [None]:
# For every image in the batch, show image minus background at index 0 of the
# second axis (presumably the first band -- TODO confirm) and print the
# minimum of that residual.
# The loop length comes from the batch itself, so it stays correct if the
# number of simulated images changes.
for i in range(full_images.shape[0]):
    residual = full_images[i, 0, :, :] - full_backgrounds[i, 0, :, :]
    plt.matshow(residual)
    plt.colorbar()
    print(residual.min())

# Define VAE

In [None]:
# Build the encoder. `full_slen` is taken from the data parameters and
# `n_bands` from the PSF stack; the stamp geometry (stamp_slen, step,
# edge_padding) and max_detections are fixed here.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=data_params['slen'],
    n_bands=n_bands,
    stamp_slen=7,
    step=2,
    edge_padding=2,
    max_detections=2,
)

# Check my extraction of subimages

In [None]:
# Ask the encoder for the image stamps together with the per-stamp locations,
# fluxes, star counts and on/off indicators, without trimming the images.
(image_stamps,
 subimage_locs,
 subimage_fluxes,
 n_stars,
 is_on_array) = star_encoder.get_image_stamps(full_images,
                                             true_full_locs,
                                             true_full_fluxes,
                                             trim_images=False)

In [None]:
# Histogram of the per-stamp star counts with one bin per integer count:
# bin edges run from the smallest count to the largest count + 1.
# Tensor reductions give the extremes directly (the builtin min/max would
# iterate over the tensor element by element), and int() hands plain Python
# integers to np.arange.
n_stars_min = int(n_stars.min())
n_stars_max = int(n_stars.max())
plt.hist(n_stars, bins=np.arange(n_stars_min, n_stars_max + 2))

In [None]:
# Rebuild the on/off indicators from the star counts alone and check that
# they agree with the indicators returned alongside the stamps.
largest_count = max(n_stars)
is_on_array2 = utils.get_is_on_from_n_stars(n_stars, largest_count)
assert (is_on_array2 == is_on_array).all()

# check total number of stars

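This asserts that the stamps cover the stars: for each image, the per-stamp star counts sum to the number of true stars that lie inside the padded interior of the image.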

In [None]:
# Quantities that do not change from image to image.
slen = data_params['slen']
# `np.float` was removed in NumPy 1.24; the builtin float is the drop-in replacement.
pad = float(star_encoder.edge_padding)
n_stamps_per_batch = star_encoder.tile_coords.shape[0]

for i in range(n_images):
    # Scale the locations by (slen - 1) so they can be compared with the
    # padding (presumably pixel units -- TODO confirm).
    _true_locs = true_full_locs[i] * (slen - 1)

    # Count the stars lying strictly inside the interior left after removing
    # `pad` from each edge of the image.
    n_true_stars = torch.sum((_true_locs[:, 0] > pad) & (_true_locs[:, 1] > pad) &
                             (_true_locs[:, 0] < (slen - pad - 1)) &
                             (_true_locs[:, 1] < (slen - pad - 1)))

    # the stamps belonging to image i occupy one contiguous run of rows
    stamp_rows = slice(i * n_stamps_per_batch, (i + 1) * n_stamps_per_batch)

    # check subimage n_stars add up
    assert torch.sum(n_stars[stamp_rows]) == n_true_stars

    # check number of nonzero fluxes add up, in every band
    for band in range(n_bands):
        assert (subimage_fluxes[stamp_rows, :, band] > 0).sum() == n_true_stars

    # check number of nonzero locs add up (two coordinates per star)
    assert (subimage_locs[stamp_rows] > 0).sum() == (n_true_stars * 2)

### assert we have correct number and pattern of nonzero entries

In [None]:
# Multiplying by the on/off indicators must leave locs and fluxes unchanged,
# i.e. every entry that the indicator zeroes out is already zero.
on_mask = is_on_array.unsqueeze(2).float()
for stamp_params in (subimage_locs, subimage_fluxes):
    assert torch.all((stamp_params * on_mask) == stamp_params)

In [None]:
# Per stamp, count the nonzero entries and compare with the star count:
# two location entries per star and n_bands flux entries per star.
n_stamps = subimage_locs.shape[0]

n_nonzero_locs = (subimage_locs != 0).view(n_stamps, -1).float().sum(1)
assert torch.all(n_nonzero_locs == n_stars.float() * 2)

n_nonzero_fluxes = (subimage_fluxes != 0).view(n_stamps, -1).float().sum(1)
assert torch.all(n_nonzero_fluxes == n_stars.float() * n_bands)

### this is what the NN sees

In [None]:
# Geometry that is the same for every draw, computed once.
n_tiles = star_encoder.tile_coords.shape[0]
stamp_slen = star_encoder.stamp_slen
edge_padding = star_encoder.edge_padding
patch_slen = stamp_slen - 2 * edge_padding

for i in range(10):
    f, axarr = plt.subplots(1, 3, figsize=(12, 6))

    # Draw one stamp at random. Indexing the length-1 result before int()
    # avoids converting a 1-d array to a scalar, which NumPy >= 1.25 deprecates.
    indx = int(np.random.choice(image_stamps.shape[0], 1)[0])

    # Stamps are laid out image after image, n_tiles stamps per image.
    image_indx = indx // n_tiles
    tile_indx = indx % n_tiles

    which_nonzero = is_on_array[indx].type(torch.bool)

    # Plot my image patch and subimage locs
    axarr[0].matshow(image_stamps[indx, 0].squeeze())
    axarr[0].scatter(subimage_locs[indx, which_nonzero, 1] * (patch_slen - 1) + edge_padding,
                     subimage_locs[indx, which_nonzero, 0] * (patch_slen - 1) + edge_padding,
                     color = 'b')

    # plot subset of full image and subset of full locs; They should match.
    x0 = star_encoder.tile_coords[tile_indx, 0]
    x1 = star_encoder.tile_coords[tile_indx, 1]

    image_full_i = full_images[image_indx]
    image_patch_i = image_full_i[:, x0:(x0 + stamp_slen),
                                    x1:(x1 + stamp_slen)]
    axarr[1].matshow(image_patch_i[0])

    # check images match
    assert torch.all((image_stamps[indx].squeeze() - image_patch_i) == 0)

    locs_i = true_full_locs[image_indx] * (star_encoder.full_slen - 1)

    # keep the true locations lying strictly inside this stamp
    which_locs = ((locs_i[:, 0] > x0.float()) & (locs_i[:, 1] > x1.float())) & \
                    (locs_i[:, 0] < (x0 + stamp_slen).float() - 1) & \
                    (locs_i[:, 1] < (x1 + stamp_slen).float() - 1)

    axarr[1].scatter(locs_i[which_locs, 1] - x1,
                     locs_i[which_locs, 0] - x0,
                     marker = 'o', color = 'b')

    # red lines mark the edge padding on both image panels
    for ax in (axarr[1], axarr[0]):
        ax.axvline(x=edge_padding, color = 'r')
        ax.axvline(x=stamp_slen - edge_padding - 1, color = 'r')
        ax.axhline(y=edge_padding, color = 'r')
        ax.axhline(y=stamp_slen - edge_padding - 1, color = 'r')


    #####
    # experimentation
#     foo = skimage.exposure.equalize_hist(image_full_i.squeeze().numpy())
#     im2 = axarr[2].matshow(foo[x0:(x0 + star_encoder.stamp_slen), 
#                         x1:(x1 + star_encoder.stamp_slen)])
#     f.colorbar(im2, ax = axarr[2])

# Check my output parameters

In [None]:
# Repeat the stamp extraction, this time with `clip_max_stars = True`;
# the results overwrite the variables from the earlier extraction.
(image_stamps,
 subimage_locs,
 subimage_fluxes,
 n_stars,
 is_on_array) = star_encoder.get_image_stamps(full_images,
                                             true_full_locs,
                                             true_full_fluxes,
                                             trim_images=False,
                                             clip_max_stars=True)

In [None]:
# Use the mean over all entries of the full backgrounds as the background
# passed to the encoder for every stamp.
# NOTE(review): presumably the TODO is to replace this single value with a
# per-stamp background -- confirm with the author.
background_stamps = full_backgrounds.mean() # TODO

In [None]:
# Run the encoder on the stamps, conditioning on the star counts, and unpack
# its five outputs.
(logit_loc_mean,
 logit_loc_log_var,
 log_flux_mean,
 log_flux_log_var,
 log_probs) = star_encoder(image_stamps, background_stamps, n_stars)

In [None]:
# display the encoder's `locs_mean_indx_mat` attribute for inspection
star_encoder.locs_mean_indx_mat

In [None]:
# display the encoder's `locs_var_indx_mat` attribute for inspection
star_encoder.locs_var_indx_mat

In [None]:
# display the encoder's `fluxes_mean_indx_mat` attribute for inspection
star_encoder.fluxes_mean_indx_mat

In [None]:
# display the encoder's `fluxes_var_indx_mat` attribute for inspection
star_encoder.fluxes_var_indx_mat

### Check we have the correct number (and pattern) of nonzero entries

In [None]:
# Multiplying by the on/off indicators must leave each encoder output
# unchanged, i.e. every entry that the indicator zeroes out is already zero.
output_on_mask = is_on_array.unsqueeze(2).float()
for encoder_output in (logit_loc_mean, logit_loc_log_var,
                       log_flux_mean, log_flux_log_var):
    assert torch.all((encoder_output * output_on_mask) == encoder_output)

In [None]:
# Per stamp, the location mean and log-variance must each have exactly two
# nonzero entries per star.
for loc_param in (logit_loc_mean, logit_loc_log_var):
    n_nonzero = (loc_param != 0).view(loc_param.shape[0], -1).float().sum(1)
    assert torch.all(n_nonzero == n_stars.float() * 2)

# the matching checks on the flux outputs are currently disabled
# assert torch.all((log_flux_mean != 0).float().sum(1) == n_stars.float())
# assert torch.all((log_flux_log_var != 0).float().sum(1) == n_stars.float())

In [None]:
# Check sampling from the star encoder

In [None]:
# Sample from the encoder for the first image only (the 0:1 slices keep the
# leading axis), requesting 3 samples with `return_log_q = True` and
# `return_map = False`.
(locs_full_image,
 fluxes_full_image,
 n_stars_full,
 log_q_locs,
 log_q_fluxes,
 log_q_n_stars) = star_encoder.sample_star_encoder(full_images[0:1],
                                                   full_backgrounds[0:1],
                                                   return_log_q=True,
                                                   return_map=False,
                                                   n_samples=3)