In [1]:
import torch
import logging
import tqdm
import time

from einops import repeat

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device
from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from bliss.encoder.metrics import CatalogMatcher
from case_studies.dc2_diffusion.utils.metrics import DetectionPerformance

# set device
# GPU 3 is the preferred card, but a host may expose fewer than four GPUs.
# Fall back to GPU 0 in that case so tensors are never sent to a device index
# that does not exist; without CUDA at all, run on the CPU.
preferred_gpu_index = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Compose the notebook's Hydra config; the config search path is this directory.
# Hydra is only initialized for the duration of the `with` block.
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose(config_name="half_pixel_notebook_config")

In [2]:
# Build the DC2 data module from the notebook config and grab its validation
# dataloader. (No encoder is constructed in this cell.)
dc2_cfg = notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
# Size of the per-tile source axis used when synthesizing `locs` below.
# NOTE(review): presumably must match the max-sources dimension of the
# dataset's tile catalogs -- confirm against the DC2 config.
max_sources_per_tile = 2

dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

In [3]:
# Accumulate detection metrics over the validation set, comparing the
# ground-truth catalog against a copy of itself whose `locs` entry is replaced
# by a value derived from `n_sources` (everything else is left unchanged).
matcher = CatalogMatcher()
f1_metric = DetectionPerformance().to(device=device)
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch_on_device = move_data_to_device(batch, device=device)

    # Ground truth, converted from tile coordinates to a full catalog.
    target_tile_cat = TileCatalog(batch_on_device["tile_catalog"])
    target_full_cat = target_tile_cat.to_full_catalog(tile_slen)

    # Broadcast the per-tile value n_sources * 0.5 over the source axis (m) and
    # a trailing axis of size 2 (presumably the two location coordinates --
    # confirm).
    # NOTE(review): a tile with n_sources == 2 gets locs of 1.0 rather than
    # 0.5; confirm that this is intended.
    half_pixel_locs = repeat(
        batch_on_device["tile_catalog"]["n_sources"] * 0.5,
        "b h w -> b h w m k",
        m=max_sources_per_tile,
        k=2,
    )

    # Build a separate dict rather than assigning into the batch's own
    # `tile_catalog`, so the ground-truth entries are never overwritten in
    # place. Key order is preserved because "locs" already exists in the dict.
    best_tile_dict = {**batch_on_device["tile_catalog"], "locs": half_pixel_locs}
    best_tile_cat = TileCatalog(best_tile_dict)
    best_full_cat = best_tile_cat.to_full_catalog(tile_slen)

    matching = matcher.match_catalogs(target_full_cat, best_full_cat)
    f1_metric.update(target_full_cat, best_full_cat, matching)

100%|██████████| 782/782 [03:15<00:00,  4.01it/s]


In [4]:
# Print each accumulated detection metric as "name: value", one per line.
metric_values = f1_metric.compute()
for metric_name, metric_value in metric_values.items():
    print(f"{metric_name}: {metric_value}")

detection_precision: 1.0
detection_recall: 1.0
detection_f1: 1.0
