In [None]:
# Reload edited Python modules (e.g. the bliss / case_studies code imported below)
# automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from os import environ
from pathlib import Path
from typing import List

import torch
import pandas as pd
import numpy as np

from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from einops import rearrange

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from bliss.surveys.des import DarkEnergySurvey, DESDownloader

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict
import case_studies.galaxy_clustering.utils.diagnostics as diagnostics
from astropy.io import fits
from astropy.visualization import make_lupton_rgb

In [None]:
# Expose only physical GPU 1 to this process. This only has an effect if CUDA has not
# been initialised yet (torch is already imported above, but normally initialises CUDA
# lazily). NOTE(review): no cell below explicitly moves the encoder or data to a GPU.
environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
# Compose the Hydra config named "config" from the ../conf/ directory.
with initialize(version_base=None, config_path="../conf/"):
    cfg = compose(config_name="config")

In [None]:
# Show the prediction section of the config (encoder spec and checkpoint path are used below).
cfg["predict"]

In [None]:
encoder = instantiate(cfg.predict.encoder)
# The encoder is only used for inference in this notebook. eval() makes layers such as
# dropout / batch-norm (if the encoder has any) use their inference-time behaviour.
encoder.eval()
# map_location="cpu": nothing in this notebook moves the encoder to a GPU. This also avoids
# load failures when the checkpoint was saved on a CUDA device that is not visible here.
checkpoint = torch.load(cfg.predict.weight_save_path, map_location="cpu")
# Presumably a Lightning-style checkpoint: the model weights live under "state_dict".
enc_state_dict = checkpoint["state_dict"]
encoder.load_state_dict(enc_state_dict)

In [None]:
# Show the encoder spec that `instantiate` builds the encoder from.
cfg["predict"]["encoder"]

In [None]:
# Collection of DES SVA tile names (each is used as a directory / file-name prefix below).
# NOTE: pd.read_pickle can execute arbitrary code, so only load this file from a trusted location.
DES_SVA_TILES = pd.read_pickle("/data/scratch/des/sva_map.pickle")
# Local copy of the DES DR2 tile tree; one sub-directory per tile.
CACHED_DATA_PATH = (
    "/nfs/turbo/lsa-regier/scratch/gapatron/"
    "desdr-server.ncsa.illinois.edu/despublic/dr2_tiles"
)
# Band order here fixes the channel order of the stacked image.
DES_BANDS = tuple("griz")
GROUNDTRUTH_PATH = "/data/scratch/des/redmapper_groundtruth"

In [None]:
# Pick one SVA tile to analyse. Set TILE_SEED to an int to make the choice reproducible;
# with None, a different tile is drawn on every run.
TILE_SEED = None
DES_TILE = np.random.default_rng(TILE_SEED).choice(DES_SVA_TILES)
DES_TILE

In [None]:
# Locate the `<band>_nobkg.fits.fz` image of every band for the chosen tile.
directory_path = Path(CACHED_DATA_PATH) / str(DES_TILE)
# sorted(): os.listdir order is arbitrary, so this makes the pick deterministic if
# several files match the same band suffix.
tile_files = sorted(os.listdir(directory_path))
dir_files = {}
for band in DES_BANDS:
    suffix = f"{band}_nobkg.fits.fz"
    matches = [f for f in tile_files if f.endswith(suffix)]
    if not matches:
        raise FileNotFoundError(
            f"No file ending in {suffix!r} for tile {DES_TILE} in {directory_path}"
        )
    dir_files[band] = matches[0]

image_bands = []
for band in DES_BANDS:
    with fits.open(directory_path / dir_files[band]) as hdul:
        # Data is on HDU 1, not HDU 0.
        image_bands.append(torch.from_numpy(hdul[1].data))

# Stack the 2-D band images into (n_bands, H, W), channels in DES_BANDS order.
full_image = torch.stack(image_bands, dim=0)
full_image.shape

In [None]:
# Cut the (n_bands, H, W) tile into overlapping square patches for the encoder:
# 2560-px windows placed every 2480 px, so neighbouring patches overlap by 80 px.
PATCH_SIZE = 2560
PATCH_STRIDE = 2480
for axis_len in full_image.shape[1:]:
    # unfold silently drops trailing pixels that do not fill a whole window.
    assert axis_len >= PATCH_SIZE and (axis_len - PATCH_SIZE) % PATCH_STRIDE == 0, (
        f"image side {axis_len} is not covered exactly by {PATCH_SIZE}-px patches "
        f"with stride {PATCH_STRIDE}"
    )
# -> (n_bands, n_patch_rows, n_patch_cols, PATCH_SIZE, PATCH_SIZE)
item = full_image.unfold(dimension=1, size=PATCH_SIZE, step=PATCH_STRIDE).unfold(
    dimension=2, size=PATCH_SIZE, step=PATCH_STRIDE
)
item.shape

In [None]:
# Run the encoder on every patch and record its mode "membership" prediction per tile.
# This is pure inference, so:
#   - eval() gives inference-time layer behaviour;
#   - no_grad() skips autograd bookkeeping.
encoder.eval()
n_patch_rows, n_patch_cols = item.shape[1], item.shape[2]
patch_h, patch_w = item.shape[-2], item.shape[-1]
ht, wt = patch_h // encoder.tile_slen, patch_w // encoder.tile_slen
# (n_patch_rows, n_patch_cols, ht, wt) grid of per-tile predictions.
outputs = torch.zeros((n_patch_rows, n_patch_cols, ht, wt))

with torch.no_grad():
    for i in range(n_patch_rows):
        for j in range(n_patch_cols):
            image = item[:, i, j, :, :].unsqueeze(0)  # (1, n_bands, patch_h, patch_w)
            batch = {"images": image, "background": torch.zeros_like(image)}
            batch_size = image.shape[0]

            input_lst = [inorm.get_input_tensor(batch) for inorm in encoder.image_normalizers]
            x = torch.cat(input_lst, dim=2)
            x_features = encoder.features_net(x)
            mask = torch.zeros([batch_size, ht, wt])
            context = encoder.make_context(None, mask)
            x_cat_marginal = encoder.catalog_net(x_features, context)
            membership = encoder.var_dist.sample(x_cat_marginal, use_mode=True)["membership"]
            # reshape (unlike squeeze) keeps the (ht, wt) layout even if ht or wt is 1.
            outputs[i, j] = membership.reshape(ht, wt)
outputs = outputs.bool()


In [None]:
# Compare per-tile predictions (upsampled to pixels) with the redMaPPer ground-truth mask.
# GROUNDTRUTH_PATH is already absolute, so no extra leading "/" is prepended.
gt_filename = f"{GROUNDTRUTH_PATH}/{DES_TILE}_redmapper_groundtruth.npy"
# .bool() makes `~` below a logical NOT even if the array is stored as 0/1 ints
# (on an int tensor `~` is a bitwise NOT, which would corrupt the counts).
gt_memberships = torch.from_numpy(np.load(gt_filename)).bool()
# Same patch grid as the image unfold (2560-px patches, stride 2480).
unfolded_gt = gt_memberships.unfold(dimension=0, size=2560, step=2480).unfold(
    dimension=1, size=2560, step=2480
)
# Each predicted tile covers a tile_slen x tile_slen block of pixels.
pred_memberships = torch.repeat_interleave(outputs, repeats=encoder.tile_slen, dim=2)
pred_memberships = torch.repeat_interleave(pred_memberships, repeats=encoder.tile_slen, dim=3)

# NOTE: counts are pooled over the overlapping patches, so pixels in the 80-px overlap
# strips are counted once per patch that contains them.
tp = (pred_memberships & unfolded_gt).sum()
tn = (~pred_memberships & ~unfolded_gt).sum()
fp = (pred_memberships & ~unfolded_gt).sum()
fn = (~pred_memberships & unfolded_gt).sum()
acc = (tp + tn) / (tp + tn + fp + fn)
prec = tp / (tp + fp + 1e-6)
rec = tp / (tp + fn + 1e-6)
f1 = 2 * prec * rec / (prec + rec + 1e-6)
print(f"Accuracy: {acc.item()}")
print(f"Precision: {prec.item()}")
print(f"Recall: {rec.item()}")
print(f"F1: {f1.item()}")

In [None]:
def blend_images(original, overlay, alpha=0.5, img_crop=0):
    """Alpha-blend an overlay onto an image and optionally crop its border.

    Args:
        original: Image array of shape (H, W, ...). Integer images (e.g. the uint8
            output of ``make_lupton_rgb``) and float images with values above 1 are
            treated as 0-255 and divided by 255.
        overlay: Array broadcastable to ``original`` with values in [0, 1];
            boolean masks work too.
        alpha: Weight of the overlay; 0 returns the (rescaled) original.
        img_crop: Number of pixels removed from each edge of the first two axes.

    Returns:
        The blended float array of shape (H - 2 * img_crop, W - 2 * img_crop, ...).
    """
    original = np.asarray(original)
    # An integer dtype means a 0-255 image even when it is so dark that max() <= 1,
    # a case a `max() > 1` check alone would leave unscaled.
    if np.issubdtype(original.dtype, np.integer) or original.max() > 1.0:
        original = original / 255.0
    blended = original * (1 - alpha) + overlay * alpha
    height, width = blended.shape[0], blended.shape[1]
    return blended[img_crop : height - img_crop, img_crop : width - img_crop]

In [None]:
# Ground-truth membership mask overlaid on an i/r/g Lupton RGB composite of the tile
# (channels 2, 1, 0 of full_image, which is stacked in DES_BANDS order).
rgb_default = make_lupton_rgb(full_image[2, :, :], full_image[1, :, :], full_image[0, :, :])
overlay = gt_memberships.unsqueeze(2).repeat(1, 1, 3).numpy()
blended = blend_images(rgb_default, overlay)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(blended)
ax.set(
    title=f"Ground-truth (redMaPPer) membership overlay: tile {DES_TILE}",
    xlabel="column [pixel]",
    ylabel="row [pixel]",
)
plt.show()

In [None]:
# Stitch the per-patch pixel predictions (n_rows, n_cols, 2560, 2560) back into one
# full-tile mask. Patches were cut every 2480 px, so neighbours overlap by 80 px; every
# patch after the first drops its leading 80 rows/columns so each pixel appears once.
patch_size, patch_stride = 2560, 2480
overlap = patch_size - patch_stride


def _stitch_index(n_patches: int) -> torch.Tensor:
    """Indices along a concatenated (n_patches * patch_size) axis that keep each pixel once."""
    return torch.cat(
        [
            torch.arange(k * patch_size + (overlap if k > 0 else 0), (k + 1) * patch_size)
            for k in range(n_patches)
        ]
    )


rearranged_memberships = rearrange(pred_memberships, "d0 d1 d2 d3 -> (d0 d2) (d1 d3)")
row_include = _stitch_index(pred_memberships.shape[0])
col_include = _stitch_index(pred_memberships.shape[1])
folded_outputs = torch.index_select(rearranged_memberships, 0, row_include)
folded_outputs = torch.index_select(folded_outputs, 1, col_include)
assert folded_outputs.shape == full_image.shape[1:], folded_outputs.shape

In [None]:
# Stitched predicted membership mask overlaid on the same i/r/g Lupton RGB composite
# (channels 2, 1, 0 of full_image, which is stacked in DES_BANDS order).
rgb_default = make_lupton_rgb(full_image[2, :, :], full_image[1, :, :], full_image[0, :, :])
overlay = folded_outputs.unsqueeze(2).repeat(1, 1, 3).numpy()
blended = blend_images(rgb_default, overlay)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(blended)
ax.set(
    title=f"Predicted membership overlay: tile {DES_TILE}",
    xlabel="column [pixel]",
    ylabel="row [pixel]",
)
plt.show()