In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import fitsio

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib
import image_utils

import json

# Use the first CUDA device when one is visible to torch; otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

print('torch version: ', torch.__version__)

In [None]:
# Pin both the NumPy and the PyTorch random number generators so that
# re-running the notebook from a fresh kernel gives the same draws.
NUMPY_SEED, TORCH_SEED = 43534, 24534
np.random.seed(NUMPY_SEED)
# bind the returned Generator to `_` so it is not echoed as the cell output
_ = torch.manual_seed(TORCH_SEED)

# Test tile images. 

Draw a full image with a sparse number of stars.

In [None]:
# data parameters: start from the project defaults stored on disk ...
with open('../data/default_star_parameters.json', 'r') as fp:
    data_params = json.load(fp)

# ... then override a few of them for this test:
#   slen                  -> side length used below for the background and the tile grid
#   min_stars / max_stars -> both 6 (presumably fixes the star count per image -- confirm
#                            against simulated_datasets_lib)
#   alpha                 -> 0.5
data_params.update(slen = 101,
                   min_stars = 6,
                   max_stars = 6,
                   alpha = 0.5)

print(data_params)


In [None]:
# image side length; used below to size the background tensor and to build the tile grid
slen = data_params['slen']

In [None]:
psf_dir = '../data/'

# Read the PSF stored in HDU 0 of each file. The FITS objects are opened
# through `with` so the underlying file handles are closed as soon as the
# arrays are in memory, instead of being left open for the kernel's lifetime.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_r_fits:
    psf_r = psf_r_fits[0].read()

# NOTE(review): the variable is called psf_g, but the file it reads is the
# i-band PSF ('...-psf-i.fits'). The name is kept so any other cell that
# refers to it keeps working -- confirm which band is intended.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_i_fits:
    psf_g = psf_i_fits[0].read()

# stack the two PSFs along a leading band axis
psf_og = torch.Tensor(np.array([psf_r, psf_g]))

# number of bands = number of PSFs stacked above
n_bands = psf_og.shape[0]

In [None]:
# Constant background: one fixed level per band, broadcast over the whole slen x slen image.
band_background_levels = (686., 1123.)

background = torch.ones(n_bands, slen, slen)
for band, level in enumerate(band_background_levels):
    background[band] = level

In [None]:
# Simulate three images from the parameters above.
dataset_kwargs = dict(n_images = 3,
                      background = background,
                      transpose_psf = False,
                      # for testing, turn off noise
                      add_noise = False)

simulated_dataset = simulated_datasets_lib.load_dataset_from_params(psf_og, data_params,
                                                                    **dataset_kwargs)

In [None]:
# band 0 of the first simulated image
first_image_band0 = simulated_dataset.images[0, 0].squeeze()
plt.matshow(first_image_band0)

In [None]:
# band 0 of the second simulated image
second_image_band0 = simulated_dataset.images[1, 0].squeeze()
plt.matshow(second_image_band0)

In [None]:
# save image parameters: keep short handles on the simulated images and the
# star locations / fluxes stored on the dataset
images, locs, fluxes = (simulated_dataset.images,
                        simulated_dataset.locs,
                        simulated_dataset.fluxes)

### Parameters

In [None]:
# Tiling configuration, passed to the image_utils calls below:
#   patch_slen   -> side length of each square tile
#   step         -> passed to tile_images / get_tile_coords together with patch_slen
#   edge_padding -> width of the tile border stripped off before comparing to reconstructions
patch_slen, step, edge_padding = 9, 4, 2

### Get batch images

In [None]:
# cut the simulated images into tiles of side patch_slen
image_patches = image_utils.tile_images(images, patch_slen, step)

In [None]:
# starting coordinates of every tile in a slen x slen image (one row per tile)
tile_coords = image_utils.get_tile_coords(slen, slen, patch_slen, step)

# number of tiles cut from a single image
n_patches = tile_coords.shape[0]

### Test tile coordinates 

In [None]:
# Each tile produced by `tile_images` must equal the window sliced straight
# out of its source image at the coordinates listed in `tile_coords`.
for i in range(image_patches.shape[0]):

    # flat tile index -> (image within the batch, tile within that image)
    image_indx, tile_indx = divmod(i, n_patches)

    start0 = tile_coords[tile_indx, 0]
    start1 = tile_coords[tile_indx, 1]

    expected_patch = images[image_indx, :,
                            start0:(start0 + patch_slen),
                            start1:(start1 + patch_slen)]

    assert np.all(image_patches[i] == expected_patch)

### Test extraction of parameters

In [None]:
# get reconstruction with the patch parameters:
# convert the full-image locations and fluxes into per-tile parameters
patch_params = image_utils.get_params_in_patches(tile_coords, locs, fluxes,
                                                 slen, patch_slen, edge_padding)

patch_locs, patch_fluxes, patch_n_stars, patch_is_on_array = patch_params

In [None]:
# side length of a tile once `edge_padding` pixels are removed on each side
inner_slen = patch_slen - 2 * edge_padding

# per-band constant background for the tiles, copied from the corner pixel
# of the full-image background
patch_background = torch.zeros(n_bands, inner_slen, inner_slen)
for band in (0, 1):
    patch_background[band] = background[band, 0, 0]

# simulator that draws images of the trimmed tile size
patch_simulator = simulated_datasets_lib.StarSimulator(
    psf_og,
    inner_slen,
    transpose_psf = False,
    background = patch_background)

# one count per tile, every entry set to the size of patch_fluxes' second
# dimension (presumably the maximum number of stars per tile -- confirm)
n_tiles, n_slots = patch_fluxes.shape[0], patch_fluxes.shape[1]
_n_stars = (torch.ones(n_tiles) * n_slots).type(torch.LongTensor)

# noiseless images drawn from the per-tile parameters
recon_means = patch_simulator.draw_image_from_params(
    patch_locs,
    patch_fluxes,
    _n_stars,
    add_noise = False)

In [None]:
# index of the band shown in the per-tile comparison plots below
band_indx = 1

In [None]:
# For every tile that contains at least one star, plot three panels:
#   [0] the observed pixels of the tile (border of `edge_padding` removed),
#       with the tile's star locations overlaid
#   [1] the image re-drawn from that tile's own parameters
#   [2] the relative residual (observed - reconstruction) / observed
for indx in range(image_patches.shape[0]):
    if patch_n_stars[indx] == 0:
        # empty tile: nothing to compare
        continue

    f, axarr = plt.subplots(1, 3, figsize=(16, 6))

    # starting coordinates of this tile inside its source image
    x0 = tile_coords[indx % n_patches, 0]
    x1 = tile_coords[indx % n_patches, 1]

    # selects the entries of patch_locs that are switched on for this tile
    which_nonzero = patch_is_on_array[indx]

    # observed tile, trimmed by edge_padding on every side so that it has the
    # same size as the reconstruction
    image_patch_indx = images[indx // n_patches, band_indx,
                              (x0 + edge_padding):(x0 - edge_padding + patch_slen),
                              (x1 + edge_padding):(x1 - edge_padding + patch_slen)]
    im1 = axarr[0].matshow(image_patch_indx)

    # locations are scaled by (trimmed side length - 1) to get pixel positions;
    # column 1 is plotted on the x axis and column 0 on the y axis
    axarr[0].scatter(patch_locs[indx, which_nonzero == 1, 1] * (patch_slen - 1 - 2 * edge_padding),
                     patch_locs[indx, which_nonzero == 1, 0] * (patch_slen - 1 - 2 * edge_padding))
    f.colorbar(im1, ax = axarr[0])

    axarr[0].set_title('n_stars: {}\n'.format(patch_n_stars[indx]))

    im2 = axarr[1].matshow(recon_means[indx, band_indx].squeeze())
    f.colorbar(im2, ax = axarr[1])
    axarr[1].set_title('reconstruction\n')

    residual = (image_patch_indx - recon_means[indx, band_indx].squeeze()) / image_patch_indx
    im3 = axarr[2].matshow(residual)
    f.colorbar(im3, ax = axarr[2])
    axarr[2].set_title('relative residual\n')


# Test getting the full image from patch parameters

We draw a crowded starfield for this.

In [None]:
# data parameters: reload the project defaults, this time without any overrides
default_params_path = '../data/default_star_parameters.json'
with open(default_params_path, 'r') as params_file:
    data_params = json.load(params_file)

print(data_params)


In [None]:
psf_dir = '../data/'

# Read the PSF stored in HDU 0 of each file. The FITS objects are opened
# through `with` so the underlying file handles are closed as soon as the
# arrays are in memory, instead of being left open for the kernel's lifetime.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_r_fits:
    psf_r = psf_r_fits[0].read()

# NOTE(review): the variable is called psf_g, but the file it reads is the
# i-band PSF ('...-psf-i.fits'). The name is kept so any other cell that
# refers to it keeps working -- confirm which band is intended.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_i_fits:
    psf_g = psf_i_fits[0].read()

# stack the two PSFs along a leading band axis
psf_og = torch.Tensor(np.array([psf_r, psf_g]))

In [None]:
# Simulate three images from the default parameters.
# NOTE(review): `background` is the tensor built earlier for slen = 101;
# confirm that the default `slen` in the JSON file matches that size.
dataset_kwargs = dict(n_images = 3,
                      background = background,
                      transpose_psf = False,
                      # for testing, turn off noise
                      add_noise = False)

simulated_dataset = simulated_datasets_lib.load_dataset_from_params(psf_og, data_params,
                                                                    **dataset_kwargs)

In [None]:
# save image parameters
images, locs, fluxes = (simulated_dataset.images,
                        simulated_dataset.locs,
                        simulated_dataset.fluxes)

# patch parameters
# the image side length is read off the simulated images themselves
slen = images.shape[-1]
patch_slen, step, edge_padding = 9, 4, 2

In [None]:
# show the dataset's `n_stars` attribute as the cell output
simulated_dataset.n_stars

### Get batch images

In [None]:
# cut the crowded-field images into tiles of side patch_slen
image_patches = image_utils.tile_images(images, patch_slen, step)

In [None]:
# again test my batches: every tile must equal the window sliced directly
# out of its source image at the coordinates listed in `tile_coords`
tile_coords = image_utils.get_tile_coords(slen, slen, patch_slen, step)
n_patches = tile_coords.shape[0]

for i in range(image_patches.shape[0]):

    # flat tile index -> (image within the batch, tile within that image)
    image_indx, tile_indx = divmod(i, n_patches)

    start0 = tile_coords[tile_indx, 0]
    start1 = tile_coords[tile_indx, 1]

    expected_patch = images[image_indx, :,
                            start0:(start0 + patch_slen),
                            start1:(start1 + patch_slen)]

    assert np.all(image_patches[i] == expected_patch)

In [None]:
# get patch parameters: convert the full-image locations and fluxes into per-tile parameters
patch_params = image_utils.get_params_in_patches(tile_coords, locs, fluxes,
                                                 slen, patch_slen, edge_padding)

patch_locs, patch_fluxes, patch_n_stars, patch_is_on_array = patch_params


### Now revert to full image parameters

In [None]:
# map the per-tile parameters back to full-image locations, fluxes and star counts
full_params = image_utils.get_full_params_from_patch_params(patch_locs, patch_fluxes,
                                                            tile_coords, slen,
                                                            patch_slen, edge_padding)

locs2, fluxes2, n_stars2 = full_params

In [None]:
# re-draw noiseless full images from the parameters recovered from the tiles,
# using the simulator attached to the dataset
full_simulator = simulated_dataset.simulator
recon_means = full_simulator.draw_image_from_params(locs = locs2,
                                                    fluxes = fluxes2,
                                                    n_stars = n_stars2,
                                                    add_noise = False)

In [None]:
# first band
band = 0
for i in range(recon_means.shape[0]):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

    simulated_image = simulated_dataset.images[i, band].squeeze()
    recon_image = recon_means[i, band].squeeze()
    residual = recon_image - simulated_image

    # panels, left to right: simulated image, reconstruction, reconstruction minus simulated
    for ax, panel in zip(axarr, (simulated_image, recon_image, residual)):
        im = ax.matshow(panel)
        fig.colorbar(im, ax = ax)

In [None]:
# second band
band = 1
for i in range(recon_means.shape[0]):
    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

    simulated_image = simulated_dataset.images[i, band].squeeze()
    recon_image = recon_means[i, band].squeeze()
    residual = recon_image - simulated_image

    # panels, left to right: simulated image, reconstruction, reconstruction minus simulated
    for ax, panel in zip(axarr, (simulated_image, recon_image, residual)):
        im = ax.matshow(panel)
        fig.colorbar(im, ax = ax)