In [None]:
# Standard library
import json
import sys

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch

# The project's own modules live one directory above this notebook, so put
# that directory at the front of the import search path before importing them.
sys.path.insert(0, '../')

import residuals_vae_lib
import simulated_datasets_lib
import sdss_dataset_lib

# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Load the Hubble catalog and matching SDSS data

In [None]:
# Hubble catalog for NGC 7078 (per the file name).
# NOTE(review): the directory is spelled "NCG7078" while the file name uses
# "ngc7078" — confirm this matches the on-disk layout.
hubble_cat_file = '../hubble_data/NCG7078/hlsp_acsggct_hst_acs-wfc_ngc7078_r.rdviq.cal.adj.zpt.txt'

# SDSS frame identifiers passed through to the dataset.
sdss_frame = dict(run=2566, camcol=6, field=65)

sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(
    hubble_cat_file=hubble_cat_file,
    slen=11,             # side length passed to the dataset (also exposed as .slen)
    max_detections=20,
    **sdss_frame,
)

In [None]:
# Display how many samples the SDSS/Hubble dataset contains.
len(sdss_hubble_data)

In [None]:
# get simulator 
# A single sky level for the simulator: the mean of the full SDSS background.
sky_intensity = sdss_hubble_data.sdss_background_full.mean()

# The simulator reuses the dataset's PSF fit and side length so simulated
# images line up with the observed ones.
simulator = simulated_datasets_lib.StarSimulator(
    psf_fit_file=sdss_hubble_data.psf_file,
    slen=sdss_hubble_data.slen,
    sky_intensity=sky_intensity,
)

# Compare observed images with images simulated from the Hubble "truth" catalog

In [None]:
def plot_star_locs(ax, data, slen):
    """Mark the first ``data['n_stars']`` catalog locations on ``ax`` with red crosses.

    Column 1 of ``data['locs']`` is drawn on the x axis and column 0 on the y axis,
    each scaled by ``slen - 1`` to pixel units (locations are presumably normalized
    to [0, 1] — TODO confirm against sdss_dataset_lib).
    """
    star_locs = data['locs'][0:data['n_stars']]
    ax.plot(star_locs[:, 1] * (slen - 1),
            star_locs[:, 0] * (slen - 1), 'x', color='r')


# Fixed seed so the same random samples are shown on every run of this cell.
rng = np.random.default_rng(0)
n_examples = 10

for _ in range(n_examples):
    # A length-1 index array, exactly as np.random.choice(..., 1) produced before.
    indx = rng.choice(len(sdss_hubble_data), 1)
    data = sdss_hubble_data[indx]

    fig, axarr = plt.subplots(1, 3, figsize=(15, 4))

    true_image = data['image'].squeeze()
    # Share the observed image's color range with the simulated panel.
    vmin = np.min(true_image)
    vmax = np.max(true_image)

    # Panel 0: observed image with catalog star locations.
    observed_artist = axarr[0].matshow(true_image, vmin=vmin, vmax=vmax)
    plot_star_locs(axarr[0], data, sdss_hubble_data.slen)
    axarr[0].set_title('observed; n_stars = {}; indx = {}\n'.format(data['n_stars'], indx))

    # Panel 1: noiseless image simulated from the catalog parameters.
    # .cpu() keeps .numpy() working if the simulator returns a GPU tensor.
    simulated_image = simulator.draw_image_from_params(
        locs=torch.Tensor(data['locs']).unsqueeze(0),
        fluxes=torch.Tensor(data['fluxes']).unsqueeze(0),
        n_stars=torch.Tensor([data['n_stars']]),
        add_noise=False,
    ).squeeze().cpu().numpy()

    axarr[1].matshow(simulated_image, vmin=vmin, vmax=vmax)
    plot_star_locs(axarr[1], data, sdss_hubble_data.slen)
    axarr[1].set_title('simulated; flux ratio = {:06f}\n'.format(vmax / np.max(simulated_image)))

    fig.colorbar(observed_artist, ax=[axarr[0], axarr[1]])

    # Panel 2: residual (observed - simulated) on its own color scale.
    residual_artist = axarr[2].matshow(true_image - simulated_image)
    axarr[2].set_title('observed - simulated\n')
    fig.colorbar(residual_artist, ax=[axarr[2]])

    plt.show()

# Define the residual VAE

In [None]:
# Single-band residual VAE sized to the dataset's side length, so it cannot
# drift from the value passed to SDSSHubbleData.
# Moved to `device` so it matches the batches moved there in the next cells.
# NOTE(review): f_min=1300. is carried over unchanged — document its meaning.
residual_vae = residuals_vae_lib.ResidualVAE(slen=sdss_hubble_data.slen,
                                             n_bands=1,
                                             f_min=1300.).to(device)

# Create the data loader

In [None]:
# Batches are drawn in dataset order (no shuffling), so the "first batch"
# used below is always the same one.
batchsize = 64
loader = torch.utils.data.DataLoader(sdss_hubble_data,
                                     batch_size=batchsize,
                                     shuffle=False)

In [None]:
# Smoke-test the residual-VAE loss on a single batch: take only the first batch
# from the loader instead of looping and breaking immediately.
data = next(iter(loader))

true_fluxes = data['fluxes'].to(device).type(torch.float)
true_locs = data['locs'].to(device).type(torch.float)
true_n_stars = data['n_stars'].to(device)
images = data['image'].to(device)
# Not used below; kept available for interactive inspection.
backgrounds = data['background'].to(device)

# Noiseless rendering of the catalog stars for this batch.
simulated_images = simulator.draw_image_from_params(
    locs=true_locs,
    fluxes=true_fluxes,
    n_stars=true_n_stars,
    add_noise=False,
)

# Whatever the catalog-driven simulation does not explain; this is the VAE's input.
residual_image = images - simulated_images

loss = residuals_vae_lib.get_residual_vae_loss(residual_image, residual_vae)
loss

In [None]:
# Inspection-only forward pass: no gradients are needed here, so skip building
# the autograd graph.
with torch.no_grad():
    recon_mean, recon_logvar, eta_mean, eta_logvar = residual_vae(residual_image, sample=True)

In [None]:
# Distribution of the VAE's reconstruction log-variances for the first batch.
# .cpu() is required before .numpy() when the tensors live on a GPU.
logvar_fig, logvar_ax = plt.subplots()
logvar_ax.hist(recon_logvar.flatten().detach().cpu().numpy())
logvar_ax.set(title='Residual VAE reconstruction log-variance',
              xlabel='recon_logvar', ylabel='count');

In [None]:
# Run the library's evaluation routine on the residual VAE with the loader and
# simulator (see residuals_vae_lib.eval_residual_vae for what it computes/returns).
# NOTE(review): no training step appears in this notebook, so this presumably
# evaluates a freshly initialized VAE — confirm that is intended.
residuals_vae_lib.eval_residual_vae(residual_vae, loader, simulator)