In [None]:
import numpy as np
import timeit

import matplotlib.pyplot as plt

import torch
import torch.optim as optim

import sys
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
from simulated_datasets_lib import plot_multiple_stars
import starnet_vae_lib
import plotting_utils 

import objectives_lib

import time

import json

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
# NOTE(review): `device` is not referenced elsewhere in this notebook, so the
# models and tensors below stay on the CPU regardless.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

from copy import deepcopy

In [None]:
# Path to the SDSS psField file holding the PSF fit. The directory components
# (2566/6/65) appear to be run/camcol/field, i.e. the same field that is loaded
# in the real-data section further down.
psf_fit_file = '../../celeste_net/sdss_stage_dir/2566/6/65/psField-002566-6-0065.fit'

print('psf file: \n', psf_fit_file)

In [None]:
# Seed NumPy's global RNG and PyTorch's default RNG so that everything drawn
# afterwards (including the simulated data) is reproducible.
NUMPY_SEED = 2
TORCH_SEED = 34543
np.random.seed(NUMPY_SEED)
# Bind the returned Generator to `_` so it is not echoed as cell output.
_ = torch.manual_seed(TORCH_SEED)

# Simulate data

In [None]:
# Data parameters: start from the defaults on disk, then override the
# star-count range (presumably the min/max number of stars per simulated image).
with open('../data/default_star_parameters.json', 'r') as params_file:
    data_params = json.load(params_file)

data_params.update(min_stars=0, max_stars=20)

print(data_params)


In [None]:
# Short aliases for parameters used repeatedly below
# (slen is presumably the image side length in pixels).
slen = data_params['slen']
max_stars = data_params['max_stars']

In [None]:
# Passed to the dataset as `n_stars` (presumably the number of simulated
# images) and reused below as the DataLoader batch size.
batchsize = 200

# Build the simulated dataset from the PSF file and the parameters above,
# with noise added and without forcing freshly generated data.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars=batchsize,
    use_fresh_data=False,
    add_noise=True,
)

In [None]:
# True parameters of the simulated data.
# batch_size equals the `n_stars` value the dataset was built with, so the
# first batch presumably covers the whole simulated dataset.
simulated_data_loader = torch.utils.data.DataLoader(
    dataset=simulated_dataset,
    batch_size=batchsize,
    shuffle=False)

# Take exactly one batch. next(iter(...)) raises StopIteration on an empty
# loader rather than leaving the names below unbound (or stale from an
# earlier run of the notebook).
simulated_batch = next(iter(simulated_data_loader))
true_fluxes = simulated_batch['fluxes']
true_locs = simulated_batch['locs']
true_n_stars = simulated_batch['n_stars']
simulated_images = simulated_batch['image']
simulated_backgrounds = simulated_batch['background']

In [None]:
# Plot the first n_rows * n_cols simulated images (channel 0), passing each
# image's true locations truncated to its first true_n_stars[i] rows.
n_rows, n_cols = 2, 5
fig, axarr = plt.subplots(n_rows, n_cols, figsize=(18, 8))
# axarr.flat walks the grid row by row, so image i lands at (i // n_cols, i % n_cols).
for i, ax in enumerate(axarr.flat):
    n_true = int(true_n_stars[i])
    plotting_utils.plot_image(ax, simulated_images[i, 0, :, :],
                              true_locs=true_locs[i, 0:n_true, :])
    

# Recall results on simulated data

### Load star counter

In [None]:
# Star-counting network (presumably predicts the number of stars per image)
# for single-band images of side slen, with up to max_stars detections.
star_counter = starnet_vae_lib.StarCounter(
    slen,
    n_bands=1,
    max_detections=max_stars,
)

In [None]:
# Restore the trained counter weights (mapped onto the CPU) and switch the
# module to evaluation mode.
counter_state = torch.load('../fits/starnet_invKL_counter_twenty_stars',
                           map_location='cpu')
star_counter.load_state_dict(counter_state)
star_counter.eval();

In [None]:
# Counter loss on the simulated data (train=False: presumably no parameter updates).
objectives_lib.eval_star_counter_loss(
    star_counter, simulated_data_loader, train=False)

### Load encoder

In [None]:
# Encoder network for single-band images of side slen, with up to
# max_stars detections.
star_encoder = starnet_vae_lib.StarEncoder(
    slen,
    n_bands=1,
    max_detections=max_stars,
)

In [None]:
# Restore the trained encoder weights (mapped onto the CPU) and switch the
# module to evaluation mode.
encoder_state = torch.load('../fits/starnet_invKL_encoder_twenty_stars',
                           map_location='cpu')
star_encoder.load_state_dict(encoder_state)
star_encoder.eval();

In [None]:
# Encoder loss on the simulated batch against its true locations, fluxes and
# star counts. This is evaluation only, so autograd is disabled to avoid
# building a graph over the whole batch.
with torch.no_grad():
    loss, locs_loss, fluxes_loss, perm = objectives_lib.get_encoder_loss(
        star_encoder, simulated_images, simulated_backgrounds,
        true_locs, true_fluxes, true_n_stars)

print(loss)

### Get results on simulated data

In [None]:
# Predictions vs. truth for the first 20 simulated images. use_true_n_stars=False,
# so the star count presumably comes from the counter rather than the truth.
plot_indx = np.arange(0, 20)
plotting_utils.print_results(
    star_counter,
    star_encoder,
    simulated_images,
    simulated_backgrounds,
    simulated_dataset.psf,
    true_locs,
    true_n_stars,
    indx=plot_indx,
    use_true_n_stars=False)

# Apply the trained networks to real SDSS data

In [None]:
import sdss_dataset_lib

## Load Hubble data

In [None]:
# Hubble (HST ACS/WFC) catalogue file for NGC 7078, paired with SDSS
# run 2566 / camcol 6 / field 65, the same field as the PSF file above.
# NOTE(review): the directory is spelled 'NCG7078' while the file name says
# 'ngc7078' — verify the on-disk spelling.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'
sdss_dataset = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    run=2566,
    camcol=6,
    field=65)

In [None]:
# Number of items in the SDSS/Hubble dataset.
n_sdss = len(sdss_dataset)
n_sdss

In [None]:
# Load the whole SDSS/Hubble dataset as a single batch (batch_size = len(dataset)).
sdss_data_loader = torch.utils.data.DataLoader(
    dataset=sdss_dataset,
    batch_size=len(sdss_dataset),
    shuffle=False)

# Take exactly one batch. next(iter(...)) raises StopIteration on an empty
# loader rather than leaving the names below unbound (or stale from an
# earlier run of the notebook).
sdss_batch = next(iter(sdss_data_loader))
# Cast the Hubble "truth" to float32 (presumably to match the model/simulation tensors).
true_sdss_fluxes = sdss_batch['fluxes'].float()
true_sdss_locs = sdss_batch['locs'].float()
# NOTE(review): star counts are cast to float here, whereas the simulated
# true_n_stars are used without a cast — confirm downstream code expects this.
true_sdss_n_stars = sdss_batch['n_stars'].float()
sdss_images = sdss_batch['image']
sdss_backgrounds = sdss_batch['background']

### Check the losses

In [None]:
# Encoder loss on the SDSS images against the Hubble catalogue "truth".
# This is evaluation only, so autograd is disabled to avoid building a graph
# over the whole batch.
with torch.no_grad():
    loss, locs_loss, fluxes_loss, perm = objectives_lib.get_encoder_loss(
        star_encoder, sdss_images, sdss_backgrounds,
        true_sdss_locs, true_sdss_fluxes, true_sdss_n_stars)

print(loss)

In [None]:
# Histogram of the location-loss values on the SDSS data, on a log10 scale.
# NOTE(review): the offset is presumably there to keep the log10 argument
# positive (the loss may be negative) — confirm the chosen value.
locs_loss_offset = 1000
fig, ax = plt.subplots()
ax.hist(np.log10(locs_loss.detach().cpu().numpy() + locs_loss_offset))
ax.set_xlabel('log10(locs_loss + 1000)')
ax.set_ylabel('count')
ax.set_title('Location loss on SDSS images');

In [None]:
# Counter loss on the SDSS data (train=False: presumably no parameter updates).
objectives_lib.eval_star_counter_loss(
    star_counter, sdss_data_loader, train=False)

## Look at the results

In [None]:
# Predictions vs. the Hubble catalogue for the first 20 SDSS images, reusing
# the simulated dataset's PSF; use_true_n_stars is left at its default.
plot_indx = np.arange(0, 20)
plotting_utils.print_results(
    star_counter, star_encoder,
    sdss_images, sdss_backgrounds,
    simulated_dataset.psf,
    true_sdss_locs, true_sdss_n_stars,
    indx=plot_indx)

# Simulated images with Hubble parameters

In [None]:
# Render noisy synthetic images from the Hubble catalogue parameters, using
# the simulator built earlier in the notebook.
simulated_hubble_images = simulated_dataset.draw_image_from_params(
    locs=true_sdss_locs,
    fluxes=true_sdss_fluxes,
    n_stars=true_sdss_n_stars,
    add_noise=True,
)

In [None]:
# Number of synthetic images rendered from the Hubble parameters.
n_simulated_hubble = len(simulated_hubble_images)
n_simulated_hubble

In [None]:
# Constant sky background for the synthetic images: a (N, 1, 1, 1) tensor of
# ones scaled by the simulator's sky_intensity.
# NOTE(review): this presumably broadcasts against the images inside
# print_results — confirm it accepts a background that is not image-sized.
sky_backgrounds = torch.ones(len(simulated_hubble_images), 1, 1, 1) * simulated_dataset.sky_intensity
plot_indx = np.arange(0, 20)
plotting_utils.print_results(
    star_counter,
    star_encoder,
    simulated_hubble_images,
    sky_backgrounds,
    simulated_dataset.psf,
    true_sdss_locs,
    true_sdss_n_stars,
    indx=plot_indx)