In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed both RNGs the notebook draws from (numpy and torch) so that a
# Restart & Run All reproduces the same results.
random_seed = 34534
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
# Index (or indices) of the SDSS band(s) to load.
bands = [2, ]

In [None]:
# Load the SDSS data set for the chosen band(s). x0 / x1 / slen presumably
# select the sub-region of the field -- see sdss_dataset_lib.SDSSHubbleData.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    x0=600, x1=0, slen=800, bands=bands, fudge_conversion=1.0)

# Observed image and background, each with a new leading singleton dimension.
full_image = sdss_hubble_data.sdss_image.unsqueeze(0)
full_background = sdss_hubble_data.sdss_background.unsqueeze(0)

# Reference ("true") star locations and fluxes.
true_locs, true_fluxes = sdss_hubble_data.locs, sdss_hubble_data.fluxes


In [None]:
# Number of image rows and columns (dims 2 and 3 of the 4-D image tensor).
slen0, slen1 = full_image.shape[2], full_image.shape[3]

In [None]:
# Whole image (squeeze drops the singleton dimensions before plotting).
plt.matshow(full_image.squeeze())
plt.colorbar()
plt.title('Full SDSS image');

In [None]:
# Distribution of the reference fluxes on a log10 scale; used below to choose
# the brightness cut for selecting stars.
plt.hist(torch.log10(true_fluxes.squeeze()), bins=100)
plt.xlabel('log10(true flux)')
plt.ylabel('number of stars');

In [None]:
def get_star_patches(full_image, true_locs, which_stars, subimage_slen):
    """Cut a square patch out of the image around each selected star.

    Parameters
    ----------
    full_image : torch.Tensor
        4-D image tensor. Patches are taken from ``full_image[0, 0]``, and the
        last two dimensions are treated as (row, column).
    true_locs : torch.Tensor
        Star locations with every value in [0, 1]. Column 0 is scaled by
        (number of rows - 1) and column 1 by (number of columns - 1) to obtain
        pixel indices.
    which_stars : index
        Anything that can index the rows of ``true_locs``. This includes an
        integer index tensor (also the 0-dim tensor that ``.squeeze()`` yields
        when a single star is selected), a list of indices, or a boolean mask.
    subimage_slen : int
        Side length of each patch. With an odd value the patch is centred on
        the pixel nearest to the star.

    Returns
    -------
    star_patches : torch.Tensor
        ``(n_selected, subimage_slen, subimage_slen)`` tensor of patches.
    patch_coords : torch.Tensor
        ``(n_selected, 2)`` tensor with the (row, column) of each patch's
        top-left pixel in the full image.

    Raises
    ------
    AssertionError
        If the image is not 4-D, a location lies outside [0, 1], or a patch
        would extend past the image border.
    """
    assert len(full_image.shape) == 4
    assert (true_locs >= 0).all() & (true_locs <= 1).all()

    slen0 = full_image.shape[-2]
    slen1 = full_image.shape[-1]

    # reshape keeps a 2-D (n_selected, n_coords) layout even when `which_stars`
    # is a 0-dim index, which would otherwise return a single 1-D row
    which_locs = true_locs[which_stars].reshape(-1, true_locs.shape[-1])
    n_selected = which_locs.shape[0]

    star_patches = torch.zeros(n_selected, subimage_slen, subimage_slen)
    patch_coords = torch.zeros(n_selected, 2)

    # maps [0, 1] locations onto pixel indices [0, slen - 1]
    pix_scale = torch.Tensor([slen0 - 1., slen1 - 1.])
    half_width = (subimage_slen - 1) / 2

    for i, loc in enumerate(which_locs):
        which_pix = (loc * pix_scale).round().type(torch.long)

        x0 = int(which_pix[0] - half_width)
        x1 = int(which_pix[1] - half_width)

        # slice ends are exclusive, so a patch may start at pixel 0 and may end
        # exactly at the image border
        assert 0 <= x0 and x0 + subimage_slen <= slen0, \
            f'patch rows [{x0}, {x0 + subimage_slen}) outside image with {slen0} rows'
        assert 0 <= x1 and x1 + subimage_slen <= slen1, \
            f'patch cols [{x1}, {x1 + subimage_slen}) outside image with {slen1} cols'

        star_patches[i] = full_image[0, 0, x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]
        patch_coords[i] = torch.Tensor([x0, x1])

    return star_patches, patch_coords

In [None]:
# Select bright stars that lie away from the image border, so that a complete
# patch can be cut around each one.
min_log10_flux = 5.0
min_loc, max_loc = 0.05, 0.95

is_bright = torch.log10(true_fluxes.squeeze()) > min_log10_flux
is_interior = ((true_locs[:, 0] < max_loc) & (true_locs[:, 1] < max_loc) &
               (true_locs[:, 0] > min_loc) & (true_locs[:, 1] > min_loc))

# flatten (not squeeze) keeps a 1-D index tensor even when exactly one star matches
which_stars = torch.nonzero(is_bright & is_interior).flatten()


In [None]:
# Patch side length: odd (2 * half-width + 1) so each patch is centred on the
# pixel nearest to its star.
subimage_slen = 2 * 2 + 1

In [None]:
# One subimage_slen x subimage_slen patch per selected star, plus the
# top-left pixel coordinates of each patch in the full image.
star_patches, patch_coords = get_star_patches(full_image, true_locs, which_stars,
                                              subimage_slen=subimage_slen)

In [None]:
# Inspect a handful of selected stars side by side:
#   left  -- plotting_utils.plot_subimage view of the full image at the patch coordinates
#   right -- the extracted patch itself
first_star = 20
last_star = min(30, len(star_patches))  # avoid IndexError when fewer stars are selected

for i in range(first_star, last_star):
    fig, axarr = plt.subplots(1, 2, figsize=(10, 4))

    x0 = int(patch_coords[i, 0])
    x1 = int(patch_coords[i, 1])

    plotting_utils.plot_subimage(axarr[0], full_image[0, 0],
                                 None,
                                 true_locs,
                                 x0=x0,
                                 x1=x1,
                                 subimage_slen=subimage_slen,
                                 add_colorbar=True, global_fig=fig,
                                 diverging_cmap=False)

    # relabel ticks with full-image pixel indices; (subimage_slen + 2) labels
    # reproduces the previous hard-coded range(x - 1, x + 7) for subimage_slen = 5
    axarr[0].set_xticklabels(np.arange(x1 - 1, x1 + subimage_slen + 2))
    axarr[0].set_yticklabels(np.arange(x0 - 1, x0 + subimage_slen + 2))

    axarr[1].matshow(star_patches[i])
    axarr[1].set_title(f'patch {i} at pixel ({x0}, {x1})')

In [None]:
# Rescale every patch so that its pixel values sum to one.
star_patches_normalized = star_patches / star_patches.reshape(len(star_patches), -1).sum(dim=1)[:, None, None]

In [None]:
# Pixel-wise average over all normalized star patches.
plt.matshow(star_patches_normalized.mean(0))
plt.colorbar()
plt.title('Mean of normalized star patches');