In [1]:
import torch
import logging
import tqdm
import time

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device
from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from bliss.encoder.metrics import CatalogMatcher

# Select the compute device. The notebook targets GPU index CUDA_DEVICE_INDEX;
# if that GPU does not exist on this host, fall back to the default CUDA
# device, and use the CPU when CUDA is unavailable altogether.
CUDA_DEVICE_INDEX = 3
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
else:
    device = torch.device("cuda")

# Compose the Hydra config named "half_pixel_notebook_config",
# searching config_path "." for it.
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose(config_name="half_pixel_notebook_config")

In [2]:
def test_tile_cat_equal(left_tile_cat, right_tile_cat):
    logger = logging.getLogger(__name__)
    is_equal = True
    for k, v in left_tile_cat.items():
        if k == "n_sources":
            right_v = right_tile_cat[k]
            left_v = v
        else:
            right_v = torch.where(right_tile_cat.is_on_mask.unsqueeze(-1),
                                right_tile_cat[k],
                                torch.zeros_like(right_tile_cat[k]))
            left_v = torch.where(left_tile_cat.is_on_mask.unsqueeze(-1),
                                v, torch.zeros_like(v))
        cur_test_equal = torch.allclose(left_v, right_v, equal_nan=True)
        if not cur_test_equal:
            logger.warning("%s are different", k)
        is_equal &= cur_test_equal
    return is_equal

In [3]:
# Build the DC2 data module from the composed config and take its validation
# loader. (This cell only sets up data; no encoder is constructed here.)
dc2_cfg = notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
max_sources_per_tile = 2  # per-tile source capacity used when re-tiling full catalogs

dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

In [4]:
# Number of validation batches; the timing cells divide by it for per-batch averages.
total_batch_num = len(dc2_val_dataloader)

In [None]:
# Baseline cost of just iterating the validation loader (no device transfer,
# no catalog conversion). perf_counter is a monotonic high-resolution clock
# intended for measuring durations, unlike time.time().
load_data_start_time = time.perf_counter()
for _ in tqdm.tqdm(dc2_val_dataloader):
    pass
load_data_end_time = time.perf_counter()
load_data_elapsed = load_data_end_time - load_data_start_time
print(f"load data time: {load_data_elapsed:0.2f}")
# Guard the average against an empty loader (total_batch_num == 0).
load_data_per_batch = load_data_elapsed / total_batch_num if total_batch_num else float("nan")
print(f"load data per batch time: {load_data_per_batch:0.2f}")

In [None]:
def _synchronize(dev):
    """Wait for queued CUDA work on ``dev`` to finish; no-op for non-CUDA devices.

    CUDA kernels launch asynchronously, so reading the clock without a sync
    measures launch overhead rather than the actual conversion work.
    """
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


# Time the tile -> full -> tile round trip on the device and check that it
# reproduces the original tile catalog.
dc2_val_dataloader = dc2.val_dataloader()
to_full_time = 0.0
to_tile_time = 0.0
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch_on_device = move_data_to_device(batch, device=device)
    ori_tile_cat = TileCatalog(batch_on_device["tile_catalog"])

    # Make sure any still-pending device work (e.g. the transfer above) has
    # finished before starting the clock.
    _synchronize(device)
    to_full_start_time = time.perf_counter()
    full_cat = ori_tile_cat.to_full_catalog(tile_slen)
    _synchronize(device)
    to_full_end_time = time.perf_counter()
    to_full_time += to_full_end_time - to_full_start_time

    to_tile_start_time = time.perf_counter()
    back_tile_cat = full_cat.to_tile_catalog(tile_slen, max_sources_per_tile=max_sources_per_tile)
    _synchronize(device)
    to_tile_end_time = time.perf_counter()
    to_tile_time += to_tile_end_time - to_tile_start_time

    # Correctness check, kept outside the timed regions.
    assert test_tile_cat_equal(back_tile_cat, ori_tile_cat)

print(f"to full time: {to_full_time:0.2f}")
print(f"to tile time: {to_tile_time:0.2f}")
# Guard the averages against an empty loader (total_batch_num == 0).
per_batch_divisor = total_batch_num if total_batch_num else float("nan")
print(f"to full time per batch: {to_full_time / per_batch_divisor:0.5f}")
print(f"to tile time per batch: {to_tile_time / per_batch_divisor:0.5f}")

In [None]:
# Source matcher used by test_full_cat_equal to pair sources across catalogs.
# NOTE(review): dist_slack is presumably the maximum match distance — confirm
# its units against CatalogMatcher.
matcher = CatalogMatcher(dist_slack=0.1)
def test_full_cat_equal(left_full_cat, right_full_cat):
    for k, v in left_full_cat.items():
        if k == "n_sources":
            assert torch.allclose(left_full_cat[k], right_full_cat[k])
        elif k == "plocs":
            matching = matcher.match_catalogs(true_cat=left_full_cat, 
                                              est_cat=right_full_cat)
            for i, (true_match, est_match) in enumerate(matching):
                assert len(true_match) == len(est_match)
                assert len(true_match) == left_full_cat["n_sources"][i]

In [None]:
# Re-tile each full catalog at a coarser tile size and convert back to a full
# catalog; the result must still contain the same sources as the original.
LARGER_TILE_SLEN = 4
# Assuming tile_slen divides LARGER_TILE_SLEN (aligned grids), one larger tile
# covers ceil(LARGER_TILE_SLEN / tile_slen) ** 2 original tiles, each holding up
# to max_sources_per_tile sources. Size the larger tiles for that worst case
# (and never below 2 * max_sources_per_tile) so a crowded larger tile cannot
# exceed its slot capacity.
larger_tiles_per_side = -(-LARGER_TILE_SLEN // tile_slen)  # ceiling division
larger_max_sources = max(max_sources_per_tile * 2,
                         max_sources_per_tile * larger_tiles_per_side ** 2)

dc2_val_dataloader = dc2.val_dataloader()
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch_on_device = move_data_to_device(batch, device=device)

    ori_tile_cat = TileCatalog(batch_on_device["tile_catalog"])
    ori_full_cat = ori_tile_cat.to_full_catalog(tile_slen)
    larger_tile_cat = ori_full_cat.to_tile_catalog(
        LARGER_TILE_SLEN, max_sources_per_tile=larger_max_sources
    )
    new_full_cat = larger_tile_cat.to_full_catalog(LARGER_TILE_SLEN)

    test_full_cat_equal(ori_full_cat, new_full_cat)