In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib

import json

# Run on the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# SDSS field under study. The Hubble comparison section further down loads the
# same (run, camcol, field); keep the two in sync.
SDSS_RUN = 2566
SDSS_CAMCOL = 6
SDSS_FIELD = 65

# load PSF: psField file for the field above (run zero-padded to 6 digits,
# field zero-padded to 4 digits in the file name).
psf_fit_file = ('../../celeste_net/sdss_stage_dir/{run}/{camcol}/{field}/'
                'psField-{run:06d}-{camcol}-{field:04d}.fit').format(
    run = SDSS_RUN, camcol = SDSS_CAMCOL, field = SDSS_FIELD)
print('psf file: \n', psf_fit_file)

In [None]:
# Fix both global RNGs so the simulated data and the random image picks below
# are reproducible across Restart & Run All.
_ = torch.manual_seed(24534)
np.random.seed(43534)

# Draw data

In [None]:
# data parameters: start from the defaults on disk, then override the star-count
# range and `alpha` (presumably the exponent of the flux prior — see the Pareto
# draw in the flux-distribution section) for this experiment.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({'min_stars': 0, 'max_stars': 20, 'alpha': 0.5})

print(data_params)


In [None]:
# Per-image star cap taken from the data parameters set above.
# NOTE(review): `max_stars` is not referenced by any cell visible in this notebook.
max_stars = data_params['max_stars']

In [None]:
n_images = 1024

# Build (or load, since use_fresh_data is False) the simulated star-field dataset
# from the PSF file and data parameters, with noise added.
# NOTE(review): the `n_stars` keyword receives the number of *images*; confirm
# against load_dataset_from_params' signature.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars = n_images,
    use_fresh_data = False,
    add_noise = True,
)

In [None]:
# Shape of the simulated image stack; the next cell indexes it as images[i, 0, :, :].
simulated_dataset.images.shape

In [None]:
# Pick one simulated image at random and overlay its true star locations.
i = np.random.choice(n_images, 1)[0]

plt.matshow(simulated_dataset.images[i, 0, :, :])
plt.title('Observed image')

locs_i = simulated_dataset.locs[i]
n_stars_i = simulated_dataset.n_stars[i]
fluxes_i = simulated_dataset.fluxes[i]

# Only the first n_stars_i rows are stars; locations are scaled by (slen - 1)
# to get pixel coordinates (presumably they are stored in [0, 1]).
n_visible = int(n_stars_i)
pixel_scale = simulated_dataset.slen - 1
locs_x = locs_i[:n_visible, 0] * pixel_scale
locs_y = locs_i[:n_visible, 1] * pixel_scale

# matshow puts the first array index on the vertical axis, so column 1 goes on x.
plt.scatter(x = locs_y, y = locs_x, c = 'b')


# Check the PSF

In [None]:
# Log scale makes the faint PSF wings visible next to the bright core.
log_psf_full = np.log(simulated_dataset.psf_full)
plt.matshow(log_psf_full)

In [None]:
# Same log-scale view for the `psf` attribute the simulator exposes.
log_psf = torch.log(simulated_dataset.psf)
plt.matshow(log_psf)

In [None]:
# 20 random locations; the first is pinned to (-0.3, -0.3), presumably outside the
# [0, 1] range the other cells scale by (slen - 1), to see how such a star renders.
locs = torch.rand(20, 2)
locs[0] = torch.tensor([-0.3, -0.3])

one_star_images = simulated_datasets_lib.plot_one_star(
    locs = locs,
    psf = simulated_dataset.psf,
)

In [None]:
# NOTE(review): commented-out duplicate of the plot_one_star cell above (without
# the pinned first location); delete if it is no longer needed.
# locs = torch.rand(20, 2)
# one_star_images = simulated_datasets_lib.plot_one_star(locs = locs, 
#                                                         psf = simulated_dataset.psf)

In [None]:
# NOTE(review): this commented-out block uses `sdss_hubble_data`, which is only
# created in the "Compare with Hubble data" section below. Uncommenting it here
# would raise NameError on a fresh kernel (Restart & Run All). It also reads
# `locs` and `one_star_images` from the plot_one_star cell above.
# for i in range(one_star_images.shape[0]): 
#     fig, axarr = plt.subplots(1, 2, figsize=(16, 8))

#     vmin = torch.min(one_star_images[i])
#     vmax = torch.max(one_star_images[i])
    
#     # plot observed image 
#     foo = axarr[0].matshow(one_star_images[i].squeeze(), vmin=vmin, vmax=vmax)
#     axarr[0].plot(locs[i, 1]* (sdss_hubble_data.slen - 1), 
#                   locs[i, 0] * (sdss_hubble_data.slen - 1), 'x', color = 'r')

#     axarr[0].set_title('locs = {}'.format(locs[i]))
    
    
    
#     axarr[1].matshow(simulated_dataset.psf, vmin=vmin, vmax=vmax)
# #     axarr[1].plot(data['locs'][0:data['n_stars'], 1] * (sdss_hubble_data.slen - 1), 
# #              data['locs'][0:data['n_stars'], 0] * (sdss_hubble_data.slen - 1), 'x', color = 'r')
#     axarr[1].set_title('simulated image; flux rat = {:06f}'.format(torch.max(simulated_dataset.psf) / vmax))
    
#     fig.colorbar(foo, ax=axarr.ravel().tolist())
    

# Compare with Hubble data

In [None]:
# Hubble catalog for NGC 7078 paired with SDSS run 2566, camcol 6, field 65
# (the same field as the PSF file loaded at the top).
# NOTE(review): the directory is spelled 'NCG7078' while the file name says
# 'ngc7078'; left unchanged because it must match the on-disk path — confirm.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file = hubble_cat_file,
    run = 2566,
    camcol = 6,
    field = 65,
)

## Check full image

In [None]:
# check the hubble coordinates overlap with the globular cluster
# check the hubble coordinates overlap with the globular cluster:
# overlay every Hubble catalog position on the full SDSS field.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
# Draw point markers only. Without a format string, plt.plot joins consecutive
# catalog entries into one polyline, which hides where the stars actually are.
# locs_x1 goes on the horizontal axis because matshow puts the first array
# index on the vertical axis.
plt.plot(sdss_hubble_data.locs_x1, sdss_hubble_data.locs_x0, '.',
         markersize = 2, alpha = 0.2)
plt.title('SDSS field with Hubble catalog positions')

In [None]:
# check the counts matrix
# Per-tile counts matrix with an explicit colorbar tied to that image.
counts_image = plt.matshow(sdss_hubble_data.counts_mat)
plt.colorbar(counts_image)

In [None]:
# these are the tiles we kept
# Tiles kept: nonzero count strictly below the detection cap.
counts_mat = sdss_hubble_data.counts_mat
kept_tiles = (counts_mat > 0) & (counts_mat < sdss_hubble_data.max_detections)
plt.matshow(kept_tiles)

### Check out some sample images

In [None]:
# Side-by-side check for 10 random tiles. Left panel: the SDSS tile with its
# background removed, plus the catalog star locations (x) and border locations (.).
# Right panel: the image the simulator renders, without noise, from the same
# catalog locs/fluxes/n_stars, with the background removed.
# Both panels share one color scale.
pixel_scale = sdss_hubble_data.slen - 1  # maps locs to pixel coordinates, as in the cells above

for i in range(10):

    # Take a scalar index (same pattern as the simulated-image cell above) so the
    # dataset is asked for one sample instead of being indexed with a length-1 array.
    indx = np.random.choice(len(sdss_hubble_data), 1)[0]

    data = sdss_hubble_data[indx]

    # Only the first n_stars rows of `locs` are plotted as stars.
    n_stars_plot = int(data['n_stars'])
    star_locs = data['locs'][0:n_stars_plot]

    fig, axarr = plt.subplots(1, 2, figsize=(16, 8))

    true_image = data['image'].squeeze() - data['background'].squeeze()
    vmin = np.min(true_image)
    vmax = np.max(true_image)

    # plot observed image
    foo = axarr[0].matshow(true_image, vmin=vmin, vmax=vmax)
    axarr[0].plot(star_locs[:, 1] * pixel_scale,
                  star_locs[:, 0] * pixel_scale, 'x', color = 'r')
    axarr[0].plot(data['locs_border'][:, 1] * pixel_scale,
                  data['locs_border'][:, 0] * pixel_scale, '.', color = 'r')

    axarr[0].set_title('observed image; n_stars = {}; indx = {}'.format(data['n_stars'], indx))

    # use parameters, simulate image (batch dimension of 1 added via unsqueeze)
    simulated_image = \
        simulated_dataset.draw_image_from_params(locs = torch.Tensor(data['locs']).unsqueeze(0),
                                                 fluxes = torch.Tensor(data['fluxes']).unsqueeze(0),
                                                 n_stars = torch.Tensor([data['n_stars']]),
                                                 add_noise = False)

    # detach/cpu are no-ops for CPU tensors; they keep .numpy() working if the
    # simulator returns a tensor on another device.
    simulated_image = simulated_image.squeeze().detach().cpu().numpy() - data['background'].squeeze()
    axarr[1].matshow(simulated_image, vmin=vmin, vmax=vmax)
    axarr[1].plot(star_locs[:, 1] * pixel_scale,
                  star_locs[:, 0] * pixel_scale, 'x', color = 'r')
    axarr[1].set_title('simulated image; flux rat = {:06f}'.format(vmax / np.max(simulated_image)))

    # Figures are deliberately not closed: the inline backend renders open figures
    # at the end of the cell.
    fig.colorbar(foo, ax=axarr.ravel().tolist())
    

# Understanding flux distributions

In [None]:
# Restrict the catalog fluxes to the `which_bright` selection.
bright_selector = sdss_hubble_data.which_bright
true_fluxes = sdss_hubble_data.fluxes[bright_selector]

In [None]:
# log10 flux distribution of every catalog entry.
all_log10_fluxes = np.log10(sdss_hubble_data.fluxes)
plt.hist(all_log10_fluxes, bins = 100);

In [None]:
# log10 flux distribution restricted to the `which_bright` selection.
log10_true_fluxes = np.log10(true_fluxes)
plt.hist(log10_true_fluxes)

In [None]:
# Smallest selected flux. Uses the array's own vectorized reduction rather than
# the builtin min(), which iterates element by element in Python and whose
# result in the presence of NaN depends on element order.
true_fluxes.min()

In [None]:
# Draw one simulated flux per selected Hubble star from the simulator's flux
# sampler (positional args 1300 and 1e6 are presumably the min/max flux —
# confirm against _draw_pareto_maxed) for the comparison plot below.
n_true_fluxes = len(true_fluxes)
foo = simulated_datasets_lib._draw_pareto_maxed(
    1300, 1e6,
    alpha = 0.5,
    shape = (n_true_fluxes, ),
)

In [None]:
# Overlay the simulated flux draws on the selected Hubble fluxes (log10 scale).
# Both histograms use one shared set of bin edges. With separately chosen edges
# (the default), bar widths and heights are not comparable between the two.
log10_sim_fluxes = np.log10(np.asarray(foo, dtype = float))
log10_hubble_fluxes = np.log10(np.asarray(true_fluxes, dtype = float))
shared_bins = np.histogram_bin_edges(
    np.concatenate([log10_sim_fluxes, log10_hubble_fluxes]), bins = 10)

plt.hist(log10_sim_fluxes, bins = shared_bins, alpha = 0.5, label = 'simulated');
plt.hist(log10_hubble_fluxes, bins = shared_bins, alpha = 0.5, label = 'Hubble (true_fluxes)');
plt.xlabel('log10 flux')
plt.ylabel('count')
plt.legend();

In [None]:
# 20 random locations; the first is pinned to (-0.5, -0.5), further outside the
# presumed [0, 1] location range, to see how such a star renders.
locs = torch.rand(20, 2)
locs[0] = torch.tensor([-0.5, -0.5])

one_star_images = simulated_datasets_lib.plot_one_star(
    locs = locs,
    psf = simulated_dataset.psf,
)

In [None]:
# Show the first entry after squeezing singleton dims (presumably the star
# rendered from locs[0] above — confirm plot_one_star's output layout).
first_one_star_image = one_star_images.squeeze()[0]
plt.matshow(first_one_star_image)