In [None]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import sys
# Put the parent directory (relative to the working directory) first on
# sys.path so the project modules imported below can be found.
sys.path.insert(0, '../')
import sdss_dataset_lib
import sdss_psf

import psf_transform_lib
import wake_lib

from astropy.io import fits
from astropy.wcs import WCS

# Use the first GPU when one is available, otherwise the CPU; tensors created
# with `.to(device)` later in the notebook follow this choice.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import os

In [None]:
# load data: the SDSS/Hubble dataset restricted to the chosen bands.
# NOTE(review): x0, x1 presumably locate the analysed patch within the full
# SDSS frame — confirm against sdss_dataset_lib.SDSSHubbleData.
bands = [2, 3]
x0, x1 = 630, 310
sdss_hubble_data = sdss_dataset_lib.SDSSHubbleData(x0=x0, x1=x1, bands=bands)


In [None]:
# the full image (index 0 — presumably the first entry of `bands`)
plt.matshow(sdss_hubble_data.sdss_image_full[0])

In [None]:
# check the hubble coordinates overlap with the globular cluster.
# The catalog positions are drawn as individual '.' markers. The default
# `plt.plot` style would join consecutive catalog entries with lines, which
# hides where the sources actually are. As before, x1 goes on the horizontal
# axis and x0 on the vertical axis.
plt.matshow(sdss_hubble_data.sdss_image_full[0])
plt.plot(sdss_hubble_data.locs_full_x1,
         sdss_hubble_data.locs_full_x0, '.', alpha = 0.2)

In [None]:
# check patch: one image per loaded band, each with its own colorbar.
for i, _band in enumerate(bands):
    plt.matshow(sdss_hubble_data.sdss_image[i])
    plt.colorbar()


In [None]:
# check alignment between bands: pixel-wise difference (second loaded band
# minus first), drawn on a colour scale symmetric about zero.
if len(bands) > 1:
    band_diff = sdss_hubble_data.sdss_image[1] - sdss_hubble_data.sdss_image[0]
    diff_lim = band_diff.abs().max()
    plt.matshow(band_diff, vmin=-diff_lim, vmax=diff_lim,
                cmap=plt.get_cmap('bwr'))
    plt.colorbar()

# Distribution of colors

In [None]:
if len(bands) > 1:
    # Per-pixel colour 2.5 * log10(image[1] / image[0]), flattened and
    # histogrammed. The mean and standard deviation (sqrt of the variance)
    # are printed afterwards.
    pixel_ratio = sdss_hubble_data.sdss_image[1] / sdss_hubble_data.sdss_image[0]
    pixel_colors = 2.5 * torch.log10(pixel_ratio).flatten()
    plt.hist(pixel_colors, bins=100);
    for stat in (pixel_colors.mean(), pixel_colors.var().sqrt()):
        print(stat)

In [None]:
if len(bands) > 1:
    # Per-star catalog colour 2.5 * log10(flux[:, 1] / flux[:, 0]),
    # histogrammed. The mean and standard deviation (sqrt of the variance)
    # are printed afterwards.
    flux_ratio = sdss_hubble_data.fluxes[:, 1] / sdss_hubble_data.fluxes[:, 0]
    catalog_colors = 2.5 * torch.log10(flux_ratio).flatten()
    plt.hist(catalog_colors, bins=100);
    for stat in (catalog_colors.mean(), catalog_colors.var().sqrt()):
        print(stat)

# Plot a few subimages

In [None]:
# Flux threshold: catalog stars whose flux in column 0 exceeds this are the
# "bright" stars overlaid on the subimages and used as ground truth below.
fmin = 500.

In [None]:
import plotting_utils

In [None]:
# Candidate subimage origins for the zoomed-in plots: 0, 10, ..., 90 on both
# axes (both names refer to the same array).
# NOTE(review): presumably sized for a ~100-pixel patch — confirm.
x0_vec = x1_vec = np.arange(0, 100, 10)

In [None]:
# Zoom in on a few random subimages (subimage_slen = 10) of each band and of
# the band difference, overlaying the catalog stars brighter than `fmin`.
# The bright-star selection does not depend on the subimage, so it is computed
# once here rather than on every iteration.
which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
bright_locs = sdss_hubble_data.locs[which_bright]

for _ in range(6):
    # Use dedicated names so the `x0`, `x1` from the data-loading cell are not
    # overwritten. Drawing a scalar (no `size` argument) also avoids NumPy's
    # deprecated conversion of a size-1 array to `int`.
    sub_x0 = int(np.random.choice(x0_vec))
    sub_x1 = int(np.random.choice(x1_vec))

    f, axarr = plt.subplots(1, 3, figsize=(16, 4))
    plotting_utils.plot_subimage(axarr[0],
                                 sdss_hubble_data.sdss_image[0],
                                 None,
                                 bright_locs,
                                 sub_x0, sub_x1,
                                 subimage_slen = 10,
                                 add_colorbar = True,
                                 global_fig = f)
    axarr[0].set_title('band ' + str(bands[0]))

    if len(bands) > 1:
        plotting_utils.plot_subimage(axarr[1],
                                     sdss_hubble_data.sdss_image[1],
                                     None,
                                     bright_locs,
                                     sub_x0, sub_x1,
                                     subimage_slen = 10,
                                     add_colorbar = True,
                                     global_fig = f)
        axarr[1].set_title('band ' + str(bands[1]))

        plotting_utils.plot_subimage(axarr[2],
                                     sdss_hubble_data.sdss_image[1] - \
                                     sdss_hubble_data.sdss_image[0],
                                     None,
                                     bright_locs,
                                     sub_x0, sub_x1,
                                     subimage_slen = 10,
                                     add_colorbar = True,
                                     global_fig = f,
                                     diverging_cmap = True)
        axarr[2].set_title('band ' + str(bands[1]) + ' - band ' + str(bands[0]))

# Get true parameters

In [None]:
# If True, the ground truth below keeps only catalog stars whose flux in
# column 0 exceeds `fmin`; otherwise every catalog star is used.
filter_by_bright = True

In [None]:
# Ground-truth stars from the catalog. Each output gets a leading batch
# dimension of size 1. With `filter_by_bright`, only stars whose flux in
# column 0 exceeds `fmin` are kept.
if filter_by_bright:
    which_bright = sdss_hubble_data.fluxes[:, 0] > fmin
    selected_locs = sdss_hubble_data.locs[which_bright]
    selected_fluxes = sdss_hubble_data.fluxes[which_bright]
else:
    selected_locs = sdss_hubble_data.locs
    selected_fluxes = sdss_hubble_data.fluxes

true_locs = selected_locs.unsqueeze(0)
true_fluxes = selected_fluxes.unsqueeze(0)
true_n_stars = torch.LongTensor([len(selected_locs)])

In [None]:
# Histogram of log10 flux (column 0) for the selected ground-truth stars.
plt.hist(torch.log10(true_fluxes[0, :, 0]))

In [None]:
# Colour distribution of the selected ground-truth stars,
# 2.5 * log10(flux[band 1] / flux[band 0]). This needs at least two bands
# because it indexes flux column 1, matching the guards on the other colour
# cells.
if len(bands) > 1:
    true_colors = 2.5 * torch.log10(true_fluxes[0, :, 1] / true_fluxes[0, :, 0]).flatten()
    plt.hist(true_colors, bins = 100);

    print(true_colors.mean())
    print(true_colors.var().sqrt())

# Load initial PSF and background

In [None]:
# Initial PSF parameters read from the psField file below. They use the same
# `bands` that loaded the image, which is defined once in the data-loading
# cell. Redefining `bands` here would let the PSF bands and the image bands
# silently diverge.
psfield_file = '../../celeste_net/sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'
init_psf_params = psf_transform_lib.get_psf_params(
                                    psfield_file,
                                    bands = bands)

In [None]:
# Initial planar-background parameters: one row of 3 values per band, with
# only column 0 set.
# NOTE(review): column 0 is presumably the constant sky level of the plane —
# confirm in wake_lib.
# Starting values are keyed by SDSS band index. This keeps them aligned with
# `bands` if it is reordered, and an unsupported band raises KeyError instead
# of silently receiving another band's value.
init_sky_by_band = {2: 686., 3: 1123.}
init_background_params = torch.zeros(len(bands), 3).to(device)
init_background_params[:, 0] = torch.Tensor([init_sky_by_band[b] for b in bands])


In [None]:
# Build the model from the observed patch (with a leading batch dimension of
# size 1) and the initial PSF and background parameters.
observed_batch = sdss_hubble_data.sdss_image.unsqueeze(0)
model_params = wake_lib.ModelParams(observed_batch,
                                    init_psf_params,
                                    init_background_params)

# Check reconstructions at the initial parameters

In [None]:
# Reconstruction (element [0] of get_loss) at the initial parameters and the
# ground-truth stars, detached from the autograd graph for plotting.
init_recon = (
    model_params
    .get_loss(locs=true_locs, fluxes=true_fluxes, n_stars=true_n_stars)[0]
    .detach()
)

In [None]:
# Observed image vs. reconstruction at the initial parameters and the
# ground-truth stars, plus the residual 2.5 * (log10(recon) - log10(observed))
# on a colour scale symmetric about zero.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    # Copy to the CPU because matplotlib cannot convert CUDA tensors, and the
    # reconstruction may live on `device`. This also keeps the residual
    # arithmetic on one device; it is a no-op for CPU tensors.
    observed_i = observed[i].cpu()
    recon_i = init_recon[0, i].cpu()

    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed_i)
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_i)
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = 2.5 * (torch.log10(recon_i) - torch.log10(observed_i))
    residual_lim = residual.abs().max()
    im2 = axarr[2].matshow(residual, vmax = residual_lim,
                           vmin = -residual_lim, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    # The panel shows a scaled log-ratio, not a plain difference.
    axarr[2].set_title('2.5 log10(recon / observed), band = ' + str(bands[i]))

# Optimize background

In [None]:
# Loss (element [1] of get_loss) with use_cached_stars=True, shown before
# calling _get_init_background().
model_params.get_loss(use_cached_stars=True)[1]

In [None]:
# NOTE(review): calls a private wake_lib helper; presumably (re)initialises
# the planar background from the data — confirm in wake_lib. The loss is
# displayed before and after this call to compare.
model_params._get_init_background()

In [None]:
# Same cached-stars loss as before _get_init_background(), to see its effect.
model_params.get_loss(use_cached_stars=True)[1]

In [None]:
from torch import optim

In [None]:
# L-BFGS that updates only the planar-background parameters. The loss is
# evaluated with use_cached_stars=True.
background_optimizer = optim.LBFGS(
    model_params.planar_background.parameters(),
    max_iter=20,
    line_search_fn='strong_wolfe',
)


def back_closure():
    """Zero grads, recompute the cached-stars loss, back-propagate and return it (passed to background_optimizer.step)."""
    background_optimizer.zero_grad()
    objective = model_params.get_loss(use_cached_stars=True)[1]
    objective.backward()
    return objective

In [None]:
# Five LBFGS steps on the background parameters. Each step runs up to
# max_iter = 20 inner iterations, as configured for background_optimizer.
for i in range(5): 
    loss = background_optimizer.step(back_closure)

In [None]:
# Inspect the planar-background parameters after the LBFGS steps.
list(model_params.planar_background.parameters())

In [None]:
# Reconstruction and loss after the background fit, using the cached stars.
recon1, loss = model_params.get_loss(use_cached_stars=True)
recon1 = recon1.detach()
print(loss)

In [None]:
# Observed image vs. reconstruction after the background fit (cached stars),
# plus the residual 2.5 * (log10(recon) - log10(observed)) on a colour scale
# symmetric about zero. The residual is cropped to [5:95, 9:95].
# NOTE(review): the crop presumably excludes edge effects — confirm.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    # Copy to the CPU because matplotlib cannot convert CUDA tensors, and the
    # reconstruction may live on `device`. This also keeps the residual
    # arithmetic on one device; it is a no-op for CPU tensors.
    observed_i = observed[i].cpu()
    recon_i = recon1[0, i].cpu()

    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed_i)
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_i)
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = 2.5 * (torch.log10(recon_i) - torch.log10(observed_i))[5:95, 9:95]
    residual_lim = residual.abs().max()
    im2 = axarr[2].matshow(residual, vmax = residual_lim,
                           vmin = -residual_lim, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    # The panel shows a scaled log-ratio, not a plain difference.
    axarr[2].set_title('2.5 log10(recon / observed), band = ' + str(bands[i]))

# Optimize psf

In [None]:
# L-BFGS that updates only the power-law PSF parameters. The loss is
# evaluated at the ground-truth star locations and fluxes.
psf_optimizer = optim.LBFGS(
    model_params.power_law_psf.parameters(),
    max_iter=20,
    line_search_fn='strong_wolfe',
)


def psf_closure():
    """Zero grads, recompute the loss at the true stars, back-propagate and return it (passed to psf_optimizer.step)."""
    psf_optimizer.zero_grad()
    objective = model_params.get_loss(locs=true_locs,
                                      fluxes=true_fluxes,
                                      n_stars=true_n_stars)[1]
    objective.backward()
    return objective

In [None]:
# One LBFGS step on the PSF parameters. The step runs up to max_iter = 20
# inner iterations, as configured for psf_optimizer.
for i in range(1): 
    loss = psf_optimizer.step(psf_closure)

In [None]:
# Loss after the PSF step, evaluated at the ground-truth stars.
_, loss = model_params.get_loss(
    locs=true_locs,
    fluxes=true_fluxes,
    n_stars=true_n_stars,
)
print(loss)

In [None]:
# Fitted power-law PSF parameters after the PSF-only LBFGS step.
list(model_params.power_law_psf.parameters())

In [None]:
# Initial PSF parameters (read from the psField file), for comparison with
# the fitted ones.
init_psf_params

In [None]:
from simulated_datasets_lib import _trim_psf

In [None]:
# Compare the PSF before and after fitting: initial, fitted, and
# fitted - initial. Each is trimmed with `_trim_psf(..., 15)` and indexed at
# `b`.
# NOTE(review): `b` is presumably the band index — confirm.
# `get_psf()` is evaluated once and reused. The difference is taken before
# trimming, exactly as before.
b = 0
fitted_psf = model_params.get_psf().detach()
psf_panels = [
    (model_params.init_psf, 'initial psf'),
    (fitted_psf, 'fitted psf'),
    (fitted_psf - model_params.init_psf, 'fitted - initial psf'),
]

f, axarr = plt.subplots(1, 3, figsize=(16, 4))
for ax, (psf, title) in zip(axarr, psf_panels):
    im = ax.matshow(_trim_psf(psf, 15)[b])
    f.colorbar(im, ax = ax)
    ax.set_title(title + ', b = ' + str(b))

# Jointly optimize all model parameters

In [None]:
# Joint L-BFGS over all model parameters. The loss is evaluated at the
# ground-truth star locations and fluxes.
optimizer = optim.LBFGS(
    model_params.parameters(),
    max_iter=20,
    line_search_fn='strong_wolfe',
)


def closure():
    """Zero grads, recompute the loss at the true stars, back-propagate and return it (passed to optimizer.step)."""
    optimizer.zero_grad()
    objective = model_params.get_loss(locs=true_locs,
                                      fluxes=true_fluxes,
                                      n_stars=true_n_stars)[1]
    objective.backward()
    return objective

In [None]:
# One joint LBFGS step over all model parameters (up to max_iter = 20 inner
# iterations); the returned loss is discarded.
for i in range(1): 
    _ = optimizer.step(closure)

In [None]:
# Reconstruction and loss after the joint optimisation, at the ground-truth stars.
recon2, loss = model_params.get_loss(
    locs=true_locs,
    fluxes=true_fluxes,
    n_stars=true_n_stars,
)
recon2 = recon2.detach()
print(loss)

In [None]:
# Observed image vs. reconstruction after the joint optimisation, plus the
# residual 2.5 * (log10(recon) - log10(observed)) on a colour scale symmetric
# about zero. The residual is cropped to [5:95, 9:95].
# NOTE(review): the crop presumably excludes edge effects — confirm.
observed = sdss_hubble_data.sdss_image
for i in range(len(bands)):
    # Copy to the CPU because matplotlib cannot convert CUDA tensors, and the
    # reconstruction may live on `device`. This also keeps the residual
    # arithmetic on one device; it is a no-op for CPU tensors.
    observed_i = observed[i].cpu()
    recon_i = recon2[0, i].cpu()

    f, axarr = plt.subplots(1, 3, figsize=(16, 4))

    im0 = axarr[0].matshow(observed_i)
    f.colorbar(im0, ax=axarr[0])
    axarr[0].set_title('observed, band = ' + str(bands[i]))

    im1 = axarr[1].matshow(recon_i)
    f.colorbar(im1, ax=axarr[1])
    axarr[1].set_title('recon, band = ' + str(bands[i]))

    residual = 2.5 * (torch.log10(recon_i) - torch.log10(observed_i))[5:95, 9:95]
    residual_lim = residual.abs().max()
    im2 = axarr[2].matshow(residual, vmax = residual_lim,
                           vmin = -residual_lim, cmap = plt.get_cmap('bwr'))
    f.colorbar(im2, ax=axarr[2])
    # The panel shows a scaled log-ratio, not a plain difference.
    axarr[2].set_title('2.5 log10(recon / observed), band = ' + str(bands[i]))

In [None]:
# Save the first parameter tensor of `power_law_psf`; np.save appends `.npy`.
# `.cpu()` is required because `.numpy()` raises on CUDA tensors and the model
# may live on `device`; it is a no-op for CPU tensors.
# NOTE(review): only parameters()[0] is saved — confirm the module has a
# single parameter tensor.
np.save('../data/fitted_powerlaw_psf_params', 
        list(model_params.power_law_psf.parameters())[0].detach().cpu().numpy())

In [None]:
# Save the first parameter tensor of `planar_background`; np.save appends
# `.npy`. `.cpu()` is required because `.numpy()` raises on CUDA tensors and
# the model may live on `device`; it is a no-op for CPU tensors.
# NOTE(review): only parameters()[0] is saved — confirm the module has a
# single parameter tensor.
np.save('../data/fitted_planar_backgrounds', 
        list(model_params.planar_background.parameters())[0].detach().cpu().numpy())

In [None]:
# Reload the saved PSF parameters onto `device` to check the file round-trips.
torch.Tensor(np.load('../data/fitted_powerlaw_psf_params.npy')).to(device)

In [None]:
# Reload the saved background parameters onto `device` to check the file round-trips.
torch.Tensor(np.load('../data/fitted_planar_backgrounds.npy')).to(device)