In [None]:
import numpy as np
import torch

import json

import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')
import simulated_datasets_lib
import sdss_dataset_lib
import sdss_psf
import image_utils 
import starnet_vae_lib
import inv_kl_objective_lib as inv_kl_lib


import plotting_utils

# Seed every RNG the notebook may draw from so Restart & Run All reproduces the figures:
# numpy drives the random subimage offsets further down; torch is seeded as well in
# case the simulator's noise draws use torch's RNG (assumption — confirm in
# simulated_datasets_lib). np.random.seed is last so the cell displays nothing.
SEED = 34534
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the data

In [None]:
# Flux cutoff: catalog entries whose first flux column exceeds this value form the
# "bright" ground-truth catalog used below.
fmin = 10 ** 3

In [None]:
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData()

# Observed image and its background, with singleton dimensions dropped.
full_image, full_background = (
    sdss_hubble_data.sdss_image.squeeze(),
    sdss_hubble_data.sdss_background.squeeze(),
)

# Ground-truth catalog restricted to entries whose first flux column exceeds `fmin`.
which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
true_locs, true_fluxes = (
    sdss_hubble_data.locs[which_bright],
    sdss_hubble_data.fluxes[which_bright],
)


In [None]:
# Number of bright ground-truth stars (rows of true_locs).
true_locs.shape[0]

In [None]:
# Make the observed image a float32 torch tensor. `torch.as_tensor` accepts either a
# NumPy array or a tensor of any dtype. The legacy `torch.Tensor(...)` constructor
# raises on tensors that are not already float32. So this cell is safe to re-run.
# The result may share memory with `sdss_hubble_data.sdss_image` when no copy is needed.
full_image = torch.as_tensor(full_image, dtype=torch.float32)
print(full_image.shape)

In [None]:
# Observed image with a colorbar.
# NOTE(review): unlike the later exports, this write is unconditional — `save_fig`
# is only defined in the next cell.
plt.matshow(full_image.squeeze())
plt.colorbar()
plt.savefig(fname='../../qualifying_exam_slides/figures/sdss_image_port.png')

In [None]:
# Toggle for writing figures into the slides directory (switched on again further down).
save_fig: bool = False

In [None]:
# Whole frame: entry 0 along the first axis of `sdss_image_full` (presumably the first band).
plt.matshow(sdss_hubble_data.sdss_image_full[0])
if save_fig:
    plt.savefig(fname='../../qualifying_exam_slides/figures/sdss_image_full.png')

In [None]:
# Size of the whole frame (np.shape returns the object's own `.shape`).
np.shape(sdss_hubble_data.sdss_image_full[0])


In [None]:
# From here on, guarded figure exports are written to disk.
save_fig: bool = True

In [None]:
# Whole frame with the 100 x 100-pixel region at columns 310-410, rows 630-730
# outlined in red. The box is drawn as a Rectangle in data (pixel) coordinates.
# Four axvline/axhline calls would need axes-fraction limits hand-computed from a
# hard-coded 1489 x 2048 frame size. This way the box stays correct for any frame size.
plt.matshow(sdss_hubble_data.sdss_image_full[0])
plt.gca().add_patch(
    plt.Rectangle((310, 630), 100, 100, fill=False, edgecolor='red', linewidth=2)
)

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/sdss_image_full_port_marked.png')


In [None]:
# plt.matshow(sdss_hubble_data.sdss_image_full[0][720:1100, 20:410])
# if save_fig: 
#     plt.savefig('../../qualifying_exam_slides/figures/sdss_m2_image2.png')

In [None]:
# Crop of the whole frame (rows 600:1200, columns 0:420) with two 100 x 100 regions
# outlined in crop pixel coordinates:
#   red  - columns 310-410, rows 30-130 (rows 630-730 of the whole frame, i.e. the
#          region marked in the previous figure);
#   blue - columns 120-220, rows 50-150.
# Rectangles in data coordinates replace hand-computed axes fractions, which also
# depended on the crop size (600 x 420).
plt.matshow(sdss_hubble_data.sdss_image_full[0][600:1200, 0:420])

plt.gca().add_patch(
    plt.Rectangle((310, 30), 100, 100, fill=False, edgecolor='red', linewidth=2)
)
plt.gca().add_patch(
    plt.Rectangle((120, 50), 100, 100, fill=False, edgecolor='blue', linewidth=2)
)

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/sdss_m2_image2_marked.png')

# Example 10x10 subimages

In [None]:
# Second cutout: presumably a slen = 400 window with corner (x0=720, x1=20) in the
# whole frame — confirm against SDSSHubbleData.
sdss_hubble_data2 = sdss_dataset_lib.SDSSHubbleData(slen=400, x0=720, x1=20)

In [None]:
# 2 x 2 grid of 10 x 10 subpatches from the second cutout. Catalog entries whose first
# flux column exceeds 3000 are overlaid.
fig, axarr = plt.subplots(2, 2, figsize=(10, 8))

x0_vec = [26, 100, 150, 300]
x1_vec = [25, 300, 100, 350]

# axarr.flat walks the grid row by row: (0, 0), (0, 1), (1, 0), (1, 1).
for i, (panel, x0_i, x1_i) in enumerate(zip(axarr.flat, x0_vec, x1_vec)):
    plotting_utils.plot_subimage(
        panel,
        sdss_hubble_data2.sdss_image[0],
        None,
        sdss_hubble_data2.locs[sdss_hubble_data2.fluxes[:, 0] > 3000],
        x0=x0_i,
        x1=x1_i,
        subimage_slen=10,
        add_colorbar=True,
        global_fig=fig,
    )

fig.tight_layout()

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/example_subpatches.png')

In [None]:
# Second cutout with the four subpatches from the previous figure outlined in red.
# Assumes plot_subimage shows image[x0:x0 + slen, x1:x1 + slen] (TODO confirm). Each
# outline therefore follows the pixel edges from x0 - 0.5 to x0 + slen - 0.5.
# Hand-computed axhline/axvline fractions had mixed +10 and +11 offsets, which left
# the right/bottom corners of each box open.
subpatch_slen = 10
plt.matshow(sdss_hubble_data2.sdss_image[0])
for x0_i, x1_i in zip(x0_vec, x1_vec):
    # Rectangle's anchor is (x, y) = (column, row) in data coordinates.
    plt.gca().add_patch(plt.Rectangle((x1_i - 0.5, x0_i - 0.5),
                                      subpatch_slen, subpatch_slen,
                                      fill=False, edgecolor='red', linewidth=1.5))

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/example_subpatches_full.png')

In [None]:
# Row offsets shifted by +30.
# NOTE(review): nothing later in the visible notebook reads x0_vec after this
# reassignment — possibly a leftover.
x0_vec = 30 + np.array([53, 41, 31, 32])


In [None]:
# Indices of catalog entries with 0.8 < locs[:, 0] < 0.9 and 0.55 < locs[:, 1] < 0.65
# whose first flux column is above 1000.
in_window = ((sdss_hubble_data.locs[:, 0] > 0.8) & (sdss_hubble_data.locs[:, 0] < 0.9)
             & (sdss_hubble_data.locs[:, 1] > 0.55) & (sdss_hubble_data.locs[:, 1] < 0.65))
is_bright = sdss_hubble_data.fluxes[:, 0] > 1000
which = torch.nonzero(in_window & is_bright).squeeze()

In [None]:
# Preview the locations overlaid in the tiled example below: entries 4..-2 of `which`
# plus entry 2. This is the same concatenation passed to plot_subimage there;
# concatenating the [4:-1] slice with itself showed duplicates that are never plotted.
torch.cat((sdss_hubble_data.locs[which][4:-1], sdss_hubble_data.locs[which][2:3]))

In [None]:
# A single 10 x 10 subimage (x0=80, x1=55) with a hand-picked subset of the stars in
# `which` overlaid. White lines split it into 2 x 2-pixel tiles.
fig, axarr = plt.subplots(1, 1, figsize=(5, 4))

plotting_utils.plot_subimage(
    axarr,
    sdss_hubble_data.sdss_image[0],
    torch.cat((sdss_hubble_data.locs[which][4:-1],
               sdss_hubble_data.locs[which][2:3])),
    None,
    x0=80,
    x1=55,
    subimage_slen=10,
    add_colorbar=True,
    global_fig=fig,
)

# Tile boundaries every 2 pixels, starting at -0.5.
for j in range(5):
    edge = 2 * j - 0.5
    axarr.axhline(edge, color='white', linewidth=2)
    axarr.axvline(edge, color='white', linewidth=2)

if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/example_tiled.png')

In [None]:
# Locations of the stars selected by `which`.
sdss_hubble_data.locs[which, :]

# Set up the simulator

In [None]:
import fitsio

# Directory holding the PSF FITS files for SDSS frame idR-002583-2-0136.
psf_dir = '../../multiband_pcat/Data/idR-002583-2-0136/psfs/'

In [None]:
# Read the PSF image from HDU 0 of the `-psf-r` file (presumably the r band) and
# prepend a length-1 leading axis. The context manager closes the FITS handle,
# which the one-liner version left open.
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_fits:
    psf = psf_fits[0].read()[None]

In [None]:
# Simulator sized to the last dimension of the observed image, using the loaded PSF
# and the mean of the observed background as the sky level.
sky_intensity = torch.Tensor([full_background.mean()])
simulator = simulated_datasets_lib.StarSimulator(
    psf,
    slen=full_image.shape[-1],
    transpose_psf=False,
    sky_intensity=sky_intensity,
)



# Simulate the image from the ground-truth catalog

In [None]:
# Noise-free rendering of the *entire* catalog (not only the bright `true_locs`),
# compared against the observed image in the next cell.
n_catalog = sdss_hubble_data.fluxes.shape[0]
truth_recon = simulator.draw_image_from_params(
    locs=sdss_hubble_data.locs.unsqueeze(0),
    fluxes=sdss_hubble_data.fluxes.unsqueeze(0),
    n_stars=torch.tensor([n_catalog], dtype=torch.long),
    add_noise=False,
).squeeze()

In [None]:
# Relative residual (simulated - observed) / observed on a symmetric blue-white-red scale.
rel_resid = (truth_recon.squeeze() - full_image) / full_image
vlim = rel_resid.abs().max()
plt.matshow(rel_resid, vmin=-vlim, vmax=vlim, cmap=plt.get_cmap('bwr'))
plt.colorbar()

In [None]:
# Five random 5 x 5 subimages of the observed image with the bright catalog overlaid.
# Each is framed by thick white lines just inside its outer pixel edges.
sub_slen = 5
for i in range(5):
    f, axarr = plt.subplots(1, 1, figsize=(3, 3))
    # `.item()` returns the single draw as a Python int and consumes the RNG exactly as
    # before. `int()` on a 1-element ndarray is deprecated since NumPy 1.25.
    # NOTE(review): 100 presumably matches the observed image size — confirm.
    plotting_utils.plot_subimage(axarr, full_image.squeeze(),
                                 true_locs,
                                 None,
                                 x0=np.random.choice(100, 1).item(),
                                 x1=np.random.choice(100, 1).item(),
                                 subimage_slen=sub_slen)
    axarr.set_xticks([])
    axarr.set_yticks([])

    # Border lines at -0.49 and sub_slen - 0.51, just inside the outer pixel edges
    # (presumably -0.5 and sub_slen - 0.5).
    for edge in (-0.49, sub_slen - 0.51):
        axarr.axvline(x=edge, color='w', linewidth=10)
    for edge in (-0.49, sub_slen - 0.51):
        axarr.axhline(y=edge, color='w', linewidth=10)

In [None]:
# Fixed 6 x 6 subimage at (x0=48, x1=59). Bright stars whose locations, scaled by 100,
# fall strictly inside (x0 + 2, x0 + 4) x (x1 + 2, x1 + 4) are overlaid in red.
x0 = 48  # int(np.random.choice(100, 1))
x1 = 59  # int(np.random.choice(100, 1))

locs_scaled = true_locs * 100
which_locs = ((locs_scaled[:, 0] > x0 + 2) & (locs_scaled[:, 0] < x0 + 4)
              & (locs_scaled[:, 1] > x1 + 2) & (locs_scaled[:, 1] < x1 + 4))

print(which_locs.sum())
print(true_locs[which_locs])

f, axarr = plt.subplots(1, 1, figsize=(3, 3))
plotting_utils.plot_subimage(axarr, full_image.squeeze(),
                             true_locs[which_locs],
                             None,
                             x0=x0,
                             x1=x1,
                             subimage_slen=6, color='r')

# Thick white segments over the middle third of the axes along the lines at 1.5 / 3.5 ...
for edge in (1.5, 3.5):
    axarr.axvline(x=edge, ymin=2/3, ymax=1/3, color='white', linewidth=5)
for edge in (1.5, 3.5):
    axarr.axhline(y=edge, xmin=2/3, xmax=1/3, color='white', linewidth=5)

# ... plus thin full-length white lines through the same positions.
axarr.axvline(x=1.5, color='white', linewidth=1)
axarr.axvline(x=3.5, color='white', linewidth=1)
axarr.axhline(y=1.5, color='white', linewidth=1)
axarr.axhline(y=3.5, color='white', linewidth=1)


In [None]:
# Example of simulated images

In [None]:
# Small simulator (slen = 11) with the same PSF and the mean of the observed
# background as the sky level, used for the illustrative images below.
simulator2 = simulated_datasets_lib.StarSimulator(
    psf,
    sky_intensity=torch.Tensor([full_background.mean()]),
    transpose_psf=False,
    slen=11,
)



In [None]:
# Four stars at fixed locations and fluxes rendered by simulator2 (add_noise left at
# its default). `foo` is the image shown in the generative-model figure below.
locs = torch.Tensor([[[0.2, 0.22], [0.32, 0.43], [0.71, 0.12], [0.83, 0.89]]])
example_fluxes = torch.Tensor([[[2000], [4500], [3000], [2500]]])
foo = simulator2.draw_image_from_params(locs=locs,
                                        fluxes=example_fluxes,
                                        n_stars=torch.tensor([4], dtype=torch.long))

In [None]:
# Quick look at the integer offsets -5 .. 4 (step defaults to 1).
np.arange(-5, 5)

In [None]:
# Standalone PSF figure. Ticks at pixels 1, 3, ..., 9 are relabelled as offsets from
# pixel 5 (x: -4 .. 4 left to right, y: 4 .. -4 top to bottom).
fig, ax = plt.subplots()

im = ax.matshow(simulator2.psf.squeeze())
fig.colorbar(im)

plt.xticks(np.arange(1, 11, step=2), np.arange(-4, 6, step=2))
plt.yticks(np.arange(1, 11, step=2), np.arange(4, -6, step=-2))

ax.xaxis.set_ticks_position('bottom')

fig.tight_layout()

# Honour the save toggle like the other figure exports.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/example_psf.png')

In [None]:
# Two-panel generative-model figure: the PSF on the left (ticks relabelled as offsets
# from pixel 5) and the simulated four-star image on the right, with the true
# locations marked.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

im0 = ax[1].matshow(foo.squeeze())
fig.colorbar(im0, ax=ax[1])
# Locations are multiplied by 10 to land on simulator2's slen = 11 pixel grid.
ax[1].scatter(locs[:, :, 1] * 10,
              locs[:, :, 0] * 10, color='b')

im1 = ax[0].matshow(simulator2.psf.squeeze())
# Use this panel's own image. `im` belongs to another cell's figure and is later
# rebound to a tensor, so relying on it made this cell order-dependent.
fig.colorbar(im1, ax=ax[0])

ax[0].set_xticks(np.arange(1, 11, step=2))
ax[0].set_xticklabels(np.arange(-4, 6, step=2))
ax[0].xaxis.set_ticks_position('bottom')

ax[0].set_yticks(np.arange(1, 11, step=2))
ax[0].set_yticklabels(np.arange(4, -6, step=-2))

fig.tight_layout()

# Honour the save toggle like the other figure exports.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/example_gen_model.png')

In [None]:
# Example: evaluating the log-likelihood

In [None]:
# slen = 21 simulator (same PSF and sky level as above) for the loss illustration.
simulator3 = simulated_datasets_lib.StarSimulator(
    psf,
    sky_intensity=torch.Tensor([full_background.mean()]),
    transpose_psf=False,
    slen=21,
)

In [None]:
# Three stars rendered by simulator3 (add_noise left at its default). `im` is the
# "observed" image for the loss example.
locs = torch.Tensor([[[0.3, 0.2], [0.7, 0.1], [0.8, 0.8]]])
im = simulator3.draw_image_from_params(
    locs=locs,
    fluxes=torch.Tensor([[[4500], [3000], [2500]]]),
    n_stars=torch.tensor([3], dtype=torch.long),
)

In [None]:
# Hand-picked location estimates (compare with `locs` above), using the true fluxes,
# and their noise-free reconstruction.
est_locs = torch.Tensor([[[0.2, 0.8], [0.3, 0.7], [0.1, 0.8]]])
est_fluxes = torch.Tensor([[[4500], [3000], [2500]]])

recon = simulator3.draw_image_from_params(
    locs=est_locs,
    fluxes=est_fluxes,
    n_stars=torch.tensor([3], dtype=torch.long),
    add_noise=False,
)

In [None]:
# Observed image with true (blue circles) and estimated (red crosses) locations,
# multiplied by 20 to land on simulator3's slen = 21 pixel grid.
plt.matshow(im.squeeze())
plt.colorbar()
plt.plot(locs[:, :, 1].squeeze().numpy() * 20,
         locs[:, :, 0].squeeze().numpy() * 20,
         'ob', label='truth')

plt.plot(est_locs[:, :, 1].squeeze().numpy() * 20,
         est_locs[:, :, 0].squeeze().numpy() * 20,
         'xr', label='est.')

plt.legend()
plt.title('observed \n')

# Honour the save toggle like the other figure exports.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/loss_example2_observed.png')

In [None]:
# Noise-free reconstruction from the estimated locations, with the estimates marked.
# `recon` is computed alongside `est_locs` above, from the same locations and fluxes,
# so it is not recomputed here.
plt.matshow(recon.squeeze())
plt.colorbar()

plt.plot(est_locs[:, :, 1].squeeze().numpy() * 20,
         est_locs[:, :, 0].squeeze().numpy() * 20,
         'xr', label='est.')

plt.title('reconstructed \n')

# Honour the save toggle like the other figure exports.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/loss_example2_recon.png')


In [None]:
# Signed residual (reconstruction minus observation) on a symmetric blue-white-red scale.
resid = recon.squeeze() - im.squeeze()
resid_lim = resid.abs().max()
plt.matshow(resid, vmax=resid_lim, vmin=-resid_lim, cmap=plt.get_cmap('bwr'))
plt.colorbar()

plt.title('recon. - observed \n')

# Honour the save toggle like the other figure exports.
if save_fig:
    plt.savefig('../../qualifying_exam_slides/figures/loss_example2_resid.png')
