In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule

from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [2]:
# Alternative checkpoint (Y-Net full-diffusion experiment); uncomment to evaluate it instead of the simple-net model.
# model_name = "exp_03-28-4"
# model_check_point_name = "encoder_199.ckpt"
# model_path = f"../../../bliss_output/DC2_ynet_full_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"
# device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")
# with initialize(config_path="./ynet_diffusion_config", version_base=None):
#     new_diffusion_notebook_cfg = compose("ynet_full_diffusion_notebook_config")

In [3]:
# Checkpoint to evaluate: simple-net diffusion experiment run + epoch checkpoint file.
model_name = "exp_04-04-2"
model_check_point_name = "encoder_97.ckpt"
model_path = f"../../../bliss_output/DC2_simple_net_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"

# GPU used for evaluation. torch.device("cuda:7") is accepted even on machines with
# fewer GPUs and only fails later at .to(device), so fall back to CPU unless that
# index actually exists.
gpu_index = 7
device = torch.device(
    f"cuda:{gpu_index}"
    if torch.cuda.is_available() and gpu_index < torch.cuda.device_count()
    else "cpu"
)
# Compose the notebook config from the simple-net diffusion config directory.
with initialize(version_base=None, config_path="./simple_net_diffusion_config"):
    new_diffusion_notebook_cfg = compose(config_name="simple_net_diffusion_notebook_config")

In [4]:
# Seed the global RNGs through Lightning so the stochastic sampling below is repeatable;
# the returned seed is left as the cell's displayed value.
seed = 7272
pytorch_lightning.seed_everything(seed)

Seed set to 7272


7272

In [5]:
# Tiling parameters and the r-band flux cut, read from the notebook config.
dc2_survey_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_survey_cfg.tile_slen
max_sources_per_tile = dc2_survey_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Build the DC2 data module, override its batch size, and prepare the validation split.
dc2: DC2DataModule = instantiate(dc2_survey_cfg)
dc2.batch_size = 512
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Build the encoder from config on the evaluation device, then load the trained weights.
encoder_cfg = new_diffusion_notebook_cfg.encoder
bliss_encoder: DiffusionEncoder = instantiate(encoder_cfg).to(device=device)

# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location=device)
pretrained_weights = checkpoint["state_dict"]
del checkpoint  # keep only the weights; don't hold the rest of the checkpoint in memory
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();



In [6]:
# Show the encoder's diffusion-sampler settings, one value per line, in this order:
# ddim_steps, ddim_objective, ddim_beta_schedule, detection eta, ddim_self_cond.
for sampler_setting in (
    bliss_encoder.ddim_steps,
    bliss_encoder.ddim_objective,
    bliss_encoder.ddim_beta_schedule,
    bliss_encoder.detection_diffusion.ddim_sampling_eta,
    bliss_encoder.ddim_self_cond,
):
    print(sampler_setting)

5
pred_noise
cosine
0.0
False


In [7]:
# bliss_encoder.catalog_parser.factors[0].threshold = 0.0

In [8]:
# Override the detection sampler's eta (configured default is 0.0); eta = 1.0 is the
# setting this notebook uses for DDPM-style sampling.
DETECTION_SAMPLING_ETA = 1.0
bliss_encoder.detection_diffusion.ddim_sampling_eta = DETECTION_SAMPLING_ETA

In [9]:
# Evaluate validation metrics for each number of DDIM sampling steps in the sweep.
for ddim_steps in [5]:
    # Rebuild the dataloader for each sweep value.
    dc2_val_dataloader = dc2.val_dataloader()
    # Keep the encoder's step count and the detection sampler's step count in sync.
    bliss_encoder.ddim_steps = ddim_steps
    bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps

    bliss_encoder.mode_metrics.reset()
    total_batch = len(dc2_val_dataloader)
    # Inference only: enter no_grad once for the whole pass instead of once per batch.
    with torch.no_grad():
        for i, batch in enumerate(tqdm.tqdm(dc2_val_dataloader, total=total_batch)):
            batch_on_device = move_data_to_device(batch, device=device)
            # The batch index is passed through to update_metrics, as before.
            bliss_encoder.update_metrics(batch_on_device, i)

    for k, v in bliss_encoder.mode_metrics.compute().items():
        print(f"(ddim_steps={ddim_steps}) {k}: {v:.2e}")

100%|██████████| 49/49 [02:21<00:00,  2.89s/it]

(ddim_steps=5) flux_err_u_mae: 9.46e+01
(ddim_steps=5) flux_err_u_mpe: -5.74e-01
(ddim_steps=5) flux_err_u_mape: 7.80e-01
(ddim_steps=5) flux_err_u_mae_bin_0: 3.86e+02
(ddim_steps=5) flux_err_u_mpe_bin_0: 2.13e-02
(ddim_steps=5) flux_err_u_mape_bin_0: 2.38e-01
(ddim_steps=5) flux_err_u_mae_bin_1: 9.41e+01
(ddim_steps=5) flux_err_u_mpe_bin_1: -2.24e-02
(ddim_steps=5) flux_err_u_mape_bin_1: 2.97e-01
(ddim_steps=5) flux_err_u_mae_bin_2: 8.15e+01
(ddim_steps=5) flux_err_u_mpe_bin_2: -2.43e-02
(ddim_steps=5) flux_err_u_mape_bin_2: 3.25e-01
(ddim_steps=5) flux_err_u_mae_bin_3: 6.43e+01
(ddim_steps=5) flux_err_u_mpe_bin_3: -3.29e-02
(ddim_steps=5) flux_err_u_mape_bin_3: 3.38e-01
(ddim_steps=5) flux_err_u_mae_bin_4: 4.29e+01
(ddim_steps=5) flux_err_u_mpe_bin_4: -1.81e-01
(ddim_steps=5) flux_err_u_mape_bin_4: 4.53e-01
(ddim_steps=5) flux_err_u_mae_bin_5: 3.90e+01
(ddim_steps=5) flux_err_u_mpe_bin_5: -1.17e+00
(ddim_steps=5) flux_err_u_mape_bin_5: 1.29e+00
(ddim_steps=5) flux_err_g_mae: 1.12e+02




In [10]:
# Print only the MAPE metrics. Label them with the step count the encoder is actually
# configured with, rather than a loop variable leaked from an earlier cell (which is
# stale after any other evaluation run, and undefined if the sweep never ran).
for k, v in bliss_encoder.mode_metrics.compute().items():
    if "mape" in k:
        print(f"(ddim_steps={bliss_encoder.ddim_steps}) {k}: {v:.2e}")

(ddim_steps=5) flux_err_u_mape: 7.80e-01
(ddim_steps=5) flux_err_u_mape_bin_0: 2.38e-01
(ddim_steps=5) flux_err_u_mape_bin_1: 2.97e-01
(ddim_steps=5) flux_err_u_mape_bin_2: 3.25e-01
(ddim_steps=5) flux_err_u_mape_bin_3: 3.38e-01
(ddim_steps=5) flux_err_u_mape_bin_4: 4.53e-01
(ddim_steps=5) flux_err_u_mape_bin_5: 1.29e+00
(ddim_steps=5) flux_err_g_mape: 3.14e-01
(ddim_steps=5) flux_err_g_mape_bin_0: 1.75e-01
(ddim_steps=5) flux_err_g_mape_bin_1: 1.81e-01
(ddim_steps=5) flux_err_g_mape_bin_2: 1.90e-01
(ddim_steps=5) flux_err_g_mape_bin_3: 1.90e-01
(ddim_steps=5) flux_err_g_mape_bin_4: 2.03e-01
(ddim_steps=5) flux_err_g_mape_bin_5: 4.62e-01
(ddim_steps=5) flux_err_r_mape: 1.91e-01
(ddim_steps=5) flux_err_r_mape_bin_0: 1.62e-01
(ddim_steps=5) flux_err_r_mape_bin_1: 1.80e-01
(ddim_steps=5) flux_err_r_mape_bin_2: 1.90e-01
(ddim_steps=5) flux_err_r_mape_bin_3: 1.98e-01
(ddim_steps=5) flux_err_r_mape_bin_4: 1.89e-01
(ddim_steps=5) flux_err_r_mape_bin_5: 2.00e-01
(ddim_steps=5) flux_err_i_mape:

In [None]:
# Re-evaluate with 1000 sampling steps and eta = 1.0 (DDPM-style sampling); this is slow.
ddpm_sampling_steps = 1000
dc2_val_dataloader = dc2.val_dataloader()
bliss_encoder.ddim_steps = ddpm_sampling_steps
bliss_encoder.detection_diffusion.sampling_timesteps = ddpm_sampling_steps
bliss_encoder.detection_diffusion.ddim_sampling_eta = 1.0  # use ddpm

bliss_encoder.mode_metrics.reset()
# Inference only: enter no_grad once for the whole pass instead of once per batch.
with torch.no_grad():
    for i, batch in enumerate(tqdm.tqdm(dc2_val_dataloader)):
        batch_on_device = move_data_to_device(batch, device=device)
        bliss_encoder.update_metrics(batch_on_device, i)

for k, v in bliss_encoder.mode_metrics.compute().items():
    if "detection" in k:
        print(f"{k}: {v:0.5f}")
    if "n_est" in k:
        # n_est is a count: int() prints a plain integer instead of whatever str(v)
        # renders (e.g. a tensor repr). Assumes v is a scalar, as the .2e formatting
        # of every metric in the sweep cell already requires.
        print(f"{k}: {int(v)}")