In [None]:
import json
import sys

import numpy as np
import torch
import matplotlib.pyplot as plt

# Make the project's library modules (one directory up) importable.
# NOTE(review): '../' is resolved against the kernel's working directory, not this file.
sys.path.insert(0, '../')

import residuals_vae_lib
import simulated_datasets_lib
import sdss_dataset_lib

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix the RNG seeds: the data loader shuffles and the generative-model cells draw
# random samples, so without a seed the figures change on every run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)


# Load the Hubble catalog and the matching SDSS field

In [None]:
# HST ACS/WFC catalog of NGC 7078, paired with SDSS run 2566 / camcol 6 / field 65.
# NOTE(review): slen and max_detections semantics live in sdss_dataset_lib — confirm.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'
sdss_field = dict(run=2566, camcol=6, field=65)
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    slen=11,
    max_detections=20,
    **sdss_field,
)

In [None]:
# Number of items in the matched Hubble/SDSS dataset.
n_items = len(sdss_hubble_data)
n_items

In [None]:
# Star simulator built from the dataset's PSF fit. The sky level passed in is a single
# scalar: the mean of the full SDSS background.
sky_intensity = sdss_hubble_data.sdss_background_full.mean()
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=sdss_hubble_data.psf_file,
    slen=sdss_hubble_data.slen,
    sky_intensity=sky_intensity,
)

# Build the data loader


In [None]:
# One batch spans the whole dataset (batch size = dataset length); order is reshuffled
# every time the loader is iterated.
batchsize = len(sdss_hubble_data)
loader = torch.utils.data.DataLoader(dataset=sdss_hubble_data, shuffle=True, batch_size=batchsize)


# Define the residual VAE and load its trained weights

In [None]:
# Single-band residual VAE sized to the dataset's slen.
# NOTE(review): the meaning of f_min=2000 is defined in residuals_vae_lib — confirm.
vae_kwargs = dict(slen=sdss_hubble_data.slen, n_bands=1, f_min=2000)
residual_vae = residuals_vae_lib.ResidualVAE(**vae_kwargs)

In [None]:
# Load the trained weights onto the CPU first (works whether or not the checkpoint was
# saved from a GPU). Then move the model to `device`, the same device the batch tensors
# are sent to, so the forward pass does not mix CPU and GPU tensors.
state_dict = torch.load('../fits/residual_vae', map_location='cpu')
residual_vae.load_state_dict(state_dict)
residual_vae.to(device)
residual_vae.eval();

In [None]:
# Evaluate the VAE over the loader with train=False.
# NOTE(review): presumably evaluation only (no optimizer step) — confirm in residuals_vae_lib.
loss = residuals_vae_lib.eval_residual_vae(
    residual_vae,
    loader,
    simulator,
    train=False,
)

In [None]:
# Show the evaluation loss in scientific notation with six decimals.
print(f'{loss:.6E}')

# Inspect the residuals

In [None]:
# Passed to residual_vae(...) below, whose call unpacks a fifth output (`normalized_residual`).
# NOTE(review): the flag says "unnormalized" but the output is named "normalized" — confirm.
return_unnormalized: bool = True

In [None]:
# Take the first (shuffled) batch from the loader and compare each observed image
# with a noiseless rendering of its true star parameters.
data = next(iter(loader))

# true parameters
true_fluxes = data['fluxes'].to(device).type(torch.float)
true_locs = data['locs'].to(device).type(torch.float)
true_n_stars = data['n_stars'].to(device)
images = data['image'].to(device)
backgrounds = data['background'].to(device)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    # reconstruction: render the true stars without noise
    simulated_images = simulator.draw_image_from_params(locs=true_locs,
                                                        fluxes=true_fluxes,
                                                        n_stars=true_n_stars,
                                                        add_noise=False)

    # reconstructed residuals
    recon_residual, recon_logvar, eta_mean, eta_logvar, normalized_residual = \
        residual_vae(images, simulated_images,
                     return_unnormalized=return_unnormalized)

In [None]:
# Distribution of the normalized residual over every element of the batch.
# Detach and move to the CPU so this works for GPU tensors and tensors that track gradients.
fig, ax = plt.subplots()
ax.hist(normalized_residual.detach().cpu().numpy().ravel(), bins=50)
ax.set_xlabel('normalized residual')
ax.set_ylabel('count');

In [None]:
# Batch items to inspect, one figure each. Indices past the end of the batch are
# skipped instead of raising IndexError.
PLOT_INDICES = range(16, 17)

# Star locations are multiplied by (slen - 1) to get pixel coordinates for overplotting.
# NOTE(review): this assumes locs are scaled to [0, 1] across the patch — confirm.
pixel_scale = sdss_hubble_data.slen - 1

for indx in PLOT_INDICES:
    if indx >= images.shape[0]:
        continue

    fig, axarr = plt.subplots(1, 4, figsize=(15, 4))

    # Detach and move to CPU so matplotlib can draw tensors that live on the GPU
    # or carry autograd history.
    image_i = images[indx].squeeze().detach().cpu()
    simulated_image_i = simulated_images[indx].squeeze().detach().cpu()
    residual_image_i = normalized_residual[indx].squeeze().detach().cpu()
    recon_residual_i = recon_residual[indx].squeeze().detach().cpu()
    locs_i = true_locs[indx, :int(true_n_stars[indx]), :].detach().cpu()

    # Observed and simulated images share the observed image's colour range,
    # with the true star locations marked on both.
    vmin = image_i.min().item()
    vmax = image_i.max().item()
    observed_im = axarr[0].matshow(image_i, vmin=vmin, vmax=vmax)
    axarr[0].set_title('observed \n')
    axarr[1].matshow(simulated_image_i, vmin=vmin, vmax=vmax)
    axarr[1].set_title('simulated \n')
    for ax in axarr[:2]:
        ax.plot(locs_i[:, 1] * pixel_scale, locs_i[:, 0] * pixel_scale, 'x', color='r')
    fig.colorbar(observed_im, ax=[axarr[0], axarr[1]])

    # Diverging colormap centred on zero. The limit is the largest |residual|, so
    # strongly negative residuals are not clipped and vmin never exceeds vmax.
    res_lim = residual_image_i.abs().max().item()
    residual_im = axarr[2].matshow(residual_image_i, vmin=-res_lim, vmax=res_lim, cmap='bwr')
    axarr[2].set_title('residual')
    axarr[3].matshow(recon_residual_i, vmin=-res_lim, vmax=res_lim, cmap='bwr')
    axarr[3].set_title('reconstructed residual')
    fig.colorbar(residual_im, ax=[axarr[2], axarr[3]])


# Sample from the generative model

In [None]:
# Draw latent samples from a standard normal. Create them on the same device as the
# VAE weights so decode() does not mix CPU and GPU tensors.
n_prior_samples = 10
eta = torch.randn(n_prior_samples, residual_vae.latent_dim,
                  device=next(residual_vae.parameters()).device)

In [None]:
# Decode the latent samples into means and log-variances (the cells below use
# exp(0.5 * logvar) as a standard deviation). Sampling needs no gradients, so skip
# building the autograd graph.
with torch.no_grad():
    recon_mean, recon_logvar = residual_vae.decode(eta)

In [None]:
# One figure per decoded mean (first channel). Iterate over however many samples were
# decoded rather than a hard-coded count.
for i in range(recon_mean.shape[0]):
    plt.matshow(recon_mean[i, 0, :, :].detach().cpu())
    plt.title(f'decoded mean {i}\n')
    plt.colorbar()

In [None]:
# One noisy draw per decoded sample: mean + exp(0.5 * logvar) * N(0, 1).
# The noise matches the mean's shape and device (no hard-coded 11x11 image size).
# NOTE(review): assumes recon_logvar has the same spatial shape as recon_mean.
for i in range(recon_mean.shape[0]):
    mean_i = recon_mean[i, 0, :, :].detach().cpu()
    std_i = torch.exp(recon_logvar[i, 0, :, :].detach().cpu() * 0.5)
    plt.matshow(mean_i + std_i * torch.randn_like(mean_i))
    plt.title(f'decoded sample {i}\n')
    plt.colorbar()