In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

from astropy.io import fits
from astropy.wcs import WCS

# Prefer the first CUDA device when one is available; otherwise fall back to CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if _has_cuda else torch.device("cpu")

import os

In [None]:
# Load the SDSS patch together with the Hubble-derived attributes used below
# (locs, fluxes, hubble_ra/hubble_dc, wcs, ...).
# NOTE(review): band indices 2 and 3 are presumably r and i (ugriz order); the
# r- and i-band frame files opened later suggest so. Confirm.
bands = [2, 3]
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands=bands)

In [None]:
# Mean sky background per band: one row per requested band after reshaping
# (assumes the background is stored band-major).
per_band_background = sdss_hubble_data.sdss_background.reshape(len(bands), -1)
per_band_background.mean(1)

In [None]:
# Show the full frame at index 0 of the leading axis (presumably the first
# requested band).
full_frame_band0 = sdss_hubble_data.sdss_image_full[0]
plt.matshow(full_frame_band0)

In [None]:
# Zoom into a 100 x 100 pixel window of the full frame
# (rows 880..979, columns 180..279).
zoom_rows, zoom_cols = slice(880, 980), slice(180, 280)
plt.matshow(sdss_hubble_data.sdss_image_full[0][zoom_rows, zoom_cols])

In [None]:
# check the hubble coordinates overlap with the globular cluster:
# overlay the catalogue positions (x1 on the horizontal axis, x0 on the
# vertical) on the full frame. Draw unconnected markers. With plt.plot's
# default line style, every consecutive pair of catalogue entries was joined
# by a line segment, which hid the positions themselves.
plt.matshow(sdss_hubble_data.sdss_image_full[0])
plt.plot(sdss_hubble_data.locs_full_x1,
         sdss_hubble_data.locs_full_x0,
         linestyle='none', marker='.', markersize=2, alpha=0.2)

In [None]:
# check patch: show each band of the patch with its own colour bar
f, axarr = plt.subplots(1, 2, figsize=(8, 4))

im0 = axarr[0].matshow(sdss_hubble_data.sdss_image[0])
axarr[0].set_title('band {}'.format(bands[0]))
f.colorbar(im0, ax=axarr[0])

im1 = axarr[1].matshow(sdss_hubble_data.sdss_image[1])
axarr[1].set_title('band {}'.format(bands[1]))
# Attach the second colour bar to the second image. It previously reused im0,
# so the right-hand panel displayed the first band's colour scale.
f.colorbar(im1, ax=axarr[1])

In [None]:
# Pixel-wise difference between the two bands on a symmetric diverging scale.
band_diff = sdss_hubble_data.sdss_image[0] - sdss_hubble_data.sdss_image[1]
diff_lim = band_diff.abs().max()

plt.matshow(band_diff, vmin=-diff_lim, vmax=diff_lim, cmap=plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
from astropy.io import fits

In [None]:
# Build the i-band WCS from the frame's primary header. The file is closed
# afterwards because only the in-memory header is needed for the WCS.
with fits.open('../../celeste_net/sdss_stage_dir/2583/2/136/frame-i-002583-2-0136.fits') as hdulist_i:
    wcs_i = WCS(hdulist_i['primary'].header)

In [None]:
# Project the Hubble catalogue (RA, Dec) onto 0-based pixel coordinates,
# once with the dataset's own WCS and once with the i-band frame's WCS.
hubble_ra = sdss_hubble_data.hubble_ra
hubble_dec = sdss_hubble_data.hubble_dc

pix_coordinates_r = sdss_hubble_data.wcs.wcs_world2pix(
    hubble_ra, hubble_dec, 0, ra_dec_order=True)
pix_coordinates_i = wcs_i.wcs_world2pix(
    hubble_ra, hubble_dec, 0, ra_dec_order=True)

In [None]:
# Spread of the r-minus-i pixel offsets along the first WCS pixel axis.
r_minus_i_axis0 = pix_coordinates_r[0] - pix_coordinates_i[0]
plt.hist(r_minus_i_axis0);

In [None]:
# Spread of the r-minus-i pixel offsets along the second WCS pixel axis.
r_minus_i_axis1 = pix_coordinates_r[1] - pix_coordinates_i[1]
plt.hist(r_minus_i_axis1);

In [None]:
# Median r-minus-i offset per pixel axis. wcs_world2pix returns the pixel axes
# in (x, y) order, so this notebook's x1 comes from index 0 and x0 from index 1.
offsets = [np.median(pix_coordinates_r[k] - pix_coordinates_i[k]) for k in (0, 1)]
shift_x1, shift_x0 = offsets

In [None]:
from simulated_datasets_lib import _get_mgrid

In [None]:
# Sampling grid sized by the patch's last axis, with a leading batch axis.
# NOTE(review): assumes the patch is square. Confirm.
image_slen = sdss_hubble_data.sdss_image.shape[-1]
grid = _get_mgrid(image_slen)[None]

In [None]:
# Median (x1, x0) shift packed as a (1, 1, 1, 2) tensor.
# NOTE(review): the /100 looks like an ad-hoc conversion to grid units. Confirm.
# This tensor is not used by the cells below.
locs = torch.Tensor([shift_x1, shift_x0])[None, None, None] / 100

In [None]:
# NOTE(review): `locs` (the r/i shift above) is not applied here, so the
# resampling below uses the unshifted grid. This looks unfinished; confirm
# whether something like `grid + locs` was intended.
grid_loc = grid

In [None]:
# Resample the first band on `grid_loc`. align_corners is passed explicitly.
# False has been PyTorch's default since 1.3, so the output is unchanged, but
# leaving it unset emits a UserWarning on every call.
# NOTE(review): confirm that _get_mgrid builds its grid for the
# align_corners=False convention.
foo = torch.nn.functional.grid_sample(sdss_hubble_data.sdss_image[0, :, :].unsqueeze(0).unsqueeze(0),
                                      grid_loc,
                                      align_corners=False)

In [None]:
# Output shape of grid_sample: (N, C, H_out, W_out).
foo.shape

In [None]:
# First band of the patch, for comparison with the resampled image below.
plt.matshow(sdss_hubble_data.sdss_image[0])

In [None]:
# Resampled first band with its singleton batch/channel axes dropped.
plt.matshow(torch.squeeze(foo))

In [None]:
# Hubble catalogue table (the path suggests the HST ACS globular-cluster
# survey file for NGC 7089). skiprows takes a row count. The original passed
# True, which only worked because bool subclasses int; 1 states the intent.
HTcat = np.loadtxt(
    '../hubble_data/NCG7089/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt',
    skiprows=1)

In [None]:
# Right ascension and declination columns of the Hubble catalogue.
# NOTE(review): column indices 21/22 are taken from the catalogue layout. Confirm.
ra, decl = HTcat[:, 21], HTcat[:, 22]

In [None]:
from astropy.io import fits

In [None]:
# Open the r- and i-band frames for this field (r first, then i).
# NOTE(review): both HDULists are left open because the next cell reads their
# headers; close them once the WCS objects exist.
hdulist_r, hdulist_i = (
    fits.open('../../celeste_net/sdss_stage_dir/2583/2/136/frame-r-002583-2-0136.fits'),
    fits.open('../../celeste_net/sdss_stage_dir/2583/2/136/frame-i-002583-2-0136.fits'),
)

In [None]:
# World-coordinate solutions for both frames, read from their primary headers.
wcs_r, wcs_i = (WCS(hdul['primary'].header) for hdul in (hdulist_r, hdulist_i))

In [None]:
# Project the catalogue (RA, Dec) onto each frame's 0-based pixel grid.
pix_coordinates_r, pix_coordinates_i = [
    frame_wcs.wcs_world2pix(ra, decl, 0, ra_dec_order=True)
    for frame_wcs in (wcs_r, wcs_i)
]

In [None]:
# r-minus-i pixel offsets along the first WCS pixel axis, full catalogue.
r_minus_i_axis0 = pix_coordinates_r[0] - pix_coordinates_i[0]
plt.hist(r_minus_i_axis0);

In [None]:
# r-minus-i pixel offsets along the second WCS pixel axis, full catalogue.
r_minus_i_axis1 = pix_coordinates_r[1] - pix_coordinates_i[1]
plt.hist(r_minus_i_axis1);

In [None]:
# The original cell displayed `pix_coordinates` before any cell assigns it,
# which raises NameError on a fresh kernel. Show the r-band catalogue pixel
# positions computed just above instead.
pix_coordinates_r

In [None]:
# Pixel position of (RA, Dec) = (0, 0) in the r-band frame (0-based origin).
# The original referenced an undefined `wcs`; the r-band WCS built above is used.
pix_coordinates = wcs_r.wcs_world2pix(0, 0, 0, ra_dec_order=True)
pix_coordinates

In [None]:
# World coordinates of pixel (1, 1) in the r-band frame, with a 0-based origin.
# The original referenced an undefined `wcs` and passed the origin as a float.
wcs_r.all_pix2world([[1.0, 1.0]], 0)

In [None]:
# Redisplay `pix_coordinates` (last assigned as the pixel position of RA/Dec (0, 0)).
pix_coordinates

In [None]:
# NOTE(review): duplicate of the previous display cell; candidate for removal.
pix_coordinates

# Plot a few randomly chosen subimages

In [None]:
# Flux threshold: stars with flux above this are treated as "bright" below.
fmin = 1e3

In [None]:
import plotting_utils

In [None]:
# Candidate subimage offsets 0, 10, ..., 90, shared by both axes.
x0_vec = 10 * np.arange(10)
x1_vec = x0_vec

In [None]:
# Six random subimages of the first band, passing the positions of bright
# catalogue stars to the plotting helper.
# The brightness mask does not depend on the loop, so compute it once. It
# reuses the `fmin` threshold, which equals the literal 1000. used before.
which_bright = sdss_hubble_data.fluxes > fmin

f, axarr = plt.subplots(2, 3, figsize=(16, 12))

for ax in axarr.flat:  # row-major, same order as axarr[i // 3, i % 3]
    # With no `size`, np.random.choice returns a scalar. Calling int() on a
    # size-1 array is deprecated since NumPy 1.25.
    x0 = int(np.random.choice(x0_vec))
    x1 = int(np.random.choice(x1_vec))

    # Pass a 2-D single-band image, as the reconstruction plots below do.
    # With two bands, .squeeze() left the image 3-D.
    plotting_utils.plot_subimage(ax,
                                 sdss_hubble_data.sdss_image[0],
                                 None,
                                 sdss_hubble_data.locs[which_bright],
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=f)

# Test the star simulator against the observed image

In [None]:
import simulated_datasets_lib
import fitsio

In [None]:
# load psf: the simulator below is single-band, so only the r-band PSF is
# used. The g-band PSF was previously loaded and then discarded when psf_og
# was immediately reassigned. The FITS file is closed after reading.
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_file:
    psf_r = psf_file[0].read()

# Leading axis indexes bands (a single band here).
psf_og = np.array([psf_r])

In [None]:
# Sky level for the single-band simulation. Use the mean background of the
# first requested band only, to match the single-band PSF and the comparison
# with observed[0] below; averaging over all bands mixed in the second band.
# The per-band split mirrors the reshape(len(bands), -1) used earlier.
band0_background = sdss_hubble_data.sdss_background.reshape(len(bands), -1)[0]
sky_intensity = torch.Tensor([band0_background.mean()])

In [None]:
# Star simulator over the patch, using the PSF and sky level defined above.
simulator = simulated_datasets_lib.StarSimulator(
    psf=psf_og,
    slen=sdss_hubble_data.slen,
    sky_intensity=sky_intensity,
)

In [None]:
# Batch the catalogue for the simulator: a single image holding every star.
_fluxes = sdss_hubble_data.fluxes.unsqueeze(0).unsqueeze(0)
_locs = sdss_hubble_data.locs.unsqueeze(0)
# Build the star count directly as int64. Going through torch.Tensor(...)
# stored it as float32 first, which is exact only up to 2**24.
_n_stars = torch.tensor([len(sdss_hubble_data.locs)], dtype=torch.long)

In [None]:
# Noise-free rendering of the catalogue through the simulator.
recon_mean = simulator.draw_image_from_params(
    locs=_locs,
    fluxes=_fluxes,
    n_stars=_n_stars,
    add_noise=False,
)

In [None]:
# Observed patch with singleton axes removed. It is defined here (and again in
# the next cell) so this cell runs on a fresh kernel; previously it read
# `observed` before any cell had assigned it.
observed = sdss_hubble_data.sdss_image.squeeze()
observed.shape

In [None]:
f, axarr = plt.subplots(1, 3, figsize=(16, 4))
ax_obs, ax_recon, ax_resid = axarr

# Left: first band of the observed patch. Middle: noise-free reconstruction.
observed = sdss_hubble_data.sdss_image.squeeze()
im0 = ax_obs.matshow(observed[0])
f.colorbar(im0, ax=ax_obs)

im1 = ax_recon.matshow(recon_mean[0, 0])
f.colorbar(im1, ax=ax_recon)

# Right: residual relative to the observation, on a symmetric diverging scale.
residual = recon_mean[0, 0] - observed[0]
relative_residual = residual / observed[0]
rel_lim = relative_residual.abs().max()
im2 = ax_resid.matshow(relative_residual, vmin=-rel_lim, vmax=rel_lim,
                       cmap=plt.get_cmap('bwr'))
f.colorbar(im2, ax=ax_resid)

In [None]:
# Median reconstruction-minus-observed residual over all pixels.
torch.median(residual.flatten())

In [None]:
# Histogram of pixel residuals, clipped to [-1000, 1000].
clipped_residual = torch.clamp(residual.flatten(), min=-1000, max=1000)
plt.hist(clipped_residual, bins=1000);

In [None]:
# Histogram of the relative residual for the simulated band. `residual` was
# computed against observed[0], so divide by observed[0]. Dividing by the full
# multi-band `observed` broadcast the band-0 residual against every band.
plt.hist((residual / observed[0]).flatten(), bins=100);

In [None]:
# Mean relative residual for the simulated band. Divide by observed[0], the
# band `residual` was computed from, rather than the multi-band `observed`.
(residual / observed[0]).mean()

In [None]:
### plot some subimages
# Observed band, reconstruction and relative residual for one random subimage.
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

x0_vec = np.arange(0, 100, 10)
x1_vec = x0_vec

# Draw scalars directly; int() on a size-1 array is deprecated since NumPy 1.25.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

print([x0, x1])

# All three panels show the same (first) band. `residual` was computed from
# observed[0], and recon_mean.squeeze() is a single 2-D band. The raw
# `observed` (and residual / observed) kept every band.
panels = [
    (observed[0], {}),
    (recon_mean.squeeze(), {}),
    (residual / observed[0], {'diverging_cmap': True}),
]

for ax, (panel_image, extra_kwargs) in zip(axarr, panels):
    plotting_utils.plot_subimage(ax,
                                 panel_image,
                                 None,
                                 sdss_hubble_data.locs[which_bright],
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=f,
                                 **extra_kwargs)

In [None]:
# Inspect the sky level held by the simulator (constructed with sky_intensity
# above; presumably stores that value. Confirm).
simulator.sky_intensity

# Check the distribution of star counts on image stamps

In [None]:
import image_utils

In [None]:
# Shape of the patch; elsewhere in this notebook the leading axis is indexed per band.
sdss_hubble_data.sdss_image.shape

In [None]:
# Tile the patch (with a leading batch axis added) into stamps of side 9,
# stepping by 2 (semantics of tile_images presumed from its parameter names).
batched_image = sdss_hubble_data.sdss_image[None]
image_stamps = image_utils.tile_images(batched_image, subimage_slen=9, step=2)

In [None]:
# Coordinates of the same stamps. The patch's last-axis length is used for
# both dimensions.
image_slen = sdss_hubble_data.sdss_image.shape[-1]
tile_coords = image_utils.get_tile_coords(image_slen, image_slen,
                                          subimage_slen=9, step=2)

In [None]:
# Catalogue parameters of bright stars (flux > fmin), assigned to stamps.
is_bright = sdss_hubble_data.fluxes > fmin
bright_locs = sdss_hubble_data.locs[is_bright][None]
bright_fluxes = sdss_hubble_data.fluxes[is_bright][None]

subimage_locs, subimage_fluxes, n_stars, is_on_array = \
    image_utils.get_params_in_patches(tile_coords,
                                      bright_locs,
                                      bright_fluxes,
                                      sdss_hubble_data.sdss_image.shape[-1],
                                      subimage_slen=9,
                                      edge_padding=3)

In [None]:
# Number of catalogue stars with flux above fmin.
(sdss_hubble_data.fluxes > fmin).sum()

In [None]:
from torch.distributions.poisson import Poisson

In [None]:
# Reference Poisson distribution (rate 0.5) overlaid on the per-stamp star counts below.
poisson_distr = Poisson(0.5)

In [None]:
# Histogram of star counts per stamp, with the Poisson expectation overlaid.
# Bin edges 0..6 give unit-width bins for counts 0..5. The expected counts are
# evaluated at those integer counts and drawn at the bin centres, giving one
# marker per bar. Previously a 7th marker sat on the last edge with no bar.
x = np.arange(0, 7)
h = plt.hist(n_stars, x)

counts = torch.Tensor(h[1][:-1])
expected = h[0].sum() * torch.exp(poisson_distr.log_prob(counts))
plt.plot(counts + 0.5, expected, marker='x', color='red')

In [None]:
# Distribution of log10 flux for the bright stars selected above.
plt.hist(sdss_hubble_data.fluxes[which_bright].log10())

In [None]:
# Magnitude 22.5 converted by convert_mag_to_nmgy, scaled by the mean of
# nelec_per_nmgy_full (presumably an electron count. Confirm the helper's units).
nmgy_at_22_5 = sdss_dataset_lib.convert_mag_to_nmgy(22.5)
nmgy_at_22_5 * sdss_hubble_data.nelec_per_nmgy_full.mean()