Results after 200 epochs (commit message: "record results of current encoder").

In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib

import inv_KL_objective_lib as objectives_lib

import time

import json

# Use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# load PSF
# SDSS psField file; judging by the path, run 2566 / camcol 6 / field 65.
# It is handed to the simulated-dataset loader below.
psf_fit_file = '../../celeste_net/sdss_stage_dir/2566/6/65/psField-002566-6-0065.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed NumPy's and torch's global RNGs so the simulated data below is reproducible.
np.random.seed(22)
# assigning to `_` stops the notebook from echoing the returned generator object
_ = torch.manual_seed(22)

# Draw data

In [None]:
# data parameters: load the on-disk defaults, then override the star-count
# range and the sky level for this experiment.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(min_stars=0,
                   max_stars=4,
                   sky_intensity=701.4577)

print(data_params)


In [None]:
# maximum stars per image; reused below for the encoder and the per-slot indicators
max_stars = data_params['max_stars']

In [None]:
batchsize = 1000

# Simulate noisy images from the PSF file and the parameters above.
# NOTE(review): `n_stars` receives `batchsize`; presumably this is the number
# of simulated images rather than stars per image -- confirm in simulated_datasets_lib.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars = batchsize,
    add_noise = True,
)


In [None]:
# true parameters: take only the first batch (up to `batchsize` images, in
# dataset order because shuffle=False).
loader = torch.utils.data.DataLoader(
                 dataset=simulated_dataset,
                 batch_size=batchsize,
                 shuffle=False)

# next(iter(...)) reads one batch directly; a for-loop with an immediate
# `break` is not needed.
data = next(iter(loader))
true_fluxes = data['fluxes']
true_locs = data['locs']
true_n_stars = data['n_stars']
images = data['image']
backgrounds = data['background']

In [None]:
# Per-slot occupancy indicator derived from the true star counts; later cells
# select occupied slots with `true_is_on[:, i] == 1`.
true_is_on = simulated_datasets_lib.get_is_on_from_n_stars(true_n_stars, max_stars)

In [None]:
# inspect the image tensor's shape (later cells index it as images[i, 0, :, :])
images.shape

In [None]:
# histogram of log10 true fluxes. Only strictly positive entries are kept:
# a zero flux (e.g. an unused star slot, if slots are zero-padded) would give
# -inf under log10 and break the histogram's automatic range.
flux_values = true_fluxes.numpy().flatten()
plt.hist(np.log10(flux_values[flux_values > 0]));

In [None]:
# Show the first 10 background-subtracted images, with their true star
# locations overlaid in blue.
n_rows, n_cols = 2, 5
_, axarr = plt.subplots(n_rows, n_cols, figsize=(18, 8))

# locations are multiplied by (slen - 1) to land on pixel indices
# (presumably they are stored in [0, 1] -- confirm in simulated_datasets_lib)
pixel_scale = images.shape[-1] - 1

for i in range(n_rows * n_cols):
    ax = axarr[divmod(i, n_cols)]
    n_stars_i = int(true_n_stars[i])

    # image
    ax.matshow(images[i, 0, :, :] - backgrounds[i, 0, :, :])
    ax.set_title('n_stars: {}\n'.format(n_stars_i))

    # plot locations: coordinate 0 goes on the vertical (row) axis and
    # coordinate 1 on the horizontal (column) axis
    locs_i = true_locs[i, 0:n_stars_i]
    locs_y = locs_i[:, 0] * pixel_scale
    locs_x = locs_i[:, 1] * pixel_scale

    ax.scatter(x = locs_x, y = locs_y, color = 'b')


# Load VAE

In [None]:
# Build the encoder for images of side data_params['slen'], a single band,
# and at most `max_stars` detections.
star_encoder = starnet_vae_lib.StarEncoder(
    data_params['slen'],
    n_bands = 1,
    max_detections = max_stars,
)

In [None]:
# Load the trained weights onto the CPU wherever they were saved, then
# switch the encoder to evaluation mode (the `;` suppresses the notebook echo).
star_encoder.load_state_dict(
    torch.load('../fits/saved_results/starnet_encoder_sleep0_four_stars',
               map_location='cpu'))
star_encoder.eval();

In [None]:
# objectives_lib.eval_star_encoder_loss(star_encoder, loader, train = False)

In [None]:
# Evaluate the encoder loss on the simulated batch. The returned `perm` is
# used further down to reorder the true parameters.
loss, counter_loss, locs_loss, fluxes_loss, perm = objectives_lib.get_encoder_loss(
    star_encoder, images, backgrounds,
    true_locs, true_fluxes, true_n_stars)

print(loss)

In [None]:
# average of the counter_loss term returned by get_encoder_loss
counter_loss.mean()

In [None]:
# distribution of the location-loss entries returned by get_encoder_loss
plt.hist(locs_loss.detach().numpy().flatten());

In [None]:
# distribution of the flux-loss entries returned by get_encoder_loss
plt.hist(fluxes_loss.detach().numpy().flatten()); 

# Check parameter estimates

In [None]:
# permute true parameters 
def permute_params(true_locs, true_fluxes, perm):
    """Reorder each image's true parameters according to ``perm``.

    For every image ``b`` and slot ``i``::

        locs_perm[b, i]   = true_locs[b, perm[b, i]]
        fluxes_perm[b, i] = true_fluxes[b, perm[b, i]]

    Args:
        true_locs: tensor indexed as ``(batch, max_stars, 2)``.
        true_fluxes: tensor indexed as ``(batch, max_stars)``.
        perm: integer tensor with ``batch`` rows and at least ``max_stars``
            columns; only the first ``max_stars`` columns are used.

    Returns:
        ``(locs_perm, fluxes_perm)``, shaped like ``true_locs`` and
        ``true_fluxes``. They are newly allocated tensors on the inputs'
        device, with the inputs' dtype.
    """
    batchsize = true_locs.shape[0]
    max_stars = true_locs.shape[1]
    perm = perm[:, :max_stars]

    # (batch, 1) row indices broadcast against (batch, max_stars) slot indices,
    # so a single advanced-indexing gather replaces the per-slot loop.
    rows = torch.arange(batchsize, device=perm.device).unsqueeze(1)
    locs_perm = true_locs[rows, perm]
    fluxes_perm = true_fluxes[rows, perm]

    return locs_perm, fluxes_perm

In [None]:
# Reorder the true locations/fluxes with the `perm` returned by get_encoder_loss,
# so slot i of locs_perm/fluxes_perm is compared with slot i of the encoder output below.
locs_perm, fluxes_perm = permute_params(true_locs, true_fluxes, perm)

In [None]:
# get variational parameters (the true star counts are passed to the encoder)
(logit_loc_mean, logit_loc_log_var,
 log_flux_mean, log_flux_log_var,
 log_probs) = star_encoder(images, backgrounds, true_n_stars)

In [None]:
# histogram of the flux log-variances, dropping exact zeros
# (presumably empty star slots -- confirm against the encoder)
flux_log_vars = log_flux_log_var.detach().numpy().ravel()
plt.hist(flux_log_vars[flux_log_vars != 0], bins = 100);

In [None]:
# histogram of the location-logit log-variances, dropping exact zeros
# (presumably empty star slots -- confirm against the encoder)
loc_log_vars = logit_loc_log_var.detach().numpy().ravel()
plt.hist(loc_log_vars[loc_log_vars != 0], bins = 100);

In [None]:
# Point estimates from the variational means: sigmoid maps the location logits
# into (0, 1), and exp undoes the log on the flux mean.
map_locs = torch.sigmoid(logit_loc_mean)
map_fluxes = torch.exp(log_flux_mean)

In [None]:
# error in locs, coordinate 0. The image-grid cell plots locs[..., 0] on the
# vertical (row) axis, so coordinate 0 is the y coordinate.
for i in range(max_stars):

    is_on_i = true_is_on[:, i] == 1
    est_i = map_locs[is_on_i, i, 0].detach().numpy()
    true_i = locs_perm[is_on_i, i, 0].detach().numpy()

    plt.plot(est_i, true_i, '+', color = 'blue')
    # red reference line y = x
    plt.plot(est_i, est_i, color = 'red')
    plt.xlabel('Estimated y coordinate', fontsize = 16)
    plt.ylabel('True y coordinate', fontsize = 16)


In [None]:
# error in locs, coordinate 1. The image-grid cell plots locs[..., 1] on the
# horizontal (column) axis, so coordinate 1 is the x coordinate.
for i in range(max_stars):

    is_on_i = true_is_on[:, i] == 1
    est_i = map_locs[is_on_i, i, 1].detach().numpy()
    true_i = locs_perm[is_on_i, i, 1].detach().numpy()

    plt.plot(est_i, true_i, '+', color = 'blue')
    # red reference line y = x
    plt.plot(est_i, est_i, color = 'red')
    plt.xlabel('Estimated x coordinate', fontsize = 16)
    plt.ylabel('True x coordinate', fontsize = 16)



In [None]:
# error in fluxes, compared on a natural-log scale
for i in range(max_stars):

    is_on_i = true_is_on[:, i] == 1
    est_log_flux = np.log(map_fluxes[is_on_i, i].detach().numpy())
    true_log_flux = np.log(fluxes_perm[is_on_i, i].detach().numpy())

    plt.plot(est_log_flux, true_log_flux, '+', color = 'blue')
    # red reference line y = x
    plt.plot(est_log_flux, est_log_flux, color = 'red')

    # the plotted values are log fluxes, so label them as such
    plt.xlabel('Estimated log flux', fontsize = 16)
    plt.ylabel('True log flux', fontsize = 16)
    



# Check reconstructions 

In [None]:
import plotting_utils

In [None]:
# Reconstructions for simulated images 0-19, with use_true_n_stars=False
# (presumably the encoder's own star-count estimate is used).
indx = np.arange(20)
plotting_utils.print_results(
    star_encoder,
    images[indx],
    backgrounds[indx],
    simulated_dataset.simulator.psf,
    true_locs[indx],
    true_n_stars[indx],
    use_true_n_stars = False,
)

In [None]:
# Reconstructions for simulated images 20-39, with use_true_n_stars=False
# (presumably the encoder's own star-count estimate is used).
indx = np.arange(20, 40)
plotting_utils.print_results(
    star_encoder,
    images[indx],
    backgrounds[indx],
    simulated_dataset.simulator.psf,
    true_locs[indx],
    true_n_stars[indx],
    use_true_n_stars = False,
)

# Check deblending properties

In [None]:
# Two equally bright stars per trial, pushed symmetrically apart along the
# diagonal: in trial i, star 0 sits at 0.5 + d_i and star 1 at 0.5 - d_i in
# both coordinates, with d_i = incr * (i + 1).
n_trials = 10

_n_stars = torch.full((n_trials,), 2, dtype=torch.long)
_fluxes = torch.ones(n_trials, max_stars) * simulated_dataset.f_min * 100

# slots past the first two keep random locations (presumably ignored,
# since _n_stars == 2)
_locs = torch.rand(n_trials, max_stars, 2)

incr = 0.01
# (n_trials, 1) offsets broadcast over the two coordinates; this replaces a
# loop that accumulated the offset one float addition at a time
dist = incr * torch.arange(1, n_trials + 1, dtype=_locs.dtype).unsqueeze(1)
_locs[:, 0, :] = 0.5 + dist
_locs[:, 1, :] = 0.5 - dist

In [None]:
# Render the synthetic two-star images without noise.
_images = simulated_dataset.simulator.draw_image_from_params(_locs, _fluxes, _n_stars,
                                                             add_noise = False)

# One constant sky background per rendered image. The leading dimension follows
# the image batch instead of a hard-coded 10, so changing n_trials keeps them aligned.
_backgrounds = torch.ones(_images.shape[0], 1, 1, 1) * simulated_dataset.sky_intensity

In [None]:
# Reconstructions for the synthetic two-star images, with
# use_true_n_stars=False (presumably the encoder's own count is used).
plotting_utils.print_results(
    star_encoder,
    _images,
    _backgrounds,
    simulated_dataset.simulator.psf,
    _locs,
    _n_stars,
    use_true_n_stars = False,
)

In [None]:
import sdss_dataset_lib

In [None]:
# Hubble catalog for NGC 7078, paired with the SDSS field whose
# run/camcol/field (2566/6/65) matches the PSF file path above.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'

# NOTE(review): slen is hard-coded to 11 here, while the encoder was built with
# data_params['slen']; confirm the two agree.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file = hubble_cat_file,
    slen = 11,
    run = 2566,
    camcol = 6,
    field = 65,
    max_detections = max_stars,
)

In [None]:
# number of examples in the SDSS/Hubble dataset (also used as the batch size below)
len(sdss_hubble_data)

In [None]:
# true parameters: load the whole SDSS/Hubble dataset as a single batch
hubble_loader = torch.utils.data.DataLoader(
                 dataset=sdss_hubble_data,
                 batch_size=len(sdss_hubble_data),
                 shuffle=False)

# next(iter(...)) reads the one batch directly; a for-loop with an immediate
# `break` is not needed.
data = next(iter(hubble_loader))
# catalog fluxes and locations are cast to torch.float
hubble_fluxes = data['fluxes'].type(torch.float)
hubble_locs = data['locs'].type(torch.float)
hubble_n_stars = data['n_stars']
sdss_images = data['image']
sdss_backgrounds = data['background']

In [None]:
# log10 of sky-subtracted simulated pixels. With noise added, many pixels sit at
# or below the sky level, and log10 of those gives NaN/-inf, so only positive
# residuals are histogrammed.
sim_resid = (images - simulated_dataset.sky_intensity).flatten()
plt.hist(torch.log10(sim_resid[sim_resid > 0]).numpy());

In [None]:
# log10 of background-subtracted SDSS pixels. Pixels at or below the
# background would give NaN/-inf, so only positive residuals are histogrammed.
sdss_resid = (sdss_images - sdss_backgrounds).flatten()
plt.hist(torch.log10(sdss_resid[sdss_resid > 0]).numpy());

In [None]:
# Reconstructions for the first 10 SDSS images, with the Hubble catalog as
# reference locations and use_true_n_stars=False.
# NOTE(review): residual_clamp = 1e16 presumably disables residual clamping in
# the plots -- confirm in plotting_utils.
indx = np.arange(10)
plotting_utils.print_results(
    star_encoder,
    sdss_images[indx],
    sdss_backgrounds[indx],
    simulated_dataset.simulator.psf,
    hubble_locs[indx],
    hubble_n_stars[indx],
    use_true_n_stars = False,
    residual_clamp = 1e16,
)