In [None]:
import numpy as np

import torch
import torch.optim as optim

import fitsio

import sys
sys.path.insert(0, '../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import inv_kl_objective_lib as objectives_lib

import time

import json

from torch.distributions import normal

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device: ', device)

# Record the torch version next to the outputs to help reproduce results.
print('torch version: ', torch.__version__)


In [None]:
# Seed every RNG used below (numpy and torch) so a fresh Restart & Run All
# reproduces the simulated data and the initial network weights.
NUMPY_SEED = 4534
TORCH_SEED = 2534
np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)

# Force deterministic cuDNN kernels; autotuning (benchmark) can pick different
# algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:
# Load the simulation/data parameters from the shared JSON config.
data_params_path = '../data/default_star_parameters.json'
with open(data_params_path, 'r') as params_file:
    data_params = json.load(params_file)

print(data_params)


In [None]:
# Load one SDSS PSF image per band (r first, then i) and stack them into a
# single array whose leading axis indexes the band.
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

def _read_primary_hdu(path):
    """Return the data stored in HDU 0 of the FITS file at `path`.

    The file is opened in a `with` block so its handle is closed after reading
    (the previous one-liner left each FITS object open).
    """
    with fitsio.FITS(path) as fits_file:
        return fits_file[0].read()

psf_r = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
psf_i = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-i.fits')

# np.array only produces a regular stacked array when both bands share a shape.
assert psf_r.shape == psf_i.shape, (psf_r.shape, psf_i.shape)
psf_og = np.array([psf_r, psf_i])


In [None]:
# sky intensity: for the r and i band
# Sky background level per band, ordered like psf_og: r band first, then i band.
sky_intensity = torch.tensor([686., 1123.])


In [None]:
# draw data
print('generating data: ')
n_images = 10
t0 = time.time()
star_dataset = \
    simulated_datasets_lib.load_dataset_from_params(psf_og,
                            data_params,
                            sky_intensity = sky_intensity,
                            n_images = n_images,
                            add_noise = True)


In [None]:
# Shuffled mini-batches of `batchsize` samples drawn from the simulated dataset.
batchsize = 2
loader = torch.utils.data.DataLoader(dataset=star_dataset,
                                     batch_size=batchsize,
                                     shuffle=True)


In [None]:
# Build the encoder for images of size data_params['slen'], with one input band
# per PSF in psf_og. The encoder is moved to `device` because the batch cells
# below move their tensors there; without this, a CUDA run would feed GPU
# tensors to CPU weights and fail.
star_encoder = starnet_vae_lib.StarEncoder(full_slen = data_params['slen'],
                                           stamp_slen = 7,
                                           step = 2,
                                           edge_padding = 2,
                                           n_bands = psf_og.shape[0],
                                           max_detections = 2).to(device)



In [None]:
# define optimizer
learning_rate = 1e-3
weight_decay = 1e-5
optimizer = optim.Adam([
                    {'params': star_encoder.parameters(),
                    'lr': learning_rate}],
                    weight_decay = weight_decay)



In [None]:
# One pass over `loader` with train=True; presumably this also takes optimizer
# steps on star_encoder (TODO confirm in inv_kl_objective_lib).
epoch_losses = objectives_lib.eval_star_encoder_loss(star_encoder, loader,
                                                     optimizer, train = True)
avg_loss, counter_loss, locs_loss, fluxes_loss = epoch_losses

In [None]:
# Display the first value returned by eval_star_encoder_loss in the training pass above.
avg_loss

In [None]:
# Take one shuffled batch from `loader` and move its tensors to `device` for
# the inspection cells below. next(iter(...)) replaces the previous
# for/enumerate/break idiom: an empty loader now raises StopIteration here
# instead of silently leaving these names unbound.
data = next(iter(loader))
true_fluxes = data['fluxes'].to(device)
true_locs = data['locs'].to(device)
images = data['image'].to(device)
backgrounds = data['background'].to(device)

In [None]:
# Loss terms for the single inspection batch; only the first four outputs are kept.
# NOTE: this rebinds counter_loss / locs_loss / fluxes_loss from the training pass.
encoder_loss_outputs = objectives_lib.get_encoder_loss(star_encoder, images, backgrounds,
                                                       true_locs, true_fluxes)
loss, counter_loss, locs_loss, fluxes_loss = encoder_loss_outputs[:4]

In [None]:
# Display the first output of get_encoder_loss for the inspection batch.
loss

In [None]:
# Display the counter-loss term for the inspection batch. This is the value from
# get_encoder_loss, not the one from the training pass.
counter_loss

In [None]:
# Cut the batch images into stamps, together with the true locations/fluxes and
# star counts that get_image_stamps assigns to them (clip_max_stars=True).
stamp_outputs = star_encoder.get_image_stamps(images, true_locs, true_fluxes,
                                              clip_max_stars = True)
image_stamps, subimage_locs, subimage_fluxes, true_n_stars, true_is_on_array = \
    stamp_outputs

# Matching background stamps (untrimmed); only the first output is needed.
background_stamps = star_encoder.get_image_stamps(backgrounds, None, None,
                                                  trim_images = False)[0]

# Forward pass: variational parameters for locations and fluxes, plus log_probs.
variational_outputs = star_encoder(image_stamps, background_stamps, true_n_stars)
(logit_loc_mean, logit_loc_log_var,
 log_flux_mean, log_flux_log_var, log_probs) = variational_outputs

In [None]:
# Shape of the background stamps returned by get_image_stamps(..., trim_images=False).
background_stamps.shape

In [None]:
# Shape of the image stamps; compare with the background stamp shape.
image_stamps.shape

In [None]:
# True fluxes returned alongside the image stamps (presumably one row per stamp; confirm).
subimage_fluxes

In [None]:
# Encoder output: mean of the location parameters (logit scale, going by the name; confirm).
logit_loc_mean

In [None]:
# Encoder output log_probs. Presumably per-stamp log-probabilities over the
# number of stars (up to max_detections) — TODO confirm in starnet_vae_lib.
log_probs

In [None]:
# Encoder output: log-variance of the log-flux parameters.
log_flux_log_var