In [None]:
import numpy as np

import torch
import torch.optim as optim
import matplotlib.pyplot as plt

import fitsio

import sys
sys.path.insert(0, '../')

import sdss_psf
import sdss_dataset_lib
import simulated_datasets_lib
import starnet_vae_lib

from psf_transform_lib import PsfLocalTransform, get_psf_loss

import time

# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print('device: ', device)

print('torch version: ', torch.__version__)


# Load data

In [None]:
# Load the SDSS/Hubble data for a single band (index 2). The PSF loaded below is the
# r-band PSF, so this presumably selects r in ugriz ordering — TODO confirm.
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(bands=[2])

In [None]:
# Observed image and its background, each given a leading batch dimension and moved to `device`.
full_image = sdss_hubble_data.sdss_image[None].to(device)
full_background = sdss_hubble_data.sdss_background[None].to(device)

# Ground-truth star locations and fluxes, batched the same way.
true_full_locs = sdss_hubble_data.locs[None].to(device)
true_full_fluxes = sdss_hubble_data.fluxes[None].to(device)


In [None]:
# Display the shape of the batched true fluxes.
true_full_fluxes.size()

In [None]:
# load psf
psf_dir = '../data/'


def _read_fits_primary(path):
    """Read the primary HDU (index 0) of a FITS file, closing the file handle afterwards."""
    with fitsio.FITS(path) as fits_file:
        return fits_file[0].read()


psf_r = _read_fits_primary(psf_dir + 'sdss-002583-2-0136-psf-r.fits')
# NOTE(review): psf_i is not used anywhere in the visible notebook; kept in case later
# cells rely on it — drop it if not.
psf_i = _read_fits_primary(psf_dir + 'sdss-002583-2-0136-psf-i.fits')

# Stack into a leading "band" axis; only the r-band PSF is used.
psf_og = np.array([psf_r])


In [None]:
# define transform: wraps the original PSF (converted to a float tensor), is sized by the
# image's last dimension, and uses a 3x3 kernel (kernel semantics live in psf_transform_lib)
psf_transform = PsfLocalTransform(
    torch.Tensor(psf_og),
    full_image.size(-1),
    kernel_size=3,
)

In [None]:
# Hand-initialise the local-transform weights: zeros everywhere except the centre tap of the
# (row-major) flattened kernel, which is set to 100. The state dict loaded later is named
# 'identity_psf', so this presumably starts the transform near the identity — TODO confirm.
# kernel_size ** 2 // 2 is the centre tap for any odd kernel_size; it equals the previously
# hard-coded index 4 when kernel_size = 3.
w = torch.zeros(psf_transform.psf_slen ** 2,
                psf_transform.n_bands,
                psf_transform.kernel_size ** 2,
                device=device)
w[:, :, psf_transform.kernel_size ** 2 // 2] = 100.

# NOTE(review): this sets `.weights`, but a later cell inspects `psf_transform.weight`;
# confirm which attribute PsfLocalTransform actually reads.
psf_transform.weights = w


In [None]:
# PSF produced by the hand-initialised weights (before the trained state dict is loaded),
# computed without building an autograd graph.
with torch.no_grad():
    psf_init = psf_transform.forward()

In [None]:
# NOTE(review): the initialisation cell assigns `.weights`, while this inspects `.weight`;
# confirm which one PsfLocalTransform uses.
psf_transform.weight.size()

In [None]:
# load trained transform; every stored tensor is mapped onto CPU storage while loading
psf_transform.load_state_dict(
    torch.load('../fits/results_11202019/identity_psf', map_location='cpu')
)


In [None]:
# PSF from the freshly loaded (trained) weights; gradient tracking is left on here and
# the sanity check in the next cell detaches it.
foo = psf_transform.forward()

In [None]:
# Largest absolute pixel difference between the trimmed transformed PSF
# (_trim_psf(foo, 25)) and the original PSF.
# .cpu() is required before .numpy() in case the transform's output lives on a GPU.
np.abs(simulated_datasets_lib._trim_psf(foo, 25).detach().cpu().numpy() - psf_og).max()

In [None]:
import numpy as np

In [None]:
# PSF from the trained weights, computed without building an autograd graph.
with torch.no_grad():
    psf_trained = psf_transform.forward()
psf_trained.size()

In [None]:
# Use the true (ground-truth) parameters in place of variational estimates.
locs = true_full_locs
fluxes = true_full_fluxes
# Number of stars per image: count entries whose fluxes[:, :, 0] (presumably the first
# band) is positive. `.long()` keeps the result on the same device as the fluxes,
# whereas `.type(torch.LongTensor)` silently moved it to the CPU.
n_stars = torch.sum(fluxes[:, :, 0] > 0, dim=1).long()

In [None]:
# Display the per-image star counts computed above.
n_stars

In [None]:
# Reconstruction (and PSF loss) using the initial PSF, i.e. the one produced by the
# hand-initialised weights before the trained state dict was loaded.
# psf_init is already a tensor: re-wrapping it in torch.Tensor(...) was a no-op on CPU
# and raises for CUDA tensors, so it is passed through directly.
recon_means1, loss1 = get_psf_loss(full_image, full_background,
                                   locs, fluxes, n_stars,
                                   psf_init,
                                   pad=5)
print(loss1)

In [None]:
# Reconstruction (and PSF loss) using the trained PSF loaded from the state dict.
recon_means2, loss2 = get_psf_loss(full_image, full_background,
                                   locs, fluxes, n_stars,
                                   psf_trained,
                                   pad=5)

print(loss2)

In [None]:
# Index into the *loaded* bands (not the SDSS band number). Only one band is loaded
# (SDSSHubbleData(bands=[2]) and psf_og = np.array([psf_r])), so the only valid index is 0.
band = 0

In [None]:
# Difference between the trained and initial PSF for the selected band, zoomed on
# pixels 40:60 along both axes, on a symmetric diverging colour scale.
psf_diff = (psf_trained - psf_init)[band, 40:60, 40:60].detach().cpu()  # matplotlib needs CPU data
diff_max = psf_diff.abs().max().item()
plt.matshow(psf_diff, vmax=diff_max, vmin=-diff_max, cmap=plt.get_cmap('bwr'))
plt.colorbar()
plt.title('trained PSF - initial PSF');

In [None]:
# Log-ratio residuals log10(simulated / observed) for the selected band, cropped to
# pixels 5:95 on both axes (pad = 5 was used for the losses above).
residual1 = (recon_means1 / full_image).log10()[0, band, 5:95, 5:95]
residual2 = (recon_means2 / full_image).log10().detach()[0, band, 5:95, 5:95]

In [None]:
# Side by side: log10(simulated / observed) residuals for the original and trained PSFs,
# plus the observed image over the same cropped window.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# One symmetric colour scale (set by the original-PSF residual) is shared by both
# residual panels so they are directly comparable.
vmax = residual1.abs().max().item()

# .detach().cpu(): matplotlib cannot read tensors that track gradients or live on a GPU.
im0 = axarr[0].matshow(residual1.squeeze().detach().cpu(),
                       vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im0, ax=axarr[0])
axarr[0].set_title('log10(simulated / observed), original PSF')

im1 = axarr[1].matshow(residual2.squeeze().detach().cpu(),
                       vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im1, ax=axarr[1])
axarr[1].set_title('log10(simulated / observed), trained PSF')

im2 = axarr[2].matshow(full_image[0, band, 5:95, 5:95].cpu())
fig.colorbar(im2, ax=axarr[2])
axarr[2].set_title('observed');

In [None]:
# Distribution of the trained-PSF log10 residuals over the cropped window.
# .cpu() is required because matplotlib cannot read GPU tensors.
plt.hist(residual2.flatten().detach().cpu().numpy(), bins=100)
plt.xlabel('log10(simulated / observed), trained PSF')
plt.ylabel('pixel count');

In [None]:
import plotting_utils

In [None]:
fig, axarr = plt.subplots(2, 3, figsize = (20, 12))

# Sub-image offsets and side length for the zoomed panels; residual arrays are indexed
# [x0:x0 + slen, x1:x1 + slen] below, so x0 offsets rows and x1 offsets columns.
x0 = 10
x1 = 50
slen = 30


def _log10_residual(recon_means):
    """Return log10(simulated / observed) over the full image, detached and on the CPU."""
    return torch.log10(recon_means.squeeze().detach().cpu() / full_image.squeeze().cpu())


def _plot_psf_row(fig, axrow, recon_means, residual, psf_label, resid_vmax, x0, x1, slen):
    """Draw observed / simulated / residual panels for one PSF on a row of three axes.

    Observed and simulated images are both shown in log10 so they share a scale; the
    residual uses a symmetric diverging colour map in [-resid_vmax, resid_vmax].
    All data are moved to the CPU because matplotlib cannot read GPU tensors.
    """
    locs_cpu = true_full_locs.squeeze().cpu()
    panels = [
        (torch.log10(full_image.squeeze().cpu()), 'observed', {}),
        (torch.log10(recon_means.squeeze().detach().cpu()), 'simulated, ' + psf_label, {}),
        (residual, 'residual, ' + psf_label,
         dict(vmin=-resid_vmax, vmax=resid_vmax, diverging_cmap=True)),
    ]
    for ax, (image, title, extra_kwargs) in zip(axrow, panels):
        plotting_utils.plot_subimage(ax, image,
                                     None,
                                     locs_cpu,
                                     x0, x1, subimage_slen=slen,
                                     add_colorbar=True, global_fig=fig,
                                     **extra_kwargs)
        ax.set_title(title)


_resid = _log10_residual(recon_means1)
_resid2 = _log10_residual(recon_means2)

# A single residual colour scale, set by the original-PSF residual inside the zoom
# window, is shared by both rows so they are directly comparable.
vmax = _resid[x0:(x0 + slen), x1:(x1 + slen)].abs().max()

#################
# ORIGINAL PSF
#################
_plot_psf_row(fig, axarr[0], recon_means1, _resid, 'original PSF', vmax, x0, x1, slen)

#################
# TRANSFORMED PSF
#################
_plot_psf_row(fig, axarr[1], recon_means2, _resid2, 'trained PSF', vmax, x0, x1, slen)


In [None]:
# fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

# im0 = axarr[0].matshow(full_image.squeeze())
# fig.colorbar(im0, ax=axarr[0])
# axarr[0].set_title('observed sdss image\n')

# im1 = axarr[1].matshow(recon_means1.squeeze())
# fig.colorbar(im1, ax=axarr[1])
# axarr[1].set_title('simulated image\n')


# im2 = axarr[2].matshow(residual1.squeeze(), 
#                 vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
# fig.colorbar(im2, ax=axarr[2])
# axarr[2].set_title('residual: (observed - simulated)/observed \n')




In [None]:
# fig, axarr = plt.subplots(1, 2, figsize=(10, 4))


# vmax = residual1.abs().max()
# im0 = axarr[0].matshow(residual1.squeeze(), 
#                       vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
# fig.colorbar(im0, ax=axarr[0])

# vmax = residual2.abs().max()
# im1 = axarr[1].matshow(residual2.squeeze(), 
#                       vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
# fig.colorbar(im1, ax=axarr[1])

# # axarr[2].matshow(full_image.squeeze()[5:95, 5:95])