In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 

import starnet_lib
import inv_kl_objective_lib as inv_kl_lib
import plotting_utils
import wake_sleep_lib

import psf_transform_lib
import image_statistics_lib

# Seed every RNG used below. numpy alone is not enough: torch.rand (whole-image
# test) and presumably the simulator's add_noise draws use torch's generator.
np.random.seed(34534)
torch.manual_seed(34534)

# Load the data

In [None]:
# Flux threshold: only catalogue sources brighter than this count as "true" stars.
f_min = 1e3

In [None]:
# Indices of the image bands to load; a single band is used throughout.
bands = [2]

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands=bands)

# Observed image and sky background, each with a leading batch axis of size 1.
full_image = sdss_hubble_data.sdss_image.unsqueeze(0)
full_background = sdss_hubble_data.sdss_background.unsqueeze(0)

# "True" catalogue: keep only sources whose first-band flux exceeds f_min.
which_bright = sdss_hubble_data.fluxes[:, 0] > f_min
true_fluxes = sdss_hubble_data.fluxes[which_bright]
true_locs = sdss_hubble_data.locs[which_bright]


In [None]:
# Quick look at the first band of the observed image.
im = plt.matshow(full_image[0, 0])
plt.colorbar(im)

# Load SDSS PSF

In [None]:
import fitsio

In [None]:
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

# Read the primary HDU of each PSF file. The context manager closes each FITS
# handle once the data have been read.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_fits:
    psf_r = psf_fits[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_fits:
    psf_i = psf_fits[0].read()  # not used further in this notebook

# Add a leading band axis; only the r-band PSF is used here.
psf_og = np.array([psf_r])
# Alternative PSF source:
# psf_og = np.loadtxt('../data/my_r_psf.txt')[None]


In [None]:
# Initial PSF: the raw SDSS PSF passed through _expand_psf with the image side
# length (presumably padding it to full-image size), as a float tensor.
psf_init = torch.Tensor(
    simulated_datasets_lib._expand_psf(psf_og, full_image.shape[-1])
)

# Load the true PSF transform

In [None]:
# PSF transform built around the SDSS PSF, sized to the full image, kernel size 3.
psf_transform = psf_transform_lib.PsfLocalTransform(
    torch.Tensor(psf_og),
    full_image.shape[-1],
    kernel_size=3,
)

In [None]:
# Load the saved transform weights (mapped onto CPU storage).
psf_transform.load_state_dict(
    torch.load('../fits/results_11202019/true_psf_transform_630x310_r',
               map_location=lambda storage, loc: storage)
)

# Evaluate the transformed PSF once; detach it from the autograd graph.
psf_truth = psf_transform.forward().detach()

In [None]:
# Check losses to make sure the transform was loaded correctly: reconstruct the
# image from the full catalogue under the initial PSF and under the loaded PSF.
all_locs = sdss_hubble_data.locs.unsqueeze(0)
all_fluxes = sdss_hubble_data.fluxes.unsqueeze(0)
all_n_stars = torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.long)

recon_mean_init, init_loss = psf_transform_lib.get_psf_loss(
    full_image, full_background, all_locs, all_fluxes,
    n_stars=all_n_stars, psf=psf_init, pad=5)

recon_mean_truth, truth_loss = psf_transform_lib.get_psf_loss(
    full_image, full_background, all_locs, all_fluxes,
    n_stars=all_n_stars, psf=psf_truth, pad=5)

In [None]:
# Report both reconstruction losses.
for loss_label, loss_value in (('init loss: ', init_loss), ('truth loss: ', truth_loss)):
    print(loss_label, loss_value)

In [None]:
# Position within `bands` used when plotting residuals.
band = 0

In [None]:
# Relative residuals (recon - data) / data for the initial PSF (left) and the
# loaded PSF (right), each on a symmetric colour scale centred at zero.
fig, axarr = plt.subplots(1, 2, figsize=(15, 4))

residual_init = (recon_mean_init[0, band] - full_image[0, band]) / full_image[0, band]
residual_truth = (recon_mean_truth[0, band] - full_image[0, band]) / full_image[0, band]

for ax, residual in zip(axarr, (residual_init, residual_truth)):
    vmax = residual.abs().max()
    im0 = ax.matshow(residual, vmin=-vmax, vmax=vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im0, ax=ax)

# Our simulator

In [None]:
# Mean sky level for each band, averaged over every background pixel.
sky_intensity = full_background.reshape(full_background.shape[1], -1).mean(dim=1)

# Simulator that renders stars with the loaded PSF on a full-size image.
simulator = simulated_datasets_lib.StarSimulator(
    psf_truth,
    slen=full_image.shape[-1],
    transpose_psf=False,
    sky_intensity=sky_intensity,
)

In [None]:
# Check again: render the full catalogue noise-free with the simulator; this is
# presumably expected to match recon_mean_truth from the PSF-loss check.
_recon_mean_truth = simulator.draw_image_from_params(
    sdss_hubble_data.locs.unsqueeze(0),
    sdss_hubble_data.fluxes.unsqueeze(0),
    n_stars=torch.Tensor([len(sdss_hubble_data.fluxes)]).type(torch.long),
    add_noise=False,
)

In [None]:
# Relative residual of the simulator's noise-free render against the data.
_residual_truth = ((_recon_mean_truth[0, band] - full_image[0, band]) / full_image[0, band])
# NOTE(review): vmax comes from residual_truth (the PSF-loss residual from an
# earlier cell), not _residual_truth, so this plot reuses that colour scale.
# Confirm this is intended rather than a copy-paste slip.
vmax = residual_truth.abs().max()
plt.matshow(_residual_truth, vmin = -vmax, vmax = vmax, cmap=plt.get_cmap('bwr'))
plt.colorbar()

# Define the VAE encoders

In [None]:
# Sleep-only encoder: 7-pixel stamps, step 2, edge padding 2, at most 2 detections.
star_encoder1 = starnet_lib.StarEncoder(
    full_slen=full_image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
)

In [None]:
# Load the sleep-only weights, mapped onto CPU storage.
star_encoder1.load_state_dict(
    torch.load('../fits/results_11202019/starnet_r',
               map_location=lambda storage, loc: storage)
)

In [None]:
# Inference mode (train(False) is what Module.eval() does).
star_encoder1.train(False);

In [None]:
# Wake-sleep encoder: same architecture as the sleep-only encoder.
star_encoder2 = starnet_lib.StarEncoder(
    full_slen=full_image.shape[-1],
    stamp_slen=7,
    step=2,
    edge_padding=2,
    n_bands=len(bands),
    max_detections=2,
)

In [None]:
# Load the wake-sleep encoder weights, mapped onto CPU storage.
star_encoder2.load_state_dict(
    torch.load('../fits/results_11202019/wake-sleep_630x310_r-encoder-iter6',
               map_location=lambda storage, loc: storage)
)

In [None]:
# Inference mode (train(False) is what Module.eval() does).
star_encoder2.train(False);

# Draw image patches

In [None]:
n_tests = 100

# Two stars per scene, symmetric about pixel (50, 50); coordinates are stored
# as pixel / 100.
locs = torch.ones((n_tests, 2, 2))
locs[:, 0, :] = 49.2 / 100  # first star
locs[:, 1, :] = 50.8 / 100  # second star

# Both stars have flux 1e4.
fluxes = torch.ones((n_tests, 2, 1)) * 1e4

n_stars = torch.full((n_tests,), 2, dtype=torch.long)

In [None]:
# Render noisy scenes, then cut out the 7x7 patch covering rows/cols 47..53.
image_patches = simulator.draw_image_from_params(
    locs, fluxes, n_stars, add_noise=True
)[:, :, 47:54, 47:54]

In [None]:
# First patch, first band.
plt.matshow(image_patches[0][0])

In [None]:
# Constant background at the mean sky level, shaped like the patches.
background_stamps = torch.ones(image_patches.shape) * full_background.mean()
# Element 4 of the encoder output is treated as log-probabilities over star counts.
probs = star_encoder1(image_patches,
                      background_stamps)[4]

In [None]:
# Most probable star count for each patch.
torch.argmax(probs, dim=1)

In [None]:
# Same patches through the wake-sleep encoder (replaces `probs`).
probs = star_encoder2(image_patches,
                      background_stamps)[4]

In [None]:
# Deblending test: one star is fixed at (49.2, 49.2) and a second, equally bright
# star is moved away by delta pixels along both axes. For each separation, record
# the fraction of patches where each encoder's most probable count is 2.
n_tests = 100
delta = np.arange(0.0, 1.7, 0.1)

probs1_vec = np.zeros(len(delta))
probs2_vec = np.zeros(len(delta))

# Quantities that do not depend on the separation are built once.
locs = torch.ones((n_tests, 2, 2))
locs[:, 0, :] = 49.2 / 100  # first star (pixel / 100) stays put
fluxes = torch.ones((n_tests, 2, 1)) * 1e4  # both stars have flux 1e4
n_stars = torch.full((n_tests,), 2, dtype=torch.long)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i, d in enumerate(delta):
        # Second star offset by d pixels in both coordinates.
        locs[:, 1, :] = (49.2 + d) / 100

        image_patches = simulator.draw_image_from_params(
            locs, fluxes, n_stars, add_noise=True)[:, :, 47:54, 47:54]
        background_stamps = torch.ones(image_patches.shape) * full_background.mean()

        log_probs1 = star_encoder1(image_patches, background_stamps)[4]
        log_probs2 = star_encoder2(image_patches, background_stamps)[4]

        # Hard decision (argmax == 2 stars), not mean posterior probability.
        probs1_vec[i] = (log_probs1.argmax(1) == 2).float().mean()
        probs2_vec[i] = (log_probs2.argmax(1) == 2).float().mean()

In [None]:
# Set to True to write the figures into the slides directory.
save_fig = False

In [None]:
# Detection rate vs. separation. delta is applied to both coordinates, so the
# Euclidean separation is delta * sqrt(2).
plt.plot(delta * np.sqrt(2), probs1_vec, '-x', color='orange', label='sleep-only')
plt.plot(delta * np.sqrt(2), probs2_vec, '-x', color='red', label='wake-sleep')

# Raw string: '\d' is an invalid escape sequence in an ordinary string literal.
plt.xlabel(r'Separation, $\delta$', size=16)
plt.ylabel('True positive rate', size=16)
plt.legend()

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/deblending_test.png')

In [None]:
# Example for paper
n_tests = 100
locs = torch.ones((n_tests, 2, 2))
# location of first star
locs[:, 0, 0] = 49.2 / 100
locs[:, 0, 1] = 49.2 / 100

locs[:, 1, 0] = 50.0 / 100
locs[:, 1, 1] = 50.0 / 100

fluxes = torch.ones((n_tests, 2, 1)) 
fluxes[:, 0, :] = 2.3e4
fluxes[:, 1, :] = 1e4


n_stars = (torch.ones(n_tests) * 2).type(torch.long)

In [None]:
# Render noisy scenes, then cut out the 7x7 patch covering rows/cols 47..53.
image_patches = simulator.draw_image_from_params(
    locs, fluxes, n_stars, add_noise=True
)[:, :, 47:54, 47:54]

In [None]:
# Top-left 6x6 of the first patch, with both true star positions marked (patch
# origin is pixel 47) and an arrow illustrating the offset delta.
plt.matshow(image_patches[0, 0, 0:6, 0:6])
# matshow puts columns on x and rows on y, hence (locs[..., 1], locs[..., 0]).
plt.scatter(locs[0, :, 1] * 100 - 47, locs[0, :, 0] * 100 - 47, color='blue')
plt.arrow(locs[0, 0, 1] * 100 - 47 + 0.05,   # x: first star's column
          locs[0, 0, 0] * 100 - 47 + 0.05,   # y: first star's row
          0.7, 0.7, length_includes_head=True, color='red')

# Raw string: '\d' is an invalid escape sequence in an ordinary string literal.
plt.text(2.8, 2.7, r'$\delta$', color='red', size=16)

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/deblending_test_ex.png')

# Test with a single star

In [None]:
# Single-star test: one star at row 50.0, column 50.9 (pixel / 100). The second
# location slot is zeroed; with n_stars = 1 it is presumably ignored.
n_tests = 100

locs = torch.ones((n_tests, 2, 2))
locs[:, 0, 0] = 50. / 100
locs[:, 0, 1] = 50.9 / 100
locs[:, 1, :] = 0

fluxes = torch.ones((n_tests, 2, 1)) * 1e4

n_stars = torch.ones(n_tests, dtype=torch.long)

In [None]:
# Render a noisy simulated image. It gets its own name so the observed SDSS
# `full_image` (used for shapes in the cells above) is not overwritten.
sim_image = simulator.draw_image_from_params(locs, fluxes, n_stars, add_noise=True)
# Patch containing the star (cols 47:54) and a patch shifted right by two
# pixels (cols 49:56).
image_patches1 = sim_image[:, :, 47:54, 47:54]
image_patches2 = sim_image[:, :, 47:54, 49:56]



In [None]:
# Containing patch with red guide lines at pixels 2 and 4 on both axes
# (each line drawn once).
plt.matshow(image_patches1[0, 0].numpy())
for pos in (2, 4):
    plt.axvline(x=pos, color='r')
    plt.axhline(y=pos, color='r')

In [None]:
# Right-shifted patch with red guide lines at pixels 2 and 4 on both axes
# (each line drawn once).
plt.matshow(image_patches2[0, 0].numpy())
for pos in (2, 4):
    plt.axvline(x=pos, color='r')
    plt.axhline(y=pos, color='r')

In [None]:
# Constant mean-sky background shaped like the single-star patches. It is sized
# from image_patches1, so it does not depend on patches from an earlier experiment.
background_stamps = torch.ones(image_patches1.shape) * full_background.mean()

In [None]:
# Wake-sleep encoder log-probabilities (output element 4) for the containing
# patch and for its right-shifted neighbour.
probs_patch_truth, probs_patch_neighbor = (
    star_encoder2(patches, background_stamps)[4]
    for patches in (image_patches1, image_patches2)
)

In [None]:
# Batch-averaged 1 - p(count index 0) for the containing patch.
1 - probs_patch_truth.exp()[:, 0].mean()

In [None]:
# Batch-averaged 1 - p(count index 0) for the neighbouring patch.
1 - probs_patch_neighbor.exp()[:, 0].mean()

In [None]:
# Repeat of the single-star setup: one star at row 50.0, column 50.9 (pixel / 100).
# The second slot is zeroed; with n_stars = 1 it is presumably ignored.
n_tests = 100

locs = torch.ones((n_tests, 2, 2))
locs[:, 0, 0] = 50. / 100
locs[:, 0, 1] = 50.9 / 100
locs[:, 1, :] = 0

fluxes = torch.ones((n_tests, 2, 1)) * 1e4

n_stars = torch.ones(n_tests, dtype=torch.long)

In [None]:
# Render a noisy simulated image. It gets its own name so the observed SDSS
# `full_image` (used for shapes in earlier cells) is not overwritten.
sim_image = simulator.draw_image_from_params(locs, fluxes, n_stars, add_noise=True)
# Patch containing the star (cols 47:54) and a patch shifted right by two
# pixels (cols 49:56).
image_patches1 = sim_image[:, :, 47:54, 47:54]
image_patches2 = sim_image[:, :, 47:54, 49:56]

In [None]:
# Containing patch with red guide lines at pixels 2 and 4 on both axes
# (each line drawn once).
plt.matshow(image_patches1[0, 0].numpy())
for pos in (2, 4):
    plt.axvline(x=pos, color='r')
    plt.axhline(y=pos, color='r')

In [None]:
# Right-shifted patch with red guide lines at pixels 2 and 4 on both axes
# (each line drawn once).
plt.matshow(image_patches2[0, 0].numpy())
for pos in (2, 4):
    plt.axvline(x=pos, color='r')
    plt.axhline(y=pos, color='r')

In [None]:
# Constant mean-sky background shaped like the single-star patches. It is sized
# from image_patches1, so it does not depend on patches from an earlier experiment.
background_stamps = torch.ones(image_patches1.shape) * full_background.mean()

In [None]:
# Wake-sleep encoder log-probabilities (output element 4) for the containing
# patch and for its right-shifted neighbour.
probs_patch_truth, probs_patch_neighbor = (
    star_encoder2(patches, background_stamps)[4]
    for patches in (image_patches1, image_patches2)
)

In [None]:
# Batch-averaged 1 - p(count index 0) for the containing patch.
1 - probs_patch_truth.exp()[:, 0].mean()

In [None]:
# Batch-averaged 1 - p(count index 0) for the neighbouring patch.
1 - probs_patch_neighbor.exp()[:, 0].mean()

In [None]:
# Sweep the star's column across [51 - 0.4, 51 + 0.4] pixels. For the containing
# patch (cols 47:54) and the right-shifted neighbour (cols 49:56), record the
# wake-sleep encoder's batch-averaged 1 - p(count index 0).
delta_vec = np.linspace(-0.4, 0.4, 20)

probs_on_truth_vec = torch.ones(len(delta_vec))
probs_on_neighbor_vec = torch.ones(len(delta_vec))

# Quantities that do not change across the sweep.
n_tests = 100
fluxes = torch.ones((n_tests, 2, 1)) * 1e4
n_stars = torch.ones(n_tests, dtype=torch.long)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i, delta_i in enumerate(delta_vec):
        # Single star at row 50, column 51 + delta_i (pixel / 100); the second
        # slot is zeroed (n_stars = 1).
        locs = torch.ones((n_tests, 2, 2))
        locs[:, 0, 0] = 50. / 100
        locs[:, 0, 1] = (51. + delta_i) / 100
        locs[:, 1, :] = 0

        # Separate name so the observed SDSS `full_image` is not overwritten.
        sim_image = simulator.draw_image_from_params(locs, fluxes, n_stars, add_noise=True)
        image_patches1 = sim_image[:, :, 47:54, 47:54]
        image_patches2 = sim_image[:, :, 47:54, 49:56]

        # Background derived from this sweep's patches instead of a
        # background_stamps left over from an earlier cell.
        background_stamps = torch.ones(image_patches1.shape) * full_background.mean()

        probs_patch_truth = star_encoder2(image_patches1, background_stamps)[4]
        probs_patch_neighbor = star_encoder2(image_patches2, background_stamps)[4]

        probs_on_truth_vec[i] = 1 - torch.exp(probs_patch_truth)[:, 0].mean()
        probs_on_neighbor_vec[i] = 1 - torch.exp(probs_patch_neighbor)[:, 0].mean()

In [None]:
# Display the per-offset values for the containing patch.
probs_on_truth_vec

In [None]:
# Batch-averaged 1 - p(count index 0) vs. the star's column offset from pixel 51,
# for the containing patch and its right-shifted neighbour.
plt.plot(delta_vec, probs_on_truth_vec.detach().numpy(), '-x', label='patch cols 47:54')
plt.plot(delta_vec, probs_on_neighbor_vec.detach().numpy(), '-x', label='patch cols 49:56')
plt.xlabel('Column offset from pixel 51')
plt.ylabel('1 - p(count index 0)')
plt.legend()


In [None]:
# Containing patch from the last sweep iteration (first sample, first band).
plt.matshow(image_patches1.numpy()[0, 0])

In [None]:
# Neighbouring patch from the last sweep iteration (first sample, first band).
plt.matshow(image_patches2.numpy()[0, 0])

In [None]:
# Pixel-wise difference between the containing and neighbouring patches.
plt.matshow((image_patches1[0, 0] - image_patches2[0, 0]).numpy())
plt.colorbar()

In [None]:
# Raw SDSS PSF for the single loaded band.
first_band_psf = psf_og[0]
plt.matshow(first_band_psf)

In [None]:
# PSF minus its up-down mirror image: nonzero where the PSF is asymmetric in rows.
plt.matshow(psf_og[0] - psf_og[0][::-1, :])
plt.colorbar()

In [None]:
# PSF minus its left-right mirror image: nonzero where the PSF is asymmetric in columns.
plt.matshow(psf_og[0] - psf_og[0][:, ::-1])
plt.colorbar()

# Check what happens on a whole image

In [None]:
# Whole-image check: n_tests image(s), each with 5 stars at uniformly random
# locations in [0, 1) (pixel / 100) and flux 1e4.
n_tests = 1
locs = torch.rand((n_tests, 5, 2))
fluxes = torch.ones((n_tests, 5, 1)) * 1e4
n_stars = torch.full((n_tests,), 5, dtype=torch.long)

In [None]:
# Noisy simulated image(s) and a constant background at the simulator's sky level.
images = simulator.draw_image_from_params(
    locs, fluxes, n_stars, add_noise=True
)
# NOTE(review): if sky_intensity holds one value per band, this multiplication
# broadcasts it along the last (width) axis, which is only right for one band.
backgrounds = torch.ones(images.shape) * simulator.sky_intensity

In [None]:
# First simulated image, first band.
plt.matshow(images[0][0])

In [None]:
# MAP estimates (first three outputs: locations, fluxes, star counts) from the
# wake-sleep encoder.
map_locs, map_fluxes, map_nstars = star_encoder2.sample_star_encoder(
    images, backgrounds, return_map=True
)[:3]

In [None]:
# NOTE(review): debugging leftover that inspects entry 1224 of tile_coords; the
# index is an unexplained magic number. Confirm it is still needed or remove it.
star_encoder2.tile_coords[1224]

In [None]:
# Overlay the MAP locations on the image (column on x, row on y; locations are
# scaled by 100, presumably back to pixel units).
plt.matshow(images[0, 0])
plt.scatter(map_locs[0, :, 1] * 100,
            map_locs[0, :, 0] * 100,
            color='r', marker='x', alpha=0.8)

# Only write the figure when requested, consistent with the other figure cells.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/sparse_field_test_simulated.png')