In [None]:
import numpy as np

import torch
import torch.optim as optim
import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '../')

import sdss_dataset_lib
import simulated_datasets_lib
import starnet_vae_lib
import kl_objective_lib

from psf_transform_lib import PsfLocalTransform, get_psf_transform_loss

import time

# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('device: ', device)

print('torch version: ', torch.__version__)


# Load data

In [None]:
# Locations of the SDSS stage directory and the Hubble catalogue used as ground truth.
# NOTE(review): the directory is spelled 'NCG7089' while the file says 'ngc7089' —
# confirm this matches the on-disk directory name.
SDSS_DIR = '../../celeste_net/sdss_stage_dir/'
HUBBLE_CAT_FILE = ('../hubble_data/NCG7089/'
                   'hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt.txt')

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(sdssdir=SDSS_DIR,
                                                   hubble_cat_file=HUBBLE_CAT_FILE)


In [None]:
# Observed image and its background, each given a leading batch dimension and moved to `device`.
full_image, full_background = (
    tensor.unsqueeze(0).to(device)
    for tensor in (sdss_hubble_data.sdss_image, sdss_hubble_data.sdss_background)
)

# True star locations and fluxes (presumably from the Hubble catalogue), batched the same way.
true_full_locs, true_full_fluxes = (
    tensor.unsqueeze(0).to(device)
    for tensor in (sdss_hubble_data.locs, sdss_hubble_data.fluxes)
)


In [None]:
# Two identically configured simulators. `simulator1` keeps the original PSF fit;
# `simulator2` is the one meant to be edited with the trained PSF transform.
simulator_kwargs = dict(psf_fit_file=str(sdss_hubble_data.psf_file),
                        slen=full_image.shape[-1],
                        sky_intensity=0.)

simulator1 = simulated_datasets_lib.StarSimulator(**simulator_kwargs)
simulator2 = simulated_datasets_lib.StarSimulator(**simulator_kwargs)


In [None]:
# Local PSF transform (kernel_size=3) initialised from simulator1's original PSF fit.
# NOTE(review): the transform is not moved to `device` here — confirm that
# get_psf_transform_loss copes when `device` is a GPU.
psf_og = torch.Tensor(simulator1.psf_og)
psf_transform = PsfLocalTransform(psf_og, simulator1.slen, kernel_size=3)

In [None]:
# Load the trained transform weights. The map_location callable returns each storage
# unchanged, which keeps it on the CPU, so a checkpoint written on a GPU also loads
# on a CPU-only machine.
psf_transform_file = '../fits/wake_sleep-portm2-101420129-psf_transform-iter2'
psf_transform.load_state_dict(
    torch.load(psf_transform_file,
               map_location=lambda storage, _location: storage))

In [None]:
# Construct a StarEncoder with the same image size as `full_image`.
# No trained weights are loaded into it in this notebook: it is only used
# below for its tiling helpers (get_image_stamps, tile_coords, stamp_slen,
# edge_padding), so the architecture settings must match those used when
# the PSF transform was trained — TODO confirm against the training run.
star_encoder = starnet_vae_lib.StarEncoder(full_slen = full_image.shape[-1],
                                           stamp_slen = 9,
                                           step = 2,
                                           edge_padding = 3,
                                           n_bands = 1,
                                           max_detections = 4)


In [None]:
# Split the *true* parameters across the encoder's image stamps.
# get_image_stamps is given the true locs/fluxes, so subimage_locs and
# subimage_fluxes are (presumably per-stamp) true values, not variational
# estimates — no encoder forward pass happens here. The other returned values
# are discarded.
_, subimage_locs, subimage_fluxes, _, _ = star_encoder.get_image_stamps(
    full_image, true_full_locs, true_full_fluxes, trim_images=False)
    


In [None]:
# Reconstruction and loss with the original, untransformed PSF (simulator1, no transform).
recon_means1, loss1 = get_psf_transform_loss(
    full_image, full_background,
    subimage_locs, subimage_fluxes,
    star_encoder.tile_coords,
    star_encoder.stamp_slen,
    star_encoder.edge_padding,
    simulator1,
    psf_transform=None,
)
print(loss1)

In [None]:
# Reconstruction and loss with the *trained* PSF transform applied through
# simulator2. Compare loss2 with loss1 from the original-PSF cell.
recon_means2, loss2 = get_psf_transform_loss(full_image, full_background,
                                subimage_locs,
                                subimage_fluxes,
                                star_encoder.tile_coords,
                                star_encoder.stamp_slen,
                                star_encoder.edge_padding,
                                simulator2,
                                psf_transform = psf_transform)

print(loss2)

In [None]:
torch.sum(simulator1.psf)  # total of the original PSF, for comparison with simulator2's PSF

In [None]:
torch.sum(simulator2.psf)  # total of simulator2's PSF, for comparison with the original PSF

In [None]:
# Original PSF minus simulator2's PSF, shown for rows/cols 40:60
# (presumably around the PSF centre — TODO confirm against the PSF size).
# detach + cpu so matplotlib can convert the tensor even when it lives on a GPU.
psf_diff = (simulator1.psf - simulator2.psf).detach().cpu()
plt.matshow(psf_diff[40:60, 40:60])
plt.colorbar()
plt.title('original PSF - simulator2 PSF');

In [None]:
# Log-ratio residuals log10(simulated / observed) for the original PSF (1) and the
# trained PSF (2), restricted to rows/cols 5:95.
# NOTE(review): 5:95 trims 5 px per side only if the image is 100 px wide — TODO
# confirm against full_image.shape.
# Both are detached and moved to the CPU so they can be handed straight to matplotlib.
residual_crop = slice(5, 95)
residual1 = torch.log10(recon_means1 / full_image)[0, 0, residual_crop, residual_crop].detach().cpu()
residual2 = torch.log10(recon_means2 / full_image)[0, 0, residual_crop, residual_crop].detach().cpu()

In [None]:
# Residuals log10(simulated / observed) for the original and the trained PSF,
# next to the observed image over the same 5:95 crop. Each residual gets its own
# symmetric colour scale centred on zero.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

residual_panels = [(residual1, 'residual, original PSF'),
                   (residual2, 'residual, trained PSF')]
for ax, (residual, title) in zip(axarr[:2], residual_panels):
    residual = residual.detach().cpu().squeeze()  # matplotlib needs a CPU tensor without grad
    vmax = residual.abs().max()
    im = ax.matshow(residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im, ax=ax)
    ax.set_title(title + '\n')

im = axarr[2].matshow(full_image.detach().cpu().squeeze()[5:95, 5:95])
fig.colorbar(im, ax=axarr[2])
axarr[2].set_title('observed sdss image (cropped)\n');

In [None]:
# Observed image, reconstruction with the original PSF, and their log-ratio residual.
# The residual's colour scale is derived from residual1 itself, so this cell does
# not depend on a `vmax` left over from whichever cell ran before it.
fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

im0 = axarr[0].matshow(full_image.detach().cpu().squeeze())
fig.colorbar(im0, ax=axarr[0])
axarr[0].set_title('observed sdss image\n')

im1 = axarr[1].matshow(recon_means1.detach().cpu().squeeze())
fig.colorbar(im1, ax=axarr[1])
axarr[1].set_title('simulated image\n')

residual1_plot = residual1.detach().cpu().squeeze()
residual1_vmax = float(residual1_plot.abs().max())
im2 = axarr[2].matshow(residual1_plot,
                       vmax = residual1_vmax, vmin = -residual1_vmax, cmap=plt.get_cmap('bwr'))
fig.colorbar(im2, ax=axarr[2])
# residual1 is log10(recon_means1 / full_image), i.e. log10(simulated / observed)
axarr[2].set_title('residual: log10(simulated / observed)\n');




In [None]:
# Residuals log10(simulated / observed) for the original PSF (left) and the
# trained PSF (right), each on its own symmetric colour scale centred on zero.
fig, axarr = plt.subplots(1, 2, figsize=(10, 4))

for ax, residual, psf_label in zip(axarr, (residual1, residual2), ('original PSF', 'trained PSF')):
    residual = residual.detach().cpu().squeeze()  # matplotlib needs a CPU tensor without grad
    vmax = residual.abs().max()
    im = ax.matshow(residual, vmax = vmax, vmin = -vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im, ax=ax)
    ax.set_title('log10(simulated / observed), ' + psf_label + '\n')

In [None]:
# Zoom on rows 20:30, cols 49:59 of the original-PSF residual (the choice of region
# is not documented here — TODO). The colour scale is shared with the trained-PSF
# zoom and computed explicitly, rather than relying on a `vmax` from another cell.
residual_zoom = (slice(20, 30), slice(49, 59))
zoom_vmax = max(float(residual1.abs().max()), float(residual2.abs().max()))
plt.matshow(residual1.detach().cpu().squeeze()[residual_zoom],
            vmax = zoom_vmax, vmin = -zoom_vmax, cmap=plt.get_cmap('bwr'));

In [None]:
# Same zoom (rows 20:30, cols 49:59) for the trained-PSF residual. The colour scale
# is shared with the original-PSF zoom and computed explicitly, rather than relying
# on a `vmax` from another cell.
residual_zoom = (slice(20, 30), slice(49, 59))
zoom_vmax = max(float(residual1.abs().max()), float(residual2.abs().max()))
plt.matshow(residual2.detach().cpu().squeeze()[residual_zoom],
            vmax = zoom_vmax, vmin = -zoom_vmax, cmap=plt.get_cmap('bwr'));