In [None]:
# Imports and global configuration for the notebook.
import json
import sys
import time
import timeit
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim

# Make the project-local modules in the parent directory importable.
# Guarded so re-running this cell does not keep prepending duplicate entries.
if './../' not in sys.path:
    sys.path.insert(0, './../')

import sdss_psf
import star_datasets_lib
import starnet_vae_lib
import objectives_lib

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('torch version: ', torch.__version__)

In [None]:
# load PSF: path to the SDSS psField file for this field
# (directory and file name presumably encode run 3900, camcol 6, field 269 -- confirm)
psf_fit_file = ('../../celeste_net/sdss_stage_dir/3900/6/269/'
                'psField-003900-6-0269.fit')
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both RNGs so later random draws in this notebook are reproducible.
# torch.manual_seed returns a Generator object; calling it first means the cell's
# last statement is np.random.seed (returns None), so nothing is displayed.
torch.manual_seed(24534)
np.random.seed(43534)

# Draw data

In [None]:
# data parameters: start from the default JSON settings, then override the
# per-image star-count range to 0..4.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(min_stars=0, max_stars=4)

print(data_params)


In [None]:
# Upper bound on stars per image, as overridden in the data-parameters cell.
# NOTE(review): max_stars is not referenced by any later visible cell.
max_stars = data_params['max_stars']

In [None]:
n_stars = 1024

# Build the star dataset from the PSF file and data parameters, with noise added.
# NOTE(review): use_fresh_data=False presumably reuses previously generated data -- confirm.
star_dataset = star_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars=n_stars,
    use_fresh_data=False,
    add_noise=True,
)

num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# get loader: sequential (unshuffled) batches of size n_stars.
batchsize = n_stars

loader = torch.utils.data.DataLoader(dataset=star_dataset,
                                     batch_size=batchsize,
                                     shuffle=False)

# DataLoader keeps a reference to the same dataset object (loader.dataset).
star_dataset.set_params_and_images()

In [None]:
# Take only the first batch from the loader (batch size is n_stars).
data = next(iter(loader))
true_fluxes = data['fluxes']
true_locs = data['locs']
true_n_stars = data['n_stars']
images = data['image']

In [None]:
# Draw one random image index from the batch.
# NOTE(review): the next cell's `for i in range(10)` loop overwrites `i`, so this
# draw currently only advances numpy's global RNG.
i, = np.random.choice(batchsize, 1)

In [None]:
# Show the first few observed images with the true star locations overlaid in blue.
n_show = min(10, images.shape[0])  # do not index past a batch smaller than 10
pix_scale = images.shape[-1] - 1   # locs are scaled by (side length - 1) into pixel coords
for i in range(n_show):
    # plain int so the title reads "nstars 3" rather than "nstars tensor(3)"
    n_stars_i = int(true_n_stars[i])

    # observed image
    plt.matshow(images[i, 0, :, :])

    # true locations: only the first n_stars_i rows of locs_i are used
    locs_i = true_locs[i]
    locs_x = locs_i[:n_stars_i, 0] * pix_scale
    locs_y = locs_i[:n_stars_i, 1] * pix_scale
    plt.scatter(x=locs_x, y=locs_y, c='b')

    plt.title('Observed image; nstars {}\n'.format(n_stars_i))


# Load the fitted VAE (StarRNN)

In [None]:
# Single-band StarRNN; slen (presumably the image side length) comes from data_params.
star_rnn = starnet_vae_lib.StarRNN(slen=data_params['slen'],
                                   n_bands=1)

In [None]:
# Load saved weights onto the CPU and switch the model to evaluation mode.
star_rnn.load_state_dict(
    torch.load('../fits/test_fit_one_detection', map_location='cpu'))
star_rnn.eval();

In [None]:
import objectives_lib

In [None]:
# Evaluate the inverse-KL objective on this batch.
# `perm` is used further down to index true_fluxes (presumably the matching of
# true stars to the detection -- confirm in objectives_lib).
loss, perm = objectives_lib.get_invKL_loss(star_rnn, images, true_fluxes, true_locs, true_n_stars)

In [None]:
# Average loss over the batch (loss is presumably one value per image).
torch.mean(loss)

In [None]:
# Run the model's forward_once on the whole batch from a zero initial hidden state.
# Inference only: no_grad avoids building an autograd graph for the whole batch.
# NOTE(review): 180 is hard-coded as the hidden-state width; it must match
# StarRNN's internal size -- confirm against starnet_vae_lib.
HIDDEN_DIM = 180
with torch.no_grad():
    h_init = torch.zeros(images.shape[0], HIDDEN_DIM)
    logit_locs_mean, logit_locs_logvar, log_flux_mean, log_flux_logvar, prob_on = \
        star_rnn.forward_once(images, h_i=h_init)

In [None]:
# For the first few images, overlay on the observed image:
#   red dot -- sigmoid of the estimated logit-location mean, in pixels
#   red x   -- 100 samples drawn in logit space and mapped through the sigmoid
#   blue    -- true star locations
n_show = min(10, images.shape[0])  # do not index past a batch smaller than 10
pix_scale = images.shape[-1] - 1   # same (side length - 1) scaling as the true locs
n_samples = 100
locs_mean = logit_locs_mean.detach()
locs_sd = torch.exp(0.5 * logit_locs_logvar.detach())
for i in range(n_show):
    # plain int so the title reads "nstars 3" rather than "nstars tensor(3)"
    n_stars_i = int(true_n_stars[i])

    # observed image
    plt.matshow(images[i, 0, :, :])

    # estimated location; column 0 is plotted as x, column 1 as y
    est_locs_i = torch.sigmoid(locs_mean[i]) * pix_scale
    plt.scatter(x=est_locs_i[0], y=est_locs_i[1], c='r')

    # posterior samples; the x draw happens before the y draw
    samples_x = torch.sigmoid(torch.randn(n_samples) * locs_sd[i, 0] + locs_mean[i, 0]) * pix_scale
    samples_y = torch.sigmoid(torch.randn(n_samples) * locs_sd[i, 1] + locs_mean[i, 1]) * pix_scale
    plt.scatter(x=samples_x, y=samples_y, color='red', marker='x')

    prob_on_i = np.round(prob_on[i].detach().numpy(), 5)
    plt.title('Observed image; nstars {}; \n prob on {}\n'.format(n_stars_i, prob_on_i))

    # true locations: only the first n_stars_i rows of locs_i are used
    locs_i = true_locs[i]
    locs_x = locs_i[:n_stars_i, 0] * pix_scale
    locs_y = locs_i[:n_stars_i, 1] * pix_scale
    plt.scatter(x=locs_x, y=locs_y, c='b')


In [None]:
# Display the estimated log-flux means returned by forward_once.
log_flux_mean

In [None]:
# Row indices 0..len(log_flux_mean)-1 (int64), used with `perm` below to index true_fluxes.
seq_tensor = torch.arange(len(log_flux_mean), dtype=torch.long)

In [None]:
# Estimated vs. true log flux, one point per image; the red line is y = x.
# The true flux per image is selected with (seq_tensor, perm).
est_log_flux = log_flux_mean.detach().reshape(-1)
true_log_flux = torch.log(true_fluxes[seq_tensor, perm]).reshape(-1)

plt.scatter(x=est_log_flux, y=true_log_flux)
flux_lo, flux_hi = float(est_log_flux.min()), float(est_log_flux.max())
plt.plot([flux_lo, flux_hi], [flux_lo, flux_hi], color='red')
plt.xlabel('estimated log flux')
plt.ylabel('true log flux')
plt.title('Estimated vs. true log flux');

In [None]:
# Re-display the estimated log-flux means for comparison with the true values in the next cell.
# NOTE(review): this duplicates an earlier display of the same tensor.
log_flux_mean

In [None]:
# True log flux per image, selected with the same (seq_tensor, perm) indexing as the scatter plot.
true_fluxes[seq_tensor, perm].log()