In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_lib
import inv_kl_objective_lib as inv_kl_lib
import image_statistics_lib

import plotting_utils

np.random.seed(34534)


# Load the data

In [None]:
# Flux threshold: catalog sources are kept as "true" stars below only when
# their flux in the first band exceeds this value.
fmin = 1000

In [None]:
# Band indices handed to SDSSHubbleData. The cells below only handle one or two bands.
# NOTE(review): presumably 2 and 3 are the SDSS r and i bands (the PSF files
# loaded later are "-r" and "-i") -- confirm.
bands = [2, 3]

In [None]:
# Background bias passed to SDSSHubbleData, one entry per band. Values are
# only known for the one- and two-band configurations.
if len(bands) == 2:
    background_bias = torch.Tensor([168., 222.])
elif len(bands) == 1:
    background_bias = torch.Tensor([168.])
else:
    # Raise explicitly instead of `assert 1 == 2`: assert statements are
    # removed when Python runs with -O, which would let this case fall through.
    raise NotImplementedError('not implemented error')
    

In [None]:
# Build the dataset object for the selected bands.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    bands=bands,
    background_bias=background_bias,
)

# Image and background as stored on the dataset object
full_image = sdss_hubble_data.sdss_image
full_background = sdss_hubble_data.sdss_background

# "True" catalog: sources whose first-band flux exceeds fmin
which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
true_locs, true_fluxes = (sdss_hubble_data.locs[which_bright],
                          sdss_hubble_data.fluxes[which_bright])


# Simulator

In [None]:
import fitsio

In [None]:
psf_dir = '../data/'

# Read HDU 0 of each PSF file inside a `with` block so the file handle is
# closed once the array has been read; the bare `fitsio.FITS(...)[0].read()`
# form never closed it.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_fits:
    psf_r = psf_fits[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_fits:
    psf_i = psf_fits[0].read()

# One PSF per requested band: r first, then i
if len(bands) == 2:
    psf_og = np.array([psf_r, psf_i])
elif len(bands) == 1:
    psf_og = np.array([psf_r])
else:
    # Explicit exception rather than `assert 1 == 2`, which is stripped under -O
    raise NotImplementedError('not implemented error')

# Mean of the background over everything except its leading axis
# (presumably one value per band -- confirm against SDSSHubbleData).
sky_intensity = full_background.reshape(full_background.shape[0], -1).mean(1)


In [None]:
# Simulator configured with the PSFs and sky level above, sized to the
# last dimension of the data image.
simulator = simulated_datasets_lib.StarSimulator(
    psf=psf_og,
    slen=full_image.shape[-1],
    transpose_psf=False,
    sky_intensity=sky_intensity,
)



# Load and clean DAOPHOT results

In [None]:
def load_daophot_results(data_file, sdss_hubble_data, include_i_band = False):
    """Load a DAOPHOT table and express its sources in this image's frame.

    Parameters
    ----------
    data_file :
        Path of a numeric text table readable by ``np.loadtxt``. Column 4 is
        read as RA, column 5 as declination, column 22 as the magnitude of the
        first band and column 29 as the i-band magnitude.
    sdss_hubble_data :
        Object supplying ``x0``, ``x1``, ``slen``, ``wcs``,
        ``nelec_per_nmgy_full`` and ``fudge_conversion``.
    include_i_band : bool
        When true, a second flux column built from column 29 is returned.

    Returns
    -------
    daophot_locs : torch.Tensor, shape (n, 2)
        Pixel offsets from ``(x0, x1)`` divided by ``slen - 1``, for the
        sources lying strictly inside the square.
    daophot_fluxes : torch.Tensor, shape (n, 1) or (n, 2)
        Fluxes for the same sources; two columns when ``include_i_band``.
    """
    x0 = sdss_hubble_data.x0
    x1 = sdss_hubble_data.x1
    slen = sdss_hubble_data.slen

    table = np.loadtxt(data_file)
    ra, decl = table[:, 4], table[:, 5]

    # World -> pixel. Entry 1 of the result is compared against x0 and
    # entry 0 against x1.
    pix = sdss_hubble_data.wcs.wcs_world2pix(ra, decl, 0, ra_dec_order = True)
    coord0, coord1 = pix[1], pix[0]

    # Sources strictly inside the slen-sized square anchored at (x0, x1)
    inside = (coord0 > x0) & (coord0 < (x0 + slen - 1)) & \
                (coord1 > x1) & (coord1 < (x1 + slen - 1))

    # Rescale the retained coordinates by (slen - 1)
    scaled0 = (coord0[inside] - x0) / (slen - 1)
    scaled1 = (coord1[inside] - x1) / (slen - 1)
    daophot_locs = torch.Tensor(np.array([scaled0, scaled1]).T)

    def _mags_to_flux(mags):
        # magnitude -> nanomaggies, then scaled by the mean of
        # nelec_per_nmgy_full and by fudge_conversion
        return sdss_dataset_lib.convert_mag_to_nmgy(mags[inside]) * \
                    sdss_hubble_data.nelec_per_nmgy_full.mean() * \
                    sdss_hubble_data.fudge_conversion

    first_band_fluxes = _mags_to_flux(table[:, 22])
    if not include_i_band:
        return daophot_locs, torch.Tensor(first_band_fluxes).unsqueeze(1)

    i_band_fluxes = _mags_to_flux(table[:, 29])
    return daophot_locs, torch.Tensor([first_band_fluxes, i_band_fluxes]).transpose(0, 1)

In [None]:
def align_daophot_locs(daophot_locs, daophot_fluxes, true_locs, slen):
    """Estimate the constant offset between DAOPHOT and true locations.

    Every DAOPHOT source is paired with the true source that minimises
    ``image_statistics_lib.get_locs_error``. The offset is the per-coordinate
    median of ``daophot - matched_true`` over the DAOPHOT sources whose
    ``log10(flux)`` exceeds 4.5.

    Parameters
    ----------
    daophot_locs : torch.Tensor, shape (n, 2)
    daophot_fluxes : torch.Tensor with n entries (one flux per DAOPHOT source)
    true_locs : torch.Tensor, shape (m, 2)
    slen : side length used to convert the errors to pixel units and back

    Returns
    -------
    (bias_x0, bias_x1) : pair of 0-d tensors, in the same units as the inputs.

    Raises
    ------
    ValueError
        If no DAOPHOT source passes the brightness cut.
    """
    perm = image_statistics_lib.get_locs_error(daophot_locs, true_locs).argmin(0)

    # Use a boolean mask rather than `nonzero(...).squeeze()`: with exactly one
    # bright source the squeeze produced a 0-d index, the selection below
    # became 1-D, and `locs_err[:, 1]` raised an IndexError.
    is_bright = torch.log10(daophot_fluxes).reshape(-1) > 4.5
    if not bool(is_bright.any()):
        raise ValueError('no DAOPHOT source with log10(flux) > 4.5; '
                         'cannot estimate the location bias')

    locs_err = (daophot_locs[is_bright] - true_locs[perm][is_bright]) * (slen - 1)

    bias_x1 = locs_err[:, 1].median() / (slen - 1)
    bias_x0 = locs_err[:, 0].median() / (slen - 1)

    return bias_x0, bias_x1

In [None]:
daophot_locs, daophot_fluxes = load_daophot_results(
    '../daophot_results/m2_2583.phot.txt', sdss_hubble_data, include_i_band = True)

# Estimate the offset relative to the true catalog from the first-band
# fluxes and subtract it from each coordinate.
bias_x0, bias_x1 = align_daophot_locs(
    daophot_locs, daophot_fluxes[:, 0], true_locs, sdss_hubble_data.slen)
daophot_locs[:, 0] -= bias_x0
daophot_locs[:, 1] -= bias_x1

# After the shift some locations can fall outside the unit square; keep only
# the sources whose two coordinates are both strictly between 0 and 1.
which_filter = ((daophot_locs > 0) & (daophot_locs < 1)).all(1)

daophot_locs, daophot_fluxes = daophot_locs[which_filter], daophot_fluxes[which_filter]

# Get DAOPHOT reconstruction

In [None]:
# Draw the noiseless image implied by the DAOPHOT catalog. Each input gets a
# leading axis of length one.
n_daophot_stars = torch.Tensor([len(daophot_fluxes)]).type(torch.long)
daophot_recon = simulator.draw_image_from_params(
    locs=daophot_locs.unsqueeze(0),
    fluxes=daophot_fluxes.unsqueeze(0),
    n_stars=n_daophot_stars,
    add_noise=False,
)

In [None]:
# First band of the data image (full_image is sdss_hubble_data.sdss_image)
plt.matshow(full_image[0])

In [None]:
# One figure per band of the DAOPHOT reconstruction
n_bands = len(bands)
for band_idx in range(n_bands):
    plt.matshow(daophot_recon[0, band_idx])
    plt.colorbar()

In [None]:
# Relative residual (reconstruction - image) / image for each band, drawn on a
# colour scale symmetric about zero. `foo` keeps the last band's residual
# once the loop has finished.
for band_idx in range(len(bands)):
    observed = sdss_hubble_data.sdss_image[band_idx]
    foo = (daophot_recon[0, band_idx].squeeze() - observed) / observed
    color_lim = foo.abs().max()
    plt.matshow(foo, vmax = color_lim, vmin = - color_lim, cmap = plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(12, 3))

# Fixed seed so the same three sub-images are drawn on every run
np.random.seed(23423)
for i in range(3):

    # `[0]` pulls the scalar out of the size-1 draw; calling int() directly on
    # a one-element array is deprecated in recent NumPy releases.
    x0 = int(np.random.choice(100, 1)[0])
    x1 = int(np.random.choice(100, 1)[0])

    # `true_locs` is the catalog already restricted to first-band flux > fmin,
    # so the threshold is not repeated here as a literal.
    plotting_utils.plot_subimage(axarr[i], sdss_hubble_data.sdss_image[0],
                                 daophot_locs,
                                 true_locs,
                                 x0, x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = fig)


fig.tight_layout()

# plt.savefig('../../qualifying_exam_slides/figures/daophot_results.png')

In [None]:
# Both panels use the first band. The relative residual is computed here for
# that same band, instead of reusing a variable left over from an earlier
# loop (which held the residual of the last band only).
band_idx = 0
band_image = sdss_hubble_data.sdss_image[band_idx]
band_resid = (daophot_recon[0, band_idx].squeeze() - band_image) / band_image

for i in range(3):
    fig, axarr = plt.subplots(1, 2, figsize=(12, 6))

    # `[0]` pulls the scalar out of the size-1 draw; calling int() directly on
    # a one-element array is deprecated in recent NumPy releases.
    x0 = int(np.random.choice(100, 1)[0])
    x1 = int(np.random.choice(100, 1)[0])

    # Left panel: the image. `true_locs` is the catalog already restricted to
    # first-band flux > fmin.
    plotting_utils.plot_subimage(axarr[0], band_image,
                                 daophot_locs,
                                 true_locs,
                                 x0, x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = fig)

    # Right panel: the relative residual of the same band on a diverging colormap
    plotting_utils.plot_subimage(axarr[1], band_resid,
                                 daophot_locs,
                                 true_locs,
                                 x0, x1, subimage_slen = 10,
                                 add_colorbar = True,
                                 diverging_cmap = True,
                                 global_fig = fig)

In [None]:
# Restrict the true catalog to sources brighter (first band) than the faintest
# DAOPHOT source before computing the summary statistics.
faintest_daophot_flux = daophot_fluxes[:, 0].min()
filter_daophot_min = true_fluxes[:, 0] > faintest_daophot_flux

# NOTE(review): the literal 101 is presumably the image side length -- confirm
# against get_summary_stats.
completeness, tpr, _, _ = image_statistics_lib.get_summary_stats(
    daophot_locs,
    true_locs[filter_daophot_min, :],
    101,
    daophot_fluxes[:, 0],
    true_fluxes[filter_daophot_min, 0],
    pad = 0,
    slack = 0.5,
)

In [None]:
# Display the first two summary statistics returned above
completeness, tpr

# Color-magnitude diagram on subset of M2

In [None]:
# Convert the two flux columns back to magnitudes.
# NOTE(review): load_daophot_results also multiplied the fluxes by
# fudge_conversion, which is not divided out here -- confirm this is intended.
nelec_per_nmgy = sdss_hubble_data.nelec_per_nmgy_full.mean()

daophot_mag = sdss_dataset_lib.convert_nmgy_to_mag(daophot_fluxes[:, 0] / nelec_per_nmgy)
daophot_i_mag = sdss_dataset_lib.convert_nmgy_to_mag(daophot_fluxes[:, 1] / nelec_per_nmgy)

In [None]:
# Magnitude distribution of the DAOPHOT sources inside the image (first band)
plt.hist(daophot_mag)
plt.xlabel('r magnitude')
plt.ylabel('number of sources');

In [None]:
# Distribution of the colour (first-band minus i-band magnitude)
plt.hist(daophot_mag - daophot_i_mag)
plt.xlabel('r - i')
plt.ylabel('number of sources');

In [None]:
# Colour on the x axis, negated magnitude on the y axis
plt.plot(daophot_mag - daophot_i_mag, -daophot_mag, '+')
plt.xlabel('r - i')
plt.ylabel('-r magnitude');

# Color-magnitude diagram on full M2

In [None]:
# Read the whole DAOPHOT table again, this time without any spatial cut
daophot_txt = np.loadtxt(fname='../daophot_results/m2_2583.phot.txt')

In [None]:
# Same magnitude columns that load_daophot_results reads (22 and 29)
daophot_mag_full = daophot_txt[:, 22]
daophot_i_mag_full = daophot_txt[:, 29]

# Keep rows where both magnitudes are below 99.
# NOTE(review): 99 looks like a missing-value sentinel -- confirm.
which_keep = np.logical_and(daophot_mag_full < 99, daophot_i_mag_full < 99)

daophot_mag_full, daophot_i_mag_full = (daophot_mag_full[which_keep],
                                        daophot_i_mag_full[which_keep])

In [None]:
# Magnitude distribution over the full table (rows kept by which_keep)
plt.hist(daophot_mag_full)
plt.xlabel('r magnitude')
plt.ylabel('number of sources');

In [None]:
# Colour distribution over the full table (rows kept by which_keep)
plt.hist(daophot_mag_full - daophot_i_mag_full)
plt.xlabel('r - i')
plt.ylabel('number of sources');

In [None]:
# Table drawn as a red curve on the colour-magnitude plot below
fic_sequence = np.loadtxt(fname='../daophot_results/an_table_28.txt')

In [None]:
# Full-table colour-magnitude diagram with the an_table_28 sequence drawn on
# top in red. Both use colour on x and negated magnitude on y.
plt.plot(daophot_mag_full - daophot_i_mag_full, -daophot_mag_full, '+', alpha = 0.5,
         label = 'DAOPHOT')
plt.plot(fic_sequence[:, 3] - fic_sequence[:, 2], -fic_sequence[:, 0], color = 'red',
         label = 'an_table_28.txt')
plt.xlabel('r - i')
plt.ylabel('-r magnitude')
plt.legend();