In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
# Make modules one directory above the current working directory importable.
# This must run before the project imports below (sdss_psf,
# simulated_datasets_lib, sdss_dataset_lib).
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib

import json

# First GPU if available, otherwise CPU.
# NOTE(review): `device` is not referenced by any cell visible in this notebook.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('torch version: ', torch.__version__)

In [None]:
# Path to the SDSS psField file. The directory layout encodes run 2566,
# camcol 6, field 65, which is the same field used for the Hubble comparison
# further down.
psf_fit_file = ('../../celeste_net/sdss_stage_dir/2566/6/65/'
                'psField-002566-6-0065.fit')
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both RNGs used in this notebook so Restart & Run All reproduces the same
# figures:
# - NumPy picks the example indices.
# - torch is presumably used by the simulator; confirm in simulated_datasets_lib.
np.random.seed(43534)
# The trailing ';' suppresses the returned Generator's repr. The result is
# deliberately not bound to `_`: assigning `_` in IPython shadows the
# last-output variable and stops IPython from updating `_`/`__`/`___`.
torch.manual_seed(24534);

# Draw data

In [None]:
# Construct a StarSimulator from the PSF file and display it. The object is not
# stored in any variable.
# NOTE(review): looks like a smoke test that the PSF file loads. The simulator
# actually used below is `simulated_dataset.simulator`; consider removing this
# cell.
simulated_datasets_lib.StarSimulator(psf_fit_file=psf_fit_file, 
                                    )

In [None]:
# Load the default star-simulation parameters, then override slen, the
# star-count range and alpha for this notebook.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(slen=11, min_stars=0, max_stars=6, alpha=0.5)

print(data_params)


In [None]:
# Convenience alias; sizes the flux and location tensors of the PSF sweep below.
max_stars = data_params['max_stars']

In [None]:
n_images = 1024

# Simulate n_images noisy images from the PSF file and the parameters above.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=n_images,
    add_noise=True,
)

In [None]:
# Show 10 randomly chosen simulated images (background subtracted) with the
# true star locations overlaid in blue.
# Bind the figure to `fig` rather than `_`, so IPython's last-output variable
# is not shadowed.
fig, axarr = plt.subplots(2, 5, figsize=(18, 8))
for i in range(0, 10):
    # Row / column of this panel in the 2 x 5 grid.
    i1, i2 = divmod(i, 5)

    # Take element [0] so int() receives a scalar. int() on a size-1 ndarray
    # is deprecated since NumPy 1.25. The random draw itself is unchanged.
    indx = int(np.random.choice(n_images, 1)[0])
    data = simulated_dataset[indx]
    n_stars_i = int(data['n_stars'])

    # image (background subtracted)
    axarr[i1, i2].matshow(data['image'][0, :, :] - data['background'][0, :, :])
    axarr[i1, i2].set_title('n_stars: {}\n'.format(n_stars_i))

    # Plot locations of the first n_stars_i entries of `locs`.
    # Assumes locs are normalized so that multiplying by (slen - 1) gives pixel
    # coordinates (the same convention as the other cells). Column 0 goes on
    # the vertical axis, column 1 on the horizontal axis.
    locs_i = data['locs']
    locs_y = locs_i[0:n_stars_i, 0] * (simulated_dataset.slen - 1)
    locs_x = locs_i[0:n_stars_i, 1] * (simulated_dataset.slen - 1)

    axarr[i1, i2].scatter(x=locs_x, y=locs_y, color='b')


# Check the PSF

In [None]:
# Full PSF on a log scale so low-amplitude regions are visible.
# NOTE(review): np.log gives -inf/nan (with a RuntimeWarning) for any
# non-positive PSF values. Confirm psf_full is strictly positive.
plt.matshow(np.log(simulated_dataset.simulator.psf_full))

In [None]:
# PSF with the colour scale clipped to [0, 0.1].
plt.matshow(simulated_dataset.simulator.psf, vmin = 0., vmax = 0.1)

In [None]:
n_trials = 10

# One star per trial (int64 counts), every flux slot set to 100 * f_min.
_n_stars = torch.ones(n_trials, dtype=torch.long)
_fluxes = torch.ones(n_trials, max_stars) * simulated_dataset.f_min * 100

# Move the star diagonally: both coordinates start at 0.5 - 1/11 and grow by
# 1/33 per trial. Every star slot of a trial gets the same location.
# NOTE(review): 1/11 presumably relates to slen = 11; confirm the intended step.
_locs = torch.zeros(n_trials, max_stars, 2) + 0.5
offset = -1/11
for trial in range(n_trials):
    _locs[trial] = 0.5 + offset
    offset += 1/33

In [None]:
# Render the location sweep above without noise (add_noise=False).
_images = simulated_dataset.simulator.draw_image_from_params(
    _locs, _fluxes, _n_stars, add_noise=False)

In [None]:
# One panel per trial: the noiseless image with the first star's location
# marked. The grid has 5 columns and as many rows as n_trials needs
# (2 x 5 at 16 x 8 inches for the default 10 trials). Changing n_trials
# therefore no longer raises IndexError or mismatches the layout.
# squeeze=False keeps axarr 2-D even when there is only one row.
n_cols = 5
n_rows = max(1, -(-n_trials // n_cols))  # ceil division, at least one row
fig, axarr = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)

# Locations are multiplied by (last image dimension - 1) to get pixel coordinates.
pix_scale = _images.shape[-1] - 1
for i in range(n_trials):
    ax = axarr[i // n_cols, i % n_cols]
    ax.matshow(_images[i, 0, :, :])
    # locs column 1 -> horizontal axis, column 0 -> vertical axis
    ax.scatter(_locs[i, 0, 1] * pix_scale, _locs[i, 0, 0] * pix_scale)
    ax.set_title('loc: {}\n'.format(_locs[i, 0]))

# Hide any unused panels in the last row.
for j in range(n_trials, n_rows * n_cols):
    axarr[j // n_cols, j % n_cols].axis('off')

fig.tight_layout()

In [None]:
# Pixel-wise difference between the first two noiseless sweep images, whose
# star locations differ by 1/33 in each coordinate.
_images[0, 0, :, :] - _images[1, 0, :, :]

# Check edge effects

TODO: this section has no cells yet.

# Compare with Hubble data

In [None]:
# Build the SDSS/Hubble dataset for SDSS run 2566, camcol 6, field 65 with
# slen = 11 (the same slen as the simulated data above) and max_detections = 20.
# NOTE(review): the directory is spelled 'NCG7078' while the file name says
# 'ngc7078'. It is left unchanged because it must match the path on disk.
hubble_cat_file = ('../hubble_data/NCG7078/'
                   'hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt')
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    slen=11,
    run=2566,
    camcol=6,
    field=65,
    max_detections=20,
)

## Check full image

In [None]:
# Check that the Hubble catalog coordinates overlap the globular cluster in the
# full SDSS frame. Catalog positions are drawn as individual faint markers.
# A plain plt.plot(x, y) would join consecutive catalog entries with line
# segments, which does not show where the stars are.
# The trailing ';' hides the Line2D repr.
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze())
plt.plot(sdss_hubble_data.locs_x1, sdss_hubble_data.locs_x0, '.',
         markersize=1, alpha=0.2, color='r');

In [None]:
# check the counts matrix (presumably the number of Hubble detections per tile;
# it is compared against max_detections below — confirm in sdss_dataset_lib)
plt.matshow(sdss_hubble_data.counts_mat)
plt.colorbar()

In [None]:
# these are the tiles we kept: counts_mat strictly between 0 and max_detections
plt.matshow((sdss_hubble_data.counts_mat > 0) & 
            (sdss_hubble_data.counts_mat < sdss_hubble_data.max_detections))

In [None]:
# Zoom into tile rows 0..44 and tile columns 94..149 of the full SDSS image.
# The same tile window is used for bool_mat below.
# Tile indices are converted to pixels with sdss_hubble_data.slen instead of a
# hard-coded 11, so the crop stays aligned if slen changes. This assumes tile
# (r, c) covers pixels [r*slen, (r+1)*slen) x [c*slen, (c+1)*slen).
plt.matshow(sdss_hubble_data.sdss_image_full.squeeze()[
    0:(45 * sdss_hubble_data.slen),
    (94 * sdss_hubble_data.slen):(150 * sdss_hubble_data.slen)])

In [None]:
# Kept tiles (counts strictly between 0 and max_detections), shown for tile
# rows 0..44 and tile columns 94..149.
bool_mat = ((sdss_hubble_data.counts_mat > 0)
            & (sdss_hubble_data.counts_mat < sdss_hubble_data.max_detections))

plt.matshow(bool_mat[0:45, 94:150])

In [None]:
# Scratch arithmetic: 27 * 11, presumably the pixel offset of tile index 27 for
# slen = 11 — confirm.
# NOTE(review): scratch cell; consider removing it or writing it as
# 27 * sdss_hubble_data.slen.
27 * 11

In [None]:
# Scratch arithmetic: 138 * 11, presumably the pixel offset of tile index 138
# for slen = 11 — confirm.
# NOTE(review): scratch cell; consider removing it or writing it as
# 138 * sdss_hubble_data.slen.
138 * 11

### Check out some sample images

In [None]:
# For 10 random tiles, show three panels:
#   [0] the observed image (background subtracted) with catalog stars ('x'),
#       positive-flux border stars (red '.') and dim stars (green '.');
#   [1] an image simulated from the same catalog parameters;
#   [2] a 0/1 mask of pixels where observed > simulated.
for i in range(0, 10): 
    
    # np.random.choice(n, 1) returns a 1-element array, not a scalar. The
    # indexing below therefore presumably carries a length-1 leading axis,
    # which the .squeeze() calls remove — confirm against SDSSHubbleData.__getitem__.
    indx = np.random.choice(len(sdss_hubble_data), 1)
        
    data = sdss_hubble_data[indx]
        
    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
    
    # Panels 0 and 1 share this colour range.
    true_image = data['image'].squeeze() - data['background'].squeeze()
    vmin = np.min(true_image)
    vmax = np.max(true_image)
    
    # get dim stars as well ... 
    # Re-query the catalog for this tile at its tile_coords with
    # return_dim_stars=True; only the dim-star locations/fluxes are kept.
    # NOTE(review): unpacking into `_` shadows IPython's last-output variable.
    x0 = sdss_hubble_data.tile_coords[indx, 0]
    x1 = sdss_hubble_data.tile_coords[indx, 1]
    
    _, _, _, _, locs_dim, fluxes_dim = \
            sdss_hubble_data._get_hubble_params_in_patch(x0, x1, sdss_hubble_data.slen, 
                                                        return_dim_stars = True)
    
    # plot observed image 
    # Locations are multiplied by (slen - 1) to get pixel coordinates;
    # locs column 1 -> horizontal axis, column 0 -> vertical axis.
    foo = axarr[0].matshow(true_image, vmin=vmin, vmax=vmax)
    axarr[0].plot(data['locs'][0:data['n_stars'], 1] * (sdss_hubble_data.slen - 1), 
                 data['locs'][0:data['n_stars'], 0] * (sdss_hubble_data.slen - 1), 'x', color = 'r')
    # Only border entries with positive flux are plotted (zero-flux entries are
    # presumably padding).
    border_filter = data['fluxes_border'] > 0

    axarr[0].plot(data['locs_border'][border_filter, 1] * (sdss_hubble_data.slen - 1), 
                 data['locs_border'][border_filter, 0] * (sdss_hubble_data.slen - 1), '.', color = 'r')
    axarr[0].plot(locs_dim[:, 1] * (sdss_hubble_data.slen - 1), 
                 locs_dim[:, 0] * (sdss_hubble_data.slen - 1), '.', color = 'g')

    axarr[0].set_title('observed; n_stars = {}; indx = {}\n'.format(data['n_stars'], indx))
    
    # use parameters, simulate image 
    # NOTE(review): the sweep cell calls
    # simulated_dataset.simulator.draw_image_from_params, while this cell calls
    # simulated_dataset.draw_image_from_params. Confirm both exist.
    # NOTE(review): n_stars is passed as a float torch.Tensor here but as a
    # LongTensor in the sweep cell. Confirm the simulator accepts floats.
    # NOTE(review): simulated_dataset was built from data_params['slen'] = 11,
    # which must match sdss_hubble_data.slen for the panels to line up.
    simulated_image = \
        simulated_dataset.draw_image_from_params(locs = torch.Tensor(data['locs']).unsqueeze(0), 
                                                 fluxes = torch.Tensor(data['fluxes']).unsqueeze(0), 
                                                 n_stars = torch.Tensor([data['n_stars']]), 
                                                 add_noise = False)
        
    # The observed image subtracts data['background'], whereas the simulated
    # image subtracts simulated_dataset.sky_intensity.
    # NOTE(review): .numpy() fails if the simulator returns a CUDA tensor.
    simulated_image = simulated_image.squeeze().numpy() - simulated_dataset.sky_intensity
    axarr[1].matshow(simulated_image, vmin=vmin, vmax=vmax)
    axarr[1].plot(data['locs'][0:data['n_stars'], 1] * (sdss_hubble_data.slen - 1), 
                 data['locs'][0:data['n_stars'], 0] * (sdss_hubble_data.slen - 1), 'x', color = 'r')
    # Ratio of the observed to the simulated peak pixel.
    # NOTE(review): divides by zero if the simulated image's max is 0.
    axarr[1].set_title('simulated; flux rat = {:06f}\n'.format(vmax / np.max(simulated_image)))
    
    fig.colorbar(foo, ax=[axarr[0], axarr[1]])
    
    # 1 where the observed pixel exceeds the simulated one, else 0.
    foo2 = axarr[2].matshow(((true_image - simulated_image) > 0) * 1)
    fig.colorbar(foo2, ax=[axarr[2]])

# Understanding the flux distributions

In [None]:
# Fluxes of the catalog entries selected by which_bright on the slen = 11
# dataset. These are presumably the stars bright enough to count as
# detections — confirm in sdss_dataset_lib.
true_fluxes = sdss_hubble_data.fluxes[sdss_hubble_data.which_bright]

In [None]:
# log10 histogram of every flux in sdss_hubble_data.fluxes (100 bins).
plt.hist(np.log10(sdss_hubble_data.fluxes), bins = 100);

In [None]:
# log10 histogram of the which_bright fluxes. The trailing ';' suppresses the
# (counts, bins, patches) tuple repr, as the all-fluxes histogram cell does.
plt.hist(np.log10(true_fluxes));

In [None]:
# Smallest bright flux. The array/tensor .min() reduces in one vectorized call,
# instead of the element-by-element Python iteration that builtin min() does.
true_fluxes.min()

In [None]:
# Draw len(true_fluxes) fluxes from simulated_datasets_lib._draw_pareto_maxed
# (by its name, a capped Pareto sampler — confirm) for comparison with the
# bright Hubble fluxes.
# NOTE(review): 1300 and 1e6 are presumably the min/max flux, and alpha = 0.5
# mirrors data_params['alpha']. Consider deriving them instead of hard-coding.
foo = simulated_datasets_lib._draw_pareto_maxed(1300, 1e6, alpha = 0.5, shape = (len(true_fluxes), ))

In [None]:
# Overlay log10 fluxes sampled from the Pareto helper (`foo`) and the bright
# Hubble fluxes. Both histograms use one shared set of 10 bin edges over the
# combined range. With independent default binning, the bar heights and widths
# of the two histograms are not comparable.
log_sampled = np.log10(np.asarray(foo, dtype=float)).ravel()
log_true = np.log10(np.asarray(true_fluxes, dtype=float)).ravel()
shared_bins = np.histogram_bin_edges(np.concatenate([log_sampled, log_true]), bins=10)

plt.hist(log_sampled, bins=shared_bins, alpha=0.5, label='sampled (Pareto)');
plt.hist(log_true, bins=shared_bins, alpha=0.5, label='Hubble (bright)');
plt.legend();

In [None]:
# Rebuild the SDSS/Hubble dataset for the same field with slen = 30 (vs 11
# above) and max_detections = 200. This rebinds sdss_hubble_data, replacing
# the slen = 11 instance.
# NOTE(review): the directory is spelled 'NCG7078' while the file name says
# 'ngc7078'. It is left unchanged because it must match the path on disk.
hubble_cat_file = ('../hubble_data/NCG7078/'
                   'hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt')
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    slen=30,
    run=2566,
    camcol=6,
    field=65,
    max_detections=200,
)

In [None]:
# Look at 10 random tiles from the slen = 30 dataset: the background-subtracted
# observed image with in-tile catalog stars ('x') and positive-flux border
# stars ('.').
for i in range(0, 10):

    # np.random.choice(n, 1) returns a 1-element array, as in the slen = 11
    # cell; the .squeeze() calls below drop the resulting length-1 axis.
    indx = np.random.choice(len(sdss_hubble_data), 1)

    data = sdss_hubble_data[indx]

    # Single panel: the second panel of the old 1 x 2 layout was never drawn on.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    true_image = data['image'].squeeze() - data['background'].squeeze()
    vmin = np.min(true_image)
    vmax = np.max(true_image)

    # Locations are multiplied by (slen - 1) to get pixel coordinates;
    # locs column 1 -> horizontal axis, column 0 -> vertical axis.
    pix_scale = sdss_hubble_data.slen - 1

    # plot observed image
    foo = ax.matshow(true_image, vmin=vmin, vmax=vmax)
    ax.plot(data['locs'][0:data['n_stars'], 1] * pix_scale,
            data['locs'][0:data['n_stars'], 0] * pix_scale, 'x', color='r')

    # Keep only border entries with positive flux, the same filter as the
    # slen = 11 cell. Unfiltered entries are presumably zero padding, which
    # would be drawn as spurious points at the origin.
    border_filter = data['fluxes_border'] > 0
    ax.plot(data['locs_border'][border_filter, 1] * pix_scale,
            data['locs_border'][border_filter, 0] * pix_scale, '.', color='r')

    ax.set_title('observed image; n_stars = {}; indx = {}'.format(data['n_stars'], indx))
    fig.colorbar(foo, ax=ax)

    # TODO: restore a side-by-side simulated comparison once a simulator whose
    # slen matches sdss_hubble_data.slen (30) exists. simulated_dataset was
    # built from data_params with slen = 11, so it cannot be reused here.
