## DC2 Classification Accuracy

In [4]:
import torch
from os import environ
from pathlib import Path
from einops import rearrange
import pickle

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from hydra import initialize, compose
from hydra.utils import instantiate

from bliss.surveys.dc2 import DC2, unsqueeze_tile_dict
from pathlib import Path

from bliss.catalog import FullCatalog

from torchmetrics import MetricCollection
import tqdm
from pytorch_lightning.utilities import move_data_to_device
from bliss.catalog import SourceType

import GCRCatalogs

# Set BLISS_HOME to the directory two levels above the resolved current working
# directory (parents[1]).
# NOTE(review): presumably the repository root when the notebook is launched
# from its own folder — confirm before running from elsewhere.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])

# Directory that holds the pickled caches written by later cells; created
# (including parents) if it does not exist yet.
output_dir = Path("./DC2_classification_accuracy/")
output_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Compose the Hydra config called "notebook_config", looked up via config_path=".".
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose("notebook_config")

# All tensors and the encoder are placed on this device. The notebook is pinned
# to CPU; the commented line below is the GPU-if-available alternative.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [6]:
# Build the DC2 survey object from the config and fetch plotting sample 0.
dc2: DC2 = instantiate(notebook_cfg.surveys.dc2)
test_sample = dc2.get_plotting_sample(0)
# Entries of the sample used by later cells: the WCS (for sky -> pixel
# conversion of the LSST catalog), the ground-truth full catalog (matching
# target for both BLISS and LSST), and the match ids.
cur_image_wcs = test_sample["wcs"]
cur_image_true_full_catalog = test_sample["full_catalog"]
cur_image_match_id = test_sample["match_id"]

In [8]:
# Load the DC2 object catalog (with truth matching) through GCRCatalogs.
GCRCatalogs.set_root_dir("/data/scratch/dc2_nfs/")
lsst_catalog_gcr = GCRCatalogs.load_catalog("desc_dc2_run2.2i_dr6_object_with_truth_match")

# Requested columns: ids, sky position and truth type, followed by the cModel
# flux and flux error of every band in u, g, r, i, z, y order.
lsst_catalog_sub = lsst_catalog_gcr.get_quantities(
    ["id_truth", "objectId", "ra", "dec", "truth_type"]
    + [
        f"cModel{quantity}_{band}"
        for band in "ugrizy"
        for quantity in ("Flux", "FluxErr")
    ]
)
lsst_catalog_df = pd.DataFrame(lsst_catalog_sub)

# Tensor view of the columns needed downstream.
lsst_catalog_tensors_dict = {
    "truth_type": torch.tensor(lsst_catalog_df["truth_type"].values).view(-1, 1),
    # One column per band, ordered g, i, r, u, y, z (so column 2 is the r band).
    "flux": torch.cat(
        [
            torch.tensor(lsst_catalog_df[f"cModelFlux_{band}"].values).view(-1, 1)
            for band in "giruyz"
        ],
        dim=1,
    ),
    "ra": torch.tensor(lsst_catalog_df["ra"].values),
    "dec": torch.tensor(lsst_catalog_df["dec"].values),
}

In [9]:
def get_lsst_params(
    lsst_catalog_tensors_dict, cur_image_wcs, image_lim,
):
    """Select the LSST catalog entries that fall inside the current image.

    Args:
        lsst_catalog_tensors_dict: mapping with "ra", "dec", "truth_type" and
            "flux" tensors ("truth_type" is squeezed along dim 1 and "flux" is
            indexed by column, so both are expected to be 2-D).
        cur_image_wcs: object providing ``all_world2pix(ra, dec, origin)``.
        image_lim: exclusive upper bound applied to both pixel coordinates.

    Returns:
        A ``(plocs, source_type, flux)`` tuple restricted to the kept entries.
        ``plocs`` stacks the second ``all_world2pix`` output first; ``source_type``
        is ``SourceType.STAR`` where truth_type == 2 and ``SourceType.GALAXY``
        elsewhere.
    """
    pix_t, pix_r = cur_image_wcs.all_world2pix(
        lsst_catalog_tensors_dict["ra"], lsst_catalog_tensors_dict["dec"], 0
    )
    plocs = torch.stack((torch.from_numpy(pix_r), torch.from_numpy(pix_t)), dim=-1)
    truth_type = lsst_catalog_tensors_dict["truth_type"]
    flux = lsst_catalog_tensors_dict["flux"]

    # keep entries strictly inside (0, image_lim) along both pixel axes
    inside_image = (
        (plocs[:, 0] > 0)
        & (plocs[:, 0] < image_lim)
        & (plocs[:, 1] > 0)
        & (plocs[:, 1] < image_lim)
    )
    # filter r band: require a positive flux in column 2
    positive_r_flux = flux[:, 2] > 0
    # filter supernova: drop truth_type == 3
    not_supernova = (truth_type != 3).squeeze(1)
    keep = inside_image & positive_r_flux & not_supernova

    is_star = truth_type[keep] == 2
    source_type = torch.where(is_star, SourceType.STAR, SourceType.GALAXY)

    return plocs[keep, :], source_type, flux[keep, :]

In [10]:
# The size of dimension 1 of the image is used as the pixel bound for both axes.
# NOTE(review): assumes the image is square along its spatial axes — confirm.
image_lim = test_sample["image"].shape[1]
# Minimum flux taken from the encoder config's min_flux_for_metrics.
r_band_min_flux = notebook_cfg.encoder.min_flux_for_metrics
lsst_plocs, lsst_source_type, lsst_flux = get_lsst_params(
    lsst_catalog_tensors_dict, cur_image_wcs, image_lim)
# Keep only sources whose flux in column 2 (the r band in the g, i, r, u, y, z
# column order built above) exceeds the threshold.
flux_mask = lsst_flux[:, 2] > r_band_min_flux
lsst_plocs = lsst_plocs[flux_mask, :]
lsst_source_type = lsst_source_type[flux_mask]
lsst_flux = lsst_flux[flux_mask, :]
# Number of remaining sources, stored as a one-element tensor.
lsst_n_sources = torch.tensor([lsst_plocs.shape[0]])

In [11]:
# Wrap the filtered LSST detections in a FullCatalog with a leading dimension
# of size 1 (unsqueeze(0)) for the single image.
# The same LSST flux tensor is supplied for both "galaxy_fluxes" and
# "star_fluxes"; the star copy is cloned so the two entries do not share storage.
lsst_full_cat = FullCatalog(height=image_lim, width=image_lim, d={
        "plocs": lsst_plocs.unsqueeze(0).to(device=device),
        "n_sources": lsst_n_sources.to(device=device),
        "source_type": lsst_source_type.unsqueeze(0).to(device=device),
        "galaxy_fluxes": lsst_flux.unsqueeze(0).to(device=device),
        "star_fluxes": lsst_flux.unsqueeze(0).clone().to(device=device),
    })

In [12]:
# change this model path according to your training setting
MODEL_PATH = "../../output/DC2_experiments/DC2_psf_aug_asinh_06-06-1/checkpoints/best_encoder.ckpt"
# Build the encoder from the config, then load the checkpoint's "state_dict"
# entry into it. `device` is passed as torch.load's second positional argument
# (map_location).
bliss_encoder = instantiate(notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(MODEL_PATH, device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
# Switch to evaluation mode; the trailing ";" suppresses the module repr in the
# cell output.
bliss_encoder.eval();

In [14]:
# Wrap the single plotting sample as a batch of size one (leading axis added to
# every entry) and move it to the target device.
batch = {
    "tile_catalog": unsqueeze_tile_dict(test_sample["tile_catalog"]),
    "images": rearrange(test_sample["image"], "h w nw -> 1 h w nw"),
    "background": rearrange(test_sample["background"], "h w nw -> 1 h w nw"),
    "psf_params": rearrange(test_sample["psf_params"], "h w -> 1 h w")
}

batch = move_data_to_device(batch, device=device)

# The encoder output is cached on disk; delete this file to force a fresh
# prediction (e.g. after changing MODEL_PATH or the config).
bliss_output_path = output_dir / "bliss_output.pkl"

if not bliss_output_path.exists():
    # predict_step is called directly here, so nothing turns autograd off for
    # us. This is inference only: disable gradient tracking so no graph is
    # built and the cached tensors do not carry autograd state.
    with torch.no_grad():
        bliss_out_dict = bliss_encoder.predict_step(batch, None)

    # Only reached when the cache file is absent, so this creates it.
    with open(bliss_output_path, "wb") as outp:
        pickle.dump(bliss_out_dict, outp, pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were produced by this notebook.
    with open(bliss_output_path, "rb") as inputp:
        bliss_out_dict = pickle.load(inputp)

Traceback (most recent call last):
  File "_pydevd_bundle/pydevd_cython.pyx", line 577, in _pydevd_bundle.pydevd_cython.PyDBFrame._handle_exception
  File "_pydevd_bundle/pydevd_cython.pyx", line 312, in _pydevd_bundle.pydevd_cython.PyDBFrame.do_wait_suspend
  File "/home/pduan/bliss/.venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 2070, in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
  File "/home/pduan/bliss/.venv/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 2106, in _do_wait_suspend
    time.sleep(0.01)
KeyboardInterrupt


In [None]:
# Convert the "mode_cat" entry of the encoder output into a FullCatalog so it
# can be matched against the ground-truth full catalog below.
bliss_full_cat: FullCatalog = bliss_out_dict["mode_cat"].to_full_catalog()

In [None]:
# Matcher and metrics come from the encoder config. The LSST metrics are an
# independent clone so the two catalogs accumulate state separately.
matcher = instantiate(notebook_cfg.encoder.matcher)
bliss_metrics = instantiate(notebook_cfg.encoder.metrics)
lsst_metrics = bliss_metrics.clone()

# Keep only the three source-type accuracy metrics, under identical keys in
# both collections.
bliss_metrics = MetricCollection({
    "source_type_accuracy": bliss_metrics["source_type_accuracy"],
    "source_type_accuracy_star": bliss_metrics["source_type_accuracy_star"],
    "source_type_accuracy_galaxy": bliss_metrics["source_type_accuracy_galaxy"],
}).to(device=device)
lsst_metrics = MetricCollection({
    "source_type_accuracy": lsst_metrics["source_type_accuracy"],
    "source_type_accuracy_star": lsst_metrics["source_type_accuracy_star"],
    "source_type_accuracy_galaxy": lsst_metrics["source_type_accuracy_galaxy"],
}).to(device=device)

# Filled from each metric's get_results_on_per_flux_bin() output below.
bliss_results = {
    "classification_acc": None,
    "classification_acc_star": None,
    "classification_acc_galaxy": None,
}

lsst_results = {
    "classification_acc": None,
    "classification_acc_star": None,
    "classification_acc_galaxy": None,
}

# Results are cached on disk; delete this file to force recomputation.
classification_result_path = output_dir / "classification_result.pkl"
if not classification_result_path.exists():
    # Both catalogs are matched against, and scored on, the same ground truth.
    bliss_matching = matcher.match_catalogs(cur_image_true_full_catalog, bliss_full_cat)
    bliss_metrics.update(cur_image_true_full_catalog, bliss_full_cat, bliss_matching)

    lsst_matching = matcher.match_catalogs(cur_image_true_full_catalog, lsst_full_cat)
    lsst_metrics.update(cur_image_true_full_catalog, lsst_full_cat, lsst_matching)

    for _metric_name, metric in bliss_metrics.items():
        results = metric.get_results_on_per_flux_bin()
        for k_results, v_results in results.items():
            bliss_results[k_results] = v_results.cpu()

    for _metric_name, metric in lsst_metrics.items():
        results = metric.get_results_on_per_flux_bin()
        for k_results, v_results in results.items():
            lsst_results[k_results] = v_results.cpu()

    # Write the cache exactly once, after every metric has been collected
    # (not once per loop iteration with partially filled results).
    with open(classification_result_path, "wb") as classification_result_file:
        pickle.dump({
            "bliss_results": bliss_results,
            "lsst_results": lsst_results,
        }, classification_result_file, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: pickle.load can execute arbitrary code; only load cache files that
    # were produced by this notebook.
    with open(classification_result_path, "rb") as classification_result_file:
        classification_result = pickle.load(classification_result_file)
    bliss_results = classification_result["bliss_results"]
    lsst_results = classification_result["lsst_results"]

In [None]:
def plot(classification_acc_1,
        classification_acc_2,
        flux_bin_cutoffs,
        source_type_name, 
        model_name_1, 
        model_name_2):
    """Plot two per-flux-bin classification accuracy curves on one axis.

    Args:
        classification_acc_1: accuracies of the first model; must provide
            ``.tolist()``, one value per x tick label.
        classification_acc_2: accuracies of the second model, same layout.
        flux_bin_cutoffs: sequence of flux cutoffs; produces
            ``len(flux_bin_cutoffs) + 1`` tick labels.
        source_type_name: source type shown in the legend entries.
        model_name_1: legend name for the first curve.
        model_name_2: legend name for the second curve.

    Returns:
        The ``(fig, ax)`` pair of the created figure. The figure is also
        displayed through ``plt.show()`` before returning.
    """
    # The lower bound of the first bin is hard-coded to 100 in the label.
    xlabels = (
        ["[100, " + str(flux_bin_cutoffs[0]) + "]"]
        + [f"[{flux_bin_cutoffs[i]}, {flux_bin_cutoffs[i + 1]}]" for i in range(len(flux_bin_cutoffs) - 1)]
        + ["> " + str(flux_bin_cutoffs[-1])]
    )

    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(
        1, 1, figsize=(10, 10), sharey=True
    )

    c1, c2 = plt.rcParams["axes.prop_cycle"].by_key()["color"][0:2]
    # The format string must be passed positionally: Axes.plot forwards keyword
    # arguments to Line2D, which has no "fmt" property.
    ax.plot(
        range(len(xlabels)),
        classification_acc_1.tolist(),
        "-o",
        color=c1,
        label=f"{model_name_1} Classification Acc ({source_type_name})",
    )
    ax.plot(
        range(len(xlabels)),
        classification_acc_2.tolist(),
        "-o",
        color=c2,
        label=f"{model_name_2} Classification Acc ({source_type_name})",
    )
    ax.set_xlabel("Flux")
    ax.set_ylabel("Classification accuracy")
    ax.set_xticks(range(len(xlabels)))
    ax.set_xticklabels(xlabels, rotation=45)
    ax.legend()

    plt.tight_layout()
    plt.show()

    return fig, ax

In [None]:
# Key under which each metric's per-flux-bin accuracy is stored in
# bliss_results / lsst_results.
results_key_for_metric = {
    "source_type_accuracy": "classification_acc",
    "source_type_accuracy_star": "classification_acc_star",
    "source_type_accuracy_galaxy": "classification_acc_galaxy",
}

# One BLISS-vs-LSST figure per metric; bin cutoffs and the source type name
# are read from the metric object itself.
for k, v in bliss_metrics.items():
    if k not in results_key_for_metric:
        raise NotImplementedError()

    results_key = results_key_for_metric[k]
    fig, ax = plot(
        bliss_results[results_key],
        lsst_results[results_key],
        flux_bin_cutoffs=v.flux_bin_cutoffs,
        source_type_name=v.source_type_name,
        model_name_1="BLISS",
        model_name_2="LSST",
    )

    fig.show()