In [None]:
# Imports and global configuration.

# --- standard library ---
import os
import pathlib
import sys

# --- third-party ---
import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.io import fits
from astropy.wcs import WCS
from torch.utils.data import Dataset

# --- local modules (live one directory up; put that directory first on the import path) ---
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed numpy's global RNG so the randomly chosen subimages in later cells are
# reproducible under Restart Kernel -> Run All.
np.random.seed(0)

In [None]:
# Load the SDSS image together with its catalog (presumably Hubble-derived, per the
# class name). Later cells read from this object: sdss_image, sdss_image_full,
# sdss_background, locs, fluxes, psf_file, slen, sdss_data and nelec_per_nmgy.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# The full SDSS image.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
plt.title('SDSS image (full)');

In [None]:
# The 100 x 100 subset of the image used in Portillo et al.
plt.matshow(sdss_hubble_data.sdss_image.squeeze())
plt.title('SDSS image (100 x 100 subset)');

In [None]:
# Check that the Hubble catalog coordinates overlap with the globular cluster by
# overlaying the catalog positions on the full SDSS image.
# `locs_full_x1` goes on the horizontal axis and `locs_full_x0` on the vertical axis,
# presumably matching matshow's (row, column) layout — TODO confirm.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
# Use scatter, not plot: plt.plot would join consecutive catalog entries with line
# segments, drawing lines across regions that contain no stars.
plt.scatter(sdss_hubble_data.locs_full_x1,
            sdss_hubble_data.locs_full_x0,
            s=2, alpha=0.2, color='red')
plt.title('Hubble catalog positions over the full SDSS image');

In [None]:
# Display the dataset's `nelec_per_nmgy` value (presumably the electrons-per-nanomaggy
# calibration factor, judging by the name — TODO confirm).
sdss_hubble_data.nelec_per_nmgy

# Plot a few random 10 x 10 subimages

In [None]:
import plotting_utils

In [None]:
# Candidate subimage offsets 0, 10, ..., 90; both axes share the very same grid array.
x0_vec = x1_vec = np.arange(0, 100, 10)

In [None]:
# Histogram of every catalog location coordinate (the `locs` tensor flattened, so all
# coordinates are pooled together).
fig, ax = plt.subplots()
ax.hist(sdss_hubble_data.locs.flatten())
ax.set(title='Catalog locations (all coordinates pooled)',
       xlabel='coordinate value', ylabel='count');

In [None]:
# Plot six randomly chosen 10 x 10 subimages of the 100 x 100 image, overlaying the
# catalog locations of stars above the flux cut.
# NOTE(review): 1300. is a magic flux cut in the units of `sdss_hubble_data.fluxes` — TODO document.
# Later cells reuse `which_bright`, so keep this name.
which_bright = sdss_hubble_data.fluxes > 1300.  # loop-invariant: computed once, not per panel
image_subset = sdss_hubble_data.sdss_image.squeeze()
bright_locs = sdss_hubble_data.locs[which_bright]

f, axarr = plt.subplots(2, 3, figsize=(16, 12))

for ax in axarr.flat:  # row-major, same panel order as axarr[i // 3, i % 3]
    # Draw scalars (no `size` argument): int() of a 1-element array is deprecated
    # since NumPy 1.25.
    x0 = int(np.random.choice(x0_vec))
    x1 = int(np.random.choice(x1_vec))

    plotting_utils.plot_subimage(ax,
                                 image_subset,
                                 None,
                                 bright_locs,
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=f)

# The subimages shown in Portillo et al.

In [None]:
# Reproduce the four 10 x 10 subimages shown in Portillo et al.
# `which_bright` (the catalog flux cut) is computed in the random-subimages cell above.

# (x0, x1) offsets passed to plot_subimage, one per panel in row-major order.
portillo_subimage_offsets = [(53, 70), (41, 23), (31, 83), (32, 64)]

f, axarr = plt.subplots(2, 2, figsize=(16, 12))

for ax, (x0, x1) in zip(axarr.flat, portillo_subimage_offsets):
    plotting_utils.plot_subimage(ax,
                                 sdss_hubble_data.sdss_image.squeeze(),
                                 None,
                                 sdss_hubble_data.locs[which_bright],
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=f)

    # Flip the vertical axis of this panel; presumably so the orientation matches the
    # figures in the paper — TODO confirm.
    ax.invert_yaxis()

# Test the simulator: render the catalog and compare it with the observed image

In [None]:
import simulated_datasets_lib

In [None]:
# Star simulator that uses the dataset's PSF fit file and the same `slen` as the
# SDSS data, with the sky intensity set to zero.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=str(sdss_hubble_data.psf_file),
    slen=sdss_hubble_data.slen,
    sky_intensity=0.,
)

In [None]:
# Render the catalog (locations + fluxes) through the simulator without noise.
# unsqueeze(0) adds a leading dimension of size 1 (presumably the simulator's batch axis).
recon_mean = simulator.draw_image_from_params(
    locs=sdss_hubble_data.locs.unsqueeze(0),
    fluxes=sdss_hubble_data.fluxes.unsqueeze(0),
    # Build the integer star count directly as int64, instead of creating a float32
    # tensor and casting it to LongTensor.
    n_stars=torch.tensor([len(sdss_hubble_data.locs)], dtype=torch.long),
    add_noise=False,
)

In [None]:
# Side-by-side comparison: background-subtracted observation, noiseless rendering of
# the catalog, and the relative residual (recon - observed) / observed.
observed = sdss_hubble_data.sdss_image.squeeze() - sdss_hubble_data.sdss_background.squeeze()
residual = recon_mean.squeeze() - observed
# NOTE(review): the relative residual blows up wherever `observed` is near zero, which
# can swamp the colour scale of the third panel.
relative_residual = residual / observed

panels = [
    (observed, 'observed (image - background)'),
    (recon_mean.squeeze(), 'reconstruction (add_noise=False)'),
    (relative_residual, '(recon - observed) / observed'),
]

f, axarr = plt.subplots(1, 3, figsize=(16, 4))
for ax, (panel, title) in zip(axarr, panels):
    im = ax.matshow(panel)
    f.colorbar(im, ax=ax)
    ax.set_title(title)

In [None]:
### Compare observed, reconstruction and relative residual on one random 10 x 10 subimage

# Same offset grid as earlier; redefined so this cell does not depend on it.
x0_vec = np.arange(0, 100, 10)
x1_vec = x0_vec

# Draw scalars (no `size` argument): int() of a 1-element array is deprecated since
# NumPy 1.25.
x0 = int(np.random.choice(x0_vec))
x1 = int(np.random.choice(x1_vec))

print([x0, x1])

# `which_bright` (the catalog flux cut) is computed in the random-subimages cell above.
bright_locs = sdss_hubble_data.locs[which_bright]
panels = [observed, recon_mean.squeeze(), residual / observed]

f, axarr = plt.subplots(1, 3, figsize=(16, 4))
for ax, panel in zip(axarr, panels):
    plotting_utils.plot_subimage(ax,
                                 panel,
                                 None,
                                 bright_locs,
                                 x0, x1,
                                 subimage_slen=10,
                                 add_colorbar=True,
                                 global_fig=f)

In [None]:
# Histogram of the per-pixel relative residual (recon - observed) / observed.
# Pixels where `observed` is exactly zero give inf/nan. An infinite value makes
# plt.hist's automatic bin range non-finite, which raises, so keep only finite values.
relative_residual_flat = np.asarray(residual / observed).ravel()
plt.hist(relative_residual_flat[np.isfinite(relative_residual_flat)], bins=100);

In [None]:
# Peak value of the simulator's PSF (presumably a quick check of its normalisation).
simulator.psf.max()

In [None]:
# Inspect the 'calibration' entry of the first element of `sdss_data`
# (presumably the first band/field — TODO confirm).
sdss_hubble_data.sdss_data[0]['calibration'] 

In [None]:
# NOTE(review): neither `x1_loc` nor `which_pixels` is defined in any earlier cell, so
# this line raised NameError on a fresh kernel (leftover scratch cell).
# TODO: delete this cell, or restore the cell that defines these names.
# x1_loc[which_pixels]