In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')
import sdss_psf
import star_datasets_lib
import starnet_vae_lib

import objectives_lib
import json

# Prefer the first CUDA device when one is available, otherwise use the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

print('torch version: ', torch.__version__)


In [None]:
# load PSF
# Relative path to the PSF fit file (presumably an SDSS psField FITS file,
# judging by its name). Nothing is read here; the path is only stored and echoed.
psf_fit_file = '../../celeste_net/sdss_stage_dir/3900/6/269/psField-003900-6-0269.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed NumPy's and PyTorch's global generators so the notebook is repeatable.
NUMPY_SEED = 32
TORCH_SEED = 54

np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)  # bound to `_` so the Generator repr is not echoed

# Draw data

In [None]:
# data parameters
# Read the default star parameters from JSON, then override the allowed
# range for the number of stars.
with open('../data/default_star_parameters.json', 'r') as param_file:
    data_params = json.load(param_file)

data_params.update(min_stars=0, max_stars=4)


In [None]:
# Upper bound on the star count, as configured in data_params above.
max_stars = data_params['max_stars']

In [None]:
# Size of the dataset requested from the loader.
# NOTE(review): the count is passed through the keyword `n_stars`; presumably
# that is the number of images to generate - confirm in star_datasets_lib.
n_images = 2048

star_dataset = star_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars=n_images,
    use_fresh_data=False,
    add_noise=True,
)


In [None]:
# true parameters
# The batch size equals the number of requested images, so the first batch is
# presumably the whole dataset.
batchsize = n_images

test_loader = torch.utils.data.DataLoader(
                 dataset=star_dataset,
                 batch_size=batchsize,
                 shuffle=False)

# Fetch the first batch directly. Unlike a `for ... break` loop, this raises
# StopIteration on an empty loader instead of silently leaving the names
# below undefined.
data = next(iter(test_loader))

true_fluxes = data['fluxes']
true_locs = data['locs']
true_n_stars = data['n_stars']
images = data['image']

In [None]:
# On/off indicators derived from the true star counts. Later cells index this
# as true_is_on[:, i] and compare the entries with 1.
# NOTE(review): presumably one row per image and one column per star slot - confirm.
true_is_on = star_datasets_lib.get_is_on_from_n_stars(true_n_stars, max_stars)

In [None]:
# Display the shape of the image batch (later cells index it as images[i, 0, :, :]).
images.shape

In [None]:
# histogram of fluxes
# All entries of the flux tensor are pooled into a single 1-D array first.
flux_values = true_fluxes.numpy().flatten()
plt.hist(flux_values);

In [None]:
# Show the first ten images on a 2 x 5 grid, each with its true star
# locations overlaid.
_, axarr = plt.subplots(2, 5, figsize=(18, 8))

# Factor applied to the stored locations before plotting
# (presumably they are fractions of the image side - confirm).
pixel_scale = images.shape[-1] - 1

for img_idx, ax in enumerate(axarr.flatten()):
    n_on = int(true_n_stars[img_idx])

    # image
    ax.matshow(images[img_idx, 0, :, :])
    ax.set_title('n_stars: {}\n'.format(true_n_stars[img_idx]))

    # plot locations: only the first n_on rows are drawn
    on_locs = true_locs[img_idx][0:n_on]
    ax.scatter(x = on_locs[:, 0] * pixel_scale,
               y = on_locs[:, 1] * pixel_scale,
               color = 'b')


# Load VAE

In [None]:
# Build the encoder for single-band images; the image side length comes from
# the data parameters and the detection cap from max_stars.
star_encoder = starnet_vae_lib.StarEncoder(
    data_params['slen'],
    n_bands=1,
    max_detections=max_stars,
)

In [None]:
# Restore the saved encoder weights onto the CPU, then switch the module to
# evaluation mode.
encoder_state = torch.load('../fits/starnet_invKL_encoder',
                           map_location=torch.device('cpu'))
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

In [None]:
# Evaluate the encoder loss over the test loader with train = False; the
# return value is displayed as the cell output.
objectives_lib.eval_star_encoder_loss(star_encoder, test_loader, train = False)

# NOTE(review): disabled alternative; `elbo_objective_lib` is not imported in this notebook.
# elbo_objective_lib.eval_star_encoder_kl_loss(star_encoder, test_loader, train = False)

In [None]:
# Encoder loss on the full batch, plus `perm`, which later cells use to index
# into the true parameters.
# NOTE(review): presumably `loss` holds one value per image and `perm` matches
# true stars to detection slots - confirm against objectives_lib.
loss, perm = objectives_lib.get_encoder_loss(star_encoder, images, true_locs,
                        true_fluxes, true_n_stars)
    
print(loss.mean())

In [None]:
# Histogram of all loss entries (detached from the graph and flattened).
plt.hist(loss.detach().numpy().flatten());

In [None]:
# Run the encoder on the whole batch, conditioning on the true star counts.
# It returns means and log-variances for the logit-locations and log-fluxes.
(logit_loc_mean,
 logit_loc_logvar,
 log_flux_mean,
 log_flux_logvar) = star_encoder(images, true_n_stars)

In [None]:
# For each of the first 20 images draw two panels:
#   left  - the image with true locations (blue dots) and the encoder's point
#           estimates (crosses, one colour per detection slot);
#   right - the image with true locations and samples from the location
#           distribution of every detection slot that is switched on.
# The original drew samples for slot 1 only, regardless of how many stars the
# image contained, and left a third, empty panel.

# Factor applied to the stored locations before plotting
# (presumably they are fractions of the image side - confirm).
pixel_scale = images.shape[-1] - 1
n_samples = 100
colors = ['red', 'green', 'purple', 'cyan']

for i in range(0, 20):
    _, axarr = plt.subplots(1, 2, figsize=(8, 4))

    n_stars_i = int(true_n_stars[i])
    # cycle through the palette so more than len(colors) slots cannot break the plot
    slot_colors = [colors[k % len(colors)] for k in range(n_stars_i)]

    # true locations: only the first n_stars_i rows are plotted
    locs_i = true_locs[i]
    locs_x = locs_i[0:n_stars_i, 0] * pixel_scale
    locs_y = locs_i[0:n_stars_i, 1] * pixel_scale

    ##################
    # MAP ESTIMATES
    ##################
    # observed image
    axarr[0].matshow(images[i, 0, :, :])

    # plot true locations
    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    # plot estimated locations (sigmoid of the logit-location mean)
    est_locs_i = torch.sigmoid(logit_loc_mean[i, :, :]).detach()
    est_locs_x = est_locs_i[0:n_stars_i, 0] * pixel_scale
    est_locs_y = est_locs_i[0:n_stars_i, 1] * pixel_scale

    axarr[0].scatter(x = est_locs_x, y = est_locs_y, c = slot_colors, marker = 'x')

    ###################
    # look at samples
    ###################
    # observed image
    axarr[1].matshow(images[i, 0, :, :])

    # SAMPLES: mean + sd * noise on the logit scale, mapped back by a sigmoid
    logit_loc_mean_i = logit_loc_mean[i, :, :]
    logit_loc_sd_i = torch.exp(0.5 * logit_loc_logvar[i, :, :])
    sample_locs_i = torch.sigmoid(logit_loc_mean_i.unsqueeze(2) + \
                                  torch.randn(max_stars, 2, n_samples) * \
                                  logit_loc_sd_i.unsqueeze(2)).detach()

    # one scatter per slot that is on, coloured like its point estimate
    for k in range(n_stars_i):
        axarr[1].scatter(x = sample_locs_i[k, 0, :] * pixel_scale,
                         y = sample_locs_i[k, 1, :] * pixel_scale,
                         c = slot_colors[k], marker = 'x', alpha = 0.1)

    # plot true locations
    axarr[1].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)



In [None]:
# Row indices 0 .. n_images - 1 as an int64 tensor, used for paired indexing.
seq_tensor = torch.arange(images.shape[0])

In [None]:
# Logit-location mean of detection slot 0 (horizontal axis) against the logit
# of the true location selected by `perm` (vertical axis), first for
# coordinate 0 and then for coordinate 1.
# NOTE(review): `perm` is used here as one index per image, whereas
# permute_params indexes it as perm[:, i] - confirm its shape.
# NOTE(review): images without stars are not masked out here - confirm intended.
# NOTE(review): relies on the private helper objectives_lib._logit.
plt.plot(logit_loc_mean[:, 0, 0].detach().numpy(), 
        objectives_lib._logit(true_locs)[seq_tensor, perm, 0].numpy(), '+')

plt.plot(logit_loc_mean[:, 0, 1].detach().numpy(), 
        objectives_lib._logit(true_locs)[seq_tensor, perm, 1].numpy(), '+')

In [None]:
# Log-flux mean of detection slot 0 against the log of the matched true flux,
# restricted to images with at least one star. The second line is the
# identity, for reference.
has_stars = true_n_stars > 0
pred_log_flux = log_flux_mean[has_stars, 0].detach().numpy()
true_log_flux = torch.log(true_fluxes)[seq_tensor, perm][has_stars].numpy()

plt.plot(pred_log_flux, true_log_flux, '+')

plt.plot(pred_log_flux, pred_log_flux)

# check parameters

In [None]:
# permute true parameters
def permute_params(true_locs, true_fluxes, perm):
    """Reorder every image's true star parameters according to ``perm``.

    For each batch element ``b`` and slot ``i``::

        locs_perm[b, i, :] = true_locs[b, perm[b, i], :]
        fluxes_perm[b, i]  = true_fluxes[b, perm[b, i]]

    Parameters
    ----------
    true_locs : torch.Tensor, shape (batchsize, max_stars, d)
        True locations (``d`` is 2 in this notebook).
    true_fluxes : torch.Tensor, shape (batchsize, max_stars)
        True fluxes.
    perm : integer tensor or array, shape (batchsize, k) with k >= max_stars
        Slot indices; only the first ``max_stars`` columns are used.

    Returns
    -------
    (locs_perm, fluxes_perm) : tuple of torch.Tensor
        New tensors with the shapes of ``true_locs`` and ``true_fluxes``.
        They keep the dtype and device of their inputs; the previous
        implementation always produced default-dtype CPU tensors.

    Raises
    ------
    IndexError
        If ``perm`` has fewer than ``max_stars`` columns.
    """
    batchsize, max_stars = true_locs.shape[0], true_locs.shape[1]

    perm = torch.as_tensor(perm, dtype=torch.long, device=true_locs.device)
    if perm.shape[1] < max_stars:
        raise IndexError(
            'perm has {} columns but {} are required'.format(perm.shape[1], max_stars))
    perm = perm[:, :max_stars]

    # A column of row indices broadcasts against perm, so one advanced-indexing
    # call gathers every (image, slot) pair without a Python loop.
    rows = torch.arange(batchsize, device=true_locs.device).unsqueeze(1)

    locs_perm = true_locs[rows, perm]
    fluxes_perm = true_fluxes[rows, perm]

    return locs_perm, fluxes_perm

In [None]:
# Reorder the true locations and fluxes with the permutation returned by the loss.
locs_perm, fluxes_perm = permute_params(true_locs, true_fluxes, perm)

In [None]:
# get variational parameters
# The encoder is run on the whole batch, conditioning on the true star counts.
(logit_loc_mean,
 logit_loc_log_var,
 log_flux_mean,
 log_flux_log_var) = star_encoder(images, true_n_stars)

In [None]:
# Point estimates on the original scale: sigmoid of the logit-location mean
# and exponential of the log-flux mean.
map_locs = torch.sigmoid(logit_loc_mean)
map_fluxes = torch.exp(log_flux_mean)

In [None]:
# error in locs
# Estimated vs. true location, one panel per coordinate. The original drew
# both coordinates onto the same implicit axes, and the second loop overwrote
# the axis labels set by the first.
_, loc_axes = plt.subplots(1, 2, figsize=(12, 5))
coord_names = ['x', 'y']

for coord, ax in enumerate(loc_axes):
    for i in range(max_stars):

        is_on_i = true_is_on[:, i]

        # keep only the images in which slot i is switched on
        est = map_locs[is_on_i == 1, i, coord].detach().numpy()
        truth = locs_perm[is_on_i == 1, i, coord].detach().numpy()

        ax.plot(est, truth, '+', color = 'blue', alpha = 0.3)
        # identity line: perfect estimates would fall on it
        ax.plot(est, est, color = 'red')

    ax.set_xlabel('Estimated {} coordinate'.format(coord_names[coord]), fontsize = 16)
    ax.set_ylabel('True {} coordinate'.format(coord_names[coord]), fontsize = 16)



In [None]:
# error in locs
# Second coordinate only: estimate on the horizontal axis, permuted truth on
# the vertical axis, with the identity line in red.
for i in range(max_stars):
    is_on_i = true_is_on[:, i]
    on_mask = is_on_i == 1

    est_y = map_locs[on_mask, i, 1].detach().numpy()
    true_y = locs_perm[on_mask, i, 1].detach().numpy()

    plt.plot(est_y, true_y, '+', color = 'blue', alpha = 0.3)
    plt.plot(est_y, est_y, color = 'red')

    plt.xlabel('Estimated y coordinate', fontsize = 16)
    plt.ylabel('True y coordinate', fontsize = 16)



In [None]:
    # NOTE(review): this cell is an indented fragment with no enclosing loop.
    # It reuses `axarr`, `i` and `is_on_i` left over from earlier cells, and
    # the same two plot calls appear inside the loop of the following cell.
    # Confirm whether it should be kept.
    axarr[1].plot(map_locs[is_on_i == 1, i, 1].detach().numpy(), 
                 locs_perm[is_on_i == 1, i, 1].detach().numpy(), '+')
    axarr[1].plot(map_locs[is_on_i == 1, i, 1].detach().numpy(), 
                 map_locs[is_on_i == 1, i, 1].detach().numpy())
    axarr[1].set_xlabel('Estimated location')
    axarr[1].set_ylabel('True location')

In [None]:
# Estimated vs. true location per detection slot: coordinate 0 in the left
# panel, coordinate 1 in the right panel. The cell creates its own figure;
# it previously drew on whatever `axarr` an earlier cell had left behind, so
# nothing was rendered here.
_, axarr = plt.subplots(1, 2, figsize=(12, 6))

for i in range(max_stars):

    is_on_i = true_is_on[:, i]

    for coord in (0, 1):
        # keep only the images in which slot i is switched on
        est = map_locs[is_on_i == 1, i, coord].detach().numpy()
        truth = locs_perm[is_on_i == 1, i, coord].detach().numpy()

        axarr[coord].plot(est, truth, '+')
        # identity line for reference
        axarr[coord].plot(est, est)

for ax in axarr:
    ax.set_xlabel('pred')
    ax.set_ylabel('truth')

In [None]:
# Estimated vs. true flux for every detection slot, restricted to the images
# in which that slot is on; the red line is the identity.
for i in range(max_stars):
    is_on_i = true_is_on[:, i]
    on_mask = is_on_i == 1

    est_flux = map_fluxes[on_mask, i].detach().numpy()
    true_flux = fluxes_perm[on_mask, i].detach().numpy()

    plt.plot(est_flux, true_flux, '+', color = 'blue', alpha = 0.3)
    plt.plot(est_flux, est_flux, color = 'red')

    plt.xlabel('Estimated flux', fontsize = 16)
    plt.ylabel('True flux', fontsize = 16)
    



In [None]:
# Check reconstructions 

In [None]:
# Draw images from the point estimates (map_locs, map_fluxes) and the true
# star counts, with add_noise = False; used below to form residuals.
recon_images = star_dataset.draw_image_from_params(map_locs, map_fluxes, true_n_stars, add_noise = False)

In [None]:
# For images 40-59 draw three panels: truth vs. point estimates, samples from
# the location distribution, and the residual image.
# The titles used `locs_loss`, which no cell in this notebook defines, so they
# now report the per-image encoder loss computed earlier
# (NOTE(review): assumes `loss[i]` holds a single value - confirm).

# Factor applied to the stored locations before plotting
# (presumably they are fractions of the image side - confirm).
pixel_scale = images.shape[-1] - 1

for i in range(40, 60):
    _, axarr = plt.subplots(1, 3, figsize=(12, 4))

    n_stars_i = int(true_n_stars[i])
    loss_i = float(loss[i])

    # observed image
    axarr[0].matshow(images[i, 0, :, :])
    axarr[0].set_title('Observed image \n loss: {:.04f}'.format(loss_i))

    # plot locations: only the first n_stars_i rows are drawn
    locs_i = true_locs[i]
    locs_x = locs_i[0:n_stars_i, 0] * pixel_scale
    locs_y = locs_i[0:n_stars_i, 1] * pixel_scale

    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    # plot estimated locations
    map_locs_i = map_locs[i].detach()
    est_locs_x = map_locs_i[0:n_stars_i, 0] * pixel_scale
    est_locs_y = map_locs_i[0:n_stars_i, 1] * pixel_scale

    axarr[0].scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    # plot posterior samples
    axarr[1].matshow(images[i, 0, :, :].detach())
    axarr[1].set_title('Observed image \n loss: {:.04f}'.format(loss_i))
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)

    for k in range(n_stars_i):
        # sd * noise + mean on the logit scale, mapped back by a sigmoid
        samples = torch.sigmoid(torch.sqrt(torch.exp(logit_loc_log_var[i, k, :])) * \
                      torch.randn((1000, 2)) + logit_loc_mean[i, k, :]).detach()

        axarr[1].scatter(x = samples[:, 0] * pixel_scale,
                         y = samples[:, 1] * pixel_scale,
                         c = 'r', marker = 'x', alpha = 0.05)
    axarr[1].scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    axarr[2].matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    axarr[2].get_xaxis().set_visible(False)
    axarr[2].get_yaxis().set_visible(False)
    axarr[2].set_title('residuals')
    


In [None]:
# Figures for the first three images: samples from the location distribution
# (left) and the residual image (right).
# The left title previously called .format(locs_loss[i]) on a string without
# placeholders; `locs_loss` is defined nowhere in this notebook, so the call
# only raised a NameError and has been removed.

# Factor applied to the stored locations before plotting
# (presumably they are fractions of the image side - confirm).
pixel_scale = images.shape[-1] - 1

for i in range(0, 3):
    _, axarr = plt.subplots(1, 2, figsize=(12, 6))

    # plot locations: only the first n_stars_i rows are drawn
    n_stars_i = int(true_n_stars[i])
    locs_i = true_locs[i]
    locs_x = locs_i[0:n_stars_i, 0] * pixel_scale
    locs_y = locs_i[0:n_stars_i, 1] * pixel_scale

    # plot posterior samples
    axarr[0].matshow(images[i, 0, :, :].detach())
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    for k in range(n_stars_i):
        # sd * noise + mean on the logit scale, mapped back by a sigmoid
        samples = torch.sigmoid(torch.sqrt(torch.exp(logit_loc_log_var[i, k, :])) * \
                      torch.randn((1000, 2)) + logit_loc_mean[i, k, :]).detach()

        axarr[0].scatter(x = samples[:, 0] * pixel_scale,
                         y = samples[:, 1] * pixel_scale,
                         c = 'r', marker = 'x', alpha = 0.05)
    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    axarr[1].matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)

    # only the first figure carries column titles
    if i == 0:
        axarr[0].set_title('Observed image', fontsize = 26)
        axarr[1].set_title('Residuals', fontsize = 26)
    
    


In [None]:
# Histogram of all logit-location log-variances returned by the encoder.
plt.hist(logit_loc_log_var.detach().numpy().flatten())

In [None]:
# Histogram of all log-flux log-variances returned by the encoder.
plt.hist(log_flux_log_var.detach().numpy().flatten())

In [None]:
# NOTE(review): `fluxes_loss` is not assigned by any cell above, so this cell
# raises a NameError on a fresh kernel - confirm where it should come from.
plt.hist(fluxes_loss.detach()); 

In [None]:
# check where the locs are particularly bad

In [None]:
# Alternative selections by loss threshold, kept for reference. They need
# `locs_loss` / `fluxes_loss`, which no cell above defines.
# bad_indx =  np.argwhere(locs_loss.detach().numpy() > 5).squeeze()
# bad_indx = np.argwhere(fluxes_loss.detach().numpy() > 2).squeeze()

# Indices of the images that contain exactly four stars.
# np.flatnonzero always returns a 1-D index array. The previous
# argwhere(...).squeeze() collapsed to a 0-d array when exactly one image
# matched, and a 0-d array cannot be iterated by the loop further down.
bad_indx = np.flatnonzero(true_n_stars.numpy() == 4)

In [None]:
# Display the selected image indices.
bad_indx

In [None]:
# Draw the three diagnostic panels (truth vs. point estimates, samples from
# the location distribution, residuals) for the selected images.
# The titles used `locs_loss`, which no cell in this notebook defines, so they
# now report the per-image encoder loss computed earlier
# (NOTE(review): assumes `loss[i]` holds a single value - confirm).

# The original counter stopped after the 11th figure; slicing keeps that cap.
max_figures = 11

# Factor applied to the stored locations before plotting
# (presumably they are fractions of the image side - confirm).
pixel_scale = images.shape[-1] - 1

# atleast_1d keeps the loop valid even if bad_indx is a 0-d array
for i in np.atleast_1d(bad_indx)[:max_figures]:
    _, axarr = plt.subplots(1, 3, figsize=(12, 4))

    n_stars_i = int(true_n_stars[i])
    loss_i = float(loss[i])

    # observed image
    axarr[0].matshow(images[i, 0, :, :])
    axarr[0].set_title('Observed image \n loss: {:.04f}'.format(loss_i))

    # plot locations: only the first n_stars_i rows are drawn
    locs_i = true_locs[i]
    locs_x = locs_i[0:n_stars_i, 0] * pixel_scale
    locs_y = locs_i[0:n_stars_i, 1] * pixel_scale

    axarr[0].scatter(x = locs_x, y = locs_y, c = 'b')
    axarr[0].get_xaxis().set_visible(False)
    axarr[0].get_yaxis().set_visible(False)

    # plot estimated locations
    map_locs_i = map_locs[i].detach()
    est_locs_x = map_locs_i[0:n_stars_i, 0] * pixel_scale
    est_locs_y = map_locs_i[0:n_stars_i, 1] * pixel_scale

    axarr[0].scatter(x = est_locs_x, y = est_locs_y, c = 'r', marker = 'x')

    # plot posterior samples
    axarr[1].matshow(images[i, 0, :, :].detach())
    axarr[1].set_title('Observed image \n loss: {:.04f}'.format(loss_i))
    axarr[1].get_xaxis().set_visible(False)
    axarr[1].get_yaxis().set_visible(False)

    for k in range(n_stars_i):
        # sd * noise + mean on the logit scale, mapped back by a sigmoid
        samples = torch.sigmoid(torch.sqrt(torch.exp(logit_loc_log_var[i, k, :])) * \
                      torch.randn((1000, 2)) + logit_loc_mean[i, k, :]).detach()

        axarr[1].scatter(x = samples[:, 0] * pixel_scale,
                         y = samples[:, 1] * pixel_scale,
                         c = 'r', marker = 'x', alpha = 0.01)
    axarr[1].scatter(x = locs_x, y = locs_y, c = 'b')

    # plot residuals
    axarr[2].matshow(images[i, 0, :, :] - recon_images[i, 0, :, :].detach())
    axarr[2].get_xaxis().set_visible(False)
    axarr[2].get_yaxis().set_visible(False)
    axarr[2].set_title('residuals')