In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib
import image_utils

import json

# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# Path to the SDSS psField file that the PSF is read from.
# Nothing is loaded in this cell; the path (relative to the working directory)
# is only stored and echoed so it is easy to see which file is used.
psf_fit_file = '../../celeste_net/sdss_stage_dir/2566/6/65/psField-002566-6-0065.fit'
print('psf file: \n', psf_fit_file)

In [None]:
# Seed both the NumPy and the PyTorch global random number generators.
np.random.seed(43534)
# torch.manual_seed returns a Generator; assigning it to `_` keeps it out of the cell output.
_ = torch.manual_seed(24534)

# Check PSF

In [None]:
# Evaluate the PSF described by the psField file at the point (0, 0).
# NOTE(review): presumably the two leading arguments are pixel coordinates — confirm in sdss_psf.
psf = sdss_psf.psf_at_points(0, 0, psf_fit_file=psf_fit_file)

In [None]:
# Display the PSF array as an image for a visual check.
plt.matshow(psf)

In [None]:
# Expand the PSF with target size 101 (the same value used for slen in the full images below).
# NOTE(review): _expand_psf is a private helper; presumably it embeds the PSF in a
# 101 x 101 array — confirm in simulated_datasets_lib.
psf_expanded = simulated_datasets_lib._expand_psf(psf, 101)

In [None]:
# Display the expanded PSF for a visual check.
plt.matshow(psf_expanded)

In [None]:
# Value at index (50, 50): the middle pixel if the expanded PSF is 101 x 101 (TODO confirm shape).
psf_expanded[50, 50]

In [None]:
# Value at index (25, 25) of the original PSF, to compare with psf_expanded[50, 50].
# NOTE(review): (25, 25) is the middle pixel only if the PSF is 51 x 51 — confirm.
psf[25, 25]

In [None]:
# Largest value in the original PSF, to compare with the two pixel values shown above.
np.max(psf)

# Check locations

In [None]:
# Small simulator (slen = 11, constant sky intensity 686) used to check star placement.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file,
    slen=11,
    sky_intensity=686,
)

In [None]:
# Parameters for a batch containing one image with one star.
# The location is (2/11, 11/11); the second coordinate is exactly 1.0, i.e. on the boundary.
locs = torch.tensor([[[2 / 11, 11 / 11]]])
# star count per image, stored as a long (integer) tensor
n_stars = torch.tensor([1], dtype=torch.long)
fluxes = torch.tensor([[1000.0]])

In [None]:
# Render the image for the single star defined above with noise switched off,
# then display it to see where the star lands.
mean = simulator.draw_image_from_params(
    locs,
    fluxes,
    n_stars,
    add_noise=False,
)
plt.matshow(mean.squeeze())

# Draw images 

In [None]:
# Data parameters: start from the defaults stored on disk ...
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

# ... then override the entries needed for this test:
# 101-pixel images with exactly 50 stars each (min_stars == max_stars).
data_params.update({
    'slen': 101,
    'min_stars': 50,
    'max_stars': 50,
    'alpha': 0.5,
})

print(data_params)


In [None]:
# Side length of the full simulated images (101, set in data_params above); reused by later cells.
slen = data_params['slen']

In [None]:
# Simulate two images from the parameters above.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=2,
    add_noise=False,  # for testing, turn off noise
)

In [None]:
# Display the first simulated image.
plt.matshow(simulated_dataset.images[0].squeeze())

In [None]:
# Display the second simulated image.
plt.matshow(simulated_dataset.images[1].squeeze())

In [None]:
# Inputs for the tiling tests below.
images = simulated_dataset.images
# side length of each tile (sub-image), in pixels
subimage_slen = 11
# NOTE(review): presumably the stride between neighbouring tile origins, so tiles
# overlap since step < subimage_slen — confirm in image_utils.tile_images
step = 5

# Test batching of images into tiles

In [None]:
# Cut the full images into tiles; also request the coordinates of each tile
# so they can be checked against the parent images.
(images_batched, tile_coords,
 nx_patches, ny_patches, n_patches) = image_utils.tile_images(
    images,
    subimage_slen,
    step,
    return_tile_coords=True,
)

## Test tile coordinates 

In [None]:
# Every tile must equal the window of its parent image that starts at the
# coordinates recorded for that tile in tile_coords.
for i in range(images_batched.shape[0]):

    # index of the full image this tile belongs to
    # (assumes tiles are stored image by image, n_patches per image — TODO confirm)
    image_indx = i // n_patches

    x0 = tile_coords[i, 0]
    x1 = tile_coords[i, 1]

    # The window size comes from subimage_slen (not a literal 11), so the check
    # stays valid when the tile size is changed.
    expected_tile = images[image_indx].squeeze()[x0:(x0 + subimage_slen),
                                                 x1:(x1 + subimage_slen)]

    assert np.all(images_batched[i].squeeze() == expected_tile), \
        'tile {} does not match its parent image'.format(i)

# Test extraction of parameters

In [None]:
# Extract the per-tile star parameters from the full-image parameters.
# Note: this rebinds n_stars, which was defined earlier for the single-star check.
(subimage_locs,
 subimage_fluxes,
 n_stars) = image_utils.get_params_in_patches(
    tile_coords,
    simulated_dataset.locs,
    simulated_dataset.fluxes,
    slen,
    subimage_slen,
    n_patches,
)

In [None]:
# get reconstruction with the subimage parameters 
patch_simulator = simulated_datasets_lib.StarSimulator(psf_fit_file, 
                                                       subimage_slen, 
                                                       data_params['sky_intensity'])

recon_means = patch_simulator.draw_image_from_params(subimage_locs, subimage_fluxes, 
                                                     n_stars, add_noise = False)

In [None]:
# For ten randomly chosen tiles show, side by side: the tile with its extracted
# star locations overlaid, the reconstruction drawn from the extracted
# parameters, and the difference between the two.
n_tiles = images_batched.shape[0]

for i in range(10):
    f, axarr = plt.subplots(1, 3, figsize=(16, 6))

    # Draw one tile index. Calling choice without a size returns a scalar, which
    # avoids converting a 1-element array to int (deprecated in recent NumPy).
    indx = int(np.random.choice(n_tiles))

    # only the first n_stars[indx] parameter slots of this tile are plotted
    n_in_tile = n_stars[indx]

    im1 = axarr[0].matshow(images_batched[indx].squeeze())
    # Locations are multiplied by (subimage_slen - 1) to put them on the tile's
    # pixel axes; column 1 goes on the plot's x axis, column 0 on its y axis.
    axarr[0].scatter(subimage_locs[indx, 0:n_in_tile, 1] * (subimage_slen - 1),
                     subimage_locs[indx, 0:n_in_tile, 0] * (subimage_slen - 1))
    f.colorbar(im1, ax=axarr[0])
    axarr[0].set_title('n_stars: {}\n'.format(n_stars[indx]))

    im2 = axarr[1].matshow(recon_means[indx].squeeze())
    f.colorbar(im2, ax=axarr[1])
    axarr[1].set_title('reconstruction\n')

    im3 = axarr[2].matshow(images_batched[indx].squeeze() - recon_means[indx].squeeze())
    f.colorbar(im3, ax=axarr[2])
    axarr[2].set_title('tile - reconstruction\n')

In [None]:
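# Test our actual simulator: repeat the tiling and parameter-extraction checks on a dense (400-star) dataset with noise turned on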

In [None]:
# Data parameters: reload the defaults from disk ...
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

# ... and override them for the dense case:
# 101-pixel images with exactly 400 stars each (min_stars == max_stars).
data_params.update({
    'slen': 101,
    'min_stars': 400,
    'max_stars': 400,
    'alpha': 0.5,
})

print(data_params)


In [None]:
# Simulate ten images from the dense parameters, this time with noise added.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=10,
    add_noise=True,
)

In [None]:
# Display the first image of the new (noisy, 400-star) dataset.
plt.matshow(simulated_dataset.images[0].squeeze())

In [None]:
# Tile the new images with the same tile size and step as before.
(images_batched, tile_coords,
 nx_patches, ny_patches, n_patches) = image_utils.tile_images(
    simulated_dataset.images,
    subimage_slen,
    step,
    return_tile_coords=True,
)

In [None]:
# Extract the per-tile star parameters for the new dataset.
# slen still holds the value read from the first data_params (101, same as here).
(subimage_locs,
 subimage_fluxes,
 n_stars) = image_utils.get_params_in_patches(
    tile_coords,
    simulated_dataset.locs,
    simulated_dataset.fluxes,
    slen,
    subimage_slen,
    n_patches,
)

In [None]:
# Show ten randomly chosen tiles of the noisy dataset with the extracted star
# locations overlaid.
n_tiles = images_batched.shape[0]

for i in range(10):
    # Draw one tile index. Calling choice without a size returns a scalar, which
    # avoids converting a 1-element array to int (deprecated in recent NumPy).
    indx = int(np.random.choice(n_tiles))

    # only the first n_stars[indx] parameter slots of this tile are plotted
    n_in_tile = n_stars[indx]

    plt.matshow(images_batched[indx].squeeze())
    # Locations are multiplied by (subimage_slen - 1) to put them on the tile's
    # pixel axes; column 1 goes on the plot's x axis, column 0 on its y axis.
    plt.scatter(subimage_locs[indx, 0:n_in_tile, 1] * (subimage_slen - 1),
                subimage_locs[indx, 0:n_in_tile, 0] * (subimage_slen - 1))
