In [None]:
# Imports and configuration for comparing the SDSS image with the Hubble catalog.
# NOTE(review): pathlib, Dataset, sdss_psf, fits, WCS and os are not used in any of
# the visible cells; consider dropping them.
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
# Make the project's local modules (one directory up) importable.
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

from astropy.io import fits
from astropy.wcs import WCS

# NOTE(review): `device` is not referenced in any of the visible cells below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import os

In [None]:
# Load the SDSS / Hubble data bundle. Later cells use its images (`sdss_image`,
# `sdss_image_full`), `sdss_background`, `psf_file`, and catalog `locs` / `fluxes`.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Average sky background level (later passed to the simulator as `sky_intensity`).
sdss_hubble_data.sdss_background.mean()

In [None]:
# The full SDSS image, with singleton dimensions squeezed out for display.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())

In [None]:
# Check that the Hubble catalog coordinates overlap the globular cluster in the
# full SDSS image. x1 goes on the horizontal (column) axis and x0 on the vertical
# (row) axis. Each source is drawn as an individual marker: a plain `plt.plot`
# would join consecutive catalog entries with line segments.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
plt.scatter(sdss_hubble_data.locs_full_x1,
            sdss_hubble_data.locs_full_x0,
            s = 1, alpha = 0.2, color = 'red');

In [None]:
# Zoom in on a 100x100 crop (rows 900-999, columns 150-249) of the full image.
# NOTE(review): the crop bounds are hard-coded; presumably chosen to show a region of interest.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze()[900:1000, 150:250])
plt.colorbar()

In [None]:
# Check the image patch (`sdss_image`) that the rest of the notebook works on.

plt.matshow(sdss_hubble_data.sdss_image.squeeze())
plt.colorbar()

# Plot a few subimages with bright catalog sources overlaid

In [None]:
# Flux threshold: catalog sources with flux > fmin are treated as "bright" below.
fmin = 1000.

In [None]:
import plotting_utils

In [None]:
# Candidate pixel offsets (0, 10, ..., 90) for the random subimage windows below;
# the same grid is used for both coordinates.
x0_vec = np.arange(0, 100, 10)
x1_vec = x0_vec

In [None]:
# Plot six randomly chosen 10x10 subimages of the SDSS patch, overlaying the
# locations of bright catalog sources (flux > fmin).
subimage_rng = np.random.default_rng(seed = 0)  # seeded so the panels are reproducible

# Loop-invariant, so compute once. `which_bright` stays a module-level name
# because later cells refer to it.
which_bright = sdss_hubble_data.fluxes > fmin
bright_locs = sdss_hubble_data.locs[which_bright]
patch_image = sdss_hubble_data.sdss_image.squeeze()

f, axarr = plt.subplots(2, 3, figsize=(16, 12))

# axarr.flat visits panels row by row, matching axarr[i // 3, i % 3].
for ax in axarr.flat:
    # choice() without `size` returns a scalar; int() of a 1-element array
    # (the old np.random.choice(..., 1)) is deprecated since NumPy 1.25.
    x0 = int(subimage_rng.choice(x0_vec))
    x1 = int(subimage_rng.choice(x1_vec))

    plotting_utils.plot_subimage(ax,
                                 patch_image,
                                 None,
                                 bright_locs,
                                 x0, x1,
                                 subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = f)

# Test the simulator: render the catalog stars and compare with the observed SDSS patch

In [None]:
import simulated_datasets_lib

In [None]:
# Build a star simulator that uses the SDSS PSF fit file, the patch side length
# and the mean SDSS background as its sky level, so that its renderings can be
# compared directly with the observed patch.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file = str(sdss_hubble_data.psf_file),
    slen = sdss_hubble_data.slen,
    sky_intensity = sdss_hubble_data.sdss_background.mean(),
)

In [None]:
# Render the noiseless (add_noise=False) model image implied by the catalog: every
# catalog star, with its location and flux, in a single image. unsqueeze(0) adds
# a leading size-1 dimension (presumably the batch axis).
n_catalog_stars = len(sdss_hubble_data.locs)
recon_mean = simulator.draw_image_from_params(
    locs = sdss_hubble_data.locs.unsqueeze(0),
    fluxes = sdss_hubble_data.fluxes.unsqueeze(0),
    # Build the star count directly as int64. The old float32 round trip
    # (torch.Tensor([...]).type(torch.LongTensor)) is inexact above 2**24.
    n_stars = torch.tensor([n_catalog_stars], dtype = torch.long),
    add_noise = False)

In [None]:
# Compare the observed SDSS patch with the noiseless catalog reconstruction.
# The third panel shows the relative residual (recon - observed) / observed on a
# diverging colour scale that is symmetric around zero.
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

observed = sdss_hubble_data.sdss_image.squeeze()
residual = recon_mean.squeeze() - observed
relative_residual = residual / observed
max_abs_rel = relative_residual.abs().max().item()  # computed once, used for both limits

# (image, title, extra matshow kwargs) for each panel.
panels = [
    (observed, 'observed', {}),
    (recon_mean.squeeze(), 'reconstruction', {}),
    (relative_residual, '(recon - observed) / observed',
     dict(vmax = max_abs_rel, vmin = -max_abs_rel, cmap = plt.get_cmap('bwr'))),
]
for ax, (panel_image, title, matshow_kwargs) in zip(axarr, panels):
    im = ax.matshow(panel_image, **matshow_kwargs)
    ax.set_title(title)
    f.colorbar(im, ax = ax)

In [None]:
# Median pixel residual (recon - observed) over the patch.
residual.flatten().median()

In [None]:
# Histogram of pixel residuals, clamped to [-1000, 1000] so outliers pile up in the edge bins.
plt.hist(residual.flatten().clamp(min = -1000, max = 1000), bins = 1000); 

In [None]:
# Histogram of per-pixel relative residuals (recon - observed) / observed.
plt.hist((residual / observed).flatten(), bins = 100); 

In [None]:
# Mean relative residual over all pixels of the patch.
(residual / observed).mean()

In [None]:
### plot some subimages
# Pick one random 10x10 window and show the observed patch, the reconstruction
# and the relative residual side by side, with bright catalog sources
# (flux > fmin) overlaid.
f, axarr = plt.subplots(1, 3, figsize=(16, 4))

x0_vec = np.arange(0, 100, 10)
x1_vec = x0_vec

# Seeded for reproducibility. choice() without `size` returns a scalar, which
# avoids NumPy's deprecated int() conversion of a 1-element array.
subimage_rng = np.random.default_rng(seed = 1)
x0 = int(subimage_rng.choice(x0_vec))
x1 = int(subimage_rng.choice(x1_vec))

print([x0, x1])

# Recompute the bright mask here rather than relying on `which_bright`, which
# only exists as a side effect of the earlier subimage-grid cell.
bright_locs = sdss_hubble_data.locs[sdss_hubble_data.fluxes > fmin]

# (image, extra plot_subimage kwargs); only the residual panel uses a diverging map.
panels = [
    (observed, {}),
    (recon_mean.squeeze(), {}),
    (residual / observed, {'diverging_cmap': True}),
]
for ax, (panel_image, extra_kwargs) in zip(axarr, panels):
    plotting_utils.plot_subimage(ax,
                                 panel_image,
                                 None,
                                 bright_locs,
                                 x0, x1,
                                 subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = f,
                                 **extra_kwargs)

In [None]:
# Sky intensity the simulator was constructed with (the mean SDSS background).
simulator.sky_intensity

# Check the distribution of star counts and fluxes on image stamps

In [None]:
import image_utils

In [None]:
# Shape of the SDSS patch that is tiled into stamps below.
sdss_hubble_data.sdss_image.shape

In [None]:
# Tile the patch into stamps of side subimage_slen = 9, taken every step = 2 pixels
# (presumably overlapping, since step < subimage_slen).
# NOTE(review): 9 / 2 must match the subimage_slen / step passed to get_tile_coords
# and get_params_in_patches below.
image_stamps = \
    image_utils.tile_images(sdss_hubble_data.sdss_image.unsqueeze(0),
                            subimage_slen = 9,
                            step = 2)

In [None]:
# Coordinates of each stamp within the patch; the tiling parameters match the
# tile_images call above.
# Pass both image dimensions instead of the last one twice, so non-square
# images are tiled correctly. This is identical for square images.
# NOTE(review): assumes get_tile_coords takes (n_rows, n_cols) in that order — confirm.
tile_coords = image_utils.get_tile_coords(sdss_hubble_data.sdss_image.shape[-2],
                                          sdss_hubble_data.sdss_image.shape[-1],
                                          subimage_slen = 9,
                                          step = 2)

In [None]:
# Assign the bright catalog stars (flux > fmin) to stamps. unsqueeze(0) adds a
# leading size-1 dimension (presumably the batch axis).
is_bright = sdss_hubble_data.fluxes > fmin
bright_locs_batch = sdss_hubble_data.locs[is_bright].unsqueeze(0)
bright_fluxes_batch = sdss_hubble_data.fluxes[is_bright].unsqueeze(0)
image_slen = sdss_hubble_data.sdss_image.shape[-1]

subimage_locs, subimage_fluxes, n_stars, is_on_array = image_utils.get_params_in_patches(
    tile_coords,
    bright_locs_batch,
    bright_fluxes_batch,
    image_slen,
    subimage_slen = 9,
    edge_padding = 3)

In [None]:
# Number of catalog sources brighter than fmin.
torch.sum(sdss_hubble_data.fluxes > fmin)

In [None]:
from torch.distributions.poisson import Poisson

In [None]:
# Reference Poisson distribution for the number of stars per stamp; it is compared
# against the empirical n_stars histogram below.
# NOTE(review): rate 0.5 is hard-coded; presumably a prior guess rather than a fitted value.
poisson_distr = Poisson(rate = 0.5)

In [None]:
# Empirical distribution of bright stars per stamp vs. the Poisson reference.
# Bin edges sit at half-integers so each bar holds exactly one count value. With
# integer edges np.arange(0, 7), the last bin is closed and merged counts 5 and 6.
star_counts = np.arange(0, 7)
bin_edges = np.arange(-0.5, star_counts[-1] + 1)
h = plt.hist(n_stars, bin_edges)

# Expected counts under the Poisson model: total histogram count * pmf(k),
# evaluated at the bar centres k.
poisson_pmf = torch.exp(poisson_distr.log_prob(torch.tensor(star_counts, dtype = torch.float32)))
plt.plot(star_counts, h[0].sum() * poisson_pmf.numpy(),
         marker = 'x', color = 'red')
plt.xlabel('number of stars per stamp')
plt.ylabel('number of stamps');

In [None]:
# Distribution of log10 flux for the bright catalog sources (flux > fmin).
# The mask is recomputed here rather than reusing `which_bright`, which only
# exists as a side effect of the earlier subimage-grid cell.
bright_fluxes = sdss_hubble_data.fluxes[sdss_hubble_data.fluxes > fmin]
plt.hist(torch.log10(bright_fluxes))
plt.xlabel('log10(flux)')
plt.ylabel('number of sources');

In [None]:
# Convert a magnitude of 22.5 to nanomaggies and scale the result by the mean
# electrons-per-nanomaggy factor of the full image; presumably this gives the flux,
# in electrons, of a 22.5-mag source.
# NOTE(review): the choice of 22.5 mag is not explained here; presumably a
# reference/zero-point magnitude — confirm.
sdss_dataset_lib.convert_mag_to_nmgy(22.5) * sdss_hubble_data.nelec_per_nmgy_full.mean()