In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset
from torch import optim

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf
import simulated_datasets_lib

import wake_lib
from wake_lib import EstimateModelParams, PlanarBackground


from astropy.io import fits
from astropy.wcs import WCS


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on so that Restart & Run All reproduces
# the simulated noise (add_noise=True below) and the random PSF perturbations
# drawn with torch.randn later in the notebook.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

import json 
import os

In [None]:
# Data parameters: the simulation settings are stored as JSON next to the data.
data_params = json.loads(pathlib.Path('../data/default_star_parameters.json').read_text())

# PSF

In [None]:
import psf_transform_lib2

In [None]:
bands = [2, 3]
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

# One row of 6 power-law PSF parameters per entry of `bands`, read from the psField file.
true_psf_params = torch.zeros(len(bands), 6)
for row, band in enumerate(bands):
    true_psf_params[row] = psf_transform_lib2.get_psf_params(psfield_file, band=band)

# Render the PSF image(s) implied by those parameters.
power_law_psf = psf_transform_lib2.PowerLawPSF(true_psf_params)
psf_og = power_law_psf.forward().detach()

In [None]:
# Inspect the first PSF slice after trimming with _trim_psf(..., 15).
psf_trimmed = simulated_datasets_lib._trim_psf(psf_og, 15)
plt.matshow(psf_trimmed[0])
plt.colorbar()

# Draw data

In [None]:
# Planar background coefficients, one row per band.
# NOTE(review): presumably (offset, x-slope, y-slope) — confirm against PlanarBackground.
true_background_params = torch.Tensor([
    [1157, -64, 118],
    [1706, -94, 147],
])
true_background = PlanarBackground(init_background_params=true_background_params)

In [None]:
n_images = 1

# Simulate noisy image(s) from the true PSF and the true planar background.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    background=true_background.forward().detach(),
    n_images=n_images,
    transpose_psf=False,
    add_noise=True,
)

full_image = simulated_dataset.images.detach()
full_background = simulated_dataset.background.detach()

# Ground-truth catalogue, used below to check each estimation step.
true_n_stars, true_locs, true_fluxes = (
    simulated_dataset.n_stars,
    simulated_dataset.locs,
    simulated_dataset.fluxes,
)

simulator = simulated_dataset.simulator

In [None]:
# Show the simulated image in every band.
for band_idx, _band in enumerate(bands):
    plt.matshow(full_image[0, band_idx])
    plt.colorbar()

In [None]:
# Noise-free rendering of the true catalogue, to compare against the noisy data.
recon_mean_truth = simulator.draw_image_from_params(
    locs=true_locs,
    fluxes=true_fluxes,
    n_stars=true_n_stars,
    add_noise=False,
)

In [None]:
# Relative residual (model - data) / data per band, cropped to [5:95, 5:95],
# drawn on a colour scale symmetric about zero.
# NOTE(review): the [5:95] crop is asymmetric if the image side is 101 — confirm intent.
for band_idx in range(len(bands)):
    rel_resid = (recon_mean_truth[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Check estimation of fluxes

In [None]:
# Flux check: PSF and background initialised at their true values;
# fluxes start from whatever the estimator uses for init_fluxes=None.
estimator = EstimateModelParams(
    full_image,
    true_locs,
    true_n_stars,
    init_psf_params=true_psf_params,
    init_background_params=true_background_params,
    init_fluxes=None,
)

In [None]:
# Reconstruction at initialisation and its relative residual per band.
recon0 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon0[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# L-BFGS over the flux parameters only.
optimizer = optim.LBFGS(
    list(estimator.flux_params_class.parameters()),
    max_iter=10,
    line_search_fn='strong_wolfe',
)

In [None]:
estimator._run_optimizer(
    optimizer,
    tol=1e-3,
    max_iter=20,
    use_cached_star_basis=True,
    print_every=True,
)

In [None]:
# log10 fluxes of the first true_n_stars sources: estimated and true.
est = estimator.get_fluxes()[0, 0:true_n_stars].detach().log10()
truth = true_fluxes[0, 0:true_n_stars].log10()

In [None]:
# Estimated vs. true log10 flux (all entries flattened together); the solid
# line is y = x, drawn from `truth` as in the magnitude plots further down.
plt.plot(est.flatten(), truth.flatten(), '+')
plt.plot(truth.flatten(), truth.flatten(), '-')
plt.xlabel('estimated log10 flux')
plt.ylabel('true log10 flux')
plt.title('flux-only fit');

In [None]:
# Reconstruction after the flux fit and its relative residual per band.
recon1 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon1[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Check estimation of background 

In [None]:
# Background check: PSF and fluxes initialised at their true values;
# background starts from whatever the estimator uses for init_background_params=None.
estimator = EstimateModelParams(
    full_image,
    true_locs,
    true_n_stars,
    init_psf_params=true_psf_params,
    init_background_params=None,
    init_fluxes=true_fluxes,
)

In [None]:
# Initial background parameters.
[*estimator.planar_background.parameters()]

In [None]:
# Reconstruction at initialisation and its relative residual per band.
recon0 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon0[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# L-BFGS over the planar-background parameters only.
optimizer = optim.LBFGS(
    list(estimator.planar_background.parameters()),
    max_iter=10,
    line_search_fn='strong_wolfe',
)

In [None]:
estimator._run_optimizer(
    optimizer,
    tol=1e-3,
    max_iter=50,
    use_cached_star_basis=True,
    print_every=True,
)

In [None]:
# Reconstruction after the background fit and its relative residual per band.
recon1 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon1[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# Fitted background parameters.
[*estimator.planar_background.parameters()]

In [None]:
# True background parameters, for comparison with the fitted values displayed in the previous cell.
true_background_params

# Check estimation of PSF

In [None]:
# Perturb every true PSF parameter by an independent multiplicative factor exp(N(0, 1)).
false_psf_params = true_psf_params * torch.randn(true_psf_params.shape).exp()

In [None]:
# Display the perturbed PSF parameters used to initialise the PSF check below.
false_psf_params

In [None]:
# PSF check: background and fluxes initialised at their true values,
# PSF initialised at the perturbed parameters.
estimator = EstimateModelParams(
    full_image,
    true_locs,
    true_n_stars,
    init_psf_params=false_psf_params,
    init_background_params=true_background_params,
    init_fluxes=true_fluxes,
)

In [None]:
# Reconstruction with the perturbed PSF and its relative residual per band.
recon0 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon0[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# L-BFGS over the PSF parameters only (a single inner iteration per step).
optimizer = optim.LBFGS(
    list(estimator.power_law_psf.parameters()),
    max_iter=1,
    line_search_fn='strong_wolfe',
)

In [None]:
# The star basis is not cached here, since the PSF itself is being optimised.
estimator._run_optimizer(
    optimizer,
    tol=1e-6,
    max_iter=50,
    use_cached_star_basis=False,
    print_every=True,
)

In [None]:
# Reconstruction after the PSF fit and its relative residual per band.
recon1 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon1[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# OK, let's jointly estimate the background and flux

In [None]:
# Joint check: true PSF; background and fluxes both start from the
# estimator's defaults (init_*=None).
estimator = EstimateModelParams(
    full_image,
    true_locs,
    true_n_stars,
    init_psf_params=true_psf_params,
    init_background_params=None,
    init_fluxes=None,
)

In [None]:
# Reconstruction at initialisation and its relative residual per band.
recon0 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon0[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# L-BFGS jointly over flux and background parameters; the PSF parameters are
# not handed to this optimizer.
optimizer = optim.LBFGS(
    [*estimator.flux_params_class.parameters(),
     *estimator.planar_background.parameters()],
    max_iter=20,
    line_search_fn='strong_wolfe',
)

estimator._run_optimizer(
    optimizer,
    tol=1e-3,
    max_iter=10,
    use_cached_star_basis=True,
    print_every=True,
)

In [None]:
# Reconstruction after the joint fit and its relative residual per band.
recon1 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon1[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# Fitted background parameters after the joint fit.
[*estimator.planar_background.parameters()]

In [None]:
# True background parameters, for comparison with the jointly fitted values displayed in the previous cell.
true_background_params

In [None]:
nelec_per_nmgy = 856.0  # NOTE(review): presumably electrons per nanomaggy — confirm calibration

# Keep sources whose true flux in column 0 is positive.
is_on_bool = true_fluxes[0, :, 0].gt(0)

# Convert estimated and true fluxes of those sources to magnitudes.
est = sdss_dataset_lib.convert_nmgy_to_mag(
    estimator.get_fluxes().detach()[0, is_on_bool, :] / nelec_per_nmgy
)
truth = sdss_dataset_lib.convert_nmgy_to_mag(
    true_fluxes[0, is_on_bool, :] / nelec_per_nmgy
)


In [None]:
# Estimated vs. true magnitude for flux column 0; the solid line is y = x.
plt.plot(est[:, 0], truth[:, 0], '+')
plt.plot(truth[:, 0], truth[:, 0], '-')
plt.xlabel('estimated magnitude')
plt.ylabel('true magnitude')
plt.title('flux column 0 (joint flux + background fit)');

In [None]:
# Estimated vs. true magnitude for flux column 1; the solid line is y = x.
plt.plot(est[:, 1], truth[:, 1], '+')
plt.plot(truth[:, 1], truth[:, 1], '-')
plt.xlabel('estimated magnitude')
plt.ylabel('true magnitude')
plt.title('flux column 1 (joint flux + background fit)');

In [None]:
import image_statistics_lib

In [None]:
# Locations were initialised at the truth and are not optimised in this
# notebook, so the true locations are passed as both estimate and truth.
image_statistics_lib.get_summary_stats(
    true_locs[0, is_on_bool],
    true_locs[0, is_on_bool],
    slen=101,  # NOTE(review): hard-coded; presumably the image side length — confirm
    est_fluxes=estimator.get_fluxes().detach()[0, is_on_bool, 0],
    true_fluxes=true_fluxes[0, is_on_bool, 0],
    nelec_per_nmgy=nelec_per_nmgy,
)[:2]

# Now let's see if coordinate ascent works

In [None]:
# Milder perturbation: multiplicative factor exp(0.1 * N(0, 1)) per PSF parameter.
false_psf_params = true_psf_params * (0.1 * torch.randn(true_psf_params.shape)).exp()

In [None]:
# Coordinate-ascent check: perturbed PSF; background and fluxes both start
# from the estimator's defaults (init_*=None).
estimator = EstimateModelParams(
    full_image,
    true_locs,
    true_n_stars,
    init_psf_params=false_psf_params,
    init_background_params=None,
    init_fluxes=None,
)

In [None]:
# Reconstruction at initialisation and its relative residual per band.
recon0 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon0[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# Run EstimateModelParams' built-in coordinate-ascent routine from the
# perturbed-PSF / default-background / default-flux initialisation above.
estimator.run_coordinate_ascent()

In [None]:
# Reconstruction after coordinate ascent and its relative residual per band.
recon1 = estimator.get_loss()[0].detach()
for band_idx in range(len(bands)):
    rel_resid = (recon1[0, band_idx] - full_image[0, band_idx]) / full_image[0, band_idx]
    rel_resid = rel_resid[5:95, 5:95]
    lim = rel_resid.abs().max()
    plt.matshow(rel_resid, vmin=-lim, vmax=lim, cmap=plt.get_cmap('bwr'))
    plt.colorbar()

In [None]:
# Fitted background parameters after coordinate ascent.
[*estimator.planar_background.parameters()]

In [None]:
# True background parameters, for comparison with the coordinate-ascent fit displayed in the previous cell.
true_background_params

In [None]:
nelec_per_nmgy = 856.0  # NOTE(review): presumably electrons per nanomaggy — confirm calibration

# Keep sources whose true flux in column 0 is positive.
is_on_bool = true_fluxes[0, :, 0].gt(0)

# Convert estimated and true fluxes of those sources to magnitudes.
est = sdss_dataset_lib.convert_nmgy_to_mag(
    estimator.get_fluxes().detach()[0, is_on_bool, :] / nelec_per_nmgy
)
truth = sdss_dataset_lib.convert_nmgy_to_mag(
    true_fluxes[0, is_on_bool, :] / nelec_per_nmgy
)


In [None]:
# Estimated vs. true magnitude for flux column 0; the solid line is y = x.
plt.plot(est[:, 0], truth[:, 0], '+')
plt.plot(truth[:, 0], truth[:, 0], '-')
plt.xlabel('estimated magnitude')
plt.ylabel('true magnitude')
plt.title('flux column 0 (coordinate ascent)');

In [None]:
# Estimated vs. true magnitude for flux column 1; the solid line is y = x.
plt.plot(est[:, 1], truth[:, 1], '+')
plt.plot(truth[:, 1], truth[:, 1], '-')
plt.xlabel('estimated magnitude')
plt.ylabel('true magnitude')
plt.title('flux column 1 (coordinate ascent)');

In [None]:
import image_statistics_lib

In [None]:
# Locations were initialised at the truth and are not optimised in this
# notebook, so the true locations are passed as both estimate and truth.
image_statistics_lib.get_summary_stats(
    true_locs[0, is_on_bool],
    true_locs[0, is_on_bool],
    slen=101,  # NOTE(review): hard-coded; presumably the image side length — confirm
    est_fluxes=estimator.get_fluxes().detach()[0, is_on_bool, 0],
    true_fluxes=true_fluxes[0, is_on_bool, 0],
    nelec_per_nmgy=nelec_per_nmgy,
)[:2]

In [None]:
# Shape of the estimator's current background, as returned by get_background().
estimator.get_background().shape

In [None]:
# Shape of the estimator's current PSF, as returned by get_psf().
estimator.get_psf().shape