In [None]:
# --- Standard library ---
import json
import sys
import time
import timeit
from copy import deepcopy

# --- Third-party ---
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import fitsio

# --- Project modules (one directory up); the path must be added before importing them ---
sys.path.insert(0, './../')
import sdss_psf
import simulated_datasets_lib
import starnet_vae_lib
import sdss_dataset_lib
import plotting_utils
import image_statistics_lib
import utils
import inv_kl_objective_lib as inv_kl_lib
import image_utils
import flux_utils

# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('torch version: ', torch.__version__)

In [None]:
# Seed NumPy and PyTorch so the simulated data below are reproducible.
np.random.seed(seed=453453)
_ = torch.manual_seed(seed=456456)  # bound to `_` only to suppress the Generator repr

# Load data

In [None]:
# Default star-simulation parameters, read from the repo's data directory.
params_path = '../data/default_star_parameters.json'
with open(params_path) as params_file:
    data_params = json.load(params_file)

In [None]:
# Override the `max_stars` value from the defaults loaded above (presumably the per-image star cap).
data_params.update(max_stars=1300)

In [None]:
# Load the r- and i-band PSF images (HDU 0 of each file) and stack them along a
# new leading band axis in that order: band 0 is r, band 1 is i.
psf_dir = '../data/'
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_fits:
    psf_r = psf_fits[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_fits:
    psf_i = psf_fits[0].read()

psf_og = torch.Tensor(np.array([psf_r, psf_i]))

n_bands = psf_og.shape[0]

# Constant sky background per band, broadcast over a (slen x slen) image.
# torch.ones is required here: multiplying torch.zeros by the per-band levels
# would leave the background identically zero.
sky_levels = torch.Tensor([854., 1345.])
assert sky_levels.shape[0] == n_bands, 'need exactly one sky level per band'
background = torch.ones(n_bands, data_params['slen'], data_params['slen']) * \
                    sky_levels[:, None, None]


In [None]:
# Draw from the same distribution I used in the sleep phase.
n_images = 4

simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    background=background,
    n_images=n_images,
    transpose_psf=False,
    add_noise=True,
)

# Simulated observations, detached from any autograd graph.
images_full = simulated_dataset.images.detach()
backgrounds_full = simulated_dataset.background.detach()

# Ground-truth catalog used to generate the images above.
true_n_stars = simulated_dataset.n_stars
true_full_locs = simulated_dataset.locs
true_full_fluxes = simulated_dataset.fluxes

simulator = simulated_dataset.simulator

# Define the flux estimator (fluxes are fit at the true star locations)

In [None]:
psf_og.size()

In [None]:
# Flux estimator for the simulated images, given the true star locations and counts.
flux_estimator = flux_utils.EstimateFluxes(
    images_full,
    true_full_locs,
    true_n_stars,
    simulator.psf,
    background[None],  # leading batch axis, equivalent to background.unsqueeze(0)
    pad=5,
    fmin=data_params['f_min'],
    init_fluxes=None,
)

## Check out initialization

In [None]:
# Reconstruction implied by the initial flux guesses, detached for plotting.
init_recon = flux_estimator.forward()
init_recon = init_recon.detach()

In [None]:
# Band index used by the plots below. psf_og stacks [psf_r, psf_i], so 0 selects
# the band simulated with the r-band PSF.
band = 0

In [None]:
# Compare each simulated image with the reconstruction from the *initial* flux
# guesses. The residual is the log10 ratio data / reconstruction, shown on the
# interior crop [10:90, 10:90] with a colour scale symmetric about zero.
interior = (slice(10, 90), slice(10, 90))

for i in range(images_full.shape[0]):
    fig, axarr = plt.subplots(1, 3, figsize=(16, 6))

    # .cpu() is a no-op for CPU tensors and makes matshow safe for GPU ones.
    observed = images_full[i, band].cpu()
    reconstructed = init_recon[i, band].cpu()

    im0 = axarr[0].matshow(observed)
    fig.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('simulated image {} (band {})'.format(i, band))

    im1 = axarr[1].matshow(reconstructed)
    fig.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('reconstruction (initial fluxes)')

    residual = (torch.log10(observed) - torch.log10(reconstructed))[interior]
    # log10 of non-positive pixels yields NaN/-inf; scale colours on finite values only.
    finite = residual[torch.isfinite(residual)]
    vmax = finite.abs().max().item() if finite.numel() > 0 else 1.0
    im2 = axarr[2].matshow(residual, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('log10 residual (data / recon)')

In [None]:
# Overlay log10 flux distributions: true stars vs. the estimator's initial fluxes
# at the same slots. Both histograms share one set of bin edges so the overlay is
# directly comparable (independent bins=100 would give different bin widths).
true_band_fluxes = true_full_fluxes[:, :, band].flatten()
has_flux = true_band_fluxes > 0  # keep slots with positive true flux (log10 needs > 0)
est_band_fluxes = flux_estimator.return_fluxes()[:, :, band].flatten()[has_flux]

log_true = torch.log10(true_band_fluxes[has_flux]).detach().cpu().numpy()
log_est = torch.log10(est_band_fluxes).detach().cpu().numpy()

all_log = np.concatenate([log_true, log_est])
bin_edges = np.histogram_bin_edges(all_log[np.isfinite(all_log)], bins=100)

plt.hist(log_true, bins=bin_edges, label='true')
plt.hist(log_est, bins=bin_edges, alpha=0.5, label='estimated (initial)')
plt.xlabel('log10 flux (band {})'.format(band))
plt.ylabel('count')
plt.legend();

# Optimize

In [None]:
# Run the flux optimization; later cells read the updated state from flux_estimator
# (the return value here is not used).
# NOTE(review): if `print_every` expects an iteration interval, True acts as 1 — confirm.
flux_estimator.optimize(print_every = True)

In [None]:
# Loss after optimization.
final_loss = flux_estimator.get_loss()
print(final_loss)

In [None]:
# Reconstruction from the optimized fluxes, detached for plotting.
recon = flux_estimator.forward().detach()

# Compare each simulated image with the reconstruction after optimization. The
# residual is the log10 ratio data / reconstruction, shown on the interior crop
# [10:90, 10:90] with a colour scale symmetric about zero.
interior = (slice(10, 90), slice(10, 90))

for i in range(images_full.shape[0]):
    fig, axarr = plt.subplots(1, 3, figsize=(16, 6))

    # .cpu() is a no-op for CPU tensors and makes matshow safe for GPU ones.
    observed = images_full[i, band].cpu()
    reconstructed = recon[i, band].cpu()

    im0 = axarr[0].matshow(observed)
    fig.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('simulated image {} (band {})'.format(i, band))

    im1 = axarr[1].matshow(reconstructed)
    fig.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('reconstruction (optimized fluxes)')

    residual = (torch.log10(observed) - torch.log10(reconstructed))[interior]
    # log10 of non-positive pixels yields NaN/-inf; scale colours on finite values only.
    finite = residual[torch.isfinite(residual)]
    vmax = finite.abs().max().item() if finite.numel() > 0 else 1.0
    im2 = axarr[2].matshow(residual, vmax=vmax, vmin=-vmax, cmap=plt.get_cmap('bwr'))
    fig.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('log10 residual (data / recon)')

In [None]:
# Overlay log10 flux distributions: true stars vs. the estimator's optimized
# fluxes at the same slots. Both histograms share one set of bin edges so the
# overlay is directly comparable (independent bins=100 would give different widths).
true_band_fluxes = true_full_fluxes[:, :, band].flatten()
has_flux = true_band_fluxes > 0  # keep slots with positive true flux (log10 needs > 0)
est_band_fluxes = flux_estimator.return_fluxes()[:, :, band].flatten()[has_flux]

log_true = torch.log10(true_band_fluxes[has_flux]).detach().cpu().numpy()
log_est = torch.log10(est_band_fluxes).detach().cpu().numpy()

all_log = np.concatenate([log_true, log_est])
bin_edges = np.histogram_bin_edges(all_log[np.isfinite(all_log)], bins=100)

plt.hist(log_true, bins=bin_edges, label='true')
plt.hist(log_est, bins=bin_edges, alpha=0.5, label='estimated (optimized)')
plt.xlabel('log10 flux (band {})'.format(band))
plt.ylabel('count')
plt.legend();

