In [None]:
import numpy as np

import matplotlib.pyplot as plt

import torch

import fitsio

import sys
sys.path.insert(0, './../')

import sdss_psf
import simulated_datasets_lib
import sdss_dataset_lib
import image_utils

import json

# Run on the first GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

print('torch version: ', torch.__version__)

In [None]:
# Fix the numpy and torch global RNG seeds so the simulation is reproducible.
SEEDS = {'numpy': 43534, 'torch': 24534}
np.random.seed(SEEDS['numpy'])
_ = torch.manual_seed(SEEDS['torch'])

# Test recovering the full image from patch parameters

We draw a crowded starfield for this test.

In [None]:
# Data parameters: start from the defaults, then set the image side length
# (`slen`) to 100 pixels.
params_path = '../data/default_star_parameters.json'
with open(params_path, 'r') as params_file:
    data_params = json.load(params_file)
data_params['slen'] = 100

print(data_params)


In [None]:
psf_dir = '../data/'


def _read_primary_hdu(path):
    """Read HDU 0 of the FITS file at `path`, closing the file afterwards."""
    with fitsio.FITS(path) as fits:
        return fits[0].read()


# Two bands: r and i. The second file is the i-band PSF (`*-psf-i.fits`),
# so it is named accordingly.
psf_r = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
psf_i = _read_primary_hdu(psf_dir + 'sdss-002583-2-0136-psf-i.fits')
psf_og = torch.Tensor(np.array([psf_r, psf_i]))

n_bands = psf_og.shape[0]

# Constant background level per band, in the same band order as psf_og.
band_background_levels = [686., 1123.]
assert len(band_background_levels) == n_bands, \
    'need exactly one background level per PSF band'

background = torch.zeros(n_bands, data_params['slen'], data_params['slen'])
for band_idx, level in enumerate(band_background_levels):
    background[band_idx] = level

In [None]:
# Simulate three images from the PSFs, data parameters and background above.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    n_images=3,
    background=background,
    transpose_psf=False,
    add_noise=False,  # for testing, turn off noise
)

In [None]:
# Keep the true images and their parameters so they can be compared after
# the patch round trip.
images, locs, fluxes = (simulated_dataset.images,
                        simulated_dataset.locs,
                        simulated_dataset.fluxes)

# Patch settings: patch side length, tile step, and edge padding (none).
slen = images.shape[-1]  # image side length (images assumed square)
patch_slen = 2
step = 2
edge_padding = 0

In [None]:
# Number of stars drawn (presumably one count per simulated image).
n_stars_true = simulated_dataset.n_stars
n_stars_true

# Get image patches

In [None]:
# Split the images into patches of side patch_slen, using tile step `step`;
# the next cell checks the layout against get_tile_coords.
image_patches = image_utils.tile_images(images, patch_slen, step)

# Test tile coordinates

In [None]:
# Check the tiling against the tile coordinates. Patches are ordered image by
# image, so patch i should equal the patch_slen x patch_slen window of image
# i // n_patches whose corner is tile_coords[i % n_patches].
tile_coords = image_utils.get_tile_coords(slen, slen, patch_slen, step)
n_patches = tile_coords.shape[0]  # number of patches per image

assert image_patches.shape[0] == images.shape[0] * n_patches, \
    'expected n_patches patches for every image'

for i in range(image_patches.shape[0]):
    image_idx = i // n_patches

    x0 = tile_coords[i % n_patches, 0]
    x1 = tile_coords[i % n_patches, 1]

    expected_patch = images[image_idx, :, x0:(x0 + patch_slen), x1:(x1 + patch_slen)]

    # Compare shapes explicitly: `==` broadcasts, so a mis-sized patch could
    # otherwise pass the element-wise check.
    assert image_patches[i].shape == expected_patch.shape, \
        f'patch {i}: shape mismatch'
    assert (image_patches[i] == expected_patch).all(), \
        f'patch {i}: pixel mismatch with image {image_idx}'

# Get patch parameters

In [None]:
# Compute per-patch parameters (locations, fluxes, star counts and the
# is-on array) from the full-image parameters.
patch_params = image_utils.get_params_in_patches(
    tile_coords, locs, fluxes, slen, patch_slen, edge_padding)
patch_locs, patch_fluxes, patch_n_stars, patch_is_on_array = patch_params


In [None]:
# Distribution of the number of stars per patch. Use an explicit labelled axes
# and plt.show() instead of leaving the (n, bins, patches) repr as cell output.
fig, ax = plt.subplots()
ax.hist(patch_n_stars)
ax.set_xlabel('number of stars in patch')
ax.set_ylabel('number of patches')
ax.set_title('Stars per patch')
plt.show()

# Now convert back to full-image parameters

In [None]:
# Convert the per-patch parameters back into full-image parameters.
locs2, fluxes2, n_stars2 = image_utils.get_full_params_from_patch_params(
    patch_locs, patch_fluxes, tile_coords, slen, patch_slen, edge_padding)

# Redraw the images from the recovered parameters: they should match the originals

In [None]:
# Redraw the images from the recovered full-image parameters, noise off.
recon_means = simulated_dataset.simulator.draw_image_from_params(
    locs=locs2,
    fluxes=fluxes2,
    n_stars=n_stars2,
    add_noise=False,
)

In [None]:
# first band: simulated image, reconstruction from the recovered parameters,
# and residual (reconstruction - simulated image)
band = 0
for i in range(recon_means.shape[0]):
    true_image = simulated_dataset.images[i, band].squeeze()
    recon_image = recon_means[i, band].squeeze()
    residual = recon_image - true_image

    # Share one color scale between the first two panels so they can be
    # compared by eye; the residual keeps its own scale.
    vmin = min(float(true_image.min()), float(recon_image.min()))
    vmax = max(float(true_image.max()), float(recon_image.max()))

    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
    panels = [(true_image, 'simulated', vmin, vmax),
              (recon_image, 'reconstructed', vmin, vmax),
              (residual, 'residual', None, None)]
    for ax, (panel, label, lo, hi) in zip(axarr, panels):
        im = ax.matshow(panel, vmin=lo, vmax=hi)
        fig.colorbar(im, ax=ax)
        ax.set_title(f'image {i}, band {band}: {label}')
    plt.show()

In [None]:
# second band: simulated image, reconstruction from the recovered parameters,
# and residual (reconstruction - simulated image)
band = 1
for i in range(recon_means.shape[0]):
    true_image = simulated_dataset.images[i, band].squeeze()
    recon_image = recon_means[i, band].squeeze()
    residual = recon_image - true_image

    # Share one color scale between the first two panels so they can be
    # compared by eye; the residual keeps its own scale.
    vmin = min(float(true_image.min()), float(recon_image.min()))
    vmax = max(float(true_image.max()), float(recon_image.max()))

    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))
    panels = [(true_image, 'simulated', vmin, vmax),
              (recon_image, 'reconstructed', vmin, vmax),
              (residual, 'residual', None, None)]
    for ax, (panel, label, lo, hi) in zip(axarr, panels):
        im = ax.matshow(panel, vmin=lo, vmax=hi)
        fig.colorbar(im, ax=ax)
        ax.set_title(f'image {i}, band {band}: {label}')
    plt.show()