In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib

import objectives_lib

import time

import json

# Pick the first CUDA GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any other cell visible in this notebook.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Print the torch version so it is recorded with the notebook output.
print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# load PSF
# Relative path to the psField .fit file; it is passed to
# simulated_datasets_lib.load_dataset_from_params when the dataset is built below.
psf_fit_file = '../../celeste_net/sdss_stage_dir/3900/6/269/psField-003900-6-0269.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both numpy and torch so the random draws in this notebook are repeatable.
np.random.seed(21)
# torch.manual_seed returns a value; binding it to `_` keeps the cell output empty.
_ = torch.manual_seed(21)

# Draw data

In [None]:
# data parameters: read the JSON defaults, then override the star-count range
with open('../data/default_star_parameters.json', 'r') as param_file:
    data_params = json.load(param_file)

# between 0 and 20 stars
data_params.update(min_stars = 0, max_stars = 20)

print(data_params)


In [None]:
# Upper bound on the star count; reused below for the is-on indicator,
# the encoder's max_detections, and the per-star-slot plotting loops.
max_stars = data_params['max_stars']

In [None]:
# Passed both as `n_stars` to the dataset loader and as the DataLoader batch size below.
# NOTE(review): presumably this is the number of simulated images — confirm against
# simulated_datasets_lib.load_dataset_from_params.
batchsize = 200

# Build the simulated dataset from the PSF file and the data parameters, with noise added.
star_dataset = \
    simulated_datasets_lib.load_dataset_from_params(psf_fit_file,
                            data_params,
                            n_stars = batchsize,
                            use_fresh_data = False, 
                            add_noise = True)

# Size of the leading dimension of the dataset's fluxes.
num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# true parameters
# Wrap the dataset in an unshuffled loader whose batch size is `batchsize`.
test_loader = torch.utils.data.DataLoader(dataset=star_dataset,
                                          batch_size=batchsize,
                                          shuffle=False)

# Only the first batch is read; its fields are unpacked into notebook-level names.
for data in test_loader:
    true_fluxes, true_locs, true_n_stars, images, backgrounds = (
        data[field] for field in ('fluxes', 'locs', 'n_stars', 'image', 'background'))
    break

In [None]:
# Indicator built from the true star counts and max_stars; later cells take column
# `[:, i]` per star slot and keep the rows where it equals 1.
true_is_on = simulated_datasets_lib.get_is_on_from_n_stars(true_n_stars, max_stars)

In [None]:
# Display the shape of the image batch (indexed below as images[i, 0, :, :]).
images.shape

In [None]:
# histogram of log10 fluxes
# Only strictly positive fluxes are kept: log10(0) is -inf and plt.hist cannot bin
# non-finite values. NOTE(review): presumably star slots that are "off" carry a flux
# of 0 (later cells mask with true_is_on before taking logs) — confirm.
flux_values = true_fluxes.numpy().flatten()
plt.hist(np.log10(flux_values[flux_values > 0]));

In [None]:
# Show the first ten background-subtracted images in a 2 x 5 grid,
# each overlaid with its true star locations.
_, axarr = plt.subplots(2, 5, figsize=(18, 8))

# locations are multiplied by (side length - 1) before being drawn on the image
max_pixel = images.shape[-1] - 1

for i, ax in enumerate(axarr.flat):
    n_stars_i = true_n_stars[i]

    # image
    ax.matshow(images[i, 0, :, :] - backgrounds[i, 0, :, :])
    ax.set_title('n_stars: {}\n'.format(n_stars_i))

    # plot locations: column 0 goes on the vertical axis, column 1 on the horizontal
    locs_i = true_locs[i]
    locs_y = locs_i[0:int(n_stars_i), 0] * max_pixel
    locs_x = locs_i[0:int(n_stars_i), 1] * max_pixel

    ax.scatter(x = locs_x, y = locs_y, color = 'b')


# Load VAE

In [None]:
# Construct the encoder for single-band images of side data_params['slen'],
# allowing up to max_stars detections.
star_encoder = starnet_vae_lib.StarEncoder(data_params['slen'], 
                                           n_bands = 1, 
                                          max_detections = max_stars)

In [None]:
# Load the fitted encoder weights. The map_location callable returns each storage
# unchanged, so tensors stay on the CPU regardless of where they were saved.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
star_encoder.load_state_dict(torch.load('../fits/starnet_invKL_encoder_twenty_stars_experimental', 
                               map_location=lambda storage, loc: storage))
# Switch to evaluation mode; the trailing `;` suppresses the cell's output repr.
star_encoder.eval(); 

In [None]:
# objectives_lib.eval_star_encoder_loss(star_encoder, test_loader, train = False)

In [None]:
# Evaluate the encoder loss on this batch against the true locations, fluxes and
# star counts. `perm` is passed to permute_params below to reorder the true parameters.
loss, locs_loss, fluxes_loss, perm = \
    objectives_lib.get_encoder_loss(star_encoder, images, backgrounds, true_locs,
                        true_fluxes, true_n_stars)
    
print(loss)

In [None]:
# Histogram of the location loss values (flattened).
plt.hist(locs_loss.detach().numpy().flatten());

In [None]:
# Histogram of the flux loss values (flattened).
plt.hist(fluxes_loss.detach().numpy().flatten()); 

# Check parameters

In [None]:
# permute true parameters
def permute_params(true_locs, true_fluxes, perm):
    """Reorder the per-star parameters of every batch element according to `perm`.

    For each batch row b and slot i the output holds the entry at index
    perm[b, i] of the input, i.e. locs_perm[b, i] = true_locs[b, perm[b, i]]
    and fluxes_perm[b, i] = true_fluxes[b, perm[b, i]].

    Parameters
    ----------
    true_locs : tensor of shape (batchsize, max_stars, 2)
    true_fluxes : tensor of shape (batchsize, max_stars)
    perm : integer tensor of shape (batchsize, max_stars) with values in [0, max_stars)

    Returns
    -------
    (locs_perm, fluxes_perm) : new tensors with the same shapes as the inputs.
        They keep the dtype and device of the inputs (the earlier implementation
        always allocated default-dtype tensors on the default device).
    """
    batchsize, max_stars = true_locs.shape[0], true_locs.shape[1]

    # Only the first max_stars columns of perm are used, as before.
    perm = perm[:, :max_stars].long()

    # Column vector of row indices; it broadcasts against perm so that a single
    # advanced-indexing call gathers every (row, slot) pair at once.
    rows = torch.arange(batchsize, device=perm.device).unsqueeze(1)

    locs_perm = true_locs[rows, perm]
    fluxes_perm = true_fluxes[rows, perm]

    return locs_perm, fluxes_perm

In [None]:
# Reorder the true locations and fluxes with the permutation returned by the loss.
locs_perm, fluxes_perm = permute_params(true_locs, true_fluxes, perm)

In [None]:
# get variational parameters
# The encoder is run on the images and backgrounds together with the TRUE star counts.
logit_loc_mean, logit_loc_log_var, \
        log_flux_mean, log_flux_log_var = star_encoder(images, backgrounds, true_n_stars)

In [None]:
# Histogram of the log-flux log-variance outputs.
# Entries exactly equal to 0 are dropped — presumably padding for star slots that
# are off; TODO confirm.
foo = log_flux_log_var.flatten().detach().numpy()
plt.hist(foo[foo != 0], bins = 100);

In [None]:
# Histogram of the logit-location log-variance outputs.
# Entries exactly equal to 0 are dropped — presumably padding for star slots that
# are off; TODO confirm.
foo = logit_loc_log_var.flatten().detach().numpy()
plt.hist(foo[foo != 0], bins = 100);

In [None]:
# Point estimates: sigmoid of the logit-location means and exp of the log-flux means.
map_locs = torch.sigmoid(logit_loc_mean)
map_fluxes = torch.exp(log_flux_mean)

In [None]:
# error in locs: estimated vs. true value of location coordinate 0, one series per star slot
# NOTE(review): coordinate 0 is plotted on the vertical (y) axis in the image cells
# of this notebook, yet the labels below call it "x" — confirm which is intended.
for slot in range(max_stars):
    on_mask = true_is_on[:, slot] == 1

    est_coord = map_locs[on_mask, slot, 0].detach().numpy()
    true_coord = locs_perm[on_mask, slot, 0].detach().numpy()

    # blue crosses: estimate against truth; red line: the identity for reference
    plt.plot(est_coord, true_coord, '+', color = 'blue')
    plt.plot(est_coord, est_coord, color = 'red')

    plt.xlabel('Estimated x coordinate', fontsize = 16)
    plt.ylabel('True x coordinate', fontsize = 16)


In [None]:
# error in locs: estimated vs. true value of location coordinate 1, one series per star slot
# NOTE(review): coordinate 1 is plotted on the horizontal (x) axis in the image cells
# of this notebook, yet the labels below call it "y" — confirm which is intended.
for slot in range(max_stars):
    on_mask = true_is_on[:, slot] == 1

    est_coord = map_locs[on_mask, slot, 1].detach().numpy()
    true_coord = locs_perm[on_mask, slot, 1].detach().numpy()

    # blue crosses: estimate against truth; red line: the identity for reference
    plt.plot(est_coord, true_coord, '+', color = 'blue')
    plt.plot(est_coord, est_coord, color = 'red')

    plt.xlabel('Estimated y coordinate', fontsize = 16)
    plt.ylabel('True y coordinate', fontsize = 16)



In [None]:
# error in fluxes: estimated vs. true natural-log flux, one series per star slot
for slot in range(max_stars):
    on_mask = true_is_on[:, slot] == 1

    est_log_flux = np.log(map_fluxes[on_mask, slot].detach().numpy())
    true_log_flux = np.log(fluxes_perm[on_mask, slot].detach().numpy())

    # blue crosses: estimate against truth; red line: the identity for reference
    plt.plot(est_log_flux, true_log_flux, '+', color = 'blue')
    plt.plot(est_log_flux, est_log_flux, color = 'red')

# The plotted quantities are log fluxes, so the axis labels say so.
plt.xlabel('Estimated log flux', fontsize = 16)
plt.ylabel('True log flux', fontsize = 16)
    



# Check reconstructions 

## These are results conditional on the true number of stars

In [None]:
# Draw noiseless images from the estimated locations and fluxes, using the true star counts.
recon_images = star_dataset.draw_image_from_params(map_locs, map_fluxes, true_n_stars, add_noise = False)

In [None]:
# For images 30..49 draw three panels: observed image with true / estimated
# locations, observed image with posterior location samples, and residuals.

# locations are multiplied by (side length - 1) before being drawn on the image
max_pixel = images.shape[-1] - 1

for i in range(30, 50):
    _, (ax_obs, ax_post, ax_resid) = plt.subplots(1, 3, figsize=(12, 4))

    n_stars_i = true_n_stars[i]
    n_on = int(n_stars_i)

    # observed image
    ax_obs.matshow(images[i, 0, :, :])
    ax_obs.set_title('Observed image \n locs loss: {:.04f}'.format(locs_loss[i]))

    # plot locations (column 0 -> vertical axis, column 1 -> horizontal axis)
    locs_i = true_locs[i]
    locs_y = locs_i[0:n_on, 0] * max_pixel
    locs_x = locs_i[0:n_on, 1] * max_pixel

    ax_obs.scatter(x = locs_x, y = locs_y, c = 'b')
    for axis in (ax_obs.get_xaxis(), ax_obs.get_yaxis()):
        axis.set_visible(False)

    # plot estimated locations
    map_locs_i = map_locs[i].detach()
    est_locs_y = map_locs_i[0:n_on, 0] * max_pixel
    est_locs_x = map_locs_i[0:n_on, 1] * max_pixel

    ax_obs.scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    # plot posterior samples
    ax_post.matshow(images[i, 0, :, :].detach())
    ax_post.set_title('Observed image \n locs loss: {:.04f}'.format(locs_loss[i]))
    for axis in (ax_post.get_xaxis(), ax_post.get_yaxis()):
        axis.set_visible(False)

    for k in range(n_on):
        # 1000 draws: sigmoid(sd * standard normal + mean) in logit-location space
        logit_sd = torch.sqrt(torch.exp(logit_loc_log_var[i, k, :]))
        noise = torch.randn((1000, 2))
        samples = torch.sigmoid(logit_sd * noise + logit_loc_mean[i, k, :]).detach()

        ax_post.scatter(x = samples[:, 1] * max_pixel,
                        y = samples[:, 0] * max_pixel,
                        c = 'r', marker = 'x', alpha = 0.05)
    ax_post.scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    ax_resid.matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    for axis in (ax_resid.get_xaxis(), ax_resid.get_yaxis()):
        axis.set_visible(False)
    ax_resid.set_title('residuals')
    


In [None]:
# results for a specific n_stars
# Same three panels as the previous cell, restricted to the images whose true star
# count equals n_stars_to_show.
n_stars_to_show = 8

# locations are multiplied by (side length - 1) before being drawn on the image
max_pixel = images.shape[-1] - 1

# Iterate over however many images the batch holds instead of a hard-coded 200,
# so changing `batchsize` cannot cause an IndexError or silently skip images.
for i in range(len(true_n_stars)):

    n_on = int(true_n_stars[i])
    if n_on != n_stars_to_show:
        continue

    _, axarr = plt.subplots(1, 3, figsize=(12, 4))

    # observed image
    axarr[0].matshow(images[i, 0, :, :])
    axarr[0].set_title('Observed image \n locs loss: {:.04f}'.format(locs_loss[i]))

    # plot locations (column 0 -> vertical axis, column 1 -> horizontal axis)
    locs_i = true_locs[i]
    locs_y = locs_i[0:n_on, 0] * max_pixel
    locs_x = locs_i[0:n_on, 1] * max_pixel

    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    # plot estimated locations
    map_locs_i = map_locs[i].detach()
    est_locs_y = map_locs_i[0:n_on, 0] * max_pixel
    est_locs_x = map_locs_i[0:n_on, 1] * max_pixel

    axarr[0].scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    # plot posterior samples
    axarr[1].matshow(images[i, 0, :, :].detach())
    axarr[1].set_title('Observed image \n locs loss: {:.04f}'.format(locs_loss[i]))
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)

    for k in range(n_on):
        # 1000 draws: sigmoid(sd * standard normal + mean) in logit-location space
        samples = torch.sigmoid(torch.sqrt(torch.exp(logit_loc_log_var[i, k, :])) * \
                      torch.randn((1000, 2)) + logit_loc_mean[i, k, :]).detach()

        axarr[1].scatter(x = samples[:, 1] * max_pixel,
                         y = samples[:, 0] * max_pixel,
                         c = 'r', marker = 'x', alpha = 0.05)
    axarr[1].scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    axarr[2].matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    axarr[2].get_xaxis().set_visible(False)
    axarr[2].get_yaxis().set_visible(False)
    axarr[2].set_title('residuals')
    


# Combined encoder and counter results

### Load star counter

In [None]:
# Construct the star counter for single-band images of side data_params['slen'],
# allowing up to data_params['max_stars'] detections.
star_counter = starnet_vae_lib.StarCounter(data_params['slen'],
                                                n_bands = 1, 
                                               max_detections = data_params['max_stars'])

In [None]:
# Load the fitted counter weights. The map_location callable returns each storage
# unchanged, so tensors stay on the CPU regardless of where they were saved.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
star_counter.load_state_dict(torch.load('../fits/starnet_invKL_counter_twenty_stars_experimental',
                               map_location=lambda storage, loc: storage))
# Switch to evaluation mode; the trailing `;` suppresses the cell's output repr.
star_counter.eval(); 

### Get results from the counter and encoder combined

In [None]:
# Run the star counter on the images. Below, the output is arg-maxed over dim 1
# and exponentiated to plot a distribution over the number of stars.
log_probs = star_counter(images)

In [None]:
# Index of the largest log-probability along dim 1, used as the estimated star count.
map_n_stars = torch.argmax(log_probs, dim = 1)

In [None]:
# get variational parameters
# This time the encoder is given the counter's star counts (map_n_stars) rather
# than the true counts; the four names assigned earlier are rebound here.
logit_loc_mean, logit_loc_log_var, \
        log_flux_mean, log_flux_log_var = star_encoder(images, backgrounds, map_n_stars)

In [None]:
# Point estimates from the re-run encoder: sigmoid of the logit-location means and
# exp of the log-flux means (rebinds map_locs / map_fluxes).
map_locs = torch.sigmoid(logit_loc_mean)
map_fluxes = torch.exp(log_flux_mean)

In [None]:
# Draw noiseless images from the estimated locations and fluxes, now using the
# counter's star counts (rebinds recon_images).
recon_images = star_dataset.draw_image_from_params(map_locs, map_fluxes, map_n_stars, add_noise = False)

In [None]:
def plot_categorical_probs(log_prob_vec, fig):
    """Draw a stem plot of a categorical distribution on the given axes.

    Parameters
    ----------
    log_prob_vec : 1-D torch tensor of log-probabilities, one per category.
        It may be attached to an autograd graph; it is detached before plotting.
    fig : object exposing matplotlib-Axes-style `vlines` and `plot` methods
        (callers in this notebook pass a matplotlib Axes). All drawing goes to
        this object, never to pyplot's implicit current axes.

    Returns
    -------
    None
    """
    # Detach (and move to host memory) once, so matplotlib only ever sees numpy data.
    probs = torch.exp(log_prob_vec).detach().cpu().numpy()
    categories = np.arange(len(probs))

    # stems: one vertical line from 0 up to each probability
    fig.vlines(categories, 0, probs, color = 'blue')

    # markers at the top of each stem
    fig.plot(categories,
             probs,
             'o', markersize = 5, color = 'blue')


In [None]:
# For images 30..49 draw three panels: observed image with true / estimated
# locations, residuals, and the counter's distribution over the number of stars.

# locations are multiplied by (side length - 1) before being drawn on the image
max_pixel = images.shape[-1] - 1

for i in range(30, 50):
    _, (ax_obs, ax_resid, ax_probs) = plt.subplots(1, 3, figsize=(12, 4))

    # observed image
    ax_obs.matshow(images[i, 0, :, :])

    # plot true locations (column 0 -> vertical axis, column 1 -> horizontal axis)
    locs_i = true_locs[i]
    n_stars_i = true_n_stars[i]
    n_true = int(n_stars_i)
    locs_y = locs_i[0:n_true, 0] * max_pixel
    locs_x = locs_i[0:n_true, 1] * max_pixel

    ax_obs.scatter(x = locs_x, y = locs_y, c = 'b')
    for axis in (ax_obs.get_xaxis(), ax_obs.get_yaxis()):
        axis.set_visible(False)

    # plot estimated locations, sliced by the estimated star count
    map_locs_i = map_locs[i].detach()
    map_n_stars_i = map_n_stars[i]
    n_est = int(map_n_stars_i)
    est_locs_y = map_locs_i[0:n_est, 0] * max_pixel
    est_locs_x = map_locs_i[0:n_est, 1] * max_pixel

    ax_obs.scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    ax_obs.set_title('Observed image \n est/true n_stars: {} / {}'.format(map_n_stars_i, n_true))

    # plot residuals
    ax_resid.matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    for axis in (ax_resid.get_xaxis(), ax_resid.get_yaxis()):
        axis.set_visible(False)
    ax_resid.set_title('residuals')

    # plot uncertainty in number of stars
    plot_categorical_probs(log_probs[i], ax_probs)
    ax_probs.set_title('estimated distribution on n_stars')

    # dashed red marker at the true count, nudged by 0.05 so it does not hide the
    # stem; it reaches up to the largest probability for this image
    peak_prob = np.max(np.exp(log_probs.detach().numpy()[i]))
    ax_probs.plot(np.ones(100) * n_true + 0.05,
                  np.linspace(start = 0, stop = peak_prob, num = 100),
                  color = 'red',
                  linestyle = '--')
    


In [None]:
# results for a specific n_stars
# Three panels for the images whose TRUE star count equals n_stars_to_show.
# At this point map_locs, logit_loc_mean and logit_loc_log_var come from the encoder
# run with the counter's star counts (map_n_stars), so the estimated locations and
# posterior samples are sliced by map_n_stars[i]; the true locations are still
# sliced by true_n_stars[i].
n_stars_to_show = 8

# locations are multiplied by (side length - 1) before being drawn on the image
max_pixel = images.shape[-1] - 1

# Iterate over however many images the batch holds instead of a hard-coded 200.
for i in range(len(true_n_stars)):

    n_true = int(true_n_stars[i])
    if n_true != n_stars_to_show:
        continue

    n_est = int(map_n_stars[i])

    _, axarr = plt.subplots(1, 3, figsize=(12, 4))

    # observed image
    # locs_loss was computed with the true star counts, so it is not shown in
    # this section; the title reports estimated vs. true counts instead.
    axarr[0].matshow(images[i, 0, :, :])
    axarr[0].set_title('Observed image \n est/true n_stars: {} / {}'.format(n_est, n_true))

    # plot true locations (column 0 -> vertical axis, column 1 -> horizontal axis)
    locs_i = true_locs[i]
    locs_y = locs_i[0:n_true, 0] * max_pixel
    locs_x = locs_i[0:n_true, 1] * max_pixel

    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    # plot estimated locations
    map_locs_i = map_locs[i].detach()
    est_locs_y = map_locs_i[0:n_est, 0] * max_pixel
    est_locs_x = map_locs_i[0:n_est, 1] * max_pixel

    axarr[0].scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    # plot posterior samples
    axarr[1].matshow(images[i, 0, :, :].detach())
    axarr[1].set_title('Posterior location samples')
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)

    for k in range(n_est):
        # 1000 draws: sigmoid(sd * standard normal + mean) in logit-location space
        samples = torch.sigmoid(torch.sqrt(torch.exp(logit_loc_log_var[i, k, :])) * \
                      torch.randn((1000, 2)) + logit_loc_mean[i, k, :]).detach()

        axarr[1].scatter(x = samples[:, 1] * max_pixel,
                         y = samples[:, 0] * max_pixel,
                         c = 'r', marker = 'x', alpha = 0.05)
    axarr[1].scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    axarr[2].matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    axarr[2].get_xaxis().set_visible(False)
    axarr[2].get_yaxis().set_visible(False)
    axarr[2].set_title('residuals')
    
