In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib

import inv_KL_objective_lib as objectives_lib

import time

import json

# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# SDSS psField file passed to load_dataset_from_params below to build the PSF
# (directory layout presumably run 2583 / camcol 2 / field 136).
psf_fit_file = (
    '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
)
print('psf file: \n', psf_fit_file)

In [None]:
# Fix both the numpy and torch RNGs so the simulated data drawn below is reproducible.
SEED = 22
np.random.seed(SEED)
_ = torch.manual_seed(SEED)

# Draw data

In [None]:
# Start from the default star-simulation parameters, then override a few for this run.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({
    'slen': 103,        # passed as `full_slen` to the encoder below
    'min_stars': 100,   # min == max: every image gets exactly 100 stars
    'max_stars': 100,
    'alpha': 0.5,
})


In [None]:
# Maximum stars per image; reused when padding the deblending scenes and the Hubble data.
max_stars = data_params['max_stars']

In [None]:
# A single noisy simulated image is enough for this inspection notebook.
batchsize = 1

simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=batchsize,
    add_noise=True,
)


In [None]:
# true parameters: pull the (single) simulated batch and its ground-truth catalogue.
loader = torch.utils.data.DataLoader(
                 dataset=simulated_dataset,
                 batch_size=batchsize,
                 shuffle=False)

# next(iter(...)) takes the first batch directly. On an empty dataset it raises
# StopIteration right away instead of silently leaving the names below undefined.
data = next(iter(loader))
true_full_fluxes = data['fluxes']
true_full_locs = data['locs']
images_full = data['image']


In [None]:
# Histogram of log(log(flux)) for the simulated stars.
# NOTE(review): any flux <= 1 gives nan/-inf under the double log; confirm fluxes exceed 1.
fig, ax = plt.subplots()
ax.hist(np.log(np.log(true_full_fluxes.numpy().flatten())))
ax.set(title='Simulated star fluxes', xlabel='log(log(flux))', ylabel='count');

In [None]:
# Display the simulated full image with its singleton dimensions squeezed away.
plt.matshow(torch.squeeze(images_full));

# Load VAE

In [None]:
# Encoder architecture. These settings presumably have to match the checkpoint loaded in
# the next cell, because load_state_dict is strict about parameter names/shapes by default.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=data_params['slen'],
    stamp_slen=15,
    step=8,
    edge_padding=2,
    n_bands=1,
    max_detections=15,
)

In [None]:
# Restore the trained weights. The map_location callable returns each storage unchanged,
# so tensors stay on the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
encoder_path = '../fits/starnet_invKL_encoder_batched_images_100stars_default'
encoder_state = torch.load(encoder_path, map_location=lambda storage, loc: storage)
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

In [None]:
# Evaluate the loaded encoder's loss on the simulated batch without training.
objectives_lib.eval_star_encoder_loss(star_encoder, loader, train=False)

# Get inferred parameters

In [None]:
# Split the full image into stamps, along with the true catalogue expressed per stamp.
(image_stamps, true_subimage_locs, true_subimage_fluxes,
 true_n_stars, is_on_array) = star_encoder.get_image_stamps(
    images_full, true_full_locs, true_full_fluxes, trim_images=False)

In [None]:
# Forward pass of the encoder on the stamps, passing the true per-stamp star counts.
# NOTE(review): the background is passed here as the scalar sky intensity, while
# get_encoder_loss below receives a per-stamp tensor (_backgrounds). Confirm which form
# StarEncoder.forward expects.
(logit_loc_mean, logit_loc_log_var,
 log_flux_mean, log_flux_log_var, log_probs) = star_encoder(
    image_stamps, data_params['sky_intensity'], true_n_stars)

In [None]:
# Constant sky background, one (1, 1, 1) entry per stamp.
n_stamps = image_stamps.shape[0]
_backgrounds = torch.ones(n_stamps, 1, 1, 1) * data_params['sky_intensity']

In [None]:
# Encoder loss on the simulated image: the total plus the per-term pieces and a
# permutation, judging by the unpacked names.
# NOTE(review): _backgrounds has one entry per stamp (image_stamps.shape[0]), while
# images_full holds `batchsize` full images. Confirm get_encoder_loss expects per-stamp
# backgrounds.
loss, counter_loss, locs_loss, fluxes_loss, perm = objectives_lib.get_encoder_loss(
    star_encoder, images_full, _backgrounds, true_full_locs, true_full_fluxes)

In [None]:
# Distribution of the counter-loss term. Detach and move to host memory explicitly so
# this also works if the loss was computed on a GPU (matplotlib needs a numpy array).
fig, ax = plt.subplots()
ax.hist(counter_loss.detach().cpu().numpy())
ax.set(title='Counter loss', xlabel='counter loss', ylabel='count');

In [None]:
# NOTE(review): this cell previously ran np.loadtxt('../fits/'). That path is a directory,
# not a file, so the call always raised. List the available fit outputs instead.
# TODO: load the intended file explicitly once it is identified.
import os
sorted(os.listdir('../fits/'))

# Check parameters (TODO: no cells yet)

# Check reconstructions

In [None]:
import plotting_utils

In [None]:
# Trim the simulator PSF to the stamp interior: stamp side length minus the padding on each edge.
interior_slen = star_encoder.stamp_slen - 2 * star_encoder.edge_padding
_psf = simulated_datasets_lib._trim_psf(simulated_dataset.simulator.psf, interior_slen)

In [None]:
# Reconstructions for stamps 0-19; the encoder infers the star count itself.
indx = np.arange(20)
plotting_utils.print_results(star_encoder,
                             image_stamps[indx],
                             _backgrounds[indx],
                             _psf,
                             true_subimage_locs[indx],
                             is_on_array[indx],
                             use_true_n_stars=False)

In [None]:
# Reconstructions for stamps 20-39; the encoder infers the star count itself.
indx = np.arange(20, 40)
plotting_utils.print_results(star_encoder,
                             image_stamps[indx],
                             _backgrounds[indx],
                             _psf,
                             true_subimage_locs[indx],
                             is_on_array[indx],
                             use_true_n_stars=False)

# Check deblending properties

In [None]:
# Build n_trials scenes, each with exactly two equally bright stars placed symmetrically
# about loc 0.5 (presumably the image centre, assuming locs are in [0, 1]). Each trial
# moves both stars `incr` further apart along each coordinate.
n_trials = 10

_n_stars = torch.full((n_trials,), 2, dtype=torch.long)
_fluxes = torch.ones(n_trials, max_stars) * simulated_dataset.f_min * 100

# Only the first two slots matter (n_stars == 2); the rest stay random.
_locs = torch.rand(n_trials, max_stars, 2)

dist = 0.0
incr = 0.01
for trial in range(n_trials):
    dist += incr
    _locs[trial, 0] = 0.5 + dist
    _locs[trial, 1] = 0.5 - dist

In [None]:
# Render the two-star deblending scenes without noise.
_images = simulated_dataset.simulator.draw_image_from_params(_locs, _fluxes, _n_stars,
                                                             add_noise=False)

# One constant sky background per rendered scene. The batch size comes from the images
# themselves, so changing n_trials keeps the two in sync (it was hard-coded to 10).
_backgrounds = torch.ones(_images.shape[0], 1, 1, 1) * simulated_dataset.sky_intensity

In [None]:
# Run the encoder on the noiseless two-star scenes and let it infer the star count.
# NOTE(review): the stamp cells above pass the trimmed `_psf` and an is-on array. Here the
# untrimmed simulator PSF and `_n_stars` are passed with full rendered images. Confirm this
# matches what print_results expects.
deblend_psf = simulated_dataset.simulator.psf
plotting_utils.print_results(star_encoder,
                             _images,
                             _backgrounds,
                             deblend_psf,
                             _locs,
                             _n_stars,
                             use_true_n_stars=False)

In [None]:
import sdss_dataset_lib

In [None]:
# Real SDSS data (run 2566, camcol 6, field 65) paired with a Hubble catalogue for ngc7078.
# NOTE(review): the directory is spelled 'NCG7078' while the file name says 'ngc7078'.
# It is kept as-is because it must match the on-disk path.
# NOTE(review): max_detections is max_stars (100), while the encoder uses 15. Confirm intended.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    slen=11,
    run=2566,
    camcol=6,
    field=65,
    max_detections=max_stars,
)

In [None]:
# Number of items in the Hubble-matched SDSS dataset.
n_hubble_items = len(sdss_hubble_data)
n_hubble_items

In [None]:
# true parameters: load the whole Hubble-matched SDSS dataset as one batch.
hubble_loader = torch.utils.data.DataLoader(
                 dataset=sdss_hubble_data,
                 batch_size=len(sdss_hubble_data),
                 shuffle=False)

# next(iter(...)) takes the single batch directly. A failure surfaces immediately instead
# of leaving the names below undefined.
data = next(iter(hubble_loader))
hubble_fluxes = data['fluxes'].float()
hubble_locs = data['locs'].float()
hubble_n_stars = data['n_stars']
sdss_images = data['image']
sdss_backgrounds = data['background']

In [None]:
# log10 of background-subtracted pixel values in the simulated image. `images` was
# undefined (NameError); the simulated image is `images_full`. Pixels at or below the sky
# level are dropped, because log10 would turn them into -inf/nan and corrupt the histogram.
sim_residuals = (images_full - simulated_dataset.sky_intensity).flatten()
plt.hist(torch.log10(sim_residuals[sim_residuals > 0]));

In [None]:
# log10 of background-subtracted pixel values in the real SDSS stamps, for comparison with
# the simulated histogram. Non-positive residuals are dropped to avoid -inf/nan from log10.
sdss_residuals = (sdss_images - sdss_backgrounds).flatten()
plt.hist(torch.log10(sdss_residuals[sdss_residuals > 0]));

In [None]:
# Run the encoder on the first 10 real SDSS items, with the Hubble catalogue as truth.
# NOTE(review): this passes the untrimmed PSF of the simulated dataset. That PSF is
# presumably built from psf_fit_file (run 2583 / camcol 2 / field 136), not the 2566/6/65
# field these images come from. Confirm which PSF print_results should receive here.
indx = np.arange(10)
plotting_utils.print_results(star_encoder,
                             sdss_images[indx],
                             sdss_backgrounds[indx],
                             simulated_dataset.simulator.psf,
                             hubble_locs[indx],
                             hubble_n_stars[indx],
                             use_true_n_stars=False,
                             residual_clamp=1e16)