In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch import optim

import sys
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

from astropy.io import fits
from astropy.wcs import WCS

# NOTE(review): `device` is not referenced anywhere else in this notebook — confirm whether it is still needed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import os

from psf_transform_lib import get_psf_loss
import psf_transform_lib2
import fitsio

In [None]:
# Load the SDSS/Hubble patch for the selected bands.
# NOTE(review): bands = [2, 3] presumably means SDSS r and i (ugriz indexing), consistent with the
# psf-r / psf-i files loaded later; (x0, x1) presumably locate the patch in the full frame — confirm
# in sdss_dataset_lib.
bands = [2, 3]
x0, x1 = 630, 310
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0=x0, x1=x1, bands=bands)


In [None]:
# Display the mean of `nelec_per_nmgy` (name suggests electrons per nanomaggy — confirm in sdss_dataset_lib).
sdss_hubble_data.nelec_per_nmgy.mean()

In [None]:
# Show the first entry of the full SDSS frame (presumably band bands[0] — confirm).
full_frame = sdss_hubble_data.sdss_image_full[0]
plt.matshow(full_frame)

In [None]:
# Check that the Hubble catalog positions overlap the globular cluster in the SDSS frame.
# Positions are drawn as point markers: the default line style used previously joined consecutive
# catalog entries with segments, drawing lines across regions that contain no stars.
# Axis order: locs_full_x1 is plotted horizontally (columns), locs_full_x0 vertically (rows).
plt.matshow(sdss_hubble_data.sdss_image_full[0])
plt.plot(sdss_hubble_data.locs_full_x1,
         sdss_hubble_data.locs_full_x0, '.', alpha = 0.2)

In [None]:
# Show each band of the extracted patch, each with its own color bar.
for band_idx in range(len(bands)):
    patch = sdss_hubble_data.sdss_image[band_idx]
    plt.matshow(patch)
    plt.colorbar()


In [None]:
# Check alignment between the first two bands: difference image on a symmetric diverging scale,
# so misaligned sources show up as paired red/blue lobes. Skipped for single-band runs.
if len(bands) > 1:
    band_diff = sdss_hubble_data.sdss_image[1] - sdss_hubble_data.sdss_image[0]
    diff_limit = band_diff.abs().max()

    plt.matshow(band_diff, vmax=diff_limit, vmin=-diff_limit,
                cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Load the SDSS PSF images (r and i bands)

In [None]:
# Load the SDSS PSF images for the r and i bands (primary HDU of each file).
# The FITS handles are opened in `with` blocks so they are closed after reading;
# previously both handles were left open.
# NOTE(review): psf_r / psf_i are not used elsewhere in this notebook.
psf_dir = '../data/'
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-r.fits') as psf_r_fits:
    psf_r = psf_r_fits[0].read()
with fitsio.FITS(psf_dir + 'sdss-002583-2-0136-psf-i.fits') as psf_i_fits:
    psf_i = psf_i_fits[0].read()

# Initialize the power-law PSF from the psField file

In [None]:
# psField file used to initialise the power-law PSF parameters. Presumably run 2583 / camcol 2 /
# field 136 in SDSS naming, matching the '002583-2-0136' tag of the PSF files loaded above.
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

In [None]:
# Initial power-law PSF parameters: one row of 6 parameters per entry of `bands`,
# each row read from the psField file for that SDSS band.
init_psf_params = torch.zeros(len(bands), 6)
for band_idx, sdss_band in enumerate(bands):
    init_psf_params[band_idx] = psf_transform_lib2.get_psf_params(psfield_file,
                                                                  band=sdss_band)

In [None]:
# Wrap the initial parameters (one row per entry of `bands`) in a PowerLawPSF module.
power_law_psf = psf_transform_lib2.PowerLawPSF(init_psf_params)

In [None]:
# Evaluate the initial PSF and detach it from the autograd graph; later cells index it as
# init_psf[band_idx, row, col].
init_psf = power_law_psf.forward().detach()

In [None]:
# Zoom on a 9x9 window (indices 46..54) of the first band's initial PSF.
# NOTE(review): presumably the PSF centre — confirm against the PSF image size.
psf_window = slice(46, 55)
plt.matshow(init_psf[0, psf_window, psf_window])
plt.colorbar()

# Set up the star simulator

In [None]:
import simulated_datasets_lib

In [None]:
# Mean background level per band, averaged over all pixels of that band.
# NOTE(review): sky_intensity is not used elsewhere in this notebook.
sky_intensity = sdss_hubble_data.sdss_background.reshape(len(bands), -1).mean(1)

# Simulator built from the initial power-law PSF and the SDSS background of the patch.
simulator = simulated_datasets_lib.StarSimulator(
    psf=init_psf,
    slen=sdss_hubble_data.slen,
    transpose_psf=False,
    background=sdss_hubble_data.sdss_background,
)

# Check the reconstruction using all catalog stars

In [None]:
# Render the noiseless model image from every catalog star (a batch of one image).
# NOTE(review): _locs / _fluxes / _n_stars are reassigned by the bright-star cell; the estimator
# cell uses whichever assignment ran last.
_fluxes = sdss_hubble_data.fluxes.unsqueeze(0)
_locs = sdss_hubble_data.locs.unsqueeze(0)
# Build the int64 star count directly instead of creating a float tensor and casting it.
_n_stars = torch.tensor([len(sdss_hubble_data.locs)], dtype=torch.long)

recon_mean_dense = simulator.draw_image_from_params(locs = _locs,
                                                    fluxes = _fluxes,
                                                    n_stars = _n_stars,
                                                    add_noise = False)

In [None]:
# Per band: observed patch, reconstruction from all stars, and the *relative* residual
# (recon - obs) / obs on a symmetric diverging scale. The residual panel was previously
# titled "recon - obse", which mislabelled the plotted quantity.
# NOTE(review): the relative residual assumes observed pixels are strictly positive.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed[i])
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_mean_dense[0, i])
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = recon_mean_dense[0, i] - observed[i]
    rel_residual = residual / observed[i]
    rel_limit = rel_residual.abs().max()
    im2 = axarr[2].matshow(rel_residual, vmax = rel_limit, vmin = -rel_limit, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('(recon - obs) / obs, band = ' + str(bands[i]))

# Reconstruction using only bright stars (first-band flux above `fmin`)

In [None]:
# Brightness cutoff: stars whose first-band flux (fluxes[:, 0]) exceeds fmin are treated as "bright".
# NOTE(review): units follow sdss_hubble_data.fluxes — confirm.
fmin = 1000.

In [None]:
# Keep only stars brighter than fmin in the first band, then render their noiseless model image.
# These _locs / _fluxes / _n_stars also initialise the estimator in the background-fit section.
which_bright = sdss_hubble_data.fluxes[:, 0] > fmin

_locs = sdss_hubble_data.locs[which_bright].unsqueeze(0)
_fluxes = sdss_hubble_data.fluxes[which_bright].unsqueeze(0)
# Build the int64 star count directly instead of creating a float tensor and casting it.
_n_stars = torch.tensor([len(_locs[0])], dtype=torch.long)

recon_mean = simulator.draw_image_from_params(locs = _locs,
                                              fluxes = _fluxes,
                                              n_stars = _n_stars,
                                              add_noise = False)

In [None]:
# Per band: observed patch, reconstruction from bright stars only, and the *relative* residual
# (recon - obs) / obs on a symmetric diverging scale. The residual panel was previously
# titled "recon - obse", which mislabelled the plotted quantity.
# NOTE(review): the relative residual assumes observed pixels are strictly positive.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed[i])
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_mean[0, i])
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = recon_mean[0, i] - observed[i]
    rel_residual = residual / observed[i]
    rel_limit = rel_residual.abs().max()
    im2 = axarr[2].matshow(rel_residual, vmax = rel_limit, vmin = -rel_limit, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('(recon - obs) / obs, band = ' + str(bands[i]))

# Estimate the planar background (PSF held fixed)

In [None]:
import wake_lib

In [None]:
# Model of the observed patch, initialised with the current _locs / _n_stars / _fluxes.
# NOTE(review): those names are reassigned by both reconstruction cells; this cell expects the
# bright-star cell to have run last (hidden-state hazard under out-of-order execution).
# init_background_params=None presumably lets wake_lib choose its own initialisation — confirm.
observed_batch = sdss_hubble_data.sdss_image.unsqueeze(0)  # prepend a batch dimension
estimator = wake_lib.EstimateModelParams(
    observed_batch,
    locs=_locs,
    n_stars=_n_stars,
    init_psf_params=init_psf_params,
    init_background_params=None,
    init_fluxes=_fluxes,
)

In [None]:
# L-BFGS with a strong-Wolfe line search; only the planar-background parameters are handed to it.
background_params = list(estimator.planar_background.parameters())
background_optimizer = optim.LBFGS(background_params,
                                   max_iter=10,
                                   line_search_fn='strong_wolfe')

In [None]:
# Fit the background parameters held by background_optimizer. use_cached_star_basis=True presumably
# reuses precomputed star images, which is valid because the PSF is not being optimised in this
# step — confirm in wake_lib.
estimator._run_optimizer(background_optimizer, tol = 1e-3, use_cached_star_basis=True)

In [None]:
# First element of get_loss(...) is used as the reconstruction after the background fit
# (presumably the model mean image — confirm in wake_lib).
recon_mean1 = estimator.get_loss(use_cached_star_basis=True)[0].detach()

In [None]:
# Per band: observed patch, reconstruction after the background fit, and the *relative* residual
# (recon - obs) / obs on a symmetric diverging scale. The residual panel was previously
# titled "recon - obse", which mislabelled the plotted quantity.
# NOTE(review): the relative residual assumes observed pixels are strictly positive.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed[i])
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_mean1[0, i])
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = recon_mean1[0, i] - observed[i]
    rel_residual = residual / observed[i]
    rel_limit = rel_residual.abs().max()
    im2 = axarr[2].matshow(rel_residual, vmax = rel_limit, vmin = -rel_limit, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('(recon - obs) / obs, band = ' + str(bands[i]))

In [None]:
# Fitted planar background for band index 1 (bands[1]).
fitted_background = estimator.planar_background.forward().detach()
plt.matshow(fitted_background[1])
plt.colorbar()

# Estimate the PSF (background held fixed)

In [None]:
# L-BFGS with a strong-Wolfe line search; only the power-law PSF parameters are handed to it.
psf_params = list(estimator.power_law_psf.parameters())
psf_optimizer = optim.LBFGS(psf_params,
                            max_iter=10,
                            line_search_fn='strong_wolfe')

In [None]:
# Fit the PSF parameters held by psf_optimizer. The star basis is not cached here, presumably because
# it depends on the PSF being optimised — confirm in wake_lib.
# NOTE(review): print_every=True is passed where an integer period may be expected (True == 1) — confirm.
estimator._run_optimizer(psf_optimizer, tol = 1e-6, max_iter = 50, 
                         use_cached_star_basis = False, 
                         print_every = True)

In [None]:
# Reconstruction after the PSF fit, recomputed without the cached star basis (presumably so the
# updated PSF is used — confirm in wake_lib).
recon_mean2 = estimator.get_loss(use_cached_star_basis=False)[0].detach()

In [None]:
# Per band: observed patch, reconstruction after the PSF fit, and the *relative* residual
# (recon - obs) / obs on a symmetric diverging scale. The residual panel was previously
# titled "recon - obse", which mislabelled the plotted quantity.
# NOTE(review): the relative residual assumes observed pixels are strictly positive.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed[i])
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_mean2[0, i])
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = recon_mean2[0, i] - observed[i]
    rel_residual = residual / observed[i]
    rel_limit = rel_residual.abs().max()
    im2 = axarr[2].matshow(rel_residual, vmax = rel_limit, vmin = -rel_limit, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    axarr[2].set_title('(recon - obs) / obs, band = ' + str(bands[i]))

In [None]:
# 0-based index into `bands` for the PSF comparison plots; with bands = [2, 3] this selects band 3.
band = 1

In [None]:
# Initial PSF for the selected band, trimmed with _trim_psf(..., 15).
trimmed_init_psf = simulated_datasets_lib._trim_psf(init_psf, 15)
plt.matshow(trimmed_init_psf[band])
plt.colorbar()

In [None]:
# PSF produced by the fitted estimator, detached from the autograd graph for plotting.
trained_psf = estimator.get_psf().detach()

In [None]:
# Fitted PSF for the selected band, trimmed the same way as the initial PSF for comparison.
trimmed_trained_psf = simulated_datasets_lib._trim_psf(trained_psf, 15)
plt.matshow(trimmed_trained_psf[band])
plt.colorbar()

In [None]:
# Display the estimator's state_dict after fitting.
estimator.state_dict()

In [None]:
# Save the first parameter tensor of the fitted power-law PSF (np.save appends '.npy').
# np.save does not create missing directories, so the results directory is created first.
# .detach().cpu() replaces .data so the save also works if the estimator lives on a GPU.
psf_params_path = '../fits/results_2020-02-06/true_powerlaw_psf_params'
os.makedirs(os.path.dirname(psf_params_path), exist_ok=True)
fitted_psf_params = list(estimator.power_law_psf.parameters())[0]
np.save(psf_params_path, fitted_psf_params.detach().cpu().numpy())

In [None]:
# Save the first parameter tensor of the fitted planar background (np.save appends '.npy').
# np.save does not create missing directories, so the results directory is created first.
# .detach().cpu() replaces .data so the save also works if the estimator lives on a GPU.
background_params_path = '../fits/results_2020-02-06/true_planarback_params'
os.makedirs(os.path.dirname(background_params_path), exist_ok=True)
fitted_background_params = list(estimator.planar_background.parameters())[0]
np.save(background_params_path, fitted_background_params.detach().cpu().numpy())

In [None]:
# NOTE(review): scratch cell — constructs a PlanarBackground with no arguments and only displays its
# repr; the result is not stored or used. Confirm the constructor accepts no arguments, or remove.
wake_lib.PlanarBackground()

In [None]:
# Mean of the SDSS background for band index 0 (bands[0]).
sdss_hubble_data.sdss_background[0].mean()