In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

from astropy.io import fits
from astropy.wcs import WCS

# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

import os

In [None]:
# Load an SDSS image patch together with the overlapping Hubble catalogue.
bands = [2, 3]
# presumably the pixel offsets of the patch within the full SDSS frame — TODO confirm
x0 = 630
x1 = 310
# One background-bias value per entry of `bands`, in the same order; keep the
# two lists in sync when changing the loaded bands.
background_bias = torch.Tensor([167., 222.])
assert len(background_bias) == len(bands), \
    'background_bias needs exactly one entry per band in `bands`'
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0 = x0,
                                                    x1 = x1,
                                                    bands = bands,
                                                    background_bias = background_bias)


In [None]:
# The full SDSS frame (index 0, i.e. presumably the first loaded band).
plt.matshow(sdss_hubble_data.sdss_image_full[0])
plt.title('full image, band ' + str(bands[0]));

In [None]:
# Check that the Hubble catalogue positions overlap the globular cluster in the
# full SDSS frame. locs_full_x1 goes on the horizontal (column) axis of matshow
# and locs_full_x0 on the vertical axis.
plt.matshow(sdss_hubble_data.sdss_image_full[0])
# Draw each position as an isolated marker; the default plt.plot line style
# would join every star to the next one in catalogue order.
plt.plot(sdss_hubble_data.locs_full_x1,
         sdss_hubble_data.locs_full_x0,
         linestyle = 'none', marker = '.', markersize = 2, alpha = 0.2)
plt.title('Hubble catalogue positions over the full SDSS frame');

In [None]:
# Inspect the image patch in each loaded band (sdss_image[i] <-> bands[i]).
for i in range(len(bands)):
    plt.matshow(sdss_hubble_data.sdss_image[i])
    plt.colorbar()
    plt.title('band ' + str(bands[i]))


In [None]:
# Alignment check: show band 1 minus band 0 on a diverging colour map whose
# limits are symmetric about zero.
if len(bands) > 1:
    band_diff = sdss_hubble_data.sdss_image[1] - sdss_hubble_data.sdss_image[0]
    diff_lim = band_diff.abs().max()
    plt.matshow(band_diff,
                cmap = plt.get_cmap('bwr'),
                vmin = -diff_lim,
                vmax = diff_lim)
    plt.colorbar()

# Distribution of colors

In [None]:
# Per-pixel colour of the patch, 2.5 * log10(band1 / band0): histogram plus
# mean and standard deviation.
if len(bands) > 1:
    pixel_ratio = sdss_hubble_data.sdss_image[1] / sdss_hubble_data.sdss_image[0]
    pixel_color = 2.5 * torch.log10(pixel_ratio).flatten()
    # Non-positive pixels give nan/inf after log10, which would break the
    # histogram range and turn the summary statistics into nan; drop them.
    pixel_color = pixel_color[torch.isfinite(pixel_color)]

    plt.hist(pixel_color, bins = 100)
    plt.xlabel('2.5 log10(band ' + str(bands[1]) + ' / band ' + str(bands[0]) + ')')
    plt.ylabel('number of pixels')

    print(pixel_color.mean())
    print(pixel_color.var().sqrt())

In [None]:
# Colour of each catalogue star, 2.5 * log10(flux_band1 / flux_band0):
# histogram plus mean and standard deviation.
if len(bands) > 1:
    flux_ratio = sdss_hubble_data.fluxes[:, 1] / sdss_hubble_data.fluxes[:, 0]
    star_color = 2.5 * torch.log10(flux_ratio).flatten()
    # Zero or negative fluxes give nan/inf after log10; drop them so the
    # histogram and the statistics stay finite.
    star_color = star_color[torch.isfinite(star_color)]

    plt.hist(star_color, bins = 100)
    plt.xlabel('2.5 log10(flux band ' + str(bands[1]) + ' / flux band ' + str(bands[0]) + ')')
    plt.ylabel('number of stars')

    print(star_color.mean())
    print(star_color.var().sqrt())

# Plot a few subimages

In [None]:
# First-band flux threshold used below to select the "bright" catalogue stars.
fmin = 500.0

In [None]:
import plotting_utils

In [None]:
# Candidate (x0, x1) offsets for the 10-pixel subimages plotted below; both
# names refer to the same array.
x0_vec = x1_vec = np.arange(0, 100, 10)

In [None]:
# Plot six randomly placed 10-pixel subimages: each loaded band with the bright
# catalogue stars overlaid, plus the band-1 minus band-0 difference.

# Bright-star mask (first-band flux > fmin). It does not change between
# iterations, so compute it once; later cells also read `which_bright`.
which_bright = (sdss_hubble_data.fluxes > fmin)[:, 0]
bright_locs = sdss_hubble_data.locs[which_bright]

for _ in range(6):
    # Draw scalars directly: int() on a size-1 array is deprecated in NumPy >= 1.25.
    x0 = int(np.random.choice(x0_vec))
    x1 = int(np.random.choice(x1_vec))

    # (image, title, use diverging colour map) for each panel to draw.
    panels = [(sdss_hubble_data.sdss_image[0], 'band ' + str(bands[0]), False)]
    if len(bands) > 1:
        panels.append((sdss_hubble_data.sdss_image[1], 'band ' + str(bands[1]), False))
        panels.append((sdss_hubble_data.sdss_image[1] - sdss_hubble_data.sdss_image[0],
                       'band ' + str(bands[1]) + ' - band ' + str(bands[0]), True))

    f, axarr = plt.subplots(1, 3, figsize=(16, 4))
    for ax, (image, title, diverging) in zip(axarr, panels):
        # Non-difference panels keep plot_subimage's default colour map.
        extra_kwargs = {'diverging_cmap': True} if diverging else {}
        plotting_utils.plot_subimage(ax,
                                     image,
                                     None,
                                     bright_locs,
                                     x0, x1,
                                     subimage_slen = 10,
                                     add_colorbar = True,
                                     global_fig = f,
                                     **extra_kwargs)
        ax.set_title(title)

# Test my simulator

In [None]:
import simulated_datasets_lib
import fitsio

In [None]:
# Load one PSF image per loaded band, stacked in the same order as `bands`.
# Band index -> filter letter assumes the SDSS ordering u, g, r, i, z, which is
# consistent with bands [2, 3] loading the r and i PSF files.
psf_dir = '../data/'


def _load_psf(band_letter):
    """Read the primary HDU of the PSF file for one filter and close the file."""
    psf_path = psf_dir + 'sdss-002583-2-0136-psf-' + band_letter + '.fits'
    with fitsio.FITS(psf_path) as psf_file:
        return psf_file[0].read()


psf_og = np.array([_load_psf('ugriz'[b]) for b in bands])


In [None]:
# Sky level per band: mean of the background map over all of its pixels.
n_bands = len(bands)
sky_intensity = sdss_hubble_data.sdss_background.reshape(n_bands, -1).mean(1)

# The background bias is handed to SDSSHubbleData when the data are loaded;
# presumably it is already accounted for there rather than added here — TODO confirm.
simulator = simulated_datasets_lib.StarSimulator(
    psf = torch.Tensor(psf_og),
    slen = sdss_hubble_data.slen,
    transpose_psf = False,
    sky_intensity = sky_intensity)

In [None]:
# Inspect the sky intensity held by the simulator.
simulator.sky_intensity

In [None]:
# If True, the simulator below only renders stars whose first-band flux exceeds
# fmin; if False, it renders every catalogue star.
filter_by_bright = False

In [None]:
# Choose which catalogue stars the simulator renders, then add the leading
# size-1 dimension the simulator call below is given.
if filter_by_bright:
    which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
    _locs = sdss_hubble_data.locs[which_bright]
    _fluxes = sdss_hubble_data.fluxes[which_bright]
else:
    _locs = sdss_hubble_data.locs
    _fluxes = sdss_hubble_data.fluxes

_n_stars = torch.tensor([len(_locs)], dtype = torch.long)
_locs = _locs.unsqueeze(0)
_fluxes = _fluxes.unsqueeze(0)

In [None]:
# Number of stars passed to the simulator (a 1-element LongTensor).
_n_stars

In [None]:
# Magnitudes of the stars passed to the simulator.
# NOTE(review): this pools the fluxes of every band and converts them with the
# band-averaged nelec_per_nmgy — confirm that mixing bands here is intended.
_mags = sdss_dataset_lib.convert_nmgy_to_mag(
    _fluxes.flatten() / sdss_hubble_data.nelec_per_nmgy.mean())
plt.hist(_mags, bins = 100);

In [None]:
# Noise-free rendering of the selected catalogue stars.
recon_mean = simulator.draw_image_from_params(
    n_stars = _n_stars,
    locs = _locs,
    fluxes = _fluxes,
    add_noise = False,
)

In [None]:
# For each band: observed patch, noise-free reconstruction, and the relative
# residual (recon - observed) / observed.
observed = sdss_hubble_data.sdss_image  # loop-invariant; later cells also read it

for i in range(len(bands)):
    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed[i])
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_mean[0, i])
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = recon_mean[0, i] - observed[i]
    rel_residual = residual / observed[i]
    # symmetric limits so zero sits at the centre of the diverging map
    rel_lim = rel_residual.abs().max()
    im2 = axarr[2].matshow(rel_residual, vmax = rel_lim, vmin = -rel_lim,
                           cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('(recon - obs) / obs, band = ' + str(bands[i]))

In [None]:
# Absolute residual (recon - observed) in the first loaded band.
residual = recon_mean[0, 0] - observed[0]

In [None]:
# Median residual over the [5:95, 5:95] crop (outer pixels excluded).
torch.median(residual[5:95, 5:95])

In [None]:
# Mean background level per loaded band; uses len(bands) rather than a
# hard-coded 2 so this cell keeps working when `bands` changes.
sdss_hubble_data.sdss_background.reshape(len(bands), -1).mean(1)

In [None]:
# Histogram of the first-band residuals; values below -2000 are clipped to -2000
# (presumably so a few outliers do not stretch the x-axis).
plt.hist(torch.clamp(residual.flatten(), min = -2000), bins = 100);

In [None]:
### plot some subimages: observed, recon and relative residual in the first band
# The bright-star mask is recomputed here (same definition as earlier cells:
# first-band flux > fmin) instead of relying on a `which_bright` left over
# from another cell.
bright_locs = sdss_hubble_data.locs[sdss_hubble_data.fluxes[:, 0] > fmin]
observed_band0 = sdss_hubble_data.sdss_image[0]
recon_band0 = recon_mean[0, 0]

f, axarr = plt.subplots(1, 3, figsize=(16, 4))

x0_vec = np.arange(0, 100, 10)
x1_vec = x0_vec

# Draw scalars directly: int() on a size-1 array is deprecated in NumPy >= 1.25.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

print([x0, x1])

# (image, title, use diverging colour map) for each panel.
panels = [(observed_band0, 'observed, band = ' + str(bands[0]), False),
          (recon_band0, 'recon, band = ' + str(bands[0]), False),
          ((recon_band0 - observed_band0) / observed_band0,
           '(recon - obs) / obs, band = ' + str(bands[0]), True)]

for ax, (image, title, diverging) in zip(axarr, panels):
    # Non-residual panels keep plot_subimage's default colour map.
    extra_kwargs = {'diverging_cmap': True} if diverging else {}
    plotting_utils.plot_subimage(ax,
                                 image,
                                 None,
                                 bright_locs,
                                 x0, x1,
                                 subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = f,
                                 **extra_kwargs)
    ax.set_title(title)

# Estimate background

In [None]:
# 0-based index into the loaded bands whose background offset is fitted below;
# band = 1 selects bands[1].
band = 1

In [None]:
# Observed and reconstructed images for `band`, cropped to [5:95, 5:95] so the
# outer pixels are left out of the background fit; both as numpy arrays.
crop = (band, slice(5, 95), slice(5, 95))
obs = sdss_hubble_data.sdss_image.numpy()[crop]
recon = recon_mean.squeeze(0)[crop].numpy()

In [None]:
def objective_fun(background_bias):
    """Pearson chi-square of `obs` against `recon` shifted by a constant offset.

    Returns sum((obs - (recon + b))**2 / (recon + b)) over all pixels. `obs`
    and `recon` are looked up as notebook globals at call time.
    """
    model = recon + background_bias
    return ((obs - model) ** 2 / model).sum()

In [None]:
from scipy.optimize import minimize

In [None]:
import autograd

In [None]:
# Initial guess for the background offset. Note: this overwrites the name x0,
# which earlier cells use for the patch / subimage offsets.
x0 = np.zeros((1,))

In [None]:
# Fit a constant background offset for `band` by minimising objective_fun with
# BFGS; autograd supplies the derivative.
estimate_background = True
if estimate_background:
    objective_grad = autograd.jacobian(objective_fun)
    out = minimize(fun = objective_fun, x0 = x0, jac = objective_grad, method = 'BFGS')

In [None]:
# scipy OptimizeResult of the BFGS fit; out.x is the fitted background offset.
out

In [None]:
# Background level implied by the fit: fitted offset plus the mean sky level
# computed for `band` above.
out.x + sky_intensity[band].numpy()

# Estimate power law

In [None]:
import autograd.numpy as anp

In [None]:
# Catalogue fluxes in the first loaded band above the Pareto cutoff f_min = 1e3.
# `fluxes` has one column per band, so pick a single band explicitly; masking
# the whole 2-D array would pool every band's fluxes into one sample.
_band0_fluxes = sdss_hubble_data.fluxes[:, 0]
fluxes = _band0_fluxes[_band0_fluxes > 1e3].numpy()

In [None]:
def negloglik(log_alpha):
    """Negative log-likelihood of a Pareto(alpha, f_min = 1e3) model for `fluxes`.

    Parametrised by log_alpha, so alpha = exp(log_alpha) is always positive.
    Per star: log p(f) = log(alpha) + alpha * log(f_min) - (alpha + 1) * log(f).
    Written with autograd.numpy so autograd can differentiate it.
    """
    alpha = anp.exp(log_alpha)
    log_density = log_alpha + alpha * anp.log(1e3) - (alpha + 1) * anp.log(fluxes)
    return -log_density.sum()

In [None]:
minimize(fun = negloglik, 
        x0 = np.log([0.5]), 
        jac = autograd.jacobian(negloglik), 
        method = 'BFGS')

In [None]:
np.exp(-0.27348425)

In [None]:
# 1000 Pareto draws with cutoff f_min = 1e3 for comparison with the catalogue.
# NOTE(review): alpha is hard-coded to 0.5 rather than taken from the Pareto
# fit in the previous cells — confirm this is intentional.
sim_fluxes = simulated_datasets_lib._draw_pareto(f_min = 1e3,
                                                 alpha = 0.5,
                                                 shape = (1000,))

In [None]:
# Compare the observed flux tail with the simulated Pareto draws in log-flux.
# density=True because the samples have different sizes (len(fluxes) vs 1000),
# so raw counts would not be comparable.
_counts, log_flux_bins, _ = plt.hist(np.log(fluxes), density = True,
                                     label = 'catalogue (flux > 1e3)')
plt.hist(np.log(sim_fluxes), bins = log_flux_bins, density = True, alpha = 0.76,
         label = 'simulated Pareto')
plt.xlabel('log flux')
plt.ylabel('density')
plt.legend();

# Check distribution on image stamps

In [None]:
import image_utils

In [None]:
# Shape of the image patch (the first axis indexes the loaded bands).
sdss_hubble_data.sdss_image.shape

In [None]:
# Cut the patch into 9-pixel stamps with step 2 (presumably overlapping,
# since step < stamp size — TODO confirm against image_utils.tile_images).
image_stamps = image_utils.tile_images(
    sdss_hubble_data.sdss_image.unsqueeze(0),  # leading size-1 dim (presumably batch)
    subimage_slen = 9,
    step = 2,
)

In [None]:
# Shape of the stamp tensor returned by image_utils.tile_images.
image_stamps.shape

In [None]:
# Coordinates of the 9-pixel stamps (step 2). NOTE(review): both size arguments
# use the last image dimension, which assumes a square patch.
_img_len = sdss_hubble_data.sdss_image.shape[-1]
tile_coords = image_utils.get_tile_coords(_img_len,
                                          _img_len,
                                          subimage_slen = 9,
                                          step = 2)

In [None]:
# Catalogue flux array: one row per star, one column per loaded band.
sdss_hubble_data.fluxes.shape

In [None]:
# Assign the bright catalogue stars (first-band flux > fmin) to the stamps.
stamp_bright_mask = sdss_hubble_data.fluxes[:, 0] > fmin
subimage_locs, subimage_fluxes, n_stars, is_on_array = image_utils.get_params_in_patches(
    tile_coords,
    sdss_hubble_data.locs[stamp_bright_mask].unsqueeze(0),
    sdss_hubble_data.fluxes[stamp_bright_mask].unsqueeze(0),
    sdss_hubble_data.sdss_image.shape[-1],
    subimage_slen = 9,
    edge_padding = 3,
)

In [None]:
from torch.distributions.poisson import Poisson

In [None]:
# Candidate prior on the number of stars per stamp.
poisson_rate = 0.4
poisson_distr = Poisson(rate = poisson_rate)

In [None]:
# Number of bright stars per stamp versus the expectation under `poisson_distr`.
stamp_count_edges = np.arange(0, 7)
counts, _, _ = plt.hist(n_stars, stamp_count_edges, label = 'stamps')

# Expected number of stamps with k stars is N * P(k). Evaluate it only for the
# k that have a bar and place it at the bar centre (k + 0.5), not the left edge.
ks = torch.Tensor(stamp_count_edges[:-1])
expected = float(counts.sum()) * torch.exp(poisson_distr.log_prob(ks))
plt.plot(ks + 0.5, expected, marker = 'x', color = 'red', label = 'Poisson expectation')
plt.xlabel('number of stars in stamp')
plt.ylabel('number of stamps')
plt.legend();