In [None]:
import sys
sys.path.insert(0, '../')

In [None]:
import numpy as np

import torch
import torch.optim as optim

import sdss_dataset_lib

import simulated_datasets_lib
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib

from wake_sleep_lib import run_joint_wake, run_wake, run_sleep

import psf_transform_lib

import time
import fitsio

import json



In [None]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device: ', device)

print('torch version: ', torch.__version__)



In [None]:
# Reproducibility: force deterministic cuDNN kernels (no autotuning), then seed
# numpy's global RNG and torch's RNG.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(4534)
_ = torch.manual_seed(2534)  # assign to suppress the Generator repr in the cell output



In [None]:
# get sdss data: the SDSS image for bands [2, 3] together with the Hubble catalog
# for NGC 7089.
# NOTE: the catalog directory is spelled 'NCG7089' in this path; keep it as-is
# unless the directory itself is renamed.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    sdssdir='../../celeste_net/sdss_stage_dir/',
    hubble_cat_file=('../hubble_data/NCG7089/'
                     'hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt'),
    bands=[2, 3],
)



In [None]:
# sdss image: add a leading batch dimension to the image and its background,
# then move both to `device`.
full_image, full_background = (
    sdss_tensor.unsqueeze(0).to(device)
    for sdss_tensor in (sdss_hubble_data.sdss_image, sdss_hubble_data.sdss_background)
)


In [None]:
# Simulation parameters for the synthetic star images. Later cells read
# data_params['slen'] to size the encoder and the PSF transform.
params_path = '../data/default_star_parameters.json'
with open(params_path, 'r') as params_fp:
    data_params = json.load(params_fp)

print(data_params)


In [None]:
# Per-band sky level: the mean background over all pixels of each band.
# NOTE(review): reshaping to (shape[1], -1) only separates bands correctly when the
# leading (batch) dimension is 1, as produced by unsqueeze(0) — confirm if that changes.
n_bands_bg = full_background.shape[1]
sky_intensity = full_background.reshape(n_bands_bg, -1).mean(dim=1)

In [None]:
# load psf: read the r- and i-band PSF images from their FITS files (primary HDU)
# and stack them into a single array, r first then i.
# The `with` blocks close each FITS handle after reading instead of leaking it.
psf_dir = '../data/'
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_r_fits:
    psf_r = psf_r_fits[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_i_fits:
    psf_i = psf_i_fits[0].read()

psf_og = np.array([psf_r, psf_i])


In [None]:
# draw data: simulate n_images star images using the SDSS PSF (psf_og), the
# parameters in data_params and the per-band sky level, with noise added.
print('generating data: ')
n_images = 6
t0 = time.time()
star_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    n_images=n_images,
    sky_intensity=sky_intensity,
    add_noise=True,
)
print('data generation time: {:.3f}secs'.format(time.time() - t0))

# Mini-batch loader over the simulated images (shuffle=True reshuffles every epoch).
batchsize = 2
loader = torch.utils.data.DataLoader(dataset=star_dataset, batch_size=batchsize, shuffle=True)


In [None]:
# define VAE: the star encoder for a full image of side data_params['slen'] with
# 2 bands. The stamp_slen / step / edge_padding / max_detections semantics are
# defined in starnet_vae_lib.
star_encoder = starnet_vae_lib.StarEncoder(
    full_slen=data_params['slen'],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=2,
    max_detections=2,
)


In [None]:
# define psf transform: a learnable local transform (kernel_size 3) of the SDSS PSF
# psf_og, for images of side data_params['slen'].
psf_transform = psf_transform_lib.PsfLocalTransform(torch.Tensor(psf_og).to(device),
                                    data_params['slen'],
                                    kernel_size = 3)
# Move the module to `device` immediately. The wake optimizer is built from its
# parameters and run_wake uses it before any later .to(device) call, and
# torch.optim recommends moving a model before constructing its optimizer.
psf_transform.to(device)


In [None]:
# Output prefix passed to run_wake / run_sleep, and the initial encoder weights
# that are loaded before each phase.
filename = '../fits/results_11122019/wake_sleep_ri-loc630x310'
init_encoder = '../fits/results_11122019/starnet_ri'

# optimizers: the wake phase updates only the PSF transform; the sleep phase
# updates only the star encoder.
psf_lr = 0.1
wake_optimizer = optim.Adam(
    [{'params': psf_transform.parameters(), 'lr': psf_lr}],
    weight_decay=1e-5,
)

sleep_optimizer = optim.Adam(
    [{'params': star_encoder.parameters(), 'lr': 5e-5}],
    weight_decay=1e-5,
)


In [None]:
# Wake phase: fit the PSF transform to the full SDSS image (wake_optimizer updates only psf_transform)

In [None]:
encoder_file = init_encoder

# Load the initial encoder weights onto the CPU, then move the encoder to
# `device` and put it in eval mode.
print('loading encoder from: ', encoder_file)
encoder_state = torch.load(encoder_file, map_location=torch.device('cpu'))
star_encoder.load_state_dict(encoder_state)
star_encoder.to(device)
star_encoder.eval();


In [None]:
# Wake phase on the real image: 4 epochs, 2 samples per step, IWAE objective.
# Output is written under filename + '-psf_transform' for iteration 0.
run_wake(
    full_image,
    full_background,
    star_encoder,
    psf_transform,
    optimizer=wake_optimizer,
    n_epochs=4,
    n_samples=2,
    out_filename=filename + '-psf_transform',
    iteration=0,
    use_iwae=True,
)

In [None]:
# Sleep phase: retrain the encoder on simulated data drawn with the wake-fitted PSF

In [None]:
# Reset the encoder to its initial weights before the sleep phase.
star_encoder.load_state_dict(
    torch.load(encoder_file, map_location=torch.device('cpu')))
star_encoder.to(device)

# load trained transform: presumably the checkpoint run_wake wrote for iteration 0
# (confirm the '-iter' suffix convention in wake_sleep_lib).
psf_transform_file = filename + '-psf_transform' + '-iter' + str(0)
print('loading psf_transform from: ', psf_transform_file)
psf_transform.load_state_dict(
    torch.load(psf_transform_file, map_location=torch.device('cpu')))
psf_transform.to(device)

# The simulator draws sleep-phase data with the fitted PSF. Detach it so the
# dataset does not hold the transform's autograd graph.
loader.dataset.simulator.psf = psf_transform.forward().detach()

In [None]:
# Sleep phase: train star_encoder for 2 epochs on the simulated batches in
# `loader`. out_filename is the output prefix (presumably for checkpoints).
run_sleep(
    star_encoder,
    loader,
    sleep_optimizer,
    n_epochs=2,
    out_filename=filename + '-encoder',
    iteration=1,
)

In [None]:
# Inspect the wake phase in more detail: draw encoder samples (and their log q) on the full image

In [None]:
# Draw 10 samples from the encoder on the full image (return_map=False) together
# with the log q values of the sampled locations, fluxes and star counts.
(sampled_locs_full_image, sampled_fluxes_full_image, sampled_n_stars_full,
 log_q_locs, log_q_fluxes, log_q_n_stars) = star_encoder.sample_star_encoder(
    full_image,
    full_background,
    n_samples=10,
    return_map=False,
    return_log_q=True,
    training=False,
)


In [None]:
# Display the shape of the sampled fluxes. The layout is set by sample_star_encoder;
# presumably n_samples (10) is the leading dimension — confirm.
sampled_fluxes_full_image.shape