In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import image_statistics_lib

import plotting_utils

# Seed every RNG the notebook touches (NumPy and torch) so Restart & Run All is reproducible.
SEED = 34534
torch.manual_seed(SEED)
np.random.seed(SEED)


# Load the data

In [None]:
# Flux threshold for the reference catalogue: only sources with first-band flux above
# this value are kept as "true" sources.
fmin = 10 ** 3

In [None]:
# Band indices to load from the SDSS frame. With a single band the r-band PSF is used
# below, so presumably index 2 is SDSS r — TODO confirm.
bands = [2]

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands = bands)

# Observed SDSS image and its sky background.
full_image = sdss_hubble_data.sdss_image
full_background = sdss_hubble_data.sdss_background

# Reference catalogue: keep only sources whose first-band flux exceeds fmin.
catalog_fluxes = sdss_hubble_data.fluxes
which_bright = catalog_fluxes[:, 0] > fmin
true_fluxes = catalog_fluxes[which_bright]
true_locs = sdss_hubble_data.locs[which_bright]


# Simulator

In [None]:
import fitsio

In [None]:
def _read_primary_hdu(path):
    """Read the data of HDU 0 from a FITS file, closing the file afterwards."""
    with fitsio.FITS(path) as fits:
        return fits[0].read()


psf_dir = '../data/'
psf_r = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
psf_i = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-i.fits')

# Pick the PSF for each requested band by band index, not by how many bands were
# requested. Otherwise e.g. bands = [3] would silently get the r-band PSF.
# Assumes SDSS ugriz ordering (2 = r, 3 = i), consistent with bands = [2] using psf-r.
psf_by_band = {2: psf_r, 3: psf_i}
unsupported_bands = [b for b in bands if b not in psf_by_band]
if not bands or unsupported_bands:
    # A real exception instead of `assert`, which is stripped under `python -O`.
    raise NotImplementedError('no PSF loaded for band(s) {}'.format(unsupported_bands or bands))
psf_og = np.array([psf_by_band[b] for b in bands])

# Per-band mean of the sky background, passed to the simulator as its sky level.
sky_intensity = full_background.reshape(full_background.shape[0], -1).mean(1)


In [None]:
# Forward model used to render images from catalogue parameters. The image side length
# and the per-band sky level come from the SDSS data loaded above.
simulator_kwargs = dict(psf = psf_og,
                        slen = full_image.shape[-1],
                        transpose_psf = False,
                        sky_intensity = sky_intensity)
simulator = simulated_datasets_lib.StarSimulator(**simulator_kwargs)



# Load and clean DAOPHOT results

In [None]:
# Sizes of axes 1 and 2 of the full SDSS frame.
len0, len1 = sdss_hubble_data.sdss_image_full.shape[1:3]

In [None]:
# DAOPHOT photometry table for this field; each row is indexed as one source below.
daophot_file = np.loadtxt(fname = './m2_2583.phot.txt')

In [None]:
# DAOPHOT columns used here: 4 -> RA, 5 -> Dec, 22 -> magnitude (meanings taken from
# the variable names; confirm against the DAOPHOT output format).
daophot_ra, daophot_decl, daophot_mags = (daophot_file[:, col] for col in (4, 5, 22))

In [None]:
# Project DAOPHOT RA/Dec onto SDSS pixel coordinates (origin 0) with the frame's WCS.
# Per astropy's wcs_world2pix, element 0 is the x (column) and element 1 the y (row) coordinate.
pix_coords = sdss_hubble_data.wcs.wcs_world2pix(daophot_ra, daophot_decl, 0,
                                                ra_dec_order = True)
# (Unused alternative: take pixel positions directly from DAOPHOT columns 21 / 20.)

# Constant sub-pixel shifts applied to the projected positions, keyed by coordinate index.
# NOTE(review): the origin of these values is not documented here — presumably an
# empirical alignment correction; confirm.
for coord_idx, shift in ((1, 0.38), (0, 0.56)):
    pix_coords[coord_idx] = pix_coords[coord_idx] - shift

In [None]:
# Pixel offsets (row, column) of the 100 x 100 analysis window inside the full SDSS frame.
x0, x1 = 630, 310

In [None]:
# Keep DAOPHOT sources strictly inside the 100-pixel window starting at (x0, x1).
in_window_rows = (pix_coords[1] > x0) & (pix_coords[1] < x0 + 100)
in_window_cols = (pix_coords[0] > x1) & (pix_coords[0] < x1 + 100)
which_locs = in_window_rows & in_window_cols
    


In [None]:
# Number of DAOPHOT sources inside the analysis window.
which_locs.sum()

In [None]:
# Window-relative positions divided by the window size (100 px): locs0 from the row (y)
# coordinate, locs1 from the column (x) coordinate.
# NOTE(review): elsewhere locations are mapped back to pixels with (slen - 1); the two
# conventions agree only if the image side length is 101 — confirm.
daophot_locs0, daophot_locs1 = ((pix_coords[coord_idx][which_locs] - offset) / 100
                                for coord_idx, offset in ((1, x0), (0, x1)))

In [None]:
# (n_sources, 2) float tensor: column 0 = normalized row position, column 1 = normalized column.
daophot_locs = torch.Tensor(np.stack([daophot_locs0, daophot_locs1], axis = 1))

In [None]:
# Sanity check: after the strict window cut and /100 normalization this should be < 1.
daophot_locs.max()

In [None]:
# Sanity check: after the strict window cut this should be > 0.
daophot_locs.min()

In [None]:
# Convert DAOPHOT magnitudes (for sources in the window) to nanomaggies. Then convert to
# image flux units using the mean nelec-per-nmgy of the full frame and the dataset's fudge factor.
daophot_nmgy = sdss_dataset_lib.convert_mag_to_nmgy(daophot_mags[which_locs])
mean_nelec_per_nmgy = sdss_hubble_data.nelec_per_nmgy_full.mean()
# Multiplication order kept left-to-right so floating-point results are unchanged.
daophot_fluxes = torch.Tensor(daophot_nmgy * mean_nelec_per_nmgy
                              * sdss_hubble_data.fudge_conversion).unsqueeze(1)

In [None]:
# Distribution of DAOPHOT fluxes inside the window. Indexing column 0 keeps this 1-D even
# for a single source; the trailing `;` suppresses the histogram-container repr.
plt.hist(daophot_fluxes[:, 0])
plt.xlabel('DAOPHOT flux')
plt.ylabel('number of sources');

# Get DAOPHOT reconstruction

In [None]:
# Render the DAOPHOT catalogue through the simulator without noise, as a batch of one image.
n_daophot = torch.tensor([len(daophot_fluxes)], dtype = torch.long)
daophot_recon = simulator.draw_image_from_params(locs = daophot_locs[None],
                                                 fluxes = daophot_fluxes[None],
                                                 n_stars = n_daophot,
                                                 add_noise = False)

In [None]:
# Noise-free DAOPHOT reconstruction. A colorbar is added so intensities can be read off,
# and the trailing `;` suppresses the repr.
plt.matshow(daophot_recon.squeeze())
plt.colorbar();

In [None]:
# Relative residual of the DAOPHOT reconstruction against band 0 of the observed image,
# drawn on a symmetric diverging colour scale centred at zero.
observed = sdss_hubble_data.sdss_image[0]
foo = (daophot_recon.squeeze() - observed) / observed
max_abs_resid = foo.abs().max()
plt.matshow(foo, vmin = -max_abs_resid, vmax = max_abs_resid, cmap = 'bwr')
plt.colorbar()

In [None]:
# Three random 10 x 10 cut-outs of the observed image, with DAOPHOT detections and the
# bright reference catalogue (true_locs, i.e. fluxes > fmin) overlaid; saved for the slides.
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))

np.random.seed(23423)
for ax in axarr:
    # Dedicated names so the analysis-window offsets x0 / x1 are not clobbered.
    # `[0]` extracts the scalar explicitly (int() on a size-1 array is deprecated in NumPy);
    # the random draws are unchanged.
    sub_x0 = int(np.random.choice(100, 1)[0])
    sub_x1 = int(np.random.choice(100, 1)[0])
    plotting_utils.plot_subimage(ax, sdss_hubble_data.sdss_image[0],
                                 daophot_locs,
                                 true_locs,
                                 sub_x0, sub_x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = fig)

fig.tight_layout()

fig.savefig('../../qualifying_exam_slides/figures/daophot_results.png')

In [None]:

# For three random cut-outs, show the observed image (left) next to the DAOPHOT relative
# residual `foo` (right). Both panels overlay DAOPHOT detections and the bright reference
# catalogue (true_locs).
# NOTE: the offsets continue the global NumPy RNG stream seeded in the previous cell, so
# they are only reproducible when the cells are run in order.
for _ in range(3):
    fig, axarr = plt.subplots(1, 2, figsize=(12, 6))

    # Dedicated names so the analysis-window offsets x0 / x1 are not clobbered.
    sub_x0 = int(np.random.choice(100, 1)[0])
    sub_x1 = int(np.random.choice(100, 1)[0])
    plotting_utils.plot_subimage(axarr[0], sdss_hubble_data.sdss_image[0],
                                 daophot_locs,
                                 true_locs,
                                 sub_x0, sub_x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = fig)

    plotting_utils.plot_subimage(axarr[1], foo,
                                 daophot_locs,
                                 true_locs,
                                 sub_x0, sub_x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 diverging_cmap = True,
                                 global_fig = fig)

In [None]:
# Match DAOPHOT detections against the bright reference catalogue.
# NOTE(review): 101 is presumably the image side length (locations are mapped back to
# pixels with slen - 1 elsewhere); pad / slack semantics live in image_statistics_lib.
summary_stats = image_statistics_lib.get_summary_stats(daophot_locs,
                                                       true_locs,
                                                       101,
                                                       daophot_fluxes[:, 0],
                                                       true_fluxes[:, 0],
                                                       pad = 0, slack = 0.5)
my_completeness1, my_tpr1, my_complete_bool1, my_tpr_bool = summary_stats

In [None]:
# Completeness returned by get_summary_stats (presumably the fraction of reference sources
# matched by a DAOPHOT detection — confirm in image_statistics_lib).
my_completeness1

In [None]:
# TPR returned by get_summary_stats (presumably the fraction of DAOPHOT detections that
# match a reference source — confirm in image_statistics_lib).
my_tpr1

In [None]:
# Number of reference sources flagged as matched (entries equal to 0 are treated as missed later).
my_complete_bool1.sum()

In [None]:
# Should have as many rows as my_complete_bool1 has entries (it is used to index true_locs).
true_locs.shape

In [None]:
# One flag per reference source; compare with true_locs.shape above.
my_complete_bool1.shape

In [None]:
# Random observed cut-out with DAOPHOT detections and the bright reference catalogue.
# Reference sources that DAOPHOT failed to match are drawn as orange circles.
subimage_slen = 10

fig, axarr = plt.subplots(1, 2, figsize=(12, 6))

# Dedicated names so the analysis-window offsets x0 / x1 are not clobbered.
sub_x0 = int(np.random.choice(100, 1)[0])
sub_x1 = int(np.random.choice(100, 1)[0])

plotting_utils.plot_subimage(axarr[0], sdss_hubble_data.sdss_image[0],
                             daophot_locs,
                             true_locs,
                             sub_x0, sub_x1, subimage_slen = subimage_slen,
                             add_colorbar = True,
                             global_fig = fig)

# Reference sources with no matching DAOPHOT detection, in pixel units.
# NOTE(review): scaled by (slen - 1) while daophot_locs were normalized by 100; these agree
# only if slen == 101 — confirm.
missed_pix = true_locs[my_complete_bool1 == 0] * (full_image.shape[-1] - 1)
# Separate mask name so the DAOPHOT window mask `which_locs` defined earlier survives.
in_subimage = (missed_pix[:, 0] > sub_x0) & \
              (missed_pix[:, 0] < (sub_x0 + subimage_slen - 1)) & \
              (missed_pix[:, 1] > sub_x1) & \
              (missed_pix[:, 1] < (sub_x1 + subimage_slen - 1))
missed_in_subimage = missed_pix[in_subimage, :] - torch.Tensor([[sub_x0, sub_x1]])
axarr[0].scatter(missed_in_subimage[:, 1], missed_in_subimage[:, 0],
                 color = 'orange', marker = 'o');

# TODO: also overlay DAOPHOT detections that matched no reference source (marker 'x'),
# presumably via `my_tpr_bool == 0`. An earlier draft referenced undefined names
# (`est_locs1`, `tpr1_bool`) and a 2-D `axarr[0, 0]`.



In [None]:
# Index of the minimal pairwise location error along dim 0, i.e. the nearest match.
# NOTE(review): `perm` is later used to align true_locs with daophot_locs, which needs one
# entry per DAOPHOT source — confirm the orientation of get_locs_error's output.
locs_error_matrix = image_statistics_lib.get_locs_error(daophot_locs, true_locs)
perm = locs_error_matrix.argmin(0)

In [None]:
# Indices of the brightest DAOPHOT detections (log10 flux > 4.5).
# Column 0 and squeeze(1) are used instead of a bare .squeeze(): with exactly one source
# (or one match) .squeeze() collapses to a 0-d tensor, and the `[:, k]` indexing of
# locs_err downstream then fails. This form always yields a 1-D index tensor.
log10_flux_cut = 4.5
which_brightest = torch.nonzero(torch.log10(daophot_fluxes[:, 0]) > log10_flux_cut).squeeze(1)

In [None]:
# Location error (DAOPHOT minus matched reference) for the brightest detections, multiplied
# by 100 — presumably back to pixels, given the /100 normalization of daophot_locs.
matched_true_locs = true_locs[perm]
locs_err = 100. * (daophot_locs[which_brightest] - matched_true_locs[which_brightest])

In [None]:
# Median error along location dim 1 (derived from the x / column pixel coordinate).
locs_err[:, 1].median()

In [None]:
# Median error along location dim 0 (derived from the y / row pixel coordinate).
locs_err[:, 0].median()

In [None]:
# Distribution of the dim-1 (column) location error for the brightest detections.
plt.hist(locs_err[:, 1], bins = 30)
plt.xlabel('DAOPHOT - reference location, dim 1 (x 100)')
plt.ylabel('number of sources');

In [None]:
# Distribution of the dim-0 (row) location error for the brightest detections.
plt.hist(locs_err[:, 0], bins = 30)
plt.xlabel('DAOPHOT - reference location, dim 0 (x 100)')
plt.ylabel('number of sources');

In [None]:
# Inspect the reference catalogue locations (normalized; scaled by slen - 1 to get pixels).
true_locs