In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
# Put the parent directory (relative to the current working directory) first on the
# module search path so the imports below resolve there; presumably these are the
# project's own modules rather than installed packages.
sys.path.insert(0, './../')

# NOTE(review): sdss_psf is not referenced anywhere else in this notebook.
import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib

import json

# Use the first CUDA device when one is available, otherwise the CPU.
# NOTE(review): `device` is not referenced anywhere else in this notebook.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# Path to the PSF file that is handed to the simulator below. The path components
# appear to be SDSS run / camcol / field, zero-padded in the psField file name.
SDSS_RUN, SDSS_CAMCOL, SDSS_FIELD = 2583, 2, 136
psf_fit_file = '../../celeste_net/sdss_stage_dir/{0}/{1}/{2}/psField-{0:06d}-{1}-{2:04d}.fit'.format(
    SDSS_RUN, SDSS_CAMCOL, SDSS_FIELD)
print('psf file: \n', psf_fit_file)

In [None]:
# Fix both global RNGs. numpy drives the random image picks further down;
# the torch seed presumably governs the simulator's sampling.
NUMPY_SEED = 43534
TORCH_SEED = 24534
np.random.seed(NUMPY_SEED)
_ = torch.manual_seed(TORCH_SEED)

# Draw data

In [None]:
# data parameters: start from the defaults stored on disk, then override the
# star-count bounds used for this notebook.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)
data_params.update(min_stars = 0, max_stars = 20)

print(data_params)


In [None]:
# Cache the star-count upper bound (overridden to 20 in the data-parameters cell).
# NOTE(review): `max_stars` is not referenced anywhere else in this notebook.
max_stars = data_params['max_stars']

In [None]:
n_images = 1024

# Noisy simulated images generated from the PSF file and data parameters above.
# NOTE(review): the library keyword is `n_stars`, but it receives n_images here;
# presumably it is the number of images to simulate. Confirm in
# simulated_datasets_lib.load_dataset_from_params (also what use_fresh_data=False implies).
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars = n_images,
    use_fresh_data = False,
    add_noise = True,
)

In [None]:
# Expect 4 dimensions: the plotting cell below indexes images as [image, 0, :, :].
np.shape(simulated_dataset.images)

In [None]:
# Pick one simulated image at random and show it with its true star locations.
i = np.random.choice(n_images, 1)[0]

plt.matshow(simulated_dataset.images[i, 0, :, :])
plt.title('Observed image')

# True parameters of image i; only the first n_stars_i location rows are plotted.
locs_i = simulated_dataset.locs[i]
n_stars_i = simulated_dataset.n_stars[i]
fluxes_i = simulated_dataset.fluxes[i]

# Locations are presumably stored in [0, 1]; scaling by (slen - 1) gives pixel coordinates.
visible_locs = locs_i[0:int(n_stars_i)] * (simulated_dataset.slen - 1)
locs_x = visible_locs[:, 0]
locs_y = visible_locs[:, 1]

# matshow puts rows on the vertical axis, so the second coordinate goes on x.
plt.scatter(x = locs_y, y = locs_x, c = 'b')


In [None]:
# Compare with Hubble data

In [None]:
# SDSS data paired with (presumably) Hubble catalog parameters, built with the class
# defaults. Used below like a dataset: len(...) and [idx] returning a dict with
# 'image', 'background', 'locs', 'fluxes' and 'n_stars'; also exposes .slen and .fluxes.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# For a few random SDSS/Hubble examples, show the observed image (left) next to an
# image the simulator renders, noise-free, from the same catalogued locations and
# fluxes (right). Catalogued locations are overlaid on both panels as red crosses.
n_examples = 10
for _ in range(n_examples):

    # Draw ONE scalar index (as in the single-image cell above). np.random.choice(n, 1)
    # alone is a shape-(1,) array, which would make `data` a batch of size 1; the
    # `.unsqueeze(0)` calls below expect a single example and add the batch dim themselves.
    indx = int(np.random.choice(len(sdss_hubble_data), 1)[0])
    data = sdss_hubble_data[indx]

    observed = data['image'].squeeze()
    # Plain int, so it can be used as a slice bound.
    n_stars = int(data['n_stars'])

    # Pixel coordinates of the first n_stars catalogued locations. Column 1 goes on the
    # x-axis because matshow draws rows down the vertical axis.
    cross_x = data['locs'][0:n_stars, 1] * (sdss_hubble_data.slen - 1)
    cross_y = data['locs'][0:n_stars, 0] * (sdss_hubble_data.slen - 1)

    # Both panels share the observed image's colour range so they are directly comparable.
    vmin = np.min(observed)
    vmax = np.max(observed)

    fig, axarr = plt.subplots(1, 2, figsize=(16, 8))

    # plot observed image
    im = axarr[0].matshow(observed, vmin=vmin, vmax=vmax)
    axarr[0].plot(cross_x, cross_y, 'x', color = 'r')
    axarr[0].set_title('observed image; indx = {}'.format(indx))

    # use parameters, simulate image (no noise)
    image = simulated_dataset.draw_image_from_params(locs = torch.Tensor(data['locs']).unsqueeze(0),
                                                     fluxes = torch.Tensor(data['fluxes']).unsqueeze(0),
                                                     n_stars = torch.Tensor([n_stars]),
                                                     add_noise = False)

    # Replace the simulator's sky level with the observed background. NOTE(review):
    # assumes sky_intensity is the background the simulator added; confirm in
    # simulated_datasets_lib. detach/cpu make .numpy() safe whatever device
    # or grad state the returned tensor has.
    simulated = image.squeeze().detach().cpu().numpy() - simulated_dataset.sky_intensity + \
        data['background'].squeeze()
    axarr[1].matshow(simulated, vmin=vmin, vmax=vmax)
    axarr[1].plot(cross_x, cross_y, 'x', color = 'r')
    axarr[1].set_title('simulated image; indx = {}'.format(indx))

    fig.colorbar(im, ax=axarr.ravel().tolist())

In [None]:
# Histogram of catalogued Hubble fluxes, restricted to values below 1e6 (the same cap
# later passed to _draw_pareto_maxed). Presumably the cut keeps a few very bright
# sources from swamping the bins.
flux_cutoff = 1e6
hubble_fluxes = sdss_hubble_data.fluxes
plt.hist(hubble_fluxes[hubble_fluxes < flux_cutoff])
plt.title('Hubble catalog fluxes below {:g}'.format(flux_cutoff))
plt.xlabel('flux')
plt.ylabel('count');

In [None]:
# Smallest catalogued Hubble flux, for comparison with the 1300 passed to
# _draw_pareto_maxed below. The array's own reduction is used because builtin min()
# loops element by element in Python and raises on arrays with more than one dimension.
sdss_hubble_data.fluxes.min()

In [None]:
# 10000 draws from the simulator's Pareto helper, for comparison with the Hubble
# fluxes above. NOTE(review): presumably 1300 is the lower bound and 1e6 the cap
# ("maxed"); confirm against simulated_datasets_lib._draw_pareto_maxed.
foo = simulated_datasets_lib._draw_pareto_maxed(1300, 1e6,
                                                shape = (10000, ),
                                                alpha = 2)

In [None]:
# Histogram of the Pareto draws, to set against the Hubble flux histogram above.
# NOTE(review): labelled as flux on the assumption that these draws model star fluxes.
plt.hist(foo)
plt.title('Draws from _draw_pareto_maxed(1300, 1e6, alpha = 2)')
plt.xlabel('flux')
plt.ylabel('count');

In [None]:
# Largest Pareto draw; presumably bounded by the 1e6 passed to _draw_pareto_maxed.
# Uses the array's own reduction instead of a Python-level loop over 10000 elements.
foo.max()