In [1]:
import torch
import logging
import math
import time

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device
from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog, FullCatalog
from typing import Dict, Tuple
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

# Pick the compute device: the GPU with (hard-coded) index 6 when CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:6")
else:
    device = torch.device("cpu")

# Compose the notebook's Hydra config from the directory holding this notebook.
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose(config_name="notebook_config")

In [2]:
def to_tile_catalog(
        ori_full_cat,
        tile_slen: int,
        max_sources_per_tile: int,
        ignore_extra_sources=False,
        filter_oob=False,
    ) -> TileCatalog:
        """Convert a FullCatalog into the corresponding TileCatalog (loop-based version).

        Standalone copy of the per-batch-element conversion; the benchmark cells in this
        notebook time it (as "old") against `FullCatalog.to_tile_catalog` (as "new").

        Args:
            ori_full_cat: The catalog to convert. The code reads its `plocs` and
                `n_sources` entries, iterates over its `items()`, and uses its `height`,
                `width`, `batch_size` and `device` attributes.
            tile_slen: The side length of the tiles.
            max_sources_per_tile: The maximum number of sources in one tile.
            ignore_extra_sources: If False (default), raises an error if the number of sources
                in one tile exceeds the `max_sources_per_tile`. If True, only adds the tile
                parameters of the first `max_sources_per_tile` sources to the new TileCatalog.
            filter_oob: If True, drop the sources whose `plocs` are not strictly inside
                (0, height) x (0, width) before tiling. (e.g. In case of data augmentation,
                there is a chance of some sources located outside the image)

        Returns:
            A TileCatalog holding `locs`, `n_sources`, and one tiled tensor for every other
            key of the FullCatalog.

        Raises:
            ValueError: If the number of sources in one tile exceeds `max_sources_per_tile` and
                `ignore_extra_sources` is False.
            KeyError: If the tile_params contain `plocs` or `n_sources`.
        """
        # TODO: a FullCatalog only needs to "know" its height and width to convert itself to a
        # TileCatalog. So those parameters should be passed on conversion, not initialization.
        # integer (row, col) tile coordinate of every source: plocs / tile_slen, truncated
        tile_coords = torch.div(ori_full_cat["plocs"], tile_slen, rounding_mode="trunc").to(torch.int)
        n_tiles_h = math.ceil(ori_full_cat.height / tile_slen)
        n_tiles_w = math.ceil(ori_full_cat.width / tile_slen)

        # prepare tiled tensors
        tile_cat_shape = (ori_full_cat.batch_size, n_tiles_h, n_tiles_w, max_sources_per_tile)
        tile_locs = torch.zeros((*tile_cat_shape, 2), device=ori_full_cat.device)
        tile_n_sources = torch.zeros(tile_cat_shape[:3], dtype=torch.int64, device=ori_full_cat.device)
        tile_params: Dict[str, Tensor] = {}
        # one zero-filled tensor per catalog key (other than plocs / n_sources), keeping the
        # key's dtype and the size of its last dimension
        for k, v in ori_full_cat.items():
            if k in {"plocs", "n_sources"}:
                continue
            size = (ori_full_cat.batch_size, n_tiles_h, n_tiles_w, max_sources_per_tile, v.shape[-1])
            tile_params[k] = torch.zeros(size, dtype=v.dtype, device=ori_full_cat.device)

        # "locs" is filled from the catalog's "plocs" in the loop below and converted to
        # within-tile fractions once each batch element has been written
        tile_params["locs"] = tile_locs

        # the conversion is done one batch element at a time
        for ii in range(ori_full_cat.batch_size):
            n_sources = int(ori_full_cat["n_sources"][ii].item())
            # only the first n_sources rows of each per-source tensor are used
            plocs_ii = ori_full_cat["plocs"][ii][:n_sources]
            # number of sources that remain after the optional out-of-bounds filter
            filter_sources = n_sources
            source_tile_coords = tile_coords[ii][:n_sources]
            if filter_oob:
                # strict inequalities: a source exactly on coordinate 0 is dropped as well
                x0_mask = (plocs_ii[:, 0] > 0) & (plocs_ii[:, 0] < ori_full_cat.height)
                x1_mask = (plocs_ii[:, 1] > 0) & (plocs_ii[:, 1] < ori_full_cat.width)
                x_mask = x0_mask * x1_mask
                filter_sources = x_mask.sum()
                source_tile_coords = source_tile_coords[x_mask]

            # nothing to place for this batch element; its tiles keep their zero values
            if filter_sources == 0:
                continue

            # row-major flat tile index of every kept source; `.unsqueeze(0)` binds to the
            # column term only, and broadcasting makes the sum a (1, n_kept) row vector
            source_indices = source_tile_coords[:, 0] * n_tiles_w + source_tile_coords[
                :, 1
            ].unsqueeze(0)
            # (n_tiles, 1) column vector with every flat tile index
            tile_indices = torch.arange(n_tiles_h * n_tiles_w, device=ori_full_cat.device).unsqueeze(1)

            # one (tile index, source index) row per source; the rows come back in
            # row-major order, i.e. grouped by tile, which the `split` below relies on
            tile_to_source_mapping = (source_indices == tile_indices).nonzero()
            tile_source_count: Tuple[Tensor, Tensor] = tile_to_source_mapping[:, 0].unique(
                sorted=True, return_counts=True
            )  # first element is tile index; second element is source count
            if tile_source_count[1].max() > max_sources_per_tile:
                if not ignore_extra_sources:
                    raise ValueError(  # noqa: WPS220
                        "# of sources per tile exceeds `max_sources_per_tile`."
                    )

            # get n_sources for each tile
            # (counts above the limit are clipped to max_sources_per_tile)
            tile_n_sources[ii].view(-1)[tile_source_count[0].flatten().tolist()] = torch.where(
                tile_source_count[1] <= max_sources_per_tile,
                tile_source_count[1],
                max_sources_per_tile,
            )

            for k, v in tile_params.items():
                if k == "plocs":
                    raise KeyError("plocs should not be in tile_params")
                # the tiled "locs" tensor is populated from the full catalog's "plocs"
                if k == "locs":
                    k = "plocs"
                if k == "n_sources":
                    raise KeyError("n_sources should not be in tile_params")
                param_matrix = ori_full_cat[k][ii][:n_sources]
                if filter_oob:
                    param_matrix = param_matrix[x_mask]
                # reorder the per-source rows into tile order, then cut them into one chunk
                # per occupied tile using the per-tile source counts
                params_on_tile = list(
                    param_matrix[tile_to_source_mapping[:, 1]].split(
                        tile_source_count[1].flatten().tolist()
                    )
                )
                # pad first tensor to desired length
                # the second argument of pad function is
                # padding_left, padding_right, padding_top, padding_bottom
                params_on_tile[0] = torch.nn.functional.pad(
                    params_on_tile[0],
                    (0, 0, 0, (max_sources_per_tile - params_on_tile[0].shape[0])),
                )
                # pad all tensors
                # (stacks the chunks into one tensor with one row of sources per occupied tile)
                params_on_tile = pad_sequence(params_on_tile, batch_first=True)
                # never write more slots than there are kept sources or than a tile can hold
                max_fill = min(filter_sources, max_sources_per_tile)
                # view the (n_tiles_h, n_tiles_w, ...) grid as a flat list of tiles and write
                # the occupied tiles; `v` is the tensor stored in tile_params, updated in place
                v[ii].view(-1, *v[ii].shape[2:])[
                    tile_to_source_mapping[:, 0].unique(sorted=True).tolist(), :max_fill
                ] = params_on_tile[:, :max_fill].to(dtype=v.dtype)

            # modify tile location
            # (the values copied from "plocs" become offsets within the tile, scaled by tile_slen)
            tile_params["locs"][ii] = (tile_params["locs"][ii] % tile_slen) / tile_slen
        tile_params.update({"n_sources": tile_n_sources})
        return TileCatalog(tile_params)

In [3]:
def test_tile_cat_equal(left_tile_cat, right_tile_cat):
    logger = logging.getLogger(__name__)
    is_equal = True
    for k, v in left_tile_cat.items():
        cur_test_equal = torch.allclose(right_tile_cat[k], v, equal_nan=True)
        if not cur_test_equal:
            logger.warning("%s are different", k)
        is_equal &= cur_test_equal
    return is_equal

In [4]:
# Read the tiling settings from the DC2 section of the notebook config, then build the
# DC2 data module and fetch its validation dataloader.
dc2_cfg = notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
max_sources_per_tile = dc2_cfg.max_sources_per_tile

dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

In [5]:
def _wait_for_device():
    """Block until all CUDA work queued on `device` has finished; no-op on the CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


# Benchmark the loop-based conversion ("old") against `FullCatalog.to_tile_catalog`
# ("new") on `device`, and check that both produce the same tile catalog.
old_gpu_time = 0
new_gpu_time = 0
for i, batch in enumerate(dc2_val_dataloader):
    batch_on_device = move_data_to_device(batch, device=device)
    full_cat = TileCatalog(batch_on_device["tile_catalog"]).to_full_catalog(tile_slen)

    # CUDA kernels are launched asynchronously, so the clock is only read after the
    # device has been synchronised; otherwise an interval may cover just the kernel
    # launches (or pending work from the previous step) rather than the computation.
    _wait_for_device()
    old_start_time = time.time()
    old_tile_cat = to_tile_catalog(full_cat, tile_slen=tile_slen, max_sources_per_tile=max_sources_per_tile)
    _wait_for_device()
    old_end_time = time.time()
    print(f"[{i}] old: {old_end_time - old_start_time: .3f}")
    old_gpu_time += (old_end_time - old_start_time)

    _wait_for_device()
    new_start_time = time.time()
    new_tile_cat = full_cat.to_tile_catalog(tile_slen=tile_slen, max_sources_per_tile=max_sources_per_tile)
    _wait_for_device()
    new_end_time = time.time()
    print(f"[{i}] new: {new_end_time - new_start_time: .3f}")
    new_gpu_time += (new_end_time - new_start_time)

    assert test_tile_cat_equal(old_tile_cat, new_tile_cat)

[0] old:  0.461
[0] new:  0.063
[1] old:  0.320
[1] new:  0.006
[2] old:  0.339
[2] new:  0.005
[3] old:  0.328
[3] new:  0.005
[4] old:  0.318
[4] new:  0.005
[5] old:  0.328
[5] new:  0.005
[6] old:  0.317
[6] new:  0.004
[7] old:  0.327
[7] new:  0.006
[8] old:  0.329
[8] new:  0.006
[9] old:  0.305
[9] new:  0.005
[10] old:  0.317
[10] new:  0.005
[11] old:  0.331
[11] new:  0.005
[12] old:  0.329
[12] new:  0.005
[13] old:  0.318
[13] new:  0.005
[14] old:  0.330
[14] new:  0.005
[15] old:  0.319
[15] new:  0.005
[16] old:  0.322
[16] new:  0.005
[17] old:  0.328
[17] new:  0.005
[18] old:  0.328
[18] new:  0.005
[19] old:  0.323
[19] new:  0.005
[20] old:  0.317
[20] new:  0.005
[21] old:  0.322
[21] new:  0.005
[22] old:  0.330
[22] new:  0.005
[23] old:  0.327
[23] new:  0.004
[24] old:  0.313
[24] new:  0.004
[25] old:  0.314
[25] new:  0.004
[26] old:  0.336
[26] new:  0.005
[27] old:  0.318
[27] new:  0.004
[28] old:  0.329
[28] new:  0.004
[29] old:  0.324
[29] new:  0.005


In [6]:
# Summary of the GPU benchmark: accumulated seconds spent in each implementation and
# the resulting speed-up factor (old time divided by new time).
print(f"old_gpu_time: {old_gpu_time: .3f}")
print(f"new_gpu_time: {new_gpu_time: .3f}")
print(f"old/new: {old_gpu_time / new_gpu_time: .3f}")

old_gpu_time:  126.758
new_gpu_time:  1.898
old/new:  66.783


In [None]:
# CPU benchmark: every validation batch is moved to the CPU, converted with both
# implementations, timed, and the two resulting tile catalogs are compared.
conversion_kwargs = {"tile_slen": tile_slen, "max_sources_per_tile": max_sources_per_tile}
old_cpu_time = new_cpu_time = 0
for i, batch in enumerate(dc2_val_dataloader):
    batch_on_device = move_data_to_device(batch, device="cpu")
    tile_cat_on_cpu = TileCatalog(batch_on_device["tile_catalog"])
    full_cat = tile_cat_on_cpu.to_full_catalog(tile_slen)

    # loop-based reference implementation defined in this notebook
    old_start_time = time.time()
    old_tile_cat = to_tile_catalog(full_cat, **conversion_kwargs)
    old_end_time = time.time()
    print(f"[{i}] old: {old_end_time - old_start_time: .3f}")
    old_cpu_time += old_end_time - old_start_time

    # method provided by FullCatalog
    new_start_time = time.time()
    new_tile_cat = full_cat.to_tile_catalog(**conversion_kwargs)
    new_end_time = time.time()
    print(f"[{i}] new: {new_end_time - new_start_time: .3f}")
    new_cpu_time += new_end_time - new_start_time

    assert test_tile_cat_equal(old_tile_cat, new_tile_cat)

In [None]:
# Summary of the CPU benchmark: accumulated seconds spent in each implementation and
# the resulting speed-up factor (old time divided by new time).
print(f"old_cpu_time: {old_cpu_time: .3f}")
print(f"new_cpu_time: {new_cpu_time: .3f}")
print(f"old/new: {old_cpu_time / new_cpu_time: .3f}")

In [7]:
# Benchmark on the first ten plotting samples: convert each sample's full catalog with
# both implementations, time them, and check that the results agree.
full_image_old_time = full_image_new_time = 0
for i in range(10):
    plotting_sample = dc2.get_plotting_sample(i)
    full_cat = plotting_sample["full_catalog"]

    # loop-based reference implementation defined in this notebook
    old_start_time = time.time()
    old_tile_cat = to_tile_catalog(
        full_cat, tile_slen=tile_slen, max_sources_per_tile=max_sources_per_tile
    )
    old_end_time = time.time()
    print(f"[{i}] old: {old_end_time - old_start_time: .3f}")
    full_image_old_time += old_end_time - old_start_time

    # method provided by the catalog object itself
    new_start_time = time.time()
    new_tile_cat = full_cat.to_tile_catalog(
        tile_slen=tile_slen, max_sources_per_tile=max_sources_per_tile
    )
    new_end_time = time.time()
    print(f"[{i}] new: {new_end_time - new_start_time: .3f}")
    full_image_new_time += new_end_time - new_start_time

    assert test_tile_cat_equal(old_tile_cat, new_tile_cat)

[0] old:  13.089
[0] new:  0.446
[1] old:  11.963
[1] new:  0.481
[2] old:  14.759
[2] new:  0.481
[3] old:  13.642
[3] new:  0.466
[4] old:  14.058
[4] new:  0.489
[5] old:  14.735
[5] new:  0.421
[6] old:  13.105
[6] new:  0.501
[7] old:  13.838
[7] new:  0.477
[8] old:  12.556
[8] new:  0.467
[9] old:  12.543
[9] new:  0.535


In [8]:
# Summary of the full-image benchmark: accumulated seconds spent in each implementation
# and the resulting speed-up factor (old time divided by new time).
print(f"full_image_old_time: {full_image_old_time: .3f}")
print(f"full_image_new_time: {full_image_new_time: .3f}")
print(f"old/new: {full_image_old_time / full_image_new_time: .3f}")

full_image_old_time:  134.288
full_image_new_time:  4.765
old/new:  28.182


In [None]:
# Hand-built catalog with one batch element and room for ten sources, of which a single
# one is active, located at plocs (200, 200).
# NOTE(review): the two leading 200s are presumably the image height and width, which
# would put this source exactly on the image boundary - confirm against FullCatalog.
test_plocs = torch.zeros(1, 10, 2)
test_plocs[0, 0, :] = 200.0
test_n_sources = torch.ones(1, dtype=torch.int32)
test_flux = torch.zeros(1, 10, 6)
test_catalog_entries = {
    "plocs": test_plocs,
    "n_sources": test_n_sources,
    "fluxes": test_flux,
}
test_full_cat = FullCatalog(200, 200, d=test_catalog_entries)

In [None]:
# Convert the hand-built catalog using 4-pixel tiles and at most 3 sources per tile.
test_full_cat.to_tile_catalog(tile_slen=4, max_sources_per_tile=3)