In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed every RNG the notebook uses (numpy and torch) from one constant so the
# random crops below and any torch sampling are reproducible on Restart & Run All.
SEED = 34534
np.random.seed(SEED)
torch.manual_seed(SEED);

# Load the data

In [None]:
f_min = 1e3  # flux threshold; NOTE(review): not referenced in the cells shown here

In [None]:
# Band indices passed to the encoders (n_bands = len(bands)) and the SDSS field to load.
bands = [2]
run, camcol, field = 94, 1, 12

In [None]:
# Load the SDSS field selected above from the local staging directory.
sdss_data = sdss_dataset_lib.SloanDigitalSkySurvey(
    sdssdir='../../celeste_net/sdss_stage_dir/',
    run=run,
    camcol=camcol,
    field=field,
)

In [None]:
# Square analysis window: top-left corner at (row x0, column x1), side length slen pixels.
x0, x1 = 100, 100
slen = 501

In [None]:
# Crop leading-axis index 0 of the field image (presumably the single loaded band)
# to the slen x slen window.
image = torch.Tensor(
    sdss_data[0]['image'][0, slice(x0, x0 + slen), slice(x1, x1 + slen)]
)

In [None]:
# Same crop as `image`, applied to the background array.
background = torch.Tensor(
    sdss_data[0]['background'][0, slice(x0, x0 + slen), slice(x1, x1 + slen)]
)

In [None]:
# Quick look at the cropped image with an intensity colorbar.
plt.colorbar(plt.matshow(image))

# Define the VAE encoders

In [None]:
# Encoder #1; its weights are loaded from '../fits/results_11202019/starnet_r' below.
star_encoder1 = starnet_vae_lib.StarEncoder(
    full_slen=image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
)

In [None]:
# Load the saved weights onto the CPU, whatever device they were saved from.
star_encoder1.load_state_dict(
    torch.load('../fits/results_11202019/starnet_r', map_location='cpu'))

In [None]:
# Switch to inference mode (same as .eval()); the trailing ';' hides the module repr.
star_encoder1.train(False);

In [None]:
# Encoder #2 (same architecture as encoder #1); its weights are the wake-sleep
# checkpoint loaded below.
star_encoder2 = starnet_vae_lib.StarEncoder(
    full_slen=image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
)

In [None]:
# Load the wake-sleep checkpoint onto the CPU, whatever device it was saved from.
star_encoder2.load_state_dict(
    torch.load('../fits/results_11202019/wake-sleep_630x310_r-encoder-iter6',
               map_location='cpu'))

In [None]:
# Switch to inference mode (same as .eval()); the trailing ';' hides the module repr.
star_encoder2.train(False);

In [None]:
# MAP catalog (first three outputs: locations, fluxes, number of stars) from the
# wake-sleep encoder. The two unsqueezes turn the 2-D crops into 4-D tensors of
# shape (1, 1, slen, slen); NOTE(review): presumably (batch, band, H, W).
# This is pure inference, so autograd bookkeeping is disabled to save memory.
with torch.no_grad():
    map_locs, map_fluxes, map_nstars = \
        star_encoder2.sample_star_encoder(image.unsqueeze(0).unsqueeze(0),
                                          background.unsqueeze(0).unsqueeze(0),
                                          return_map = True)[0:3]

In [None]:
# MAP detections overlaid on the image. map_locs are assumed to be normalized to
# [0, 1] over the crop, so scale them back to pixels by (slen - 1) rather than a
# hard-coded 500. The result is identical for slen = 501 and stays correct if the
# crop size changes.
plt.matshow(image)
plt.scatter(map_locs.squeeze()[:, 1] * (slen - 1),
            map_locs.squeeze()[:, 0] * (slen - 1),
            color = 'r', marker = 'x')
plt.title('MAP detections (star_encoder2)');

In [None]:
# Four random zoomed-in crops with the MAP detections overlaid.
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

# Side length (pixels) of each crop. The original cell defined this as 50 but
# passed a hard-coded 10 to plot_subimage; use the variable consistently.
subimage_slen = 50

# axarr.flat walks the 2x2 grid row-major, i.e. the same order as [i // 2, i % 2].
for ax in axarr.flat:
    # Draw the crop origin so the whole crop stays inside the slen x slen image.
    # np.random.choice(n) returns a scalar, avoiding the deprecated size-1-array
    # -> int conversion. NOTE(review): assumes plot_subimage treats (_x0, _x1) as
    # the crop's top-left corner — confirm in plotting_utils.
    _x0 = int(np.random.choice(slen - subimage_slen))
    _x1 = int(np.random.choice(slen - subimage_slen))

    plotting_utils.plot_subimage(ax, image,
                                 map_locs.squeeze(),
                                 None,
                                 _x0,
                                 _x1, subimage_slen = subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)

In [None]:
import fitsio

In [None]:
# File name of this field's SDSS photoObj catalog (zero-padded run and field).
photo_name = f"photoObj-{run:06d}-{camcol:d}-{field:04d}.fits"

In [None]:
# True: use the coadd "truth" catalog; False: use this field's SDSS photoObj catalog.
use_truth = True

In [None]:
# Catalog positions (RA/Dec) plus star / galaxy index arrays, taken from either the
# coadd "truth" catalog or this field's SDSS photoObj catalog.
# fits_file is kept as a global HDU handle (not closed), as before, in case later
# cells read further columns from it.
#
# np.flatnonzero always returns a 1-D index array. The previous
# np.argwhere(...).squeeze() collapsed a single match to a 0-d array, which turned
# a lone star or galaxy into a scalar position downstream.
if use_truth:
    fits_file = fitsio.FITS('coadd_field_catalog_runjing_liu_0.fit')[1]
    ra = fits_file['ra'][:]
    decl = fits_file['dec'][:]

    # Rows with probpsf == 1 are treated as stars, probpsf == 0 as galaxies.
    probpsf = fits_file['probpsf'][:]
    which_stars = np.flatnonzero(probpsf == 1)
    which_galx = np.flatnonzero(probpsf == 0)

else:
    fits_file = fitsio.FITS('../../celeste_net/sdss_stage_dir/{}/{}/{}/{}'.format(
        run, camcol, field, photo_name))[1]

    ra = fits_file['RA'][:]
    decl = fits_file['DEC'][:]

    # see https://www.sdss.org/dr12/algorithms/classify/
    # OBJC_TYPE 6 is treated as a star and 3 as a galaxy.
    objc_type = fits_file['OBJC_TYPE'][:]
    which_stars = np.flatnonzero(objc_type == 6)
    which_galx = np.flatnonzero(objc_type == 3)

In [None]:
# RA extent of the catalog, shown as (max, min).
(np.max(ra), np.min(ra))

In [None]:
# Dec extent of the catalog, shown as (max, min).
(np.max(decl), np.min(decl))

In [None]:
from astropy.io import fits
from astropy.wcs import WCS

In [None]:
# File name of the r-band SDSS frame for this run / camcol / field.
frame_name = "frame-r-{:06d}-{:d}-{:04d}.fits".format(run, camcol, field)

In [None]:
# Open the frame file and build its world-coordinate system from the primary header.
frame_path = '../../celeste_net/sdss_stage_dir/{}/{}/{}/{}'.format(
    run, camcol, field, frame_name)
hdulist = fits.open(frame_path)
wcs = WCS(hdulist['primary'].header)

In [None]:
# Catalog RA/Dec -> pixel coordinates of the frame. origin = 0 gives 0-based pixel
# indices, the same convention as numpy indexing of the image.
pix_coordinates = wcs.wcs_world2pix(ra, decl, 0, ra_dec_order=True)

In [None]:
# wcs_world2pix returns (column, row) arrays; _x0 holds rows and _x1 holds columns,
# matching the image's [row, column] indexing.
_x1, _x0 = pix_coordinates

def get_locs(_x0, _x1, row_start=None, col_start=None, side_len=None):
    """Convert full-frame pixel positions into normalized crop locations.

    Keeps the points that lie strictly inside the open window
    (row_start, row_start + side_len - 1) x (col_start, col_start + side_len - 1),
    shifts them to crop coordinates and divides by (side_len - 1).

    Args:
        _x0: row pixel coordinates (array-like or scalar).
        _x1: column pixel coordinates (array-like or scalar, same length as _x0).
        row_start: crop's first row; defaults to the notebook global ``x0``.
        col_start: crop's first column; defaults to the notebook global ``x1``.
        side_len: crop side length; defaults to the notebook global ``slen``.

    Returns:
        float32 tensor of shape (n_inside, 2) with columns (row, col) / (side_len - 1);
        shape (0, 2) when no point falls inside the window.
    """
    # Fall back to the notebook's crop globals, resolved at call time as before.
    if row_start is None:
        row_start = x0
    if col_start is None:
        col_start = x1
    if side_len is None:
        side_len = slen

    # atleast_1d so a single object (a scalar position) is handled like an array.
    rows = np.atleast_1d(np.asarray(_x0, dtype=float))
    cols = np.atleast_1d(np.asarray(_x1, dtype=float))

    inside = (rows > row_start) & (rows < row_start + side_len - 1) & \
             (cols > col_start) & (cols < col_start + side_len - 1)

    # Stack in numpy first: building a tensor from a list of ndarrays is slow and
    # triggers a torch warning.
    locs = np.stack([rows[inside] - row_start, cols[inside] - col_start], axis=1)

    return torch.as_tensor(locs, dtype=torch.float32) / (side_len - 1)

# Normalized in-window locations of the catalog stars and galaxies.
star_locs, galaxy_locs = (get_locs(_x0[idx], _x1[idx])
                          for idx in (which_stars, which_galx))

In [None]:
# Full-crop overview: MAP detections (red), catalog stars (blue), catalog galaxies (green).
# star_locs / galaxy_locs are normalized by (slen - 1) in get_locs. map_locs are
# assumed to use the same [0, 1] scale (the original hard-coded 500 == slen - 1).
# All three are scaled back to pixels with that factor.
plt.matshow(image)
plt.colorbar()

plt.scatter(map_locs.squeeze()[:, 1] * (slen - 1),
            map_locs.squeeze()[:, 0] * (slen - 1),
            color = 'r', marker = 'x', label = 'MAP detections')

plt.scatter(star_locs[:, 1] * (slen - 1),
            star_locs[:, 0] * (slen - 1),
            color = 'b', marker = 'x', label = 'catalog stars')

plt.scatter(galaxy_locs[:, 1] * (slen - 1),
            galaxy_locs[:, 0] * (slen - 1),
            color = 'g', marker = 'x', label = 'catalog galaxies')

plt.legend();



In [None]:
# One random crop shown twice: MAP detections vs. catalog stars (left) and vs.
# catalog galaxies (right). The original 1x3 grid left its third panel empty, so a
# 1x2 grid is used with the same per-panel size.
fig, axarr = plt.subplots(1, 2, figsize=(16, 6))

subimage_slen = 50

# Draw the crop origin so the whole crop stays inside the slen x slen image.
# np.random.choice(n) returns a scalar, avoiding the deprecated size-1-array -> int
# conversion. NOTE(review): assumes plot_subimage treats (_x0, _x1) as the crop's
# top-left corner — confirm in plotting_utils.
_x0 = int(np.random.choice(slen - subimage_slen))
_x1 = int(np.random.choice(slen - subimage_slen))

for ax, true_locs in zip(axarr, (star_locs, galaxy_locs)):
    plotting_utils.plot_subimage(ax, image,
                                 map_locs.squeeze(),
                                 true_locs,
                                 _x0,
                                 _x1, subimage_slen = subimage_slen,
                                 add_colorbar = True,
                                 global_fig = fig)
