In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed numpy's global RNG so the randomly placed zoom windows are reproducible.
SEED = 34534
np.random.seed(SEED)

# Load the data

In [None]:
f_min = 1000.

In [None]:
bands = [2]
run = 94
camcol = 1
field = 12

In [None]:
sdss_data = sdss_dataset_lib.SloanDigitalSkySurvey(sdssdir = '../../celeste_net/sdss_stage_dir/', 
                                      run = run, camcol = camcol, field = field)

In [None]:
x0 = 400
x1 = 300
slen = 101

In [None]:
image = torch.Tensor(sdss_data[0]['image'][0, x0:(x0 + slen), x1:(x1 + slen)])

In [None]:
background = torch.Tensor(sdss_data[0]['background'][0, x0:(x0 + slen), x1:(x1 + slen)])

In [None]:
plt.matshow(image)
plt.colorbar()

# Define and load the VAE encoders

In [None]:
star_encoder1 = starnet_vae_lib.StarEncoder(full_slen = image.shape[-1],
                                           stamp_slen = 7,
                                           step = 2,
                                           edge_padding = 2,
                                           n_bands = len(bands),
                                           max_detections = 2)

In [None]:
star_encoder1.load_state_dict(torch.load('../fits/results_11202019/starnet_r', 
                                       map_location=lambda storage, loc: storage))

In [None]:
star_encoder1.eval();

In [None]:
star_encoder2 = starnet_vae_lib.StarEncoder(full_slen = image.shape[-1],
                                           stamp_slen = 7,
                                           step = 2,
                                           edge_padding = 2,
                                           n_bands = len(bands),
                                           max_detections = 2)

In [None]:
star_encoder2.load_state_dict(torch.load('../fits/results_11202019/wake-sleep_630x310_r-encoder-iter6', 
                                       map_location=lambda storage, loc: storage))

In [None]:
star_encoder2.eval();

In [None]:
map_locs, map_fluxes, map_nstars = \
    star_encoder2.sample_star_encoder(image.unsqueeze(0).unsqueeze(0), 
                                 background.unsqueeze(0).unsqueeze(0), 
                                 return_map = True)[0:3]

In [None]:
plt.matshow(image)
plt.scatter(map_locs.squeeze()[:, 1] * slen, 
            map_locs.squeeze()[:, 0] * slen, 
            color = 'r', marker = 'x')

In [None]:
fig, axarr = plt.subplots(2, 2, figsize=(8, 6.5))

for i in range(4): 
    _x0 = int(np.random.choice(slen, 1))
    _x1 = int(np.random.choice(slen, 1))
    subimage_slen = 50
    
    plotting_utils.plot_subimage(axarr[i // 2, i % 2], image, 
                                         map_locs.squeeze(), 
                                         None, 
                                         _x0, 
                                         _x1, subimage_slen = 10, 
                                        add_colorbar = True, 
                                         global_fig = fig)

In [None]:
import fitsio

In [None]:
photo_name = "photoObj-{:06d}-{:d}-{:04d}.fits".format(run, camcol, field)

In [None]:
use_truth = True

In [None]:
if use_truth: 
    fits_file = fitsio.FITS('coadd_field_catalog_runjing_liu_0.fit')[1]
    ra = fits_file['ra'][:]
    decl = fits_file['dec'][:]
    
    flux_r = fits_file['psfmag_r'][:]
    flux_i = fits_file['psfmag_i'][:]
    
    which_stars = np.argwhere((fits_file['probpsf'][:] == 1) & (flux_r < 22.5)).squeeze()
    which_galx = np.argwhere(fits_file['probpsf'][:] == 0).squeeze()
        
else: 
    fits_file = fitsio.FITS('../../celeste_net/sdss_stage_dir/' + 
                            str(run) + '/' + str(camcol) + '/' + str(field) + 
                            '/' + photo_name)[1]
    
    ra = fits_file['RA'][:]
    decl = fits_file['DEC'][:]
    
    # see https://www.sdss.org/dr12/algorithms/classify/
    which_stars = np.argwhere(fits_file['OBJC_TYPE'][:] == 6).squeeze()
    which_galx = np.argwhere(fits_file['OBJC_TYPE'][:] == 3).squeeze()

In [None]:
(ra.max(), ra.min())

In [None]:
(decl.max(), decl.min())

In [None]:
from astropy.io import fits
from astropy.wcs import WCS

In [None]:
frame_name = "frame-{}-{:06d}-{:d}-{:04d}.fits".format('r', run, camcol, field)

In [None]:
hdulist = fits.open('../../celeste_net/sdss_stage_dir/' + str(run) + '/' + str(camcol) + '/' + str(field) + \
                        '/' + frame_name)
wcs = WCS(hdulist['primary'].header)

In [None]:
pix_coordinates = wcs.wcs_world2pix(ra, decl, 0, ra_dec_order = True)

In [None]:
# Index 1 of the pixel coordinates is compared against x0 (the cutout's row offset)
# and index 0 against x1 (its column offset).
_x0 = pix_coordinates[1]
_x1 = pix_coordinates[0]

def get_locs(_x0, _x1): 
    """Convert full-frame pixel coordinates to normalized locations in the cutout.

    Keeps points strictly between the first and last row / column of the cutout
    (module-level x0, x1, slen), shifts them to cutout coordinates and divides by
    (slen - 1), so 0 and 1 correspond to the first and last cutout pixel.

    Args:
        _x0: row coordinates in the full frame (array or scalar).
        _x1: column coordinates in the full frame, same length as _x0.

    Returns:
        torch.Tensor of shape (n_kept, 2) holding (row, col) locations.
    """
    # atleast_1d so a single selected object (a scalar after .squeeze()) still works
    rows = np.atleast_1d(_x0)
    cols = np.atleast_1d(_x1)

    inside = (rows > x0) & (rows < (x0 + slen - 1)) & \
             (cols > x1) & (cols < (x1 + slen - 1))

    # np.stack avoids building a tensor from a list of ndarrays (slow) and gives
    # shape (0, 2) when nothing is inside.
    locs = np.stack([rows[inside] - x0, cols[inside] - x1], axis = 1)

    return torch.Tensor(locs) / (slen - 1)

star_locs = get_locs(_x0[which_stars], _x1[which_stars])
galaxy_locs = get_locs(_x0[which_galx], _x1[which_galx])

In [None]:
save_fig = True

In [None]:
fig, axarr = plt.subplots(1, 1, figsize=(6, 5))

im = axarr.matshow(image)
fig.colorbar(im);



axarr.scatter(map_locs.squeeze()[:, 1] * slen, 
            map_locs.squeeze()[:, 0] * slen, 
            color = 'r', marker = 'x')


axarr.scatter(star_locs[:, 1] * slen, 
            star_locs[:, 0] * slen, 
            color = 'w', marker = 'o', alpha = 0.9)

axarr.scatter(galaxy_locs[:, 1] * slen, 
            galaxy_locs[:, 0] * slen, 
            color = 'g', marker = 'x')

fig.tight_layout()


if save_fig: 
    plt.savefig('../../qualifying_exam_slides/figures/sparse_field_test.png')

In [None]:
# Fixed zoom window (offsets _x0 = 15, _x1 = 33 passed to plot_subimage) with the
# MAP detections overlaid.
_x0, _x1 = 15, 33
subimage_slen = 10

fig, axarr = plt.subplots(1, 1, figsize=(4, 5 * .75))
plotting_utils.plot_subimage(axarr, image, map_locs.squeeze(), None,
                             _x0, _x1,
                             subimage_slen = subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)
fig.tight_layout()

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/sparse_field_test_zoomed.png')

In [None]:
psf_sparse = sdss_psf.psf_at_points(0, 0, '../../celeste_net/sdss_stage_dir/94/1/12/psField-000094-1-0012.fit')
psf_crowded = sdss_psf.psf_at_points(0, 0, '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit')

fig, axarr = plt.subplots(1, 2, figsize=(8, 4))

# Show rows / columns 20..30 of each PSF image side by side.
panels = [(psf_sparse, 'PSF in sparse field (94, 1, 12)\n'),
          (psf_crowded, 'PSF in crowded field (2583, 2, 136)\n')]
for ax, (psf_image, title) in zip(axarr, panels):
    ax.matshow(psf_image[20:31, 20:31])
    ax.set_title(title)

fig.tight_layout()

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/sparse_field_test_psfs.png')

In [None]:
# Zoomed comparisons of the MAP detections against catalog stars (left panel) and
# catalog galaxies (right panel), for two fixed window offsets.
subimage_slen = 10

for _x0, _x1 in [(40, 50), (15, 33)]:
    # Only two panels are drawn per figure, so request exactly two axes.
    fig, axarr = plt.subplots(1, 2, figsize=(16, 6))

    for ax, catalog_locs in zip(axarr, [star_locs, galaxy_locs]):
        plotting_utils.plot_subimage(ax, image, 
                                     map_locs.squeeze(), 
                                     catalog_locs, 
                                     _x0, 
                                     _x1, subimage_slen = subimage_slen, 
                                     add_colorbar = True, 
                                     global_fig = fig)



In [None]:
psf = sdss_psf.psf_at_points(0, 0, '../../celeste_net/sdss_stage_dir/94/1/12/psField-000094-1-0012.fit')

In [None]:
plt.matshow(psf)

# Why don't we do the PSF EDA on this sparse field?

In [None]:
sdss_image = torch.Tensor(sdss_data[0]['image'])

In [None]:
pix_coordinates = wcs.wcs_world2pix(ra, decl, 0, ra_dec_order = True)

In [None]:
star_locs0 = pix_coordinates[1][which_stars] / sdss_image.shape[1]
star_locs1 = pix_coordinates[0][which_stars] / sdss_image.shape[2]

which_keep = (star_locs0 > 0) & (star_locs0 < 1) & (star_locs1 > 0) & (star_locs1 < 1)

star_locs = torch.Tensor(np.array([star_locs0, star_locs1]).transpose()[which_keep])

flux_r = flux_r[torch.nonzero(torch.Tensor(which_keep)).squeeze()]
flux_i = flux_i[torch.nonzero(torch.Tensor(which_keep)).squeeze()]

In [None]:
plt.matshow(sdss_image[0])
plt.scatter(star_locs[:, 1] * sdss_image.shape[2], 
            star_locs[:, 0] * sdss_image.shape[1])

In [None]:
def get_star_patches(full_image, true_locs, which_stars, subimage_slen):
    """Cut square patches around selected stars and flag the ones that look blended.

    Args:
        full_image: 4-d tensor; only ``full_image[0, 0]`` is read, and its two axes
            are treated as rows and columns.
        true_locs: (n, 2) tensor of normalized (row, col) locations in [0, 1]; the
            pixel position is ``loc * (size - 1)`` rounded to the nearest pixel.
        which_stars: indices (or a boolean mask) into ``true_locs`` selecting the
            stars to cut out. A 0-d index is treated as a single star.
        subimage_slen: side length of each patch; use an odd value so the star's
            pixel is the patch centre.

    Returns:
        star_patches: (n_selected, subimage_slen, subimage_slen) tensor.
        patch_coords: (n_selected, 2) tensor, top-left (row, col) of each patch.
        is_blended: (n_selected,) tensor, 1 where the patch's centre pixel is below
            the patch maximum (a brighter neighbour or a mis-centred star).

    Raises:
        AssertionError: if ``full_image`` is not 4-d, a location is outside [0, 1],
            or a patch would extend past the image border.
    """
    assert len(full_image.shape) == 4
    assert (true_locs >= 0).all() & (true_locs <= 1).all()

    slen0 = full_image.shape[-2]
    slen1 = full_image.shape[-1]

    # Flatten so a single index that was squeezed to 0-d still selects one row
    # (true_locs[0-d index] would otherwise drop the star axis).
    which_stars = torch.as_tensor(which_stars).reshape(-1)
    which_locs = true_locs[which_stars]
    n_stars = which_locs.shape[0]

    star_patches = torch.zeros(n_stars, subimage_slen, subimage_slen)
    patch_coords = torch.zeros(n_stars, 2)
    is_blended = torch.zeros(n_stars)

    # offset from the patch corner to its centre pixel
    half = (subimage_slen - 1) // 2
    scale = torch.Tensor([slen0 - 1., slen1 - 1.])

    for i in range(n_stars):
        center = (which_locs[i] * scale).round().type(torch.long)

        row0 = int(center[0]) - half
        col0 = int(center[1]) - half

        # The patch must lie inside the image; touching the border is allowed.
        assert row0 >= 0
        assert col0 >= 0
        assert (row0 + subimage_slen) <= slen0
        assert (col0 + subimage_slen) <= slen1

        patch = full_image[0, 0, row0:(row0 + subimage_slen), col0:(col0 + subimage_slen)]
        star_patches[i] = patch
        patch_coords[i] = torch.Tensor([row0, col0])

        if patch[half, half] < patch.max():
            is_blended[i] = 1

    return star_patches, patch_coords, is_blended

### Brighter-fatter effect

In [None]:
plt.hist(flux_r)

In [None]:
thresh = np.median(flux_r)

In [None]:
thresh

In [None]:
which_stars = torch.nonzero((flux_r.squeeze() > thresh) & 
                           (star_locs[:, 0] < 0.95) & (star_locs[:, 1] < 0.95) & 
                           (star_locs[:, 0] > 0.05) & (star_locs[:, 1] > 0.05)).squeeze()


print(len(which_stars))

subimage_slen = 9

star_patches, patch_coords, is_blended = \
    get_star_patches(sdss_image.unsqueeze(0), star_locs, which_stars, subimage_slen)

star_patches = star_patches[is_blended == 0]
    
# normalize
star_patches_normalized = \
    star_patches / star_patches.view(star_patches.shape[0], -1).sum(1).unsqueeze(-1).unsqueeze(-1)

plt.matshow(star_patches_normalized.mean(0))
plt.colorbar()

In [None]:
which_stars = torch.nonzero((flux_r.squeeze() < thresh) & 
                           (star_locs[:, 0] < 0.95) & (star_locs[:, 1] < 0.95) & 
                           (star_locs[:, 0] > 0.05) & (star_locs[:, 1] > 0.05)).squeeze()


print(len(which_stars))

subimage_slen = 9

star_patches, patch_coords, is_blended = \
    get_star_patches(sdss_image.unsqueeze(0), star_locs, which_stars, subimage_slen)

star_patches = star_patches[is_blended == 0]
    
# normalize
star_patches_normalized2 = \
    star_patches / star_patches.view(star_patches.shape[0], -1).sum(1).unsqueeze(-1).unsqueeze(-1)

plt.matshow(star_patches_normalized2.mean(0))
plt.colorbar()

In [None]:
star_patches_normalized.mean()

In [None]:
star_patches_normalized2.mean()

In [None]:
# figure for paper
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# vmax = max(star_patches_normalized.mean(0).max(), star_patches_normalized2.mean(0).max())
# vmin = min(star_patches_normalized.mean(0).min(), star_patches_normalized2.mean(0).min())

im0 = axarr[0].matshow(star_patches_normalized.mean(0))
fig.colorbar(im0, ax = axarr[0])
axarr[0].set_title('\nOverlayed Stars with mag > 21.1', fontsize = 16)

im1 = axarr[1].matshow(star_patches_normalized2.mean(0))
fig.colorbar(im1, ax = axarr[1])
axarr[1].set_title('\nOverlayed Stars with mag <  21.1', fontsize = 16)

diff = star_patches_normalized2.mean(0) - star_patches_normalized.mean(0)
im2 = axarr[2].matshow(diff, vmax = diff.abs().max(), vmin = -diff.abs().max(), 
                       cmap = plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])
axarr[2].set_title('\nDifference', fontsize = 16)

for i in range(3): 
    axarr[i].set_xticks([])
    axarr[i].set_yticks([])

fig.tight_layout()
fig.savefig('../../qualifying_exam_slides/figures/psf_by_mag.png')

In [None]:
diff.sum()

### The effect of color

In [None]:
plt.hist(flux_r - flux_i)

In [None]:
plt.plot(flux_r.numpy(), 
         (flux_r - flux_i).numpy(), 'x')

In [None]:
thresh = np.median(flux_r - flux_i)

In [None]:
thresh

In [None]:
which_stars = torch.nonzero(((flux_r - flux_i).squeeze() > thresh) & 
                           (star_locs[:, 0] < 0.95) & (star_locs[:, 1] < 0.95) & 
                           (star_locs[:, 0] > 0.05) & (star_locs[:, 1] > 0.05)).squeeze()


print(len(which_stars))

subimage_slen = 9

star_patches, patch_coords, is_blended = \
    get_star_patches(sdss_image.unsqueeze(0), star_locs, which_stars, subimage_slen)

star_patches = star_patches[is_blended == 0]
    
# normalize
star_patches_normalized = \
    star_patches / star_patches.view(star_patches.shape[0], -1).sum(1).unsqueeze(-1).unsqueeze(-1)

plt.matshow(star_patches_normalized.mean(0))
plt.colorbar()

In [None]:
which_stars = torch.nonzero(((flux_r - flux_i).squeeze() < thresh) & 
                           (star_locs[:, 0] < 0.95) & (star_locs[:, 1] < 0.95) & 
                           (star_locs[:, 0] > 0.05) & (star_locs[:, 1] > 0.05)).squeeze()


print(len(which_stars))

subimage_slen = 9

star_patches, patch_coords, is_blended = \
    get_star_patches(sdss_image.unsqueeze(0), star_locs, which_stars, subimage_slen)

star_patches = star_patches[is_blended == 0]
    
# normalize
star_patches_normalized2 = \
    star_patches / star_patches.view(star_patches.shape[0], -1).sum(1).unsqueeze(-1).unsqueeze(-1)

plt.matshow(star_patches_normalized2.mean(0))
plt.colorbar()

In [None]:
star_patches_normalized.mean()

In [None]:
star_patches_normalized2.mean()

In [None]:
# figure for paper
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# vmax = max(star_patches_normalized.mean(0).max(), star_patches_normalized2.mean(0).max())
# vmin = min(star_patches_normalized.mean(0).min(), star_patches_normalized2.mean(0).min())

im0 = axarr[0].matshow(star_patches_normalized.mean(0))
fig.colorbar(im0, ax = axarr[0])
axarr[0].set_title('\nOverlayed Stars with mag > 21.1', fontsize = 16)

im1 = axarr[1].matshow(star_patches_normalized2.mean(0))
fig.colorbar(im1, ax = axarr[1])
axarr[1].set_title('\nOverlayed Stars with mag <  21.1', fontsize = 16)

diff = star_patches_normalized2.mean(0) - star_patches_normalized.mean(0)
im2 = axarr[2].matshow(diff, vmax = diff.abs().max(), vmin = -diff.abs().max(), 
                       cmap = plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])
axarr[2].set_title('\nDifference', fontsize = 16)

for i in range(3): 
    axarr[i].set_xticks([])
    axarr[i].set_yticks([])

fig.tight_layout()
# fig.savefig('../../qualifying_exam_slides/figures/psf_by_mag.png')

In [None]:
diff.sum()