In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')
import sdss_psf
import star_datasets_lib
import starnet_vae_lib

import json

# Prefer the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

print('torch version: ', torch.__version__)

In [None]:
# SDSS psField file; this path is handed to the dataset loader further down.
psf_fit_file = (
    '../../celeste_net/sdss_stage_dir/3900/6/269/psField-003900-6-0269.fit'
)
print('psf file: \n', psf_fit_file)

In [None]:
# Fix the torch and numpy global RNGs so every random draw below is reproducible.
# (Assigning to `_` keeps the returned torch Generator from being displayed.)
_ = torch.manual_seed(24534)
np.random.seed(43534)

# Draw simulated data

In [None]:
# Start from the default simulation parameters on disk, then override the
# allowed range of stars per image.
with open('../data/default_star_parameters.json', 'r') as params_file:
    data_params = json.load(params_file)

data_params.update(min_stars=0, max_stars=4)

print(data_params)


In [None]:
# Upper bound on stars per image; reused below as the counter's max_detections.
max_stars = data_params["max_stars"]

In [None]:
# Size passed to the loader's `n_stars` argument; it is also reused as the
# batch size below (presumably the number of simulated images — confirm).
n_stars = 1024

star_dataset = star_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars=n_stars,
    use_fresh_data=False,
    add_noise=True,
)

# Leading dimension of the stored fluxes.
num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# Batch size equals n_stars, so the first batch presumably covers the whole
# dataset; no shuffling keeps the order deterministic.
batchsize = n_stars
loader = torch.utils.data.DataLoader(star_dataset, batch_size=batchsize, shuffle=False)

# loader.dataset is star_dataset itself; populate it before drawing batches.
star_dataset.set_params_and_images()

In [None]:
# Pull the first batch from the loader (with batch_size == n_stars this is
# presumably the entire dataset). next(iter(...)) replaces the for/break
# idiom and raises StopIteration on an empty loader instead of silently
# leaving the names below undefined.
data = next(iter(loader))
true_fluxes = data['fluxes']
true_locs = data['locs']
true_n_stars = data['n_stars']
images = data['image']

In [None]:
# Random image index for ad-hoc inspection. It was previously bound to `i`,
# which the plotting loop in the next cell immediately overwrote, making the
# draw dead code; a distinct name keeps it available.
example_idx = int(np.random.choice(batchsize, 1)[0])
example_idx

In [None]:
# Visual sanity check: show the first few simulated images with their true
# star locations overlaid.
n_plot = min(10, images.shape[0])  # avoid IndexError on batches smaller than 10
slen = images.shape[-1]

for i in range(n_plot):
    n_stars_i = int(true_n_stars[i])  # plain int so the title doesn't read "tensor(N)"
    locs_i = true_locs[i, :n_stars_i]

    fig, ax = plt.subplots()
    ax.matshow(images[i, 0, :, :])

    # Locations are scaled by (slen - 1), so presumably stored in [0, 1].
    # NOTE(review): column 0 is drawn on the x-axis (image columns) and column 1
    # on the y-axis (image rows); confirm this matches the simulator's
    # (x, y) vs (row, col) convention.
    ax.scatter(x=locs_i[:, 0] * (slen - 1), y=locs_i[:, 1] * (slen - 1), c='b')

    ax.set_title('Observed image; nstars {}\n'.format(n_stars_i))


# Check the star counter

In [None]:
# Counter network for single-band images of side data_params['slen'], with
# max_detections tied to the max_stars used when simulating the data.
star_counter = starnet_vae_lib.StarCounter(
    slen=data_params['slen'],
    n_bands=1,
    max_detections=max_stars,
)

In [None]:
# Forward pass of the freshly constructed counter over the whole batch.
# Nothing below backpropagates through the result, so skip building the
# autograd graph.
with torch.no_grad():
    log_probs = star_counter(images)

In [None]:
# Shape of the counter output (displayed as the cell result).
log_probs.shape

In [None]:
# This check expects each row of exp(log_probs) to sum to one, i.e. a
# normalized distribution per image. Turn the visual check into an assertion
# so a regression fails loudly, but still display the row sums.
prob_sums = torch.exp(log_probs).sum(1)
assert torch.allclose(prob_sums, torch.ones_like(prob_sums), atol=1e-5), \
    'counter output rows do not sum to 1'
prob_sums