In [None]:
# Standard library
import json
import sys

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the project modules importable.
# NOTE(review): './../' is resolved against the kernel's working directory,
# not the notebook file's location — confirm the notebook is launched from its own folder.
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import inv_KL_objective_lib as objectives_lib
import KL_objective_lib

# Use the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# Location of the SDSS psField file used to build the PSF.
# NOTE(review): the directory layout presumably encodes run/camcol/field (3900/6/269) — confirm.
psf_dir = '../../celeste_net/sdss_stage_dir/3900/6/269'
psf_fit_file = psf_dir + '/psField-003900-6-0269.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both RNGs up front so the simulated data below is reproducible.
NUMPY_SEED = 43534
TORCH_SEED = 24534

np.random.seed(NUMPY_SEED)
# The trailing ';' suppresses the repr of the returned Generator without
# binding `_`, which would shadow IPython's last-output variable.
torch.manual_seed(TORCH_SEED);

# Draw data

In [None]:
# Start from the default simulation parameters, then override the ones
# this experiment changes (min_stars == max_stars fixes the star count).
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(slen=31,
                   min_stars=200,
                   max_stars=200,
                   alpha=0.5)

print(data_params)


In [None]:
# Upper bound on the number of stars, read from the simulation parameters.
# NOTE(review): not referenced by any of the cells visible here — candidate for removal.
max_stars = data_params['max_stars']

In [None]:
n_images = 4

# Build the simulated dataset (with noise) from the PSF file and the parameters above.
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=n_images,
    add_noise=True,
)

# Leading dimension of the simulated fluxes.
# NOTE(review): presumably one entry per simulated image — confirm.
num_unlabeled = star_dataset.fluxes.shape[0]
print('num unlabeled', num_unlabeled)

In [None]:
# Batch size equals n_images; keep the dataset order.
batchsize = n_images
loader = torch.utils.data.DataLoader(dataset=star_dataset, batch_size=batchsize, shuffle=False)

# DataLoader keeps a reference to star_dataset, so this acts on loader.dataset.
# NOTE(review): presumably (re)draws the star parameters and images — confirm in simulated_datasets_lib.
star_dataset.set_params_and_images()

In [None]:
# Grab the first batch from the loader. Using next(iter(...)) instead of an
# enumerate/break loop avoids clobbering IPython's `_`. It also fails loudly on
# an empty loader instead of leaving the names below undefined.
try:
    data = next(iter(loader))
except StopIteration:
    raise RuntimeError('loader produced no batches; check n_images and the dataset') from None

true_full_fluxes = data['fluxes']
true_full_locs = data['locs']
full_images = data['image']
full_backgrounds = data['background']

In [None]:
# Shape of the image batch; 4-D, presumably (image, band, row, col) — confirm.
full_images.size()

In [None]:
# Shape of the true star locations for this batch.
# NOTE(review): layout presumably (image, star, coordinate) — confirm in simulated_datasets_lib.
true_full_locs.size()

In [None]:
# Show channel 0 of the first `n_show` simulated images; the colorbar makes
# pixel intensities readable and the title identifies each panel.
n_show = 1
for i in range(n_show):
    fig, ax = plt.subplots()
    im = ax.matshow(full_images[i, 0, :, :])
    fig.colorbar(im, ax=ax)
    ax.set_title('simulated image {} (channel 0)'.format(i))

# Define VAE

In [None]:
# Build the StarNet encoder for images of side data_params['slen'].
# NOTE(review): stamp_slen / step / edge_padding presumably define a sliding
# window of stamps over the image — confirm in starnet_vae_lib.
encoder_kwargs = dict(full_slen=data_params['slen'],
                      stamp_slen=7,
                      step=4,
                      edge_padding=1,
                      n_bands=1,
                      max_detections=3)
star_encoder = starnet_vae_lib.StarEncoder(**encoder_kwargs)


In [None]:
# Sanity-check the reconstruction loss (KL_objective_lib.get_kl_loss) on the simulated batch

In [None]:
# Use a constant sky level as the background for every pixel.
# ones_like matches the images' dtype and device, so this stays consistent
# if the images are moved to `device` later.
# NOTE(review): this replaces the background taken from the loader batch — confirm intended.
full_backgrounds = torch.ones_like(full_images) * data_params['sky_intensity']

In [None]:
# Evaluate the KL objective for the (not yet trained) encoder on the simulated batch.
kl_outputs = KL_objective_lib.get_kl_loss(
    star_encoder,
    full_images,
    full_backgrounds,
    star_dataset.simulator,
)
map_loss, ps_loss = kl_outputs

In [None]:
# Display the second value returned by get_kl_loss.
# NOTE(review): whether this is a scalar or a per-image tensor is not visible here — confirm in KL_objective_lib.
ps_loss