In [None]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import inv_KL_objective_lib as objectives_lib

# Prefer the first CUDA GPU; fall back to the CPU when none is visible.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# SDSS psField FITS file used to load the PSF for simulation and reconstruction.
psf_fit_file = '../../celeste_net/sdss_stage_dir/3900/6/269/psField-003900-6-0269.fit'
print('psf file: \n', psf_fit_file)

# Fail fast with a clear message rather than somewhere inside the simulation code.
if not os.path.isfile(psf_fit_file):
    raise FileNotFoundError('PSF file not found: {}'.format(psf_fit_file))

In [None]:
# Seed both RNGs up front so the simulated data and random plot choices are reproducible.
_ = torch.manual_seed(24534)
np.random.seed(43534)

# Draw simulated data

In [None]:
# Start from the default star-simulation parameters, then override a few for this run.
# NOTE(review): min_stars == max_stars == 400, so presumably every image gets 400 stars.
with open('../data/default_star_parameters.json', 'r') as params_file:
    data_params = json.load(params_file)

data_params.update(slen=103, min_stars=400, max_stars=400, alpha=0.5)

print(data_params)


In [None]:
# Upper bound on the number of simulated stars, copied out of data_params.
# NOTE(review): no later cell shown here reads `max_stars`; L123 uses data_params directly.
max_stars = data_params['max_stars']

In [None]:
n_images = 1

# Simulate noisy images from the PSF file and the data parameters above.
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file, data_params, n_images=n_images, add_noise=True)

# Number of simulated examples = leading dimension of the stored fluxes.
num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# Non-shuffling loader whose batch size equals the number of simulated images.
batchsize = n_images
loader = torch.utils.data.DataLoader(dataset=star_dataset, batch_size=batchsize, shuffle=False)

# loader.dataset is star_dataset itself, so call the method on it directly.
# NOTE(review): presumably (re)generates the simulated parameters and images — confirm.
star_dataset.set_params_and_images()

In [None]:
# The loader was built with batch_size = n_images, so the simulated set presumably
# fits in one batch. Take that batch explicitly instead of looping over the loader
# and silently keeping only the last batch.
data = next(iter(loader))
true_fluxes = data['fluxes']
true_locs = data['locs']
images = data['image']
backgrounds = data['background']

In [None]:
# Batch image tensor; indexed below as images[i, 0, :, :], so presumably (batch, band, height, width) — confirm.
images.shape

In [None]:
# Pick one batch index at random using numpy's global RNG.
(i,) = np.random.choice(batchsize, 1)

In [None]:
# Show the observed image for the batch index `i` chosen at random in the previous cell.
# (The earlier `for i in range(1)` loop silently overwrote that choice with 0.)
fig, ax = plt.subplots()
im = ax.matshow(images[i, 0, :, :])
fig.colorbar(im, ax=ax)
ax.set_title('Observed image {}\n'.format(i));


# Pass the images through the VAE encoder

In [None]:
# Encoder over the full simulated image, which it splits into stamps.
# NOTE(review): stamp_slen / step / edge_padding presumably control stamp size, stride and
# border padding; max_detections caps detections per stamp — confirm in starnet_vae_lib.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=data_params['slen'],
    stamp_slen=15,
    step=8,
    edge_padding=2,
    n_bands=1,
    max_detections=20,
)

### Check subimage (stamp) extraction against noise-free reconstructions

In [None]:
# Cut the batch into stamps with their true star parameters (trimmed variant).
(image_stamps, subimage_locs, subimage_fluxes,
 n_stars, is_on_array) = star_encoder.get_image_stamps(
    images, true_locs, true_fluxes, trim_images=True)

In [None]:
# Re-render every stamp, noise-free, from its own extracted parameters. The next cell
# compares these reconstructions with the observed stamps.
# Patches are rendered at the stamp size minus the edge padding on both sides.
patch_slen = star_encoder.stamp_slen - 2 * star_encoder.edge_padding
patch_simulator = simulated_datasets_lib.StarSimulator(psf_fit_file,
                                                       patch_slen,
                                                       data_params['sky_intensity'])

# Use the per-stamp star counts returned by get_image_stamps. The full-image
# max_stars (400) is unrelated to the number of stars in a single stamp.
_n_stars = torch.as_tensor(n_stars).type(torch.LongTensor)

recon_means = patch_simulator.draw_image_from_params(subimage_locs,
                                                     subimage_fluxes,
                                                     _n_stars,
                                                     add_noise = False)

In [None]:
# For 10 random stamps, plot side by side: the observed stamp with its true star
# locations, the noise-free reconstruction, and the residual (observed - reconstruction).
patch_slen = star_encoder.stamp_slen - 2 * star_encoder.edge_padding  # loop-invariant

for i in range(10):
    fig, axarr = plt.subplots(1, 3, figsize=(16, 6))
    # Draw a scalar index directly. Calling int() on the 1-element array returned by
    # np.random.choice(n, 1) is deprecated since NumPy 1.25.
    indx = int(np.random.choice(image_stamps.shape[0]))

    which_nonzero = is_on_array[indx]

    im1 = axarr[0].matshow(image_stamps[indx].squeeze())
    # Locations are multiplied by (patch_slen - 1) for plotting; column 1 goes on x, column 0 on y.
    axarr[0].scatter(subimage_locs[indx, which_nonzero, 1] * (patch_slen - 1),
                     subimage_locs[indx, which_nonzero, 0] * (patch_slen - 1))
    fig.colorbar(im1, ax = axarr[0])
    axarr[0].set_title('n_stars: {}\n'.format(n_stars[indx]))

    im2 = axarr[1].matshow(recon_means[indx].squeeze())
    fig.colorbar(im2, ax = axarr[1])
    axarr[1].set_title('reconstruction\n')

    im3 = axarr[2].matshow(image_stamps[indx].squeeze() - recon_means[indx].squeeze())
    fig.colorbar(im3, ax = axarr[2])
    axarr[2].set_title('residual (observed - reconstruction)\n')

### Check encoder outputs, matching, and loss

In [None]:
# TODO: build per-stamp backgrounds. For now this is a single scalar (the mean over every
# element of `backgrounds`), and the encoder call below uses it for all stamps.
background_stamps = backgrounds.mean() # TODO

In [None]:
# Re-extract the stamps without trimming; these feed the encoder below.
(image_stamps, subimage_locs, subimage_fluxes,
 n_stars, is_on_array) = star_encoder.get_image_stamps(
    images, true_locs, true_fluxes, trim_images=False)

In [None]:
# Encoder forward pass on the stamps: location and log-flux means / log-variances,
# plus `log_probs` (presumably over the number of stars — confirm in starnet_vae_lib).
(logit_loc_mean, logit_loc_log_var,
 log_flux_mean, log_flux_log_var,
 log_probs) = star_encoder(image_stamps, background_stamps, n_stars)

In [None]:
# Shape of the encoder's predicted location means (on a logit scale, per the name).
logit_loc_mean.shape

In [None]:
# Raw values of the predicted location means.
logit_loc_mean

In [None]:
# Log-probability of the true stamp locations under the encoder's predicted distribution.
# NOTE(review): presumably evaluated for every (true, predicted) pairing, per "_all_combs".
locs_log_probs_all = objectives_lib.get_locs_logprob_all_combs(
    subimage_locs, logit_loc_mean, logit_loc_log_var)


In [None]:
# Shape of the location log-prob matrix (presumably stamps x true x predicted — confirm).
locs_log_probs_all.shape

In [None]:
# Log-probability of the true stamp fluxes under the encoder's predicted log-flux distribution.
# NOTE(review): presumably evaluated for every (true, predicted) pairing, per "_all_combs".
flux_log_probs_all = objectives_lib.get_fluxes_logprob_all_combs(
    subimage_fluxes, log_flux_mean, log_flux_log_var)

In [None]:
# Shape of the flux log-prob matrix; expected to match locs_log_probs_all.shape — confirm.
flux_log_probs_all.shape

In [None]:
# Shape of the on/off mask over detection slots (indexed per stamp as is_on_array[indx]).
is_on_array.shape

In [None]:
import hungarian_alg

In [None]:
# Match true and predicted detections per stamp from the location log-probs and the
# on/off mask (Hungarian assignment, per the module name).
perm = hungarian_alg.run_batch_hungarian_alg_parallel(
    locs_log_probs_all, is_on_array)

In [None]:
# NOTE(review): repeats an earlier `is_on_array.shape` cell; candidate for deletion.
is_on_array.shape

In [None]:
# NOTE(review): repeats an earlier `locs_log_probs_all.shape` cell; candidate for deletion.
locs_log_probs_all.shape

In [None]:
# Shape of the matching returned by the Hungarian routine.
perm.shape

In [None]:
# NOTE(review): repeats an earlier `locs_log_probs_all.shape` cell; candidate for deletion.
locs_log_probs_all.shape

In [None]:
# Shape after permuting the location log-prob matrix by the matching `perm` (private helper).
objectives_lib._permute_losses_mat(locs_log_probs_all, perm).shape

In [None]:
# NOTE(review): identical to the previous cell; candidate for deletion.
objectives_lib._permute_losses_mat(locs_log_probs_all, perm).shape

In [None]:
# Encoder loss on this batch, called with the full `images` / `backgrounds` (not the stamps).
objectives_lib.get_encoder_loss(
    star_encoder, images, backgrounds, true_locs, true_fluxes
)

In [None]:
# Encoder loss evaluated over the whole loader (presumably averaged over batches — confirm).
objectives_lib.eval_star_encoder_loss(star_encoder, loader)

In [None]:
# Distribution of per-stamp star counts (from the untrimmed stamps extracted above).
# The axes are labelled and the trailing `;` suppresses the histogram-tuple repr.
fig, ax = plt.subplots()
ax.hist(n_stars)
ax.set_xlabel('n_stars per stamp')
ax.set_ylabel('number of stamps');

In [None]:
# Number of batches the loader yields.
len(loader)

In [None]:
image = image_stamps
# NOTE(review): hard-coded background level. Confirm it matches the simulated sky
# (data_params['sky_intensity']) or derive it from `backgrounds`.
background = 686.

In [None]:
# Scratch check of a log + min-max normalisation of the stamps: subtract the background,
# add 1000 before the log, then take each stamp's minimum and maximum.
# NOTE(review): any pixel below (background - 1000) makes the log NaN.
log_img = torch.log(image - background + 1000)

# Flatten each stamp's pixels (keeping a singleton band axis) once, so min/max are per stamp.
log_img_per_stamp = log_img.view(image.shape[0], 1, -1)
mins = log_img_per_stamp.min(-1)[0]
maxes = log_img_per_stamp.max(-1)[0]

In [None]:
# Minimum log-pixel value of the first stamp (one entry per band).
mins[0]

In [None]:
# Maximum log-pixel value of the first stamp (one entry per band).
maxes[0]

In [None]:
# Per-stamp min-max scaling of the log image; values should land in [0, 1].
# The trailing [..., None, None] broadcasts the per-stamp stats over the pixel axes.
foo = (log_img - mins[..., None, None]) / (maxes - mins)[..., None, None]

In [None]:
# Should be 0 if the min-max scaling above is correct.
torch.min(foo[0])

In [None]:
# Should be 1 if the min-max scaling above is correct.
torch.max(foo[0])

In [None]:
# Load saved test losses from a whitespace-delimited text file.
# NOTE(review): reuses the name `foo` (previously the scaled stamps) for unrelated data.
foo = np.loadtxt('../fits/test_losses')

In [None]:
# Column 6 of the loaded losses.
# NOTE(review): what this column holds isn't shown here — confirm.
foo[:, 6]