# Setup

In [None]:
# Enable IPython's autoreload extension so that edited modules are re-imported
# automatically before each cell executes (mode 2 = reload all modules).
%load_ext autoreload
%autoreload 2
# Called without arguments, %aimport only lists the modules tracked by autoreload.
%aimport

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline 

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from astropy.table import Table

from bliss.datasets import sdss
from bliss.inference import SDSSFrame
from bliss.inference import reconstruct_scene_at_coordinates
from case_studies.sdss_galaxies.plots.main import load_models

In [None]:
# check GPU is configured correctly
# Use the first CUDA device when one is usable; otherwise fall back to the CPU so
# that later cells do not fail with an obscure CUDA error on GPU-less machines.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Show which GPUs are exposed to this process (prints an empty line when unset,
# matching the output of the former `!echo $CUDA_VISIBLE_DEVICES`).
print(os.environ.get('CUDA_VISIBLE_DEVICES', ''))
print(f'using device: {device}')

In [None]:
from case_studies.sdss_galaxies.plots.autoencoder import scatter_bin_plot, set_rc_params

In [None]:
# Apply the plotting defaults defined in case_studies.sdss_galaxies.plots.autoencoder
# (presumably matplotlib rcParams — confirm in that module).
set_rc_params()

# Get Config

In [None]:
# Compose the Hydra configuration only; no model is loaded in this cell.
# NOTE(review): `Encoder` is imported here but not used in this cell.
from hydra import compose, initialize
from hydra.utils import instantiate
from bliss.encoder import Encoder

# Hydra is initialised only for the duration of the `with` block; `cfg` is the
# composed config named "config" with no overrides applied.
with initialize(config_path="../../config"):
    cfg = compose("config", overrides=[])
    

# Run

In [None]:
from bliss.datasets.galsim_galaxies import SingleGalsimGalaxies

In [None]:
# Build the prior object described by the `models.prior` node of the Hydra config.
prior = instantiate(cfg.models.prior)

In [None]:
# Draw one sample from the prior.
# NOTE(review): the meaning of the four positional 1s is not visible here
# (presumably batch size and tile-grid dimensions) — confirm against sample_prior.
tile_cat = prior.sample_prior(1, 1, 1, 1)

In [None]:
# Display the sampled catalog as the cell output.
tile_cat

In [None]:
# Print the image-tensor shape of every validation batch.
# NOTE(review): `ds` is never defined in the visible notebook, so this cell raises
# NameError on a fresh kernel. Presumably it was a dataset such as the imported
# SingleGalsimGalaxies built in a since-deleted cell — restore that cell or drop this one.
dl = ds.val_dataloader()
for x in dl:
    print(x['images'].shape)

In [None]:
# Second pass over the same dataloader, printing batch image shapes again.
# NOTE(review): duplicates the loop in the previous cell — presumably a leftover
# check that the loader can be iterated twice; confirm or delete.
for x in dl:
    print(x['images'].shape)

In [None]:
# Load data

# load sdss data
# NOTE(review): both paths are absolute and user-specific, so this cell only runs
# on the original author's machine; consider a configurable data directory.
sdss_dir = '/home/imendoza/bliss/data/sdss/'
# presumably the pixel scale in arcsec/pixel — confirm the unit SDSSFrame expects
pixel_scale = 0.393
coadd_file = "/home/imendoza/bliss/data/coadd_catalog_94_1_12.fits"
frame = SDSSFrame(sdss_dir, pixel_scale, coadd_file)

# Run the encoder on blend images

In [None]:
# load models
# NOTE(review): the imports and config composition below repeat an earlier cell.
from hydra import compose, initialize
from hydra.utils import instantiate
from bliss.encoder import Encoder

with initialize(config_path="../../config"):
    cfg = compose("config", overrides=[])
    
    
# load_models returns a pair, bound here as `enc` and `dec`
# (presumably the encoder and decoder placed on `device` — confirm in plots.main).
enc, dec = load_models(cfg, device)
# Shorthand for the encoder's border padding; later cells read enc.border_padding directly.
bp = enc.border_padding
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Build the detection encoder described by the `models.detection_encoder` config node.
detection_encoder = instantiate(cfg.models.detection_encoder)

In [None]:
# Load trained weights for the detection encoder from a checkpoint file.
# NOTE(review): torch.load unpickles the file — only use checkpoints from a trusted source.
PATH = '../../../../epoch=989-val_loss=-0.015.ckpt'
# map_location='cpu' keeps the loaded tensors on the CPU regardless of where they were saved.
model_checkpoint = torch.load(PATH, map_location='cpu')
model_state_dict = model_checkpoint["state_dict"]
# `det` holds load_state_dict's return value, not a model
# (for a torch nn.Module this is the missing/unexpected-keys report — confirm).
det = detection_encoder.load_state_dict(model_state_dict)

In [None]:
# Show the checkpoint-loading method as the cell output WITHOUT rebinding
# `detection_encoder`. The original assignment replaced the instantiated model
# (with its freshly loaded weights) by a bound method, so any later use of
# `detection_encoder` as a model would fail.
# NOTE(review): if the intent was to reload from the checkpoint file, call the
# method with the checkpoint path and bind the result to a new name instead.
detection_encoder.load_from_checkpoint

In [None]:
# Override two batching attributes on the encoder
# (presumably they bound the work done per forward pass to limit GPU memory — confirm in Encoder).
enc.n_rows_per_batch = 10
enc.n_images_per_batch = 15

In [None]:
# Load the saved blends; the result is indexed by string keys below and in later
# cells ('images', 'background', 'plocs', 'n_sources').
# NOTE(review): torch.load unpickles the file — only use trusted files.
blend_images = torch.load('../../models/simulated_blended_galaxies.pt')
# Cell output: shape of the image tensor (unpacked later as n_images, c, slen, slen).
blend_images['images'].shape

In [None]:
# Unpack the image-tensor dimensions, then give the shared background a leading
# batch axis and expand it to one entry per blend image, shape (n_images, 1, slen, slen).
n_images, c, slen, _ = blend_images['images'].shape
background = (
    blend_images['background']
    .unsqueeze(0)
    .expand(n_images, 1, slen, slen)
)

In [None]:
# Run the encoder on all blend images with their expanded background; the result
# is converted with .to_full_params() in the next cell.
tile_map = enc.variational_mode(blend_images['images'], background)

In [None]:
# Move the encoder output to the CPU and convert it with to_full_params();
# later cells read `full.plocs` and `full.n_sources`.
full = tile_map.cpu().to_full_params()

In [None]:
# Sanity check: shape of the estimated locations and the largest estimated
# source count over all images.
print(full.plocs.shape)
print(full.n_sources.max().item())

In [None]:
# Visual check for one blend: image with estimated (red x) and true (blue +)
# source locations overlaid, followed by the true vs. estimated source counts.
ii = 50  # index of the blend to inspect

plt.figure(figsize=(8, 8))
blend_img = blend_images['images'][ii].cpu().numpy().reshape(slen, slen)
plt.imshow(blend_img)

# Estimated locations: keep the first n_sources rows for this image. They are
# shifted by the encoder's border padding before plotting; the true locations
# below are not. Column 1 is drawn on the x axis and column 0 on the y axis.
# (The -0.5 is presumably a pixel-centre convention shift — confirm.)
plocs = full.plocs[ii, :full.n_sources[ii].item()].cpu().numpy()
plt.scatter(
    x=plocs[:, 1] - 0.5 + enc.border_padding,
    y=plocs[:, 0] - 0.5 + enc.border_padding,
    s=75,
    color='r',
    marker='x',
)

true_plocs = blend_images['plocs'][ii].cpu().numpy()
plt.scatter(
    x=true_plocs[:, 1] - 0.5,
    y=true_plocs[:, 0] - 0.5,
    s=100,
    color='b',
    marker='+',
)

print(blend_images['n_sources'][ii].item(), full.n_sources[ii].item())

In [None]:
# Collect true and estimated locations for ALL images.
# Note: this rebinds `true_plocs`, which the plotting cell above set to a single
# image's NumPy array; from here on it is the full 'plocs' entry of blend_images.
n_batches = len(full.n_sources)
true_plocs = blend_images['plocs']
est_plocs = full.plocs

In [None]:
# Cell output: shape of the estimated locations.
est_plocs.shape

In [None]:
# Cell output: shape of the true locations, for comparison with est_plocs.shape.
true_plocs.shape

In [None]:
from bliss.reporting import match_by_locs

# Debug the true/estimated matching for the selected image(s). Estimated
# locations are shifted by the encoder's border padding before being matched
# against the true locations.
for ii in (50,):
    padded_est = est_plocs[ii] + enc.border_padding
    tindx, eindx, dkeep, _ = match_by_locs(true_plocs[ii], padded_est)

    print(tindx, eindx, dkeep)
    print(true_plocs[ii], padded_est)
    # number of matches retained by the `dkeep` mask
    print(len(tindx[dkeep]))
    # matched pairs: true locations, then the corresponding padded estimates
    print(true_plocs[ii][tindx][dkeep])
    print(est_plocs[ii][eindx][dkeep] + enc.border_padding)

In [None]:
# Load the cached blend-simulation results; later cells index it by the keys
# 'est_ellips', 'true_ellips', 'snr' and 'blendedness'.
# NOTE(review): torch.load unpickles the file — only use trusted files.
d = torch.load('../../output/sdss_figures_cache/blendsim_cache.pt')

In [None]:
# Cell output: shape of the estimated ellipticities.
d['est_ellips'].shape

In [None]:
# Histogram of the first component (column 0) of the true ellipticities.
# The original passed the builtin `range` positionally, which matplotlib receives
# as `bins`; a Python type is not a valid bin specification and the call fails.
# Use an explicit bin count, matching the other histograms in this notebook.
plt.hist(d['true_ellips'][:, 0].numpy(), bins=30)

In [None]:
# Histogram of the cached SNR values with matplotlib's default binning.
plt.hist(d['snr'].numpy())

In [None]:
# Histogram of the cached blendedness values with matplotlib's default binning.
plt.hist(d['blendedness'].numpy())
# NOTE(review): the x label is the symbol "a" — confirm it is the intended label for blendedness.
plt.xlabel(r"$a$")

In [None]:
# Count entries whose log10(SNR) is below 0.4.
# Uses a vectorised NumPy reduction instead of Python's builtin sum(), which
# iterates the array element by element; reducing over axis 0 gives the same
# result as the builtin for any input rank. Requires `numpy` imported as `np`
# in the notebook's import cell.
(np.log10(d['snr'].numpy()) < 0.4).sum(axis=0)

# Dataset of blends

In [None]:
# Print the kernel's working directory; the relative paths used in this notebook resolve against it.
!pwd

In [None]:
from bliss.catalog import TileCatalog
# NOTE(review): absolute, user-specific path. An earlier cell loads
# '../../models/simulated_blended_galaxies.pt' by relative path — presumably the
# same file; confirm and use one location.
blend_data = torch.load('/home/imendoza/bliss.git/feature/case_studies/sdss_galaxies/models/simulated_blended_galaxies.pt')
# pop() removes the image and background entries so that only the remaining
# entries of blend_data are handed to TileCatalog below.
images = blend_data.pop("images")
background = blend_data.pop("background")
n_batches, _, slen, _ = images.shape
# a single background of shape (1, slen, slen) is shared by every image
assert background.shape == (1, slen, slen)

# prepare background
# add a leading batch axis, then expand to one entry per image
background = background.unsqueeze(0)
background = background.expand(n_batches, 1, slen, slen)

# first create FullCatalog from simulated data
# NOTE(review): the literal 4 is presumably the tile side length in pixels — confirm against TileCatalog.
# This rebinds `tile_cat`, `background`, `n_batches` and `slen` from earlier cells.
tile_cat = TileCatalog(4, blend_data).cpu()
full_truth = tile_cat.to_full_params()

In [None]:
# Histogram (30 bins) of log10(SNR) over every non-zero SNR entry of the true catalog.
# Zero entries are removed before taking the logarithm
# (presumably they mark empty source slots — confirm).
snr = full_truth['snr'].numpy().ravel()
snr = snr[snr.nonzero()]
plt.hist(x=np.log10(snr), bins=30)


In [None]:
# Histogram (30 bins) of every non-zero blendedness entry of the true catalog.
# Zero entries are removed first (presumably empty slots or unblended sources — confirm).
ble = full_truth['blendedness'].numpy().ravel()
ble = ble[ble.nonzero()]
plt.hist(x=ble, bins=30)


In [None]:

# Scatter of all true source locations, pooled over images.
# Rows whose first coordinate is not strictly positive are dropped
# (presumably padding rows for unused source slots — confirm).
# Column 1 is drawn on the x axis and column 0 on the y axis.
_plocs = full_truth.plocs.reshape(-1, 2)
_plocs = _plocs[_plocs[:, 0].gt(0)]
plt.scatter(x=_plocs[:, 1].numpy(), y=_plocs[:, 0].numpy())