In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib

import json

# Use the first CUDA GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# load PSF: path to the SDSS psField file that is later passed to
# simulated_datasets_lib.load_dataset_from_params
psf_fit_file = (
    '../../celeste_net/sdss_stage_dir/3900/6/269/'
    'psField-003900-6-0269.fit'
)
print('psf file: \n', psf_fit_file)

In [None]:
# Seed numpy's and torch's global RNGs so random draws later in the notebook
# (e.g. the image picked for display via np.random.choice) are reproducible
# on Restart & Run All.
np.random.seed(43534)
# Trailing `;` hides the returned Generator's repr without assigning to `_`,
# which would shadow IPython's last-output variable and stop it updating.
torch.manual_seed(24534);

# Draw simulated data

In [None]:
# data parameters: load the defaults from disk, then override the
# per-image star-count range to [0, 20]
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update(min_stars = 0, max_stars = 20)

print(data_params)


In [None]:
# Upper bound on stars per image, as configured in data_params above.
# NOTE(review): not referenced again in this notebook -- drop if unneeded.
max_stars = data_params['max_stars']

In [None]:
# Number of simulated images to generate.
n_images = 1024

# Build the simulated dataset from the PSF file and star parameters, with noise.
# NOTE(review): the image count is passed through the `n_stars` keyword of
# load_dataset_from_params -- confirm that keyword means "number of images".
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_stars = n_images,
    use_fresh_data = False,
    add_noise = True,
)

In [None]:
# Sanity-check the image stack's dimensions; it is indexed below as images[i, 0, :, :].
simulated_dataset.images.shape

In [None]:
# observed image: pick one simulated image at random (global numpy RNG)
i = np.random.choice(n_images, 1)[0]

# ground-truth catalogue for image i
# NOTE(review): fluxes_i is not used in this cell.
locs_i = simulated_dataset.locs[i]
n_stars_i = simulated_dataset.n_stars[i]
fluxes_i = simulated_dataset.fluxes[i]

# Keep the first n_stars_i rows (presumably the remainder is padding -- confirm)
# and scale by (slen - 1), presumably mapping normalised coords to pixels.
# Despite the names, column 0 (`locs_x`) is drawn on the vertical axis and
# column 1 (`locs_y`) on the horizontal axis, i.e. they act as (row, col).
locs_x = locs_i[:int(n_stars_i), 0] * (simulated_dataset.slen - 1)
locs_y = locs_i[:int(n_stars_i), 1] * (simulated_dataset.slen - 1)

plt.matshow(simulated_dataset.images[i, 0, :, :])
plt.title('Observed image')
plt.scatter(x = locs_y, y = locs_x, c = 'b')


In [None]:
# Compare simulated images with the SDSS/Hubble data

In [None]:
# Real-data comparison set, constructed with no arguments.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

In [None]:
# Distribution of values in the full SDSS background, plus their mean.
# Flatten once and reuse it instead of flattening the whole array twice.
sdss_background_flat = sdss_hubble_data.sdss_background_full.flatten()

plt.hist(sdss_background_flat)
plt.title('SDSS background (sdss_background_full)')
plt.xlabel('pixel value')
plt.ylabel('count')
print(sdss_background_flat.mean())

In [None]:
# Show one SDSS image with its catalogued star locations overlaid.
indx = 98
data = sdss_hubble_data[indx]

# Cast the star count to a plain int before slicing, as the simulated-image
# cell does with int(n_stars_i), so the slice bound also works if n_stars
# arrives as a float or a 1-element array/tensor.
n_stars_obs = int(data['n_stars'])
pix_scale = sdss_hubble_data.slen - 1
locs_obs = data['locs'][:n_stars_obs]

plt.matshow(data['image'].squeeze())
# Column 1 goes on the horizontal axis and column 0 on the vertical axis,
# the same convention as the simulated-image scatter plot.
plt.plot(locs_obs[:, 1] * pix_scale,
         locs_obs[:, 0] * pix_scale, 'x', color = 'r')

plt.title('image {}'.format(indx))

In [None]:
# Preview the catalogue locations with a leading batch dimension of size 1,
# the form passed to draw_image_from_params below.
torch.Tensor(data['locs'])[None]

In [None]:
# Render the observed catalogue through the simulator, without noise, so it
# can be compared with the real image. Each input gets a leading batch dim of 1.
obs_locs = torch.Tensor(data['locs']).unsqueeze(0)
# NOTE(review): the catalogue 'fluxes' entry is converted as 10**(22 - value).
# If that entry is a magnitude, a Pogson conversion would normally include a
# /2.5 factor -- confirm the intended units.
obs_fluxes = (10 ** (22 - torch.Tensor(data['fluxes']))).unsqueeze(0)
obs_n_stars = torch.Tensor([data['n_stars']])

image = simulated_dataset.draw_image_from_params(locs = obs_locs,
                                                 fluxes = obs_fluxes,
                                                 n_stars = obs_n_stars,
                                                 add_noise = False)

In [None]:
# Noise-free rendering with simulated_dataset.sky_intensity subtracted.
sky_subtracted = image.squeeze() - simulated_dataset.sky_intensity
plt.matshow(sky_subtracted)