In [None]:
import torch
import pytorch_lightning as pl

%matplotlib inline
import matplotlib.pyplot as plt

print(torch.__version__)

# Prefer GPU index 2 when the machine has it; otherwise fall back to the
# highest-numbered visible GPU, or to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda", min(2, torch.cuda.device_count() - 1))
    # torch.cuda.set_device rejects non-CUDA devices, so it is only called here.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

from hydra.experimental import initialize, compose

from bliss import sleep
from bliss.datasets import simulated, sdss

import numpy as np
import time 

# Seed torch's and numpy's global random number generators.
torch.manual_seed(184)
np.random.seed(143)

In [None]:
import plotting

In [None]:
import os

# Default location of the BLISS checkout. A BLISS_HOME that is already exported
# in the environment takes precedence, so the notebook is not tied to one
# user's home directory.
os.environ.setdefault("BLISS_HOME", "/home/runingli/astronomy/celeste/")

# Get config file

In [None]:
# Select the "m2" option for every config group, written in the
# "group=option" form that is handed to hydra's compose().
config_groups = ("model", "dataset", "training", "optimizer")
overrides = ["{}={}".format(group, "m2") for group in config_groups]

In [None]:
# Compose the hydra config named "config" from ../../config with the
# `overrides` list applied; `cfg` is what the following cells build from.
with initialize(config_path="../../config"):
    cfg = compose("config", overrides=overrides)

# Load dataset and encoder

In [None]:
# Build the simulated-data module, the sleep-phase model, and the Lightning
# trainer (trainer keyword arguments come from cfg.training.trainer).
datamodule = simulated.SimulatedModule(cfg)
sleep_net = sleep.SleepPhase(cfg)
trainer = pl.Trainer(**cfg.training.trainer)

### Check priors

In [None]:
# prior on n-stars

%matplotlib inline

# Normalised histogram of the 'n_sources' entries of one simulated batch,
# using bins of width 0.5 on [0, 7).
n_sources_sample = datamodule.dataset.get_batch()['n_sources'].flatten().cpu()
half_step_bins = np.arange(7, step=0.5)

plt.hist(n_sources_sample, bins=half_step_bins, density=True);

plt.xlabel('n-sources per tile')

In [None]:
# prior on fluxes

%matplotlib inline

# Overlay the log10-flux histograms for indices 0 and 1 of the last 'fluxes'
# axis; a fresh batch is drawn for each of them.
for band in range(2):
    band_fluxes = datamodule.dataset.get_batch()['fluxes'][:, :, :, band].flatten()
    # only strictly positive entries are kept before taking the log
    positive_fluxes = band_fluxes[band_fluxes > 0]

    plt.hist(np.log10(positive_fluxes.cpu()),
             bins=50,
             density=True,
             alpha=0.5,
             label='band-' + str(band))

plt.legend()

plt.xlabel('log10(fluxes)')

In [None]:
# prior on colors

%matplotlib inline

plt.figure()

# Color = log10(flux at index 1) - log10(flux at index 0) along the last axis,
# kept only where the index-0 flux is strictly positive.
fluxes = datamodule.dataset.get_batch()['fluxes']
log_flux_first = torch.log10(fluxes[:, :, :, 0])
log_flux_second = torch.log10(fluxes[:, :, :, 1])
has_positive_flux = fluxes[:, :, :, 0].flatten() > 0
colors = (log_flux_second - log_flux_first).flatten()[has_positive_flux]

bins = plt.hist(colors.cpu(), bins=50, density=True)

plt.xlabel('color')

# Plot a simulated image

In [None]:
# One batch from the simulated dataset; the plotting cells below index into it.
batch = datamodule.dataset.get_batch()

In [None]:
%matplotlib inline

# Index of the image in the batch to display.
image_indx = 0

# Entry 0 along the second axis of that image (presumably the first band — TODO confirm).
image = batch['images'][image_indx, 0].cpu()

f, axarr = plt.subplots(1, 1, figsize=(5, 5))
plotting.plot_image(axarr, image)

### Some subimages

In [None]:
# Image geometry read from the dataset's image decoder; both values are passed
# to the plotting helpers in the cells below.
slen = datamodule.dataset.image_decoder.slen
border_padding = datamodule.dataset.image_decoder.border_padding

In [None]:
from bliss.models.encoder import get_full_params

In [None]:
# Locations of the first image's sources on the full image: keep only batch
# element 0 (sliced as 0:1 so the leading dimension is preserved).
first_image_params = {
    'n_sources': batch['n_sources'][0:1],
    'locs': batch['locs'][0:1],
}
full_params = get_full_params(slen, first_image_params)

locs_full = full_params['locs']

In [None]:
# A 2 x 5 grid of randomly chosen sub-images of the simulated image, each with
# the full-image source locations overlaid.
f, axarr = plt.subplots(2, 5, figsize=(16, 8))

subimage_slen = 10
indx_vec = np.arange(0, slen, subimage_slen)

# ravel() walks the grid row by row, i.e. in the order axarr[i // 5, i % 5].
for ax in axarr.ravel():
    # offsets handed to plot_image; x0 is drawn before x1 so the random
    # stream is consumed in the same order on every run
    x0 = np.random.choice(indx_vec)
    x1 = np.random.choice(indx_vec)

    plotting.plot_locations(locs_full.squeeze(), slen, border_padding,
                            ax, marker='o', color='b')

    plotting.plot_image(ax, image,
                        x0, x1, subimage_slen, subimage_slen)


# Train!

In [None]:
# Wall-clock start; the elapsed time printed below covers both fitting and saving.
t0 = time.time()
trainer.fit(sleep_net, datamodule = datamodule)
# Only the image encoder's weights are saved; the commented-out line at the end
# of this cell restores them without retraining.
torch.save(sleep_net.image_encoder.state_dict(), './starnet')

print('TOTAL TIME ELAPSED: {:.3f}secs'.format(time.time() - t0))

# sleep_net.image_encoder.load_state_dict(torch.load('./starnet')); 

In [None]:
# Move the encoder to `device` and switch it to eval mode before inference.
sleep_net.image_encoder.to(device);
sleep_net.image_encoder.eval(); 

# Results on SDSS data

### Load SDSS data

In [None]:
import os

In [None]:
# Archive holding the SDSS image and the Hubble ground-truth arrays.
hubble_data = np.load('../../data/true_hubble_m2.npz')

In [None]:
def _hubble_tensor(key):
    """Return the array stored under `key` in `hubble_data` as a tensor on `device`."""
    return torch.from_numpy(hubble_data[key]).to(device)


sdss_image = _hubble_tensor("sdss_image")
print(sdss_image.shape)

# the true parameters
hubble_locs = _hubble_tensor("true_locs")
hubble_fluxes = _hubble_tensor("true_fluxes")
nelec_per_nmgy = _hubble_tensor("nelec_per_nmgy")

### Get MAP estimates

In [None]:
# Shape of the SDSS image tensor (the same value is printed when the data is loaded).
sdss_image.shape

In [None]:
# MAP estimate for the SDSS image; unsqueeze(0) adds a leading dimension of
# size one before the image is handed to the encoder.
# This is inference only, so it runs under no_grad: no autograd graph is built
# for the full image.
with torch.no_grad():
    map_estimate = sleep_net.image_encoder.map_estimate(sdss_image.unsqueeze(0))

map_nstars = map_estimate['n_sources'].detach()
map_locs = map_estimate['locs'].detach()
map_log_fluxes = map_estimate['log_fluxes'].detach()

# all galaxies should be off
assert torch.all(map_estimate['galaxy_bool'] == 0.), \
    "map_estimate returned non-zero galaxy_bool entries"

In [None]:
# The 'n_sources' entry of the MAP estimate.
print(map_nstars)

### Example subimages

In [None]:
# A 2 x 5 grid of randomly chosen sub-images of sdss_image[0], with the Hubble
# locations (blue circles) and the MAP locations (red crosses) overlaid.
f, axarr = plt.subplots(2, 5, figsize=(24, 8))

subimage_slen = 10
indx_vec = np.arange(0, slen, subimage_slen)

# ravel() walks the grid row by row, i.e. in the order axarr[i // 5, i % 5].
for ax in axarr.ravel():
    # offsets handed to plot_image; x0 is drawn before x1 so the random
    # stream is consumed in the same order on every run
    x0 = np.random.choice(indx_vec)
    x1 = np.random.choice(indx_vec)

    plotting.plot_locations(hubble_locs, slen, border_padding,
                            ax, marker='o', color='b')
    plotting.plot_locations(map_locs.squeeze(), slen, border_padding,
                            ax, marker='x', color='red')

    im = plotting.plot_image(ax, sdss_image[0],
                             x0, x1, subimage_slen, subimage_slen)

    f.colorbar(im, ax=ax)

In [None]:
# NOTE(review): `s` looks like an unfinished attribute name (possibly left over
# from tab completion) — confirm the intended attribute or remove this cell.
sleep_net.image_encoder.s

### Summary statistics

In [None]:
# Overall detection statistics of the MAP estimate against the Hubble catalog;
# the first two values returned by get_summary_stats are used as TPR and PPV.
summary_stats = sdss.get_summary_stats(
    map_locs.squeeze(),
    hubble_locs,
    slen,
    map_log_fluxes.exp()[0, :, 0],
    hubble_fluxes[:, 0],
    nelec_per_nmgy=nelec_per_nmgy,
)
tpr, ppv = summary_stats[0:2]

print('True positive rate: {:.3f}'.format(tpr))
print('Positive predictive value: {:.3f}'.format(ppv))

# F1 as written in the formula below: 2 * tpr * ppv / (tpr + ppv)
f1_score = 2 * tpr * ppv / (tpr + ppv)
print('F1: {:.3f}'.format(f1_score))

# Get summary statistics as a function of magnitude

In [None]:
def get_tpr_vec(
    est_locs, true_locs, slen, est_fluxes, true_fluxes, nelec_per_nmgy, mag_vec
):
    """True positive rate of the estimates within bins of true magnitude.

    The true sources are split into the bins delimited by consecutive entries
    of ``mag_vec``; for each bin the first value returned by
    ``sdss.get_summary_stats`` (called with all estimates and only the true
    sources of that bin) is recorded.

    Bins are half-open, ``[mag_vec[i], mag_vec[i + 1])``, except the last one,
    which also includes its upper edge. A source lying exactly on an edge is
    therefore counted once; with edges that are percentiles of the same
    magnitudes this keeps the minimum and maximum source.

    Args:
        est_locs: estimated locations, passed through to get_summary_stats.
        true_locs: true locations, indexable by a boolean mask over sources.
        slen: passed through to get_summary_stats.
        est_fluxes: estimated fluxes, passed through to get_summary_stats.
        true_fluxes: true fluxes, one per true source; divided by
            ``nelec_per_nmgy`` before conversion to magnitudes.
        nelec_per_nmgy: conversion factor, also passed to get_summary_stats.
        mag_vec: bin edges in magnitudes (assumed increasing).

    Returns:
        ``(tpr_vec, mag_vec, counts_vec)``: per-bin rates and per-bin numbers
        of true sources as numpy arrays of length ``len(mag_vec) - 1``, with
        ``mag_vec`` returned unchanged. Bins without true sources keep a rate
        of 0, matching how get_ppv_vec treats empty bins.
    """
    # convert to magnitude
    true_mags = sdss.convert_nmgy_to_mag(true_fluxes / nelec_per_nmgy)

    n_bins = len(mag_vec) - 1
    tpr_vec = np.zeros(n_bins)
    counts_vec = np.zeros(n_bins)

    for i in range(n_bins):
        which_true = true_mags >= mag_vec[i]
        if i == n_bins - 1:
            # closed upper edge so the faintest source is not dropped
            which_true = which_true & (true_mags <= mag_vec[i + 1])
        else:
            which_true = which_true & (true_mags < mag_vec[i + 1])

        # int() pulls a plain Python scalar out of the tensor, whatever its device
        n_true = int(which_true.sum())
        counts_vec[i] = n_true

        if n_true == 0:
            # nothing to detect in this bin
            continue

        tpr_vec[i] = float(
            sdss.get_summary_stats(
                est_locs,
                true_locs[which_true],
                slen,
                est_fluxes,
                true_fluxes[which_true],
                nelec_per_nmgy,
            )[0]
        )

    return tpr_vec, mag_vec, counts_vec


def get_ppv_vec(
    est_locs, true_locs, slen, est_fluxes, true_fluxes, nelec_per_nmgy, mag_vec
):
    """Positive predictive value of the estimates within bins of estimated magnitude.

    The estimated sources are split into the bins delimited by consecutive
    entries of ``mag_vec``; for each non-empty bin the second value returned by
    ``sdss.get_summary_stats`` (called with only the estimates of that bin and
    all true sources) is recorded.

    Bins are half-open, ``[mag_vec[i], mag_vec[i + 1])``, except the last one,
    which also includes its upper edge, so an estimate lying exactly on an
    edge is counted once. Estimates outside ``[mag_vec[0], mag_vec[-1]]`` fall
    in no bin.

    Args:
        est_locs: estimated locations, indexable by a boolean mask over estimates.
        true_locs: true locations, passed through to get_summary_stats.
        slen: passed through to get_summary_stats.
        est_fluxes: estimated fluxes, one per estimate; divided by
            ``nelec_per_nmgy`` before conversion to magnitudes.
        true_fluxes: true fluxes, passed through to get_summary_stats.
        nelec_per_nmgy: conversion factor, also passed to get_summary_stats.
        mag_vec: bin edges in magnitudes (assumed increasing).

    Returns:
        ``(ppv_vec, mag_vec, counts_vec)``: per-bin values and per-bin numbers
        of estimates as numpy arrays of length ``len(mag_vec) - 1``, with
        ``mag_vec`` returned unchanged. Bins without estimates keep a value of 0.
    """
    est_mags = sdss.convert_nmgy_to_mag(est_fluxes / nelec_per_nmgy)

    n_bins = len(mag_vec) - 1
    ppv_vec = np.zeros(n_bins)
    counts_vec = np.zeros(n_bins)

    for i in range(n_bins):
        which_est = est_mags >= mag_vec[i]
        if i == n_bins - 1:
            # closed upper edge so an estimate on the last edge is not dropped
            which_est = which_est & (est_mags <= mag_vec[i + 1])
        else:
            which_est = which_est & (est_mags < mag_vec[i + 1])

        # int() pulls a plain Python scalar out of the tensor, whatever its device
        n_est = int(which_est.sum())
        counts_vec[i] = n_est

        if n_est == 0:
            # no estimates in this bin: the value stays at 0
            continue

        ppv_vec[i] = float(
            sdss.get_summary_stats(
                est_locs[which_est],
                true_locs,
                slen,
                est_fluxes[which_est],
                true_fluxes,
                nelec_per_nmgy,
            )[1]
        )

    return ppv_vec, mag_vec, counts_vec


In [None]:
# Bin edges: the 0th, 10th, ..., 100th percentiles of the Hubble magnitudes
# computed from column 0 of the true fluxes.
hubble_mags = sdss.convert_nmgy_to_mag(hubble_fluxes[:, 0] / nelec_per_nmgy)
percentiles = np.linspace(0, 1, 11) * 100
mag_vec = np.percentile(hubble_mags.cpu().numpy(), percentiles)
# Keep the edges in the dtype of the magnitudes. The legacy torch.Tensor(...)
# constructor always produces float32, which can round an edge away from the
# magnitude it was computed from and shift sources across bin boundaries.
mag_vec = torch.as_tensor(mag_vec, dtype=hubble_mags.dtype, device=device)

# MAP fluxes at index 0 of the first and last axes, matching hubble_fluxes[:, 0]
map_fluxes = map_log_fluxes.exp()[0, :, 0]

# get tpr as function of magnitude (first element of the returned tuple)
tpr_vec = get_tpr_vec(
    map_locs.squeeze(),
    hubble_locs,
    slen,
    map_fluxes,
    hubble_fluxes[:, 0],
    nelec_per_nmgy=nelec_per_nmgy,
    mag_vec=mag_vec,
)[0]

# get ppv as function of magnitude (first element of the returned tuple)
ppv_vec = get_ppv_vec(
    map_locs.squeeze(),
    hubble_locs,
    slen,
    map_fluxes,
    hubble_fluxes[:, 0],
    nelec_per_nmgy=nelec_per_nmgy,
    mag_vec=mag_vec,
)[0]

In [None]:
fig, axarr = plt.subplots(1, 2, figsize=(15, 5))

# Curves are drawn against the percentile at the start of each bin; the ticks
# are then relabelled with the corresponding magnitude edges.
percentiles = np.linspace(0, 1, 10 + 1) * 100
bin_starts = percentiles[:-1]
edge_labels = np.round(mag_vec.cpu().numpy(), 1)[:-1]

panels = (
    (tpr_vec, 'Hubble magnitude', 'True positive rate'),
    (ppv_vec, 'Estimated magnitude', 'Positive predictive value'),
)

for ax, (values, x_label, y_label) in zip(axarr, panels):
    ax.plot(bin_starts, values, '-o')
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_xticks(bin_starts)
    ax.set_xticklabels(edge_labels)


In [None]:
# Magnitudes of the Hubble sources and of the MAP estimates, both obtained from
# fluxes divided by nelec_per_nmgy.
hubble_mags = sdss.convert_nmgy_to_mag(hubble_fluxes[:, 0] / nelec_per_nmgy)
est_mags = sdss.convert_nmgy_to_mag(map_log_fluxes.exp()[0, :, 0] / nelec_per_nmgy)

In [None]:
# Magnitude distributions: Hubble catalog (grey) vs. MAP estimates (red).
# bins[1] holds the bin edges of the first histogram, reused for the second.
bins = plt.hist(hubble_mags.cpu(), color = 'grey', bins = 50);
plt.hist(est_mags.cpu(), color = 'red', alpha = 0.5, bins = bins[1]);

In [None]:
# np.savez('../../data/true_hubble_m2', 
#          sdss_image = sdss_image.cpu(), 
#          true_locs = hubble_locs.cpu(), 
#          true_fluxes = hubble_fluxes.cpu(), 
#          nelec_per_nmgy = nelec_per_nmgy)