In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule

from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [2]:
# Checkpoint to evaluate; the path is relative to this notebook's directory.
model_name = "exp_03-16-3"
model_check_point_name = "encoder_113.ckpt"
model_path = f"../../../bliss_output/DC2_ynet_full_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"

# Preferred GPU. Fall back to CPU when CUDA is unavailable OR this machine has
# fewer GPUs than `gpu_index` requires; an out-of-range ordinal would otherwise
# only fail later, when the encoder is moved to the device.
gpu_index = 2
if torch.cuda.is_available() and gpu_index < torch.cuda.device_count():
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Compose the hydra config that describes both the DC2 data module and the encoder.
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    new_diffusion_notebook_cfg = compose("ynet_full_diffusion_notebook_config")

In [3]:
# Seed Python, NumPy and torch RNGs through Lightning so sampling is repeatable.
# The call returns the seed, which is displayed as this cell's output.
seed = 7272
pytorch_lightning.seed_everything(seed)

Seed set to 7272


7272

In [4]:
# Settings read from the composed notebook config.
dc2_survey_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_survey_cfg.tile_slen
max_sources_per_tile = dc2_survey_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# DC2 data module, prepared for validation with a fixed batch size of 512.
dc2: DC2DataModule = instantiate(dc2_survey_cfg)
dc2.batch_size = 512
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Encoder built from config on `device`, then overwritten with the checkpoint's
# "state_dict" entry (only that entry is kept; the rest of the checkpoint is dropped).
bliss_encoder: DiffusionEncoder = instantiate(new_diffusion_notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [5]:
# Show the diffusion settings carried by the loaded encoder, one per line,
# followed by the threshold of the first catalog-parser factor.
for encoder_setting in (
    bliss_encoder.ddim_steps,
    bliss_encoder.ddim_objective,
    bliss_encoder.ddim_beta_schedule,
    bliss_encoder.detection_diffusion.ddim_sampling_eta,
    bliss_encoder.ddim_self_cond,
    bliss_encoder.catalog_parser.factors[0].threshold,
):
    print(encoder_setting)

5
pred_noise
cosine
0.0
False
0.0


In [6]:
# Pin the first catalog-parser factor's threshold to 0.0 explicitly, so the
# evaluation below does not depend on whatever value the encoder was built with.
first_parser_factor = bliss_encoder.catalog_parser.factors[0]
first_parser_factor.threshold = 0.0

In [None]:
# Sweep the number of DDIM sampling steps and print every non-"bin" validation
# metric for each setting. Each setting gets a fresh validation dataloader and a
# reset metric collection, so results never accumulate across settings.
ddim_steps_sweep = [1, 2, 5, 10, 20, 50, 100, 200]
# eta = 0.0 is the value the loaded encoder reports above. It is set explicitly
# because the DDPM cell below changes eta on the shared encoder; without this,
# re-running the sweep afterwards would silently sample with a different eta.
ddim_sweep_eta = 0.0

for ddim_steps in ddim_steps_sweep:
    dc2_val_dataloader = dc2.val_dataloader()
    bliss_encoder.ddim_steps = ddim_steps
    bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps
    bliss_encoder.detection_diffusion.ddim_sampling_eta = ddim_sweep_eta

    bliss_encoder.mode_metrics.reset()
    progress = tqdm.tqdm(dc2_val_dataloader, desc=f"ddim_steps={ddim_steps}")
    for batch_idx, batch in enumerate(progress):
        batch_on_device = move_data_to_device(batch, device=device)
        with torch.no_grad():
            bliss_encoder.update_metrics(batch_on_device, batch_idx)

    for k, v in bliss_encoder.mode_metrics.compute().items():
        if "bin" not in k:
            print(f"(ddim_steps={ddim_steps}) {k}: {v:.2e}")

100%|██████████| 49/49 [01:27<00:00,  1.78s/it]


(ddim_steps=1) detection_precision: 5.05e-03
(ddim_steps=1) detection_recall: 1.45e-01
(ddim_steps=1) detection_f1: 9.77e-03
(ddim_steps=1) n_true_sources: 1.56e+05
(ddim_steps=1) n_est_sources: 4.48e+06
(ddim_steps=1) flux_err_u_mae: 6.15e+09
(ddim_steps=1) flux_err_u_mpe: -9.54e+07
(ddim_steps=1) flux_err_u_mape: 9.54e+07
(ddim_steps=1) flux_err_g_mae: 5.28e+09
(ddim_steps=1) flux_err_g_mpe: -3.90e+07
(ddim_steps=1) flux_err_g_mape: 3.90e+07
(ddim_steps=1) flux_err_r_mae: 5.48e+09
(ddim_steps=1) flux_err_r_mpe: -2.51e+07
(ddim_steps=1) flux_err_r_mape: 2.51e+07
(ddim_steps=1) flux_err_i_mae: 6.13e+09
(ddim_steps=1) flux_err_i_mpe: -2.05e+07
(ddim_steps=1) flux_err_i_mape: 2.05e+07
(ddim_steps=1) flux_err_z_mae: 5.90e+09
(ddim_steps=1) flux_err_z_mpe: -1.67e+07
(ddim_steps=1) flux_err_z_mape: 1.67e+07
(ddim_steps=1) flux_err_y_mae: 6.76e+09
(ddim_steps=1) flux_err_y_mpe: -1.68e+07
(ddim_steps=1) flux_err_y_mape: 1.68e+07
(ddim_steps=1) n_sources:s0_precision: 9.97e-01
(ddim_steps=1) n

100%|██████████| 49/49 [01:49<00:00,  2.24s/it]


(ddim_steps=2) detection_precision: 1.21e-01
(ddim_steps=2) detection_recall: 4.67e-01
(ddim_steps=2) detection_f1: 1.93e-01
(ddim_steps=2) n_true_sources: 1.56e+05
(ddim_steps=2) n_est_sources: 5.99e+05
(ddim_steps=2) flux_err_u_mae: 1.69e+05
(ddim_steps=2) flux_err_u_mpe: -7.06e+02
(ddim_steps=2) flux_err_u_mape: 7.06e+02
(ddim_steps=2) flux_err_g_mae: 6.70e+05
(ddim_steps=2) flux_err_g_mpe: -6.86e+02
(ddim_steps=2) flux_err_g_mape: 6.86e+02
(ddim_steps=2) flux_err_r_mae: 4.27e+06
(ddim_steps=2) flux_err_r_mpe: -3.05e+03
(ddim_steps=2) flux_err_r_mape: 3.05e+03
(ddim_steps=2) flux_err_i_mae: 1.49e+06
(ddim_steps=2) flux_err_i_mpe: -6.44e+02
(ddim_steps=2) flux_err_i_mape: 6.44e+02
(ddim_steps=2) flux_err_z_mae: 2.00e+06
(ddim_steps=2) flux_err_z_mpe: -6.04e+02
(ddim_steps=2) flux_err_z_mape: 6.04e+02
(ddim_steps=2) flux_err_y_mae: 2.32e+06
(ddim_steps=2) flux_err_y_mpe: -3.44e+02
(ddim_steps=2) flux_err_y_mape: 3.44e+02
(ddim_steps=2) n_sources:s0_precision: 9.96e-01
(ddim_steps=2) n

100%|██████████| 49/49 [02:46<00:00,  3.39s/it]


(ddim_steps=5) detection_precision: 7.66e-01
(ddim_steps=5) detection_recall: 8.13e-01
(ddim_steps=5) detection_f1: 7.88e-01
(ddim_steps=5) n_true_sources: 1.56e+05
(ddim_steps=5) n_est_sources: 1.65e+05
(ddim_steps=5) flux_err_u_mae: 3.42e+04
(ddim_steps=5) flux_err_u_mpe: -2.02e+00
(ddim_steps=5) flux_err_u_mape: 2.29e+00
(ddim_steps=5) flux_err_g_mae: 9.09e+04
(ddim_steps=5) flux_err_g_mpe: -1.05e+00
(ddim_steps=5) flux_err_g_mape: 1.25e+00
(ddim_steps=5) flux_err_r_mae: 3.50e+05
(ddim_steps=5) flux_err_r_mpe: -2.13e+00
(ddim_steps=5) flux_err_r_mape: 2.28e+00
(ddim_steps=5) flux_err_i_mae: 1.94e+05
(ddim_steps=5) flux_err_i_mpe: -1.76e+00
(ddim_steps=5) flux_err_i_mape: 1.85e+00
(ddim_steps=5) flux_err_z_mae: 2.17e+05
(ddim_steps=5) flux_err_z_mpe: -2.25e+00
(ddim_steps=5) flux_err_z_mape: 2.34e+00
(ddim_steps=5) flux_err_y_mae: 2.90e+05
(ddim_steps=5) flux_err_y_mpe: -1.88e+00
(ddim_steps=5) flux_err_y_mape: 2.04e+00
(ddim_steps=5) n_sources:s0_precision: 9.97e-01
(ddim_steps=5) n

100%|██████████| 49/49 [04:20<00:00,  5.32s/it]


(ddim_steps=10) detection_precision: 8.15e-01
(ddim_steps=10) detection_recall: 8.26e-01
(ddim_steps=10) detection_f1: 8.21e-01
(ddim_steps=10) n_true_sources: 1.56e+05
(ddim_steps=10) n_est_sources: 1.58e+05
(ddim_steps=10) flux_err_u_mae: 5.78e+03
(ddim_steps=10) flux_err_u_mpe: -3.58e-01
(ddim_steps=10) flux_err_u_mape: 5.71e-01
(ddim_steps=10) flux_err_g_mae: 1.93e+04
(ddim_steps=10) flux_err_g_mpe: -1.37e-01
(ddim_steps=10) flux_err_g_mape: 2.94e-01
(ddim_steps=10) flux_err_r_mae: 3.23e+04
(ddim_steps=10) flux_err_r_mpe: -1.49e-01
(ddim_steps=10) flux_err_r_mape: 3.00e-01
(ddim_steps=10) flux_err_i_mae: 3.41e+04
(ddim_steps=10) flux_err_i_mpe: -1.83e-01
(ddim_steps=10) flux_err_i_mape: 3.09e-01
(ddim_steps=10) flux_err_z_mae: 3.35e+04
(ddim_steps=10) flux_err_z_mpe: -2.50e-01
(ddim_steps=10) flux_err_z_mape: 3.78e-01
(ddim_steps=10) flux_err_y_mae: 4.04e+04
(ddim_steps=10) flux_err_y_mpe: -2.21e-01
(ddim_steps=10) flux_err_y_mape: 4.05e-01
(ddim_steps=10) n_sources:s0_precision: 9

100%|██████████| 49/49 [07:33<00:00,  9.25s/it]


(ddim_steps=20) detection_precision: 8.21e-01
(ddim_steps=20) detection_recall: 8.19e-01
(ddim_steps=20) detection_f1: 8.20e-01
(ddim_steps=20) n_true_sources: 1.56e+05
(ddim_steps=20) n_est_sources: 1.55e+05
(ddim_steps=20) flux_err_u_mae: 3.90e+03
(ddim_steps=20) flux_err_u_mpe: -2.11e-01
(ddim_steps=20) flux_err_u_mape: 4.51e-01
(ddim_steps=20) flux_err_g_mae: 1.28e+04
(ddim_steps=20) flux_err_g_mpe: -6.91e-02
(ddim_steps=20) flux_err_g_mape: 2.18e-01
(ddim_steps=20) flux_err_r_mae: 2.09e+04
(ddim_steps=20) flux_err_r_mpe: -4.44e-02
(ddim_steps=20) flux_err_r_mape: 2.01e-01
(ddim_steps=20) flux_err_i_mae: 1.90e+04
(ddim_steps=20) flux_err_i_mpe: -6.20e-02
(ddim_steps=20) flux_err_i_mape: 1.89e-01
(ddim_steps=20) flux_err_z_mae: 2.10e+04
(ddim_steps=20) flux_err_z_mpe: -8.33e-02
(ddim_steps=20) flux_err_z_mape: 2.38e-01
(ddim_steps=20) flux_err_y_mae: 2.19e+04
(ddim_steps=20) flux_err_y_mpe: -8.63e-02
(ddim_steps=20) flux_err_y_mape: 2.90e-01
(ddim_steps=20) n_sources:s0_precision: 9

  6%|▌         | 3/49 [01:07<16:49, 21.94s/it]

In [None]:
# Reference run: 1000 sampling steps with eta = 1.0 ("use ddpm"), for comparison
# with the DDIM sweep. The encoder's previous sampler settings are restored in
# `finally`, so neither a completed nor an interrupted run leaks eta = 1.0 /
# 1000 steps into later cells.
ddpm_steps = 1000
ddpm_sampling_eta = 1.0  # use ddpm

previous_ddim_steps = bliss_encoder.ddim_steps
previous_sampling_timesteps = bliss_encoder.detection_diffusion.sampling_timesteps
previous_sampling_eta = bliss_encoder.detection_diffusion.ddim_sampling_eta

dc2_val_dataloader = dc2.val_dataloader()
bliss_encoder.ddim_steps = ddpm_steps
bliss_encoder.detection_diffusion.sampling_timesteps = ddpm_steps
bliss_encoder.detection_diffusion.ddim_sampling_eta = ddpm_sampling_eta
try:
    bliss_encoder.mode_metrics.reset()
    for batch_idx, batch in enumerate(tqdm.tqdm(dc2_val_dataloader)):
        batch_on_device = move_data_to_device(batch, device=device)
        with torch.no_grad():
            bliss_encoder.update_metrics(batch_on_device, batch_idx)
    ddpm_metrics = bliss_encoder.mode_metrics.compute()
finally:
    bliss_encoder.ddim_steps = previous_ddim_steps
    bliss_encoder.detection_diffusion.sampling_timesteps = previous_sampling_timesteps
    bliss_encoder.detection_diffusion.ddim_sampling_eta = previous_sampling_eta

for k, v in ddpm_metrics.items():
    if "detection" in k:
        print(f"{k}: {v:0.5f}")
    if "n_est" in k:
        # int() so a scalar count prints as a plain number rather than a tensor repr.
        print(f"{k}: {int(v)}")