In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed both numpy and torch so that any stochastic step is reproducible.
np.random.seed(34534)
torch.manual_seed(34534);

In [None]:
# Indices of the image bands to load. Index 2 is presumably SDSS r (ugriz
# order), matching the r-band PSF file read further down. TODO confirm.
bands = [2]

In [None]:
# Load the image cutout and its catalog ("true") parameters.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    x0 = 600,
    x1 = 0,
    slen = 801,
    bands = bands,
    fudge_conversion = 1.0,
)

# image and background, each with a leading batch dimension added
full_image = sdss_hubble_data.sdss_image.unsqueeze(0)
full_background = sdss_hubble_data.sdss_background.unsqueeze(0)

# true parameters
true_locs, true_fluxes = sdss_hubble_data.locs, sdss_hubble_data.fluxes


In [None]:
# spatial size of the image along dimensions 2 and 3
slen0, slen1 = (full_image.shape[dim] for dim in (2, 3))

In [None]:
plt.matshow(full_image.squeeze())
plt.title('SDSS image')
plt.colorbar();

In [None]:
# distribution of Hubble color; values above 5 are piled into the last bin
plt.hist(sdss_hubble_data.hubble_color.clamp(max = 5.0), bins = 100)
plt.xlabel('Hubble V-I color (clamped at 5)')
plt.ylabel('count');

In [None]:
# color against brightness for all catalog sources
plt.scatter(torch.log10(sdss_hubble_data.fluxes.squeeze()),
            sdss_hubble_data.hubble_color.clamp(max = 10))
plt.xlabel('log10 flux')
plt.ylabel('Hubble V-I color (clamped at 10)');

In [None]:
# same scatter, restricted to sources with log10 flux > 5
log10_flux = torch.log10(true_fluxes.squeeze())
is_bright = log10_flux > 5.0
plt.scatter(log10_flux[is_bright],
            sdss_hubble_data.hubble_color.clamp(max = 10)[is_bright])
plt.xlabel('log10 flux')
plt.ylabel('Hubble V-I color (clamped at 10)');

In [None]:
# color distribution of sources with 5 < log10 flux < 6
log10_flux = torch.log10(true_fluxes.squeeze())
in_flux_range = (log10_flux > 5.0) & (log10_flux < 6.0)
plt.hist(sdss_hubble_data.hubble_color.clamp(max = 5.0)[in_flux_range], bins = 100)
plt.xlabel('Hubble V-I color (clamped at 5), 5 < log10 flux < 6')
plt.ylabel('count');

In [None]:
# median (clamped) color of sources with log10 flux > 5
bright = torch.log10(true_fluxes.squeeze()) > 5.0
sdss_hubble_data.hubble_color.clamp(max = 5.0)[bright].median()

# Load the PSF and the saved truth reconstruction, and compare it with the image

In [None]:
import fitsio

psf_dir = '../data/'

# Read the r-band PSF image from the first HDU. The context manager makes sure
# the FITS file handle is closed once the data has been read.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_fits:
    psf_r = psf_fits[0].read()

psf_og = np.array([psf_r])


In [None]:
# saved reconstruction, given two leading singleton dimensions to match full_image
truth_recon = torch.Tensor(np.loadtxt('truth_recon'))[None, None]

In [None]:
plt.matshow(truth_recon.squeeze())
plt.title('truth reconstruction')
plt.colorbar();

In [None]:
# Relative residual of the reconstruction. NOTE(review): inf/nan wherever
# full_image == 0.
residual = truth_recon.sub(full_image).div(full_image)

In [None]:
# Signed relative residual on a symmetric diverging scale centered at zero.
# The limit ignores non-finite values (pixels where full_image == 0).
finite_residual = residual[torch.isfinite(residual)]
residual_lim = finite_residual.abs().max()
plt.matshow(residual.squeeze(), vmin = -residual_lim, vmax = residual_lim,
            cmap = plt.get_cmap('bwr'))
plt.title('(truth_recon - image) / image')
plt.colorbar();

In [None]:
plt.hist(torch.log10(true_fluxes.squeeze()), bins = 100)
plt.xlabel('log10 flux')
plt.ylabel('count');

In [None]:
def get_star_patches(full_image, true_locs, which_stars, subimage_slen):
    """
    Cut a square patch out of the image around each selected star.

    Parameters
    ----------
    full_image : torch.Tensor
        4-d tensor. Patches are taken from ``full_image[0, 0]``, i.e. from
        its last two dimensions.
    true_locs : torch.Tensor
        One row per star with two coordinates in [0, 1]. A row is scaled by
        ``(slen0 - 1, slen1 - 1)`` and rounded to get the star's pixel along
        the last two image dimensions.
    which_stars : index into ``true_locs``
        Selects the stars to extract: an integer tensor or list, a single
        (possibly 0-d) integer index, or a boolean mask.
    subimage_slen : int
        Side length of every square patch. An odd value centers the patch on
        the star's pixel.

    Returns
    -------
    star_patches : torch.Tensor, shape (n_stars, subimage_slen, subimage_slen)
    patch_coords : torch.Tensor, shape (n_stars, 2)
        Index of each patch's first pixel along the last two image dimensions.
    is_blended : torch.Tensor, shape (n_stars,)
        1. where the patch's central pixel is not its brightest pixel, else 0.

    Raises
    ------
    AssertionError
        If the image is not 4-d, a location lies outside [0, 1], or a patch
        would extend past the image boundary.
    """
    assert len(full_image.shape) == 4
    assert (true_locs >= 0).all() & (true_locs <= 1).all()

    slen0 = full_image.shape[-2]
    slen1 = full_image.shape[-1]

    # reshape keeps the selection 2-d even for a single 0-d index, e.g. the
    # result of calling .squeeze() on a one-row torch.nonzero() output
    which_locs = true_locs[which_stars].reshape(-1, 2)
    n_stars = which_locs.shape[0]

    # nearest pixel of every selected star (scale is loop-invariant)
    scale = torch.Tensor([slen0 - 1., slen1 - 1.])
    which_pix = (which_locs * scale).round().type(torch.long)

    half_width = (subimage_slen - 1) / 2
    center = int(half_width)

    star_patches = torch.zeros(n_stars, subimage_slen, subimage_slen)
    patch_coords = torch.zeros(n_stars, 2)

    is_blended = torch.zeros(n_stars)

    for i in range(n_stars):
        x0 = int(which_pix[i, 0] - half_width)
        x1 = int(which_pix[i, 1] - half_width)

        # A patch may touch the image edge: slicing start:start + subimage_slen
        # is valid for start >= 0 and start + subimage_slen <= size.
        assert x0 >= 0, f'patch for star {i} starts before the image (dim -2)'
        assert x1 >= 0, f'patch for star {i} starts before the image (dim -1)'
        assert (x0 + subimage_slen) <= slen0, f'patch for star {i} runs past the image (dim -2)'
        assert (x1 + subimage_slen) <= slen1, f'patch for star {i} runs past the image (dim -1)'

        star_patches[i] = full_image[0, 0, x0:(x0 + subimage_slen), x1:(x1 + subimage_slen)]
        patch_coords[i] = torch.Tensor([x0, x1])

        # something in the patch is brighter than its center -> flag as blended
        if star_patches[i, center, center] < star_patches[i].max():
            is_blended[i] = 1

    return star_patches, patch_coords, is_blended

# The effect of star color (V-I) on the PSF

In [None]:
# Stars with V-I > 0.5 and 5 < log10 flux < 6, away from the image border
# (presumably so every patch lies inside the image). Cut a patch around each,
# discard patches flagged as blended, and normalize each remaining patch to
# unit sum.
log10_flux = torch.log10(true_fluxes.squeeze())

# squeeze(1) rather than squeeze(): keeps which_stars 1-d even when exactly one
# star matches (a 0-d tensor would make len() below raise a TypeError).
which_stars = torch.nonzero((sdss_hubble_data.hubble_color > 0.5) &
                            (log10_flux > 5.0) & (log10_flux < 6.0) &
                            (true_locs[:, 0] < 0.95) & (true_locs[:, 1] < 0.95) &
                            (true_locs[:, 0] > 0.05) & (true_locs[:, 1] > 0.05)).squeeze(1)

# number of selected stars, before the blending cut
print(len(which_stars))

subimage_slen = 7

star_patches, patch_coords, is_blended = \
    get_star_patches(full_image, true_locs, which_stars, subimage_slen)

star_patches = star_patches[is_blended == 0]

# normalize each patch to sum to one
star_patches_normalized = star_patches / star_patches.sum(dim = (1, 2), keepdim = True)

plt.matshow(star_patches_normalized.mean(0))
plt.colorbar()

In [None]:
# total of the average normalized patch (each patch sums to one)
star_patches_normalized.mean(dim = 0).sum()

In [None]:
# Show up to ten of the individual unnormalized patches. This is bounded by the
# number of patches so it does not raise IndexError when fewer survive the cuts.
for i in range(min(10, star_patches.shape[0])):
    plt.matshow(star_patches[i])

In [None]:
# Stars with V-I < 0.5 and 5 < log10 flux < 6, away from the image border
# (presumably so every patch lies inside the image). Cut a patch around each,
# discard patches flagged as blended, and normalize each remaining patch to
# unit sum.
log10_flux = torch.log10(true_fluxes.squeeze())

# squeeze(1) rather than squeeze(): keeps which_stars 1-d even when exactly one
# star matches (a 0-d tensor would make len() below raise a TypeError).
which_stars = torch.nonzero((sdss_hubble_data.hubble_color < 0.5) &
                            (log10_flux > 5.0) & (log10_flux < 6.0) &
                            (true_locs[:, 0] < 0.95) & (true_locs[:, 1] < 0.95) &
                            (true_locs[:, 0] > 0.05) & (true_locs[:, 1] > 0.05)).squeeze(1)

# number of selected stars, before the blending cut
print(len(which_stars))

subimage_slen = 7

star_patches2, patch_coords, is_blended = \
    get_star_patches(full_image, true_locs, which_stars, subimage_slen)

star_patches2 = star_patches2[is_blended == 0]

# normalize each patch to sum to one
star_patches_normalized2 = star_patches2 / star_patches2.sum(dim = (1, 2), keepdim = True)

plt.matshow(star_patches_normalized2.mean(0))
plt.colorbar()

In [None]:
# total of the average normalized patch (each patch sums to one)
star_patches_normalized2.mean(dim = 0).sum()

In [None]:
# Show up to ten of the individual unnormalized patches. This is bounded by the
# number of patches so it does not raise IndexError when fewer survive the cuts.
for i in range(min(10, star_patches2.shape[0])):
    plt.matshow(star_patches2[i])

In [None]:
# Difference of the two average patches (V-I < 0.5 minus V-I > 0.5) on a
# symmetric diverging color scale centered at zero.
foo = star_patches_normalized2.mean(0) - star_patches_normalized.mean(0)
foo_lim = foo.abs().max()
plt.matshow(foo, vmin = -foo_lim, vmax = foo_lim, cmap = plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
# NOTE(review): leftover scratch evaluation of the builtin max(); it plays no
# part in the analysis and can be deleted.
max(3, 2)

In [None]:
# Figure for paper: the average normalized patch of each color group on a
# shared color scale, plus the difference of the two averages.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

mean_patch_hi = star_patches_normalized.mean(0)   # stars with V-I > 0.5
mean_patch_lo = star_patches_normalized2.mean(0)  # stars with V-I < 0.5

# shared limits so the two averages are directly comparable
vmax = max(mean_patch_hi.max(), mean_patch_lo.max())
vmin = min(mean_patch_hi.min(), mean_patch_lo.min())

im0 = axarr[0].matshow(mean_patch_hi, vmax = vmax, vmin = vmin)
fig.colorbar(im0, ax = axarr[0])
axarr[0].set_title('\nOverlayed Stars with V-I > 0.5', fontsize = 16)

im1 = axarr[1].matshow(mean_patch_lo, vmax = vmax, vmin = vmin)
fig.colorbar(im1, ax = axarr[1])
axarr[1].set_title('\nOverlayed Stars with V-I < 0.5', fontsize = 16)

# difference on a symmetric diverging scale centered at zero
diff = mean_patch_lo - mean_patch_hi
diff_lim = diff.abs().max()
im2 = axarr[2].matshow(diff, vmax = diff_lim, vmin = -diff_lim,
                       cmap = plt.get_cmap('bwr'))
fig.colorbar(im2, ax = axarr[2])
axarr[2].set_title('\nDifference', fontsize = 16)

# hide pixel-index ticks on every panel
for i, ax in enumerate(axarr):
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
fig.savefig('../../qualifying_exam_slides/figures/psf_by_color.png')

In [None]:
# Marginal profiles of each average patch: sum(0) collapses dimension 0 of the
# patch, sum(1) collapses dimension 1. Suffix _1 is V-I > 0.5, _2 is V-I < 0.5.
profile0_1, profile1_1 = (star_patches_normalized.mean(0).sum(d) for d in (0, 1))
profile0_2, profile1_2 = (star_patches_normalized2.mean(0).sum(d) for d in (0, 1))

In [None]:
from scipy.interpolate import interp1d

In [None]:
# pixel offsets from the patch center at which the profiles are sampled
_x = np.arange(- (subimage_slen - 1) / 2, (subimage_slen + 1) / 2)

# cubic interpolants through each sampled profile, for smooth curves when plotting
f0_1, f0_2, f1_1, f1_2 = [interp1d(_x, profile.numpy(), kind = 'cubic')
                          for profile in (profile0_1, profile0_2, profile1_1, profile1_2)]

In [None]:
# dense grid over the same offset range as _x, for evaluating the interpolants
x = np.linspace(start = -(subimage_slen - 1) / 2, stop = (subimage_slen - 1) / 2, num = 100)

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(12, 4))

# per panel: (V-I > 0.5 profile, V-I < 0.5 profile, their interpolants, title)
panel_specs = [
    (profile1_1, profile1_2, f1_1, f1_2, 'x-coordinate star profile'),
    (profile0_1, profile0_2, f0_1, f0_2, 'y-coordinate star profile'),
]

for i, (prof_hi, prof_lo, interp_hi, interp_lo, title) in enumerate(panel_specs):
    ax = axarr[i]
    # sampled profile values as crosses, interpolated curves as faint lines
    ax.scatter(_x, prof_hi.numpy(), color = 'b', marker = 'x', label = 'V-I > 0.5')
    ax.scatter(_x, prof_lo.numpy(), color = 'r', marker = 'x', label = 'V-I < 0.5')
    ax.plot(x, interp_hi(x), color = 'b', alpha = 0.5)
    ax.plot(x, interp_lo(x), color = 'r', alpha = 0.5)
    ax.legend()
    ax.set_title(title, fontsize = 16)
    ax.set_xlabel('pixel coordinate', fontsize = 14)
    ax.set_ylabel('normalized brightness', fontsize = 14)

fig.tight_layout()
fig.savefig('../../qualifying_exam_slides/figures/psf_profile_by_color.png')

In [None]:
# hard-coded crop of the full frame showing a saturation example
saturation_crop = sdss_hubble_data.sdss_image_full.squeeze()[1100:1120, 410:430]
plt.matshow(saturation_crop)
plt.savefig('../../qualifying_exam_slides/figures/saturation_ex.png')

In [None]:
# hard-coded crop of the analysis image showing a bleed-train example
bleed_train_crop = sdss_hubble_data.sdss_image.squeeze()[280:300, 200:220]
plt.matshow(bleed_train_crop)
plt.savefig('../../qualifying_exam_slides/figures/bleed_train_ex.png')