In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib
import image_utils

import json

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# Path to the SDSS psField file the PSF is read from
# (the 2566/6/65 directory layout presumably encodes run/camcol/field — TODO confirm).
psf_fit_file = ('../../celeste_net/sdss_stage_dir/'
                '2566/6/65/'
                'psField-002566-6-0065.fit')
print('psf file: \n', psf_fit_file)

In [None]:
np.random.seed(43534)
_ = torch.manual_seed(24534)

# Check the PSF

In [None]:
# Evaluate the PSF model from the psField file at point (0, 0).
psf = sdss_psf.psf_at_points(
    0, 0,
    psf_fit_file=psf_fit_file,
)

In [None]:
# Show the PSF stamp; the trailing ';' hides the AxesImage repr in the cell output.
plt.matshow(psf)
plt.colorbar()
plt.title('PSF');

In [None]:
# Expand the PSF stamp to size 101 using the library's private helper
# (presumably pads it into a 101x101 frame — TODO confirm against _expand_psf).
psf_expanded = simulated_datasets_lib._expand_psf(
    psf,
    101,
)

In [None]:
# Show the expanded PSF; the trailing ';' hides the AxesImage repr in the cell output.
plt.matshow(psf_expanded)
plt.colorbar()
plt.title('Expanded PSF');

In [None]:
# Value at index (50, 50) of the expanded PSF, for comparison with psf[25, 25] in the next cell.
psf_expanded[50][50]

In [None]:
# Value of the original PSF stamp at (25, 25) — presumably its centre pixel; TODO confirm stamp size.
psf[25][25]

In [None]:
# Peak PSF value; if the PSF is centred this should match the centre pixel shown above.
np.amax(psf)

# Check star locations

In [None]:
# Small simulator for a single-star sanity check; 686 is passed as the sky intensity.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file,
    slen=11,
    sky_intensity=686,
)

In [None]:
# One image containing one star of flux 1000.
# Locations appear to be given as fractions of the 11-pixel side length — TODO confirm.
locs = torch.tensor([[[0 / 11, 5 / 11]]], dtype=torch.float32)
n_stars = torch.tensor([1], dtype=torch.long)
fluxes = torch.tensor([[1000.]], dtype=torch.float32)

In [None]:
# Render the noiseless image for the parameters above and display it;
# the trailing ';' hides the AxesImage repr in the cell output.
mean = simulator.draw_image_from_params(locs, fluxes, n_stars, add_noise = False)
plt.matshow(mean.squeeze())
plt.colorbar()
plt.title('Noiseless simulated image');

# Test image tiling

Draw full images containing a sparse set of stars (10 per image).

In [None]:
# Start from the default star parameters, then override them for a sparse test field.
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

data_params.update({
    'slen': 101,
    'min_stars': 10,
    'max_stars': 10,  # min == max: presumably exactly 10 stars per image
    'alpha': 0.5,
})

print(data_params)


In [None]:
# Image side length from the data parameters (presumably; key is 'slen').
# NOTE(review): `slen` is not read by any later cell in this notebook — candidate for removal.
slen = data_params['slen']

In [None]:
# Two sparse-field images. Noise is turned off so tiles and reconstructions can be compared exactly.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file,
    data_params,
    n_images=2,
    add_noise=False,
)

In [None]:
# First simulated image; the trailing ';' hides the AxesImage repr in the cell output.
plt.matshow(simulated_dataset.images[0].squeeze())
plt.title('Simulated image 0');

In [None]:
# Second simulated image; the trailing ';' hides the AxesImage repr in the cell output.
plt.matshow(simulated_dataset.images[1].squeeze())
plt.title('Simulated image 1');

In [None]:
# Keep the ground-truth images and star parameters for comparison against the tiled versions.
full_images, full_locs, full_fluxes = (
    simulated_dataset.images,
    simulated_dataset.locs,
    simulated_dataset.fluxes,
)

### Tiling parameters

In [None]:
# Side length of the full images (last axis), plus the tiling parameters:
# tile side length, tile step (presumably the stride between tile corners) and the
# border width cropped from each tile.
full_slen = full_images.shape[-1]
subimage_slen, step, edge_padding = 9, 4, 2

### Tile the images into a batch

In [None]:
# Cut every full image into subimage_slen-sized tiles placed `step` pixels apart.
images_batched = image_utils.tile_images(full_images, subimage_slen, step)

# get tile coordinates: the (x0, x1) corner of each tile, shared by all images
tile_coords = image_utils.get_tile_coords(full_slen, full_slen, subimage_slen, step)
n_patches = tile_coords.shape[0]

# The checks below index tiles image-major (image = i // n_patches, tile = i % n_patches),
# so the batch must contain exactly n_patches tiles per image.
assert images_batched.shape[0] == full_images.shape[0] * n_patches, \
    'expected {} tiles, got {}'.format(full_images.shape[0] * n_patches, images_batched.shape[0])

### Test tile coordinates 

In [None]:
# Every tile in the batch must equal the corresponding slice of its source image.
for i in range(images_batched.shape[0]):
    # Tiles are laid out image-major: n_patches consecutive tiles per image.
    image_indx, patch_indx = divmod(i, n_patches)

    x0 = tile_coords[patch_indx, 0]
    x1 = tile_coords[patch_indx, 1]

    expected_tile = full_images[image_indx].squeeze()[x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]
    tile = images_batched[i].squeeze()

    # Compare shapes explicitly so a truncated border tile cannot be hidden by broadcasting.
    assert tile.shape == expected_tile.shape, \
        'tile {}: shape {} != expected {}'.format(i, tuple(tile.shape), tuple(expected_tile.shape))
    assert np.all(tile == expected_tile), \
        'tile {} (image {}, corner ({}, {})) does not match the full image'.format(
            i, image_indx, int(x0), int(x1))

### Test extraction of per-tile star parameters

In [None]:
# Extract per-tile star parameters, then re-render each tile's interior
# (tile minus edge_padding on every side) from those parameters alone.
subimage_locs, subimage_fluxes, n_stars, is_on_array = image_utils.get_params_in_patches(
    tile_coords, full_locs, full_fluxes, full_slen, subimage_slen, edge_padding)

patch_simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file,
    subimage_slen - 2 * edge_padding,
    data_params['sky_intensity'],
)

# Render every star slot of every tile; this presumes unused slots carry zero flux — TODO confirm.
n_tiles, max_stars_per_tile = subimage_fluxes.shape[0], subimage_fluxes.shape[1]
_n_stars = torch.full((n_tiles,), max_stars_per_tile, dtype=torch.long)

recon_means = patch_simulator.draw_image_from_params(
    subimage_locs,
    subimage_fluxes,
    _n_stars,
    add_noise=False,
)

In [None]:
# For each tile that contains at least one star, show:
#   left   - the tile interior cropped from the full image, with extracted star locations overlaid
#   middle - the reconstruction rendered from the per-tile parameters
#   right  - the relative residual (data - recon) / data
patch_slen = subimage_slen - 2 * edge_padding
for indx in range(images_batched.shape[0]):
    if n_stars[indx] == 0:
        continue

    fig, axarr = plt.subplots(1, 3, figsize=(16, 6))

    x0 = tile_coords[indx % n_patches, 0]
    x1 = tile_coords[indx % n_patches, 1]

    which_nonzero = is_on_array[indx]

    # Interior of the tile: drop `edge_padding` pixels on each side (channel 0 of the full image).
    image_patch_indx = full_images[indx // n_patches, 0,
                                   (x0 + edge_padding):(x0 - edge_padding + subimage_slen),
                                   (x1 + edge_padding):(x1 - edge_padding + subimage_slen)]
    recon_patch = recon_means[indx].squeeze()

    im1 = axarr[0].matshow(image_patch_indx)
    # Locations appear to be scaled so that 1 maps to pixel (patch_slen - 1); matshow puts
    # the column index on x, hence loc[..., 1] on x and loc[..., 0] on y.
    on_mask = which_nonzero == 1
    axarr[0].scatter(subimage_locs[indx, on_mask, 1] * (patch_slen - 1),
                     subimage_locs[indx, on_mask, 0] * (patch_slen - 1))
    fig.colorbar(im1, ax=axarr[0])
    # int() so the title reads "n_stars: 2" rather than "n_stars: tensor(2)".
    axarr[0].set_title('n_stars: {}\n'.format(int(n_stars[indx])))

    im2 = axarr[1].matshow(recon_patch)
    fig.colorbar(im2, ax=axarr[1])
    axarr[1].set_title('reconstruction\n')

    residual = (image_patch_indx - recon_patch) / image_patch_indx
    im3 = axarr[2].matshow(residual)
    fig.colorbar(im3, ax=axarr[2])
    axarr[2].set_title('relative residual\n')


# Test recovering full-image parameters from patch parameters

For this test we draw a crowded star field using the default parameters.

In [None]:
# Reload the untouched defaults; the sparse-field overrides from the previous section are discarded.
default_params_path = '../data/default_star_parameters.json'
with open(default_params_path) as fp:
    data_params = json.load(fp)

print(data_params)


In [None]:
n_images = 2

# Crowded field from the default parameters; noise off so the round trip can be checked exactly.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_fit_file, data_params,
    n_images=n_images,
    add_noise=False,
)

In [None]:
# First crowded-field image; the trailing ';' hides the AxesImage repr in the cell output.
plt.matshow(simulated_dataset.images[0].squeeze())
plt.title('Crowded field: simulated image 0');

In [None]:
# Keep the ground-truth images and star parameters of the crowded field.
full_images, full_locs, full_fluxes = (
    simulated_dataset.images,
    simulated_dataset.locs,
    simulated_dataset.fluxes,
)

# Tiling parameters: full-image side length (last axis), tile side length,
# tile step (presumably the stride between tile corners) and the border cropped from each tile.
full_slen = full_images.shape[-1]
subimage_slen, step, edge_padding = 9, 4, 2

### Tile the images into a batch

In [None]:
# Cut every crowded-field image into subimage_slen-sized tiles placed `step` pixels apart.
images_batched = image_utils.tile_images(full_images, subimage_slen, step)

# get tile coordinates: the (x0, x1) corner of each tile, shared by all images
tile_coords = image_utils.get_tile_coords(full_slen, full_slen, subimage_slen, step)
n_patches = tile_coords.shape[0]

# Downstream code assumes n_patches consecutive tiles per image (image-major layout).
assert images_batched.shape[0] == full_images.shape[0] * n_patches, \
    'expected {} tiles, got {}'.format(full_images.shape[0] * n_patches, images_batched.shape[0])

In [None]:
# get subimage parameters
subimage_locs, subimage_fluxes, n_stars, is_on_array = \
    image_utils.get_params_in_patches(tile_coords,
                                      full_locs,
                                      full_fluxes,
                                      full_slen,
                                      subimage_slen,
                                      edge_padding)


### Convert patch parameters back to full-image parameters

In [None]:
# Map the per-tile locations and fluxes back to full-image parameters; the third output is discarded.
locs_full_image, fluxes_full_image, _ = image_utils.get_full_params_from_patch_params(
    subimage_locs, subimage_fluxes,
    tile_coords, full_slen, subimage_slen, edge_padding,
    batchsize=n_images,
)

In [None]:
# Number of stars per image = count of strictly positive fluxes.
# NOTE(review): this presumes the positive-flux stars occupy the leading slots — TODO confirm.
recon_n_stars = torch.sum(fluxes_full_image > 0, dim=1)

recon_means = simulated_dataset.simulator.draw_image_from_params(
    locs=locs_full_image,
    fluxes=fluxes_full_image,
    n_stars=recon_n_stars,
    add_noise=False,
)

In [None]:
# Compare each simulated image with the image re-rendered from the recovered full-image
# parameters. With noise off, the residual is expected to be near zero if the round trip is lossless.
for i in range(n_images):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

    true_image = simulated_dataset.images[i].squeeze()
    recon_image = recon_means[i].squeeze()
    residual = (recon_image - true_image)

    im1 = axarr[0].matshow(true_image)
    fig.colorbar(im1, ax=axarr[0])
    axarr[0].set_title('simulated image {}'.format(i))

    im2 = axarr[1].matshow(recon_image)
    fig.colorbar(im2, ax=axarr[1])
    axarr[1].set_title('reconstruction')

    im3 = axarr[2].matshow(residual)
    fig.colorbar(im3, ax=axarr[2])
    # Report the worst-case pixel error so the check is readable without inspecting colors.
    axarr[2].set_title('residual (max |r| = {:.3g})'.format(float(abs(residual).max())))