In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset
from torch import optim

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf
import simulated_datasets_lib

from wake_lib import EstimateModelParams, BackgroundBias


from astropy.io import fits
from astropy.wcs import WCS


# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

import json 
import os

In [None]:
# Simulation settings shared by the cells below (e.g. data_params['slen'] sets the
# spatial size of the background/image tensors).
data_params = json.loads(pathlib.Path('../data/default_star_parameters.json').read_text())

# PSF

In [None]:
import psf_transform_lib2

In [None]:
bands = [2, 3]
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

# One row of 6 PSF parameters per band, read from the psField file.
init_psf_params = torch.zeros(len(bands), 6)
for i, band in enumerate(bands):
    init_psf_params[i] = psf_transform_lib2.get_psf_params(psfield_file, band = band)

power_law_psf = psf_transform_lib2.PowerLawPSF(init_psf_params)
# Detached copy of the rendered PSF; the simulated data below are drawn with it.
psf_og = power_law_psf.forward().detach()

# Draw data

In [None]:
# Bias model initialised with one row of 3 coefficients per band
# (presumably row order follows `bands` — confirm against BackgroundBias).
background_bias = BackgroundBias(
    init_background_params = torch.Tensor([
        [303.6371, -64.7714, 118.8480],
        [361.8250, -94.0563, 147.5636],
    ])
)


In [None]:
# Constant per-band level (broadcast over the slen x slen grid) plus the detached
# output of the bias model.
background = (
    torch.ones(len(bands), data_params['slen'], data_params['slen'])
    * torch.Tensor([854., 1345.])[:, None, None]
    + background_bias.forward().detach()
)

In [None]:
n_images = 1

# Fix the RNG state so the noisy simulated image, and every fit below that depends
# on it, is reproducible across kernel restarts. Both numpy and torch are seeded
# because it is not visible here which one simulated_datasets_lib draws from.
np.random.seed(0)
torch.manual_seed(0)

# Simulate n_images noisy images from the fixed PSF and background defined above.
simulated_dataset = simulated_datasets_lib.load_dataset_from_params(
    psf_og,
    data_params,
    background = background,
    n_images = n_images,
    transpose_psf = False,
    add_noise = True)

full_image = simulated_dataset.images.detach()
full_background = simulated_dataset.background.detach()

# Ground-truth catalogue used to generate the images.
true_n_stars = simulated_dataset.n_stars
true_locs = simulated_dataset.locs
true_fluxes = simulated_dataset.fluxes

simulator = simulated_dataset.simulator

In [None]:
# Noise-free image rendered from the true catalogue, to compare against the noisy data.
recon_mean_truth = simulator.draw_image_from_params(
    locs = true_locs,
    fluxes = true_fluxes,
    n_stars = true_n_stars,
    add_noise = False,
)

In [None]:
# Relative residual (truth recon - image) / image, one figure per band.
# NOTE(review): the fixed [5:95, 5:95] crop keeps rows/cols 5..94 and presumably
# assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon_mean_truth[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    # Symmetric float limits so a zero residual sits at the centre of 'bwr'.
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: (truth recon - image) / image'.format(bands[i]))
    plt.colorbar()

# Check estimation of fluxes

In [None]:
# Flux check: locations, PSF parameters and background are the ones used for the
# simulation; fluxes are given no initial value.
estimator = EstimateModelParams(
    full_image, true_locs, true_n_stars,
    init_psf_params = init_psf_params,
    init_background = full_background,
    init_fluxes = None,
)

In [None]:
# Reconstruction at the estimator's initial fluxes (before any optimisation).
recon0 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon0[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: initial fluxes, (recon - image) / image'.format(bands[i]))
    plt.colorbar()

In [None]:
# L-BFGS over the flux parameters only; no other estimator parameters are handed
# to this optimizer.
optimizer = optim.LBFGS(
    list(estimator.flux_params_class.parameters()),
    max_iter = 20,
    line_search_fn = 'strong_wolfe',
)

In [None]:
# Drive the flux optimizer through the estimator's own loop.
# NOTE(review): _run_optimizer is a private helper of EstimateModelParams.
estimator._run_optimizer(
    optimizer,
    tol = 1e-3,
    max_iter = 5,
    print_every = True,
)

In [None]:
# log10 fluxes of the first true_n_stars sources of image 0: estimate, then truth.
est, truth = (
    torch.log10(fluxes[0, 0:true_n_stars])
    for fluxes in (estimator.get_fluxes().detach(), true_fluxes)
)

In [None]:
# Estimated vs. true log10 flux; points on the y = x line are perfect recoveries.
# Detach and move to CPU so matplotlib can convert the tensors even if they live
# on CUDA or (for the truth) still track gradients.
est_flat = est.detach().cpu().flatten()
truth_flat = truth.detach().cpu().flatten()
plt.plot(est_flat, truth_flat, '+')
if est_flat.numel() > 0:
    # Reference line spanning both axes' data ranges, not just the estimates.
    lo = min(est_flat.min().item(), truth_flat.min().item())
    hi = max(est_flat.max().item(), truth_flat.max().item())
    plt.plot([lo, hi], [lo, hi])
plt.xlabel('estimated log10 flux')
plt.ylabel('true log10 flux');

In [None]:
# Reconstruction after the flux optimisation above.
recon1 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon1[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: fitted fluxes, (recon - image) / image'.format(bands[i]))
    plt.colorbar()

# Check estimation of background 

In [None]:
# Background check: true fluxes/locations/PSF, but the initial background has the
# bias model's contribution subtracted out.
estimator = EstimateModelParams(
    full_image, true_locs, true_n_stars,
    init_psf_params = init_psf_params,
    init_background = full_background - background_bias.forward().detach(),
    init_fluxes = true_fluxes,
)

In [None]:
# Reconstruction at the initial (bias-free) background, before optimisation.
recon0 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon0[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: initial background, (recon - image) / image'.format(bands[i]))
    plt.colorbar()

In [None]:
# L-BFGS over the estimator's background-bias parameters only.
optimizer = optim.LBFGS(
    list(estimator.background_bias.parameters()),
    max_iter = 20,
    line_search_fn = 'strong_wolfe',
)

In [None]:
def closure():
    """L-BFGS closure: zero grads, recompute the loss, backprop, print and return it.

    `optimizer` and `estimator` are looked up as module-level names at call time.
    """
    optimizer.zero_grad()
    current_loss = estimator.get_loss()[1]
    current_loss.backward()
    print(current_loss)
    return current_loss

In [None]:
# One L-BFGS step (up to max_iter inner iterations); the returned loss is discarded.
_ = optimizer.step(closure = closure)

In [None]:
# Reconstruction after the background-bias optimisation above.
recon1 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon1[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: fitted background, (recon - image) / image'.format(bands[i]))
    plt.colorbar()

In [None]:
# Fitted background-bias parameters (compare with the generating ones in the next cell).
[*estimator.background_bias.parameters()]

In [None]:
# Parameters of the bias model that generated the simulated background.
[*background_bias.parameters()]

# Check estimation of PSF

In [None]:
# PSF parameters used below as the PSF initialisation for the PSF-recovery check.
# NOTE(review): despite the variable name, the file is `true_psf_params.npy` from an
# earlier results directory — confirm it differs from init_psf_params.
false_psf_params = torch.Tensor(
    np.load('../fits/results_2020-02-04/true_psf_params.npy')
)

In [None]:
# PSF check: true locations, fluxes and background; PSF starts from false_psf_params.
estimator = EstimateModelParams(
    full_image, true_locs, true_n_stars,
    init_psf_params = false_psf_params,
    init_background = full_background,
    init_fluxes = true_fluxes,
)

In [None]:
# Reconstruction with the initial (loaded) PSF, before optimisation.
recon0 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon0[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: initial PSF, (recon - image) / image'.format(bands[i]))
    plt.colorbar()

In [None]:
# L-BFGS over the estimator's power-law PSF parameters only (more inner iterations
# than the flux/background fits).
optimizer = optim.LBFGS(
    list(estimator.power_law_psf.parameters()),
    max_iter = 50,
    line_search_fn = 'strong_wolfe',
)

In [None]:
def closure():
    """L-BFGS closure for the PSF fit: zero grads, recompute loss, backprop, print, return.

    Identical body to the closure defined for the background fit; redefining it
    here keeps this section self-contained. `optimizer` and `estimator` are
    resolved as module-level names at call time.
    """
    optimizer.zero_grad()
    psf_loss = estimator.get_loss()[1]
    psf_loss.backward()
    print(psf_loss)
    return psf_loss

In [None]:
# Single L-BFGS step over the PSF parameters; the returned loss is discarded.
_ = optimizer.step(closure = closure)

In [None]:
# Reconstruction after the PSF optimisation above.
recon1 = estimator.get_loss()[0].detach()
# Relative residual (recon - image) / image per band.
# NOTE(review): [5:95, 5:95] presumably assumes data_params['slen'] == 100 — confirm.
for i in range(len(bands)):
    rel_resid = ((recon1[0, i] - full_image[0, i]) / full_image[0, i])[5:95, 5:95]
    # matplotlib cannot convert CUDA tensors or tensors that require grad.
    rel_resid = rel_resid.detach().cpu()
    vlim = rel_resid.abs().max().item()
    plt.matshow(rel_resid, vmax = vlim, vmin = -vlim, cmap = plt.get_cmap('bwr'))
    plt.title('band {}: fitted PSF, (recon - image) / image'.format(bands[i]))
    plt.colorbar()