In [None]:
import torch
import torch.optim as optim

import numpy as np

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import residuals_vae_lib
import simulated_datasets_lib
import starnet_vae_lib

import matplotlib.pyplot as plt

import inv_KL_objective_lib as objectives_lib

import time

import json

import plotting_utils

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('device: ', device)
print('torch version: ', torch.__version__)

# Fix the global numpy and torch RNG seeds so random draws in this notebook are reproducible.
np.random.seed(43534)
_ = torch.manual_seed(24534)



# Get Hubble data

In [None]:
max_stars = 20

# Hubble catalog for NGC 7078. NOTE(review): the directory is spelled 'NCG7078'
# while the file name says 'ngc7078' — presumably this matches the on-disk layout.
hubble_cat_file = ('../hubble_data/NCG7078/'
                   'hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt')

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    sdssdir='../../celeste_net/sdss_stage_dir/',
    slen=11,
    run=2566,
    camcol=6,
    field=65,
    max_detections=max_stars)

# One unshuffled batch holding every item of the dataset.
batchsize = len(sdss_hubble_data)
sdss_loader = torch.utils.data.DataLoader(dataset=sdss_hubble_data,
                                          batch_size=batchsize,
                                          shuffle=False)

# Mean of the full SDSS background; passed to the simulator as its sky intensity.
sky_intensity = sdss_hubble_data.sdss_background_full.mean()

# Get simulator

In [None]:
# Simulator parameters: the default star parameters, adjusted to match the
# SDSS field loaded above (star-count cap and sky intensity).
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params['min_stars'] = 0
data_params['max_stars'] = max_stars
data_params['sky_intensity'] = sky_intensity

print(data_params)

# Simulated dataset rendered with the SDSS PSF, with noise added.
# NOTE(review): `n_stars` appears to be the number of simulated images (the
# whole set is read below as one batch of 1000) — confirm in simulated_datasets_lib.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    sdss_hubble_data.psf_file,
    data_params,
    n_stars=1000,
    add_noise=True)

batchsize = 1000
simulated_loader = torch.utils.data.DataLoader(dataset=simulated_dataset,
                                               batch_size=batchsize,
                                               shuffle=True)

# Only the first (shuffled) batch is needed. Fail loudly on an empty loader
# instead of silently leaving the variables below undefined.
data = next(iter(simulated_loader), None)
if data is None:
    raise RuntimeError('simulated_loader produced no batches')

simulated_true_fluxes = data['fluxes']
simulated_true_locs = data['locs']
simulated_true_n_stars = data['n_stars']
_simulated_images = data['image']
simulated_backgrounds = data['background']

# Get generative model (residual VAE)

In [None]:
# Training cycle whose checkpoints are loaded below ('../fits/residual_vae_wake<cycle>'
# and '../fits/starnet_encoder_sleep<cycle>'); cycle 0 skips the residual VAE.
cycle = 10

In [None]:
# Residual VAE for the selected cycle; no residual model exists for cycle 0.
residual_vae = None
if cycle != 0:
    residual_vae = residuals_vae_lib.ResidualVAE(slen=sdss_hubble_data.slen,
                                                 n_bands=1,
                                                 f_min=200.)
    # The identity map_location keeps every loaded tensor on the CPU.
    residual_vae.load_state_dict(
        torch.load('../fits/residual_vae_wake{}'.format(cycle),
                   map_location=lambda storage, loc: storage))
    residual_vae.eval()
    
    

In [None]:
# Perturb the simulated images multiplicatively: images * residual + images,
# with residuals decoded from standard-normal latents. Without a residual VAE
# (cycle 0) the residuals are zero and the images are used unchanged.
if residual_vae is not None:
    # Sample on the data's device. The VAE was loaded with a CPU map_location and
    # never moved, so sending eta to `device` would fail on a CUDA machine.
    eta = torch.randn(_simulated_images.shape[0], residual_vae.latent_dim,
                      device=_simulated_images.device)
    # Inference only: skip building an autograd graph over the whole batch.
    with torch.no_grad():
        # decode returns (mean, log-variance); only the mean is used
        residuals = residual_vae.decode(eta)[0]

    simulated_images = _simulated_images * residuals.detach() + _simulated_images
else:
    simulated_images = _simulated_images
    residuals = torch.zeros_like(_simulated_images)

In [None]:
# The residual maps applied to the first ten simulated images, on a 2 x 5 grid.
_, axarr = plt.subplots(2, 5, figsize=(18, 8))
for i in range(10):
    row, col = divmod(i, 5)
    axarr[row, col].matshow(residuals[i, 0, :, :].detach())

In [None]:
# First ten simulated images with their true star locations overlaid in blue.
_, axarr = plt.subplots(2, 5, figsize=(18, 8))

# locs are multiplied by (slen - 1) to place them on the pixel grid
# (presumably they are normalized to [0, 1]).
pixel_scale = simulated_images.shape[-1] - 1

for i, ax in enumerate(axarr.flat):
    # int() so the title shows the count instead of a tensor repr
    n_stars_i = int(simulated_true_n_stars[i])

    ax.matshow(simulated_images[i, 0, :, :])
    ax.set_title('n_stars: {}\n'.format(n_stars_i))

    locs_i = simulated_true_locs[i]
    locs_y = locs_i[0:n_stars_i, 0] * pixel_scale
    locs_x = locs_i[0:n_stars_i, 1] * pixel_scale

    ax.scatter(x=locs_x, y=locs_y, color='b')


# Get VAE encoder (StarNet)

In [None]:
star_encoder = starnet_vae_lib.StarEncoder(data_params['slen'],
                                           n_bands=1,
                                           max_detections=max_stars)

# Load the encoder saved after the sleep phase of the selected `cycle`
# (the cycle-0 checkpoint is the encoder trained on simulated data only).
encoder_init = '../fits/starnet_encoder_sleep' + str(cycle)
print('loading encoder from: ', encoder_init)
star_encoder.load_state_dict(torch.load(encoder_init,
                                        map_location=lambda storage, loc: storage))

star_encoder.eval();

In [None]:
# Encoder loss on the simulated batch. This is evaluation only, so no autograd
# graph is built (saves memory on the 1000-image batch).
with torch.no_grad():
    loss, counter_loss, locs_loss, fluxes_loss, perm = \
        objectives_lib.get_encoder_loss(star_encoder, simulated_images, simulated_backgrounds,
                                        simulated_true_locs, simulated_true_fluxes,
                                        simulated_true_n_stars)

print(loss)

# Check results on simulated data

In [None]:
# Encoder results on the first 20 simulated images (use_true_n_stars=False).
indx = np.arange(20)

plotting_utils.print_results(star_encoder,
                             simulated_images[indx],
                             simulated_backgrounds[indx],
                             simulated_dataset.simulator.psf,
                             simulated_true_locs[indx],
                             simulated_true_n_stars[indx],
                             use_true_n_stars=False)

# Check results on SDSS data

In [None]:
# sdss_loader uses batch_size = len(dataset), so its first batch is the whole
# dataset. Fail loudly if it yields nothing instead of leaving names undefined.
data = next(iter(sdss_loader), None)
if data is None:
    raise RuntimeError('sdss_loader produced no batches')

hubble_fluxes = data['fluxes']
hubble_locs = data['locs']
hubble_n_stars = data['n_stars']
sdss_images = data['image']
sdss_backgrounds = data['background']

In [None]:
# Encoder results on SDSS images 30-39, with Hubble locations as reference.
# residual_clamp is set very high, presumably to effectively disable clamping.
indx = 30 + np.arange(10)

plotting_utils.print_results(star_encoder,
                             sdss_images[indx],
                             sdss_backgrounds[indx],
                             simulated_dataset.simulator.psf,
                             hubble_locs[indx],
                             hubble_n_stars[indx],
                             use_true_n_stars=False,
                             residual_clamp=1e16)

In [None]:
# Images with at most 4 Hubble stars; inspect the 21st-30th such images.
# np.flatnonzero always returns a 1-D index array, whereas
# np.argwhere(...).squeeze() becomes 0-d (unsliceable) when exactly one image matches.
indx = np.flatnonzero(np.asarray(hubble_n_stars) <= 4)[20:30]

plotting_utils.print_results(star_encoder,
                             sdss_images[indx],
                             sdss_backgrounds[indx],
                             simulated_dataset.simulator.psf,
                             hubble_locs[indx],
                             hubble_n_stars[indx],
                             use_true_n_stars=False,
                             residual_clamp=200.)

In [None]:
# Encoder outputs for the selected SDSS images; true_n_stars=None presumably
# lets the encoder estimate the number of stars itself.
(map_n_stars, map_locs, map_fluxes,
 logit_loc_mean, logit_loc_log_var,
 log_flux_mean, log_flux_log_var,
 log_probs, recon_mean) = plotting_utils.get_variational_parameters(
    star_encoder,
    sdss_images[indx],
    sdss_backgrounds[indx],
    simulated_dataset.simulator.psf,
    true_n_stars=None)

In [None]:
# Selected SDSS images with Hubble (true) locations and the encoder's
# estimated locations (map_locs) overlaid.
fig, axarr = plt.subplots(2, 5, figsize=(15, 6))

# indx may hold fewer than 10 entries (e.g. from the <= 4 stars filter), so only
# fill as many panels as there are selected images.
n_panels = min(len(indx), axarr.size)

for i, ax in enumerate(axarr.ravel()[:n_panels]):
    image = sdss_images[indx[i]][0, :, :]
    n_stars_i = int(hubble_n_stars[indx[i]])
    locs_i = hubble_locs[indx[i]]

    plotting_utils.plot_image(ax,
                              image,
                              true_locs=locs_i[0:n_stars_i],
                              estimated_locs=map_locs[i, 0:int(map_n_stars[i])])

In [None]:
# Inspect samples from the cycle-1 residual VAE (a fixed checkpoint, independent
# of `cycle`). When cycle 0 left `residual_vae` as None, build the model first.
# NOTE: this overwrites the weights loaded earlier; re-running earlier cells
# afterwards will use these cycle-1 weights.
if residual_vae is None:
    residual_vae = residuals_vae_lib.ResidualVAE(slen=sdss_hubble_data.slen,
                                                 n_bands=1,
                                                 f_min=200.)
residual_vae.load_state_dict(torch.load('../fits/residual_vae_wake1',
                                        map_location=lambda storage, loc: storage))
residual_vae.eval();

In [None]:
# Ten standard-normal latent draws used to visualize the residual VAE decoder.
eta = torch.randn((10, residual_vae.latent_dim))

In [None]:
# decode returns a (mean, log-variance) pair; note this rebinds `recon_mean`
# from the encoder outputs computed earlier.
(recon_mean, recon_logvar) = residual_vae.decode(eta)

In [None]:
# One diverging-colormap figure per decoded sample, symmetric around zero.
bwr = plt.get_cmap('bwr')
for i in range(10):
    sample = recon_mean[i, 0, :, :]
    limit = torch.abs(sample).max()
    plt.matshow(sample.detach(), vmin=-limit, vmax=limit, cmap=bwr)
    plt.colorbar()