In [None]:
# Reload edited project modules automatically before each cell executes,
# so changes to bliss source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# With no arguments, %aimport only lists which modules autoreload is tracking;
# NOTE(review): this looks like a leftover debugging line and can likely be removed.
%aimport

In [None]:
# NOTE(review): all figures in this notebook are drawn with plotly and `plt` is never
# used, so this inline matplotlib backend appears to be unnecessary.
%matplotlib inline 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from astropy.table import Table

from bliss.inference import SDSSFrame
from bliss.datasets import sdss
from bliss.inference import reconstruct_scene_at_coordinates
from case_studies.sdss_galaxies.plots import load_models

In [None]:
# Select the first CUDA device. Constructing a torch.device does not itself verify
# that CUDA is available; a misconfigured GPU will only surface when models/tensors
# are moved to it.
device = torch.device('cuda:0')
# Shell escape: print which GPUs are visible to this kernel process.
!echo $CUDA_VISIBLE_DEVICES

# Load data

In [None]:
# load sdss data
# NOTE(review): these are absolute, user-specific paths; consider a configurable
# DATA_DIR so the notebook can run on other machines.
sdss_dir = '/home/imendoza/bliss/data/sdss/'
# Passed straight to SDSSFrame; units/meaning are not visible here — TODO confirm.
pixel_scale = 0.393
coadd_file = "/home/imendoza/bliss/data/coadd_catalog_94_1_12.fits"
# Later cells use frame.image, frame.background, frame.wcs and frame.get_catalog.
frame = SDSSFrame(sdss_dir, pixel_scale, coadd_file)

In [None]:
# load models: compose the hydra config, then build the encoder/decoder on `device`
from hydra import compose, initialize
from hydra.utils import instantiate
from bliss.encoder import Encoder

with initialize(config_path="../config"):
    cfg = compose(config_name="config", overrides=[])

enc, dec = load_models(cfg, device)
# Encoder border padding; later cells start the inference region this many pixels
# into the frame and shift locations back by it.
bp = enc.border_padding
torch.cuda.empty_cache()  # release cached GPU memory left over from model loading

In [None]:
# get catalog
# Restrict to the frame interior: start `bp` pixels in from the top/left edge, and end
# so that the region height/width (h_end - h, w_end - w) is the largest multiple of 4
# that fits inside a `bp`-wide border on both sides (adjustment for using the whole
# frame; 4 is presumably the encoder tile side length — TODO confirm).
h = w = bp
n_rows, n_cols = frame.image.shape[2], frame.image.shape[3]
h_end = bp + (n_rows - 2 * bp) // 4 * 4
w_end = bp + (n_cols - 2 * bp) // 4 * 4
coadd_params = frame.get_catalog((h, h_end), (w, w_end))

In [None]:
# PHOTO catalog (photoObj file) for the single frame, presumably the same 94/1/12
# field as the coadd catalog loaded above.
# NOTE(review): relative path, unlike the absolute `sdss_dir` used for the frame.
photo_obj_file = '../../../data/sdss/94/1/12/photoObj-000094-1-0012.fits'
frame_cat = Table.read(photo_obj_file)

In [None]:
# Compare object counts. NOTE: the two numbers cover different areas — every row of
# the frame PHOTO catalog vs. only the coadd sources inside the (h, h_end) x (w, w_end)
# region selected above.
n_frame_objects = len(frame_cat)
print("# objects detected by photo on frame: ", n_frame_objects)
n_coadd_objects = coadd_params.n_sources.sum().item()
print("# objects detected by photo on coadd: ", n_coadd_objects)

# Run BLISS inference and get source locations

In [None]:
# inference: run BLISS over the (h, h_end) x (w, w_end) region of the frame
with torch.no_grad():
    # slen=300: chunk size forwarded to reconstruct_scene_at_coordinates (hand-picked)
    inference_kwargs = dict(h_range=(h, h_end), w_range=(w, w_end), slen=300, device=device)
    _, tile_est = reconstruct_scene_at_coordinates(
        enc, dec, frame.image, frame.background, **inference_kwargs
    )
map_recon = tile_est.to_full_params()

# One flux per source: galaxy flux for galaxies, star flux for stars (galaxy_bools /
# star_bools presumably act as 0/1 masks). NOTE: this overwrites the "fluxes" entry.
galaxy_part = map_recon["galaxy_bools"] * map_recon["galaxy_fluxes"]
star_part = map_recon["star_bools"] * map_recon["fluxes"]
map_recon["fluxes"] = galaxy_part + star_part
map_recon["mags"] = sdss.convert_flux_to_mag(map_recon["fluxes"])
torch.cuda.empty_cache()

In [None]:
# prepare inference locs
# Every tensor is moved to CPU before .numpy() (which fails on CUDA tensors); previously
# only plocs did this. reshape (rather than squeeze) keeps plocs (n, 2) and the per-source
# arrays (n,) even when only a single source is detected.
# The "+ bp - 0.5" shift maps locations back to full-frame 0-based pixel coordinates
# (plotting adjustment), which is also what all_pix2world(..., 0) below expects.
plocs = map_recon.plocs.cpu().numpy().reshape(-1, 2) + bp - 0.5
# plocs columns are (row, col); the WCS wants (x, y) = (col, row), hence the swap.
coords = frame.wcs.all_pix2world(plocs[:, [1, 0]], 0)
galaxy_bool = map_recon['galaxy_bools'].cpu().numpy().astype(bool).reshape(-1)
galaxy_prob = map_recon['galaxy_probs'].cpu().numpy().reshape(-1)
mags = map_recon['mags'].cpu().numpy().reshape(-1)


In [None]:
# coadd params: the coadd catalog over the same region, as numpy arrays
# .cpu() is a no-op for CPU tensors and keeps .numpy() safe otherwise; reshape (rather
# than squeeze) keeps shapes (n, 2) / (n,) even if the region holds a single source.
coplocs = coadd_params.plocs.cpu().numpy().reshape(-1, 2) + bp - 0.5  # plotting adjustment
# Cast to bool: this mask is used with `~` and boolean indexing later, which would
# silently become bitwise NOT / integer fancy indexing on a 0/1 integer array.
cogbool = coadd_params['galaxy_bools'].cpu().numpy().astype(bool).reshape(-1)

cora = coadd_params['ra'].cpu().numpy().reshape(-1)
codec = coadd_params['dec'].cpu().numpy().reshape(-1)
comags = coadd_params['mags'].cpu().numpy().reshape(-1)

# Plot coadd vs. BLISS detections on the frame (Plotly)

In [None]:
from bliss.reporting import match_by_locs

# Match coadd locations to BLISS-estimated locations. coindx / indx are paired indices
# into coplocs / plocs; dkeep presumably flags the pairs that are close enough to count
# as a real match — confirm in bliss.reporting.
coindx, indx, dkeep, _ = match_by_locs(torch.from_numpy(coplocs), torch.from_numpy(plocs))

# BLISS detections with no coadd partner at all. np.setdiff1d keeps an integer dtype even
# when every detection is matched; the previous np.array(list(set(...))) produced an empty
# float64 array in that case, which is not a valid index and made plocs[nindx] raise.
nindx = np.setdiff1d(np.arange(len(plocs)), indx)
assert len(indx) + len(nindx) == len(plocs)

# Classification labels of the pairs that survive the distance cut.
true_is_gal = np.asarray(cogbool[coindx][dkeep]).astype(bool)
est_is_gal = np.asarray(galaxy_bool[indx][dkeep]).astype(bool)
# Only one direction of error: coadd says galaxy but BLISS says star.
gal_misclass = true_is_gal & ~est_is_gal
plocs_misclassified = plocs[indx][dkeep][gal_misclass]
plocs_unmatched = plocs[nindx]
# Paired by the matcher but rejected by the distance cut.
plocs_unmatched_dist = plocs[indx][~dkeep]


In [None]:
def _hover_text(mag_values, ras, decs, probs=None):
    """Build one hover label per point: magnitude, optional galaxy probability, (ra, dec)."""
    if probs is None:
        return [
            f'mag:{m:.2f}; (ra, dec):({r:.4f}, {d:.4f})'
            for m, r, d in zip(mag_values, ras, decs)
        ]
    return [
        f'mag:{m:.2f}; prob_galaxy:{p:.2f}; (ra, dec):({r:.4f}, {d:.4f})'
        for m, p, r, d in zip(mag_values, probs, ras, decs)
    ]


def _scatter(name, locs, color, size, symbol, hovertext=None):
    """Marker trace for (row, col) locations; plotly x is the column and y is the row."""
    return go.Scatter(
        name=name,
        x=locs[:, 1],
        y=locs[:, 0],
        mode='markers',
        marker=dict(color=color, size=size, symbol=symbol),
        hovertext=hovertext,
    )


# Overlay coadd and BLISS detections, plus matching diagnostics, on the frame image.
image = frame.image.squeeze().numpy()
# zmin / zmax: hand-tuned display stretch for this frame's pixel values.
fig = px.imshow(image, width=800, height=550, zmin=800, zmax=1150, color_continuous_scale='gray')

pred_gal_coords = coords[galaxy_bool]
pred_star_coords = coords[~galaxy_bool]
traces = [
    _scatter('galaxy coadd', coplocs[cogbool], 'red', 8, 'cross',
             _hover_text(comags[cogbool], cora[cogbool], codec[cogbool])),
    _scatter('star coadd', coplocs[~cogbool], 'blue', 8, 'cross',
             _hover_text(comags[~cogbool], cora[~cogbool], codec[~cogbool])),
    _scatter('galaxy pred', plocs[galaxy_bool], 'hotpink', 11, 'x',
             _hover_text(mags[galaxy_bool], pred_gal_coords[:, 0], pred_gal_coords[:, 1],
                         probs=galaxy_prob[galaxy_bool])),
    _scatter('star pred', plocs[~galaxy_bool], 'cyan', 11, 'x',
             _hover_text(mags[~galaxy_bool], pred_star_coords[:, 0], pred_star_coords[:, 1],
                         probs=galaxy_prob[~galaxy_bool])),
    _scatter('unmatched', plocs_unmatched, 'magenta', 10, 'cross'),
    _scatter('unmatched (distance)', plocs_unmatched_dist, 'green', 10, 'cross'),
    _scatter('misclassified', plocs_misclassified, 'yellow', 10, 'cross'),
]
fig.add_traces(traces)

fig.update_layout(legend=dict(orientation="h", y=1.05))  # horizontal legend above the image
fig.show()

