In [None]:
import torch
import tqdm
from torch import nn
import torch.nn.functional as F
from torch import optim

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from bliss.encoder.metrics import CatalogMatcher
from bliss.encoder.convnet_layers import C3, ConvBlock
from bliss.global_env import GlobalEnv

from case_studies.dc2_new_diffusion.utils.catalog_parser import CatalogParser
from case_studies.dc2_new_diffusion.utils.metrics import DetectionPerformance

In [2]:
# Prefer the second GPU (the original choice), but fall back to the first GPU on
# single-GPU machines and to CPU when CUDA is unavailable, so the notebook still
# runs after "Restart & Run All" on a different machine.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Compose the notebook's Hydra config ("new_diffusion_notebook_config") from config_path=".".
with initialize(config_path=".", version_base=None):
    new_diffusion_notebook_cfg = compose("new_diffusion_notebook_config")

In [3]:
# Values read from the composed Hydra config.
tile_slen = new_diffusion_notebook_cfg.surveys.dc2.tile_slen
max_sources_per_tile = new_diffusion_notebook_cfg.surveys.dc2.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

BATCH_SIZE = 1024
SEED = 7272

dc2: DC2DataModule = instantiate(new_diffusion_notebook_cfg.surveys.dc2)
dc2.batch_size = BATCH_SIZE
dc2.setup(stage="fit")
GlobalEnv.current_encoder_epoch = 1
GlobalEnv.seed_in_this_program = SEED
# Also seed torch's global RNG so the network initialisation in the next cell and
# any dataloader shuffling that uses the default generator are reproducible.
torch.manual_seed(SEED)
dc2_train_dataloader = dc2.train_dataloader()

catalog_parser: CatalogParser = instantiate(new_diffusion_notebook_cfg.encoder.catalog_parser)

In [4]:
# Post-processing CNN. Its input is the encoded tile catalog upsampled 4x; the
# training loss compares its output directly with the tile-resolution encoding,
# so the two stride-2 blocks below are expected to undo that 4x upsampling.
target_ch = catalog_parser.n_params_per_source
postprocess_net_ch = 16
width_x1 = postprocess_net_ch
width_x2 = postprocess_net_ch * 2
width_x4 = postprocess_net_ch * 4

postprocess_layers = [
    ConvBlock(target_ch, width_x1, kernel_size=5),
    ConvBlock(width_x1, width_x2, kernel_size=3, stride=2),  # presumably halves H, W
    C3(width_x2, width_x2, n=3),
    ConvBlock(width_x2, width_x4, kernel_size=3, stride=2),  # presumably halves H, W again
    C3(width_x4, width_x4, n=3),
    ConvBlock(width_x4, width_x4, kernel_size=1),
    nn.Conv2d(width_x4, target_ch, kernel_size=1),  # project back to catalog channels
]
postprocess_net = nn.Sequential(*postprocess_layers).to(device=device)
optimizer = optim.Adam(postprocess_net.parameters(), lr=1e-3)

In [5]:
# Train the post-processing net to reconstruct the tile-resolution catalog
# encoding from its 4x bilinearly upsampled version (mean-squared error).
postprocess_net.train()
total_batch = len(dc2_train_dataloader)  # batches per epoch
epoch = 5  # number of passes over the training set
total_steps = total_batch * epoch
log_every = 50
i = 0  # global optimisation step; deliberately NOT reset between epochs
for epoch_idx in range(epoch):
    # A fresh dataloader is built at the start of every epoch.
    dc2_train_dataloader = dc2.train_dataloader()
    for batch in tqdm.tqdm(dc2_train_dataloader, desc=f"epoch {epoch_idx + 1}/{epoch}"):
        batch_on_device = move_data_to_device(batch, device=device)
        target_cat = TileCatalog(batch_on_device["tile_catalog"])
        target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
        encoded_catalog_tensor = catalog_parser.encode(target_cat1).permute([0, 3, 1, 2])  # (b, k, h, w)
        upsampled_catalog_tensor = F.interpolate(
            encoded_catalog_tensor, scale_factor=4, mode="bilinear"
        )  # (b, k, H, W)

        optimizer.zero_grad()
        pred = postprocess_net(upsampled_catalog_tensor)
        # Guard against silent broadcasting in the loss if the network's
        # downsampling ever stops matching the 4x upsampling above.
        assert pred.shape == encoded_catalog_tensor.shape, (pred.shape, encoded_catalog_tensor.shape)
        loss = F.mse_loss(pred, encoded_catalog_tensor)
        loss.backward()
        optimizer.step()

        if (i + 1) % log_every == 0:
            # `i` counts across all epochs, so report it against the total step
            # count rather than batches-per-epoch (which produced "200/191").
            print(f"step [{i + 1}/{total_steps}], loss: {loss.item():.6f}")
        i += 1
        # NOTE(review): despite its name this counter is advanced once per batch,
        # not once per epoch; behaviour kept as-is — confirm this is intended.
        GlobalEnv.current_encoder_epoch += 1

 26%|██▌       | 50/191 [00:41<02:11,  1.07it/s]

step [50/191], loss: 0.028013


 53%|█████▎    | 101/191 [01:14<01:19,  1.14it/s]

step [100/191], loss: 0.017862


 79%|███████▊  | 150/191 [01:45<00:31,  1.31it/s]

step [150/191], loss: 0.010165


100%|██████████| 191/191 [02:06<00:00,  1.51it/s]
  5%|▌         | 10/191 [00:13<01:07,  2.69it/s]

step [200/191], loss: 0.005332


 31%|███▏      | 60/191 [00:47<00:38,  3.44it/s]

step [250/191], loss: 0.004079


 58%|█████▊    | 110/191 [01:16<00:20,  4.03it/s]

step [300/191], loss: 0.003695


 84%|████████▍ | 160/191 [01:46<00:07,  3.94it/s]

step [350/191], loss: 0.003214


100%|██████████| 191/191 [02:04<00:00,  1.53it/s]
  9%|▉         | 18/191 [00:21<04:24,  1.53s/it]

step [400/191], loss: 0.003123


 36%|███▌      | 69/191 [00:57<01:22,  1.48it/s]

step [450/191], loss: 0.003004


 62%|██████▏   | 119/191 [01:31<00:37,  1.90it/s]

step [500/191], loss: 0.002324


 88%|████████▊ | 169/191 [02:16<00:08,  2.72it/s]

step [550/191], loss: 0.001919


100%|██████████| 191/191 [02:26<00:00,  1.30it/s]
 14%|█▍        | 27/191 [00:25<01:08,  2.40it/s]

step [600/191], loss: 0.001682


 41%|████      | 78/191 [01:07<00:42,  2.64it/s]

step [650/191], loss: 0.001770


 66%|██████▋   | 127/191 [01:40<00:41,  1.56it/s]

step [700/191], loss: 0.001268


 93%|█████████▎| 177/191 [02:20<00:08,  1.74it/s]

step [750/191], loss: 0.000823


100%|██████████| 191/191 [02:27<00:00,  1.30it/s]
 19%|█▉        | 36/191 [00:31<01:34,  1.65it/s]

step [800/191], loss: 0.000419


 45%|████▌     | 86/191 [01:10<01:31,  1.14it/s]

step [850/191], loss: 0.000325


 71%|███████   | 136/191 [01:46<01:48,  1.98s/it]

step [900/191], loss: 0.000242


 98%|█████████▊| 187/191 [02:13<00:02,  2.00it/s]

step [950/191], loss: 0.000209


100%|██████████| 191/191 [02:14<00:00,  1.42it/s]


In [6]:
# Validation-time helpers: the validation dataloader, a catalog matcher
# (dist_slack=1.0), and the detection metric, placed on the model's device.
dc2_val_dataloader = dc2.val_dataloader()
matcher = CatalogMatcher(dist_slack=1.0)
f1_metric = DetectionPerformance().to(device=device)

In [7]:
# Evaluate: push each validation batch's (brightest-per-tile) encoding through
# the net, decode the prediction back to a catalog, and accumulate detection
# metrics against the target catalog. No gradients are needed anywhere here.
postprocess_net.eval()
with torch.no_grad():
    for val_batch in tqdm.tqdm(dc2_val_dataloader):
        val_batch_on_device = move_data_to_device(val_batch, device=device)
        target_tile_cat = TileCatalog(
            val_batch_on_device["tile_catalog"]
        ).get_brightest_sources_per_tile(band=2, exclude_num=0)
        target_full_cat = target_tile_cat.to_full_catalog(tile_slen)

        encoded = catalog_parser.encode(target_tile_cat).permute([0, 3, 1, 2])  # (b, k, h, w)
        upsampled = F.interpolate(encoded, scale_factor=4, mode="bilinear")  # (b, k, H, W)
        pred_tensor = postprocess_net(upsampled)
        pred_clipped = catalog_parser.clip_tensor(pred_tensor.permute([0, 2, 3, 1]))
        pred_full_cat = catalog_parser.decode(pred_clipped).to_full_catalog(tile_slen)

        f1_metric.update(
            target_full_cat,
            pred_full_cat,
            matcher.match_catalogs(target_full_cat, pred_full_cat),
        )

100%|██████████| 25/25 [00:31<00:00,  1.24s/it]


In [8]:
# Print each aggregated detection metric as "name: value", one per line.
detection_metrics = f1_metric.compute()
print("\n".join(f"{metric_name}: {metric_value}" for metric_name, metric_value in detection_metrics.items()))

detection_precision: 0.9716938138008118
detection_recall: 0.9717060923576355
detection_f1: 0.9716999530792236
n_true_sources: 157914.0
n_est_sources: 157916.0
