In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib

import objectives_lib

import json

# Use the first GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): none of the visible cells move the model or data to `device`.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

print('torch version: ', torch.__version__)


In [None]:
# Load PSF: SDSS psField file for run 2566, camcol 6, field 65. The identifiers
# appear both in the directory path and in the zero-padded file name
# (psField-<run:06d>-<camcol>-<field:04d>.fit), so they are parameterized once.
sdss_run, sdss_camcol, sdss_field = 2566, 6, 65
psf_fit_file = ('../../celeste_net/sdss_stage_dir/{run}/{camcol}/{field}/'
                'psField-{run:06d}-{camcol}-{field:04d}.fit').format(
                    run=sdss_run, camcol=sdss_camcol, field=sdss_field)
print('psf file: \n', psf_fit_file)

In [None]:
# Seed the global NumPy and torch RNGs so later random draws (e.g. the
# DataLoader's shuffling) are reproducible.
NUMPY_SEED = 43534
TORCH_SEED = 24534
np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)  # assigned only to hide the returned Generator repr

# Draw data

In [None]:
# Simulation parameters: start from the defaults on disk and override the
# maximum number of stars per image.
params_path = '../data/default_star_parameters.json'
with open(params_path, 'r') as params_file:
    data_params = json.load(params_file)

data_params.update(max_stars=20)

print(data_params)


In [None]:
# NOTE(review): despite its name, n_stars appears to be the number of simulated
# images (it is reused as the batch size below), not a star count.
n_stars = 1024

# Build the simulated dataset from the PSF file and the parameters above, with
# noise added. use_fresh_data=False presumably reuses previously generated data
# -- confirm in simulated_datasets_lib.
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars=n_stars,
    use_fresh_data=False,
    add_noise=True,
)

In [None]:
# Shuffled loader whose batch size equals n_stars, so (presumably) the whole
# simulated dataset arrives as a single batch.
batchsize = n_stars
test_loader = torch.utils.data.DataLoader(dataset=star_dataset,
                                          batch_size=batchsize,
                                          shuffle=True)

In [None]:
# True parameters: take the first batch from the loader and unpack its fields.
data = next(iter(test_loader))

true_fluxes = data['fluxes']
true_locs = data['locs']
true_n_stars = data['n_stars']
images = data['image']
backgrounds = data['background']

In [None]:
# Histogram of the true flux stored in the first star slot of each image.
# NOTE(review): images with n_stars == 0 presumably contribute a padding value
# to this slot -- confirm in simulated_datasets_lib.
_, flux_ax = plt.subplots()
flux_ax.hist(true_fluxes[:, 0])
flux_ax.set_xlabel('true_fluxes[:, 0]')
flux_ax.set_ylabel('number of images')
flux_ax.set_title('True flux of the first star slot');

In [None]:
# Size of the image batch (indexed below as images[i, band, :, :]).
images.size()

In [None]:
# 2 x 5 grid of the first ten images, with the true star locations overlaid in
# blue. Locations are multiplied by (slen - 1) to get pixel coordinates, so
# they are presumably normalized to [0, 1] -- confirm in simulated_datasets_lib.
pixel_scale = images.shape[-1] - 1
_, axarr = plt.subplots(2, 5, figsize=(18, 8))

for i in range(10):
    row, col = divmod(i, 5)
    ax = axarr[row, col]

    ax.matshow(images[i, 0, :, :])
    ax.set_title('n_stars: {}\n'.format(true_n_stars[i]))

    # Only the first n_stars rows of true_locs[i] hold real stars.
    n_stars_i = int(true_n_stars[i])
    locs_i = true_locs[i, :n_stars_i]
    ax.scatter(x=locs_i[:, 1] * pixel_scale, y=locs_i[:, 0] * pixel_scale, c='b')


# Load trained star counter

In [None]:
# Network that scores each possible star count for an image; its outputs are
# used below as per-count log-probabilities.
star_counter = starnet_vae_lib.StarCounter(
    data_params['slen'],
    n_bands=1,
    max_detections=data_params['max_stars'],
)

In [None]:
# Load the trained weights onto the CPU, then switch the network to eval mode.
counter_weights_path = '../fits/starnet_invKL_counter_twenty_stars_new_prior'
counter_state = torch.load(counter_weights_path, map_location='cpu')
star_counter.load_state_dict(counter_state)
_ = star_counter.eval()  # assigned only to hide the module repr

In [None]:
# Score every image in the batch. This is inference only: no_grad avoids building
# an autograd graph over the whole batch, and it yields a log_probs tensor that
# does not require grad, so later cells can hand it to NumPy/matplotlib directly.
with torch.no_grad():
    log_probs = star_counter(images, backgrounds)

# predictive accuracy

In [None]:
# Column 0 of log_probs is the log-probability of a zero-star count. The column
# index is the count, as the argmax-based prediction below relies on.
log_probs.select(1, 0)

In [None]:
# MAP star count: the highest-scoring count category for each image.
map_n_stars = log_probs.argmax(dim=1)

In [None]:
# Overall accuracy: fraction of images whose MAP star count equals the truth.
is_correct = map_n_stars.float() == true_n_stars
is_correct.float().mean()

In [None]:
# Top-k accuracy: fraction of images whose true star count is among the k
# highest-scoring count categories. topk indices are distinct per row, so a
# broadcast comparison plus any() counts each image at most once.
topk = 3
topk_n_stars = torch.topk(log_probs, k=topk, dim=1).indices

# (batch, k) vs (batch, 1) -> (batch, k) boolean matches, reduced over k.
topk_hit = (topk_n_stars.float() == true_n_stars.unsqueeze(1)).any(dim=1)
topk_hit.float().mean()


In [None]:
# Per-class accuracy: for each star count n, the fraction of images with n true
# stars whose MAP count is also n. Looping over log_probs.shape[1] (the number
# of count categories the counter scores) covers every class the counter can
# predict; range(max_stars) would skip the top category(ies) whenever there are
# more columns than max_stars. Counts absent from the batch report nan.
for n in range(log_probs.shape[1]):
    has_n_stars = true_n_stars == n
    class_accuracy = (map_n_stars[has_n_stars] == n).float().mean()
    print('n_stars = {:2d}: accuracy = {:.3f} ({} images)'.format(
        n, class_accuracy.item(), int(has_n_stars.sum())))

In [None]:
# Star-counter loss evaluated over test_loader with train=False.
test_counter_loss = objectives_lib.eval_star_counter_loss(
    star_counter, test_loader, train=False)
test_counter_loss

# plot individual images

In [None]:
def plot_categorical_probs(log_prob_vec, fig):
    """Draw a stem plot of a categorical distribution given its log-probabilities.

    Args:
        log_prob_vec: 1-D tensor of log-probabilities, one entry per category.
            The category index is used as the x coordinate.
        fig: matplotlib Axes to draw on. Despite the name, this is an Axes
            object such as ``axarr[1]`` from ``plt.subplots``, not a Figure.
    """
    # detach + cpu so the values can be handed to matplotlib even if the tensor
    # requires grad or lives on a GPU.
    probs = torch.exp(log_prob_vec).detach().cpu().numpy()
    categories = np.arange(len(probs))

    # Vertical stem from 0 up to each probability. Draw on the given axes
    # rather than pyplot's "current" axes.
    fig.vlines(categories, 0, probs, color='blue')
    fig.plot(categories, probs, 'o', markersize=5, color='blue')


In [None]:
# For each of the first ten images, show two panels:
# - left: the image with its true star locations;
# - right: the counter's predicted distribution over star counts.
# Counts are integers, so they are shown without decimals.
pixel_scale = images.shape[-1] - 1

for indx in range(10):
    n_true = int(true_n_stars[indx])
    n_pred = int(map_n_stars[indx])
    # Only the first n_true rows of true_locs[indx] hold real stars.
    true_locs_indx = true_locs[indx, :n_true]

    _, axarr = plt.subplots(1, 2, figsize=(8, 4))

    axarr[0].matshow(images[indx].squeeze())
    axarr[0].scatter(true_locs_indx[:, 1] * pixel_scale,
                     true_locs_indx[:, 0] * pixel_scale,
                     marker='o', color='blue')
    axarr[0].set_title('truth: {}'.format(n_true))

    plot_categorical_probs(log_probs[indx, :], axarr[1])
    axarr[1].set_title('pred: {}'.format(n_pred))

# Look only at misclassified images

In [None]:
# Show at most max_wrong_plots misclassified images (MAP count != true count)
# from the index window [100, 200). Each figure has two panels:
# - left: true star locations;
# - right: the predicted count distribution.
max_wrong_plots = 10
pixel_scale = images.shape[-1] - 1

# Compute the mismatch mask once instead of converting true_n_stars every iteration.
is_wrong = map_n_stars != true_n_stars.long()
window = range(100, min(200, len(is_wrong)))  # guard against smaller batches
wrong_indices = [indx for indx in window if is_wrong[indx]][:max_wrong_plots]

for indx in wrong_indices:
    n_true = int(true_n_stars[indx])
    n_pred = int(map_n_stars[indx])
    # Only the first n_true rows of true_locs[indx] hold real stars.
    true_locs_indx = true_locs[indx, :n_true]

    _, axarr = plt.subplots(1, 2, figsize=(8, 4))

    axarr[0].matshow(images[indx].squeeze())
    axarr[0].scatter(true_locs_indx[:, 1] * pixel_scale,
                     true_locs_indx[:, 0] * pixel_scale,
                     marker='o', color='blue')
    axarr[0].set_title('truth: {}'.format(n_true))

    plot_categorical_probs(log_probs[indx, :], axarr[1])
    axarr[1].set_title('pred: {}'.format(n_pred))