In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from case_studies.dc2_diffusion.utils.slim_encoder import SlimDiffusionEncoder
from case_studies.dc2_diffusion.utils.only_locs_encoder import OnlyLocsDiffusionEncoder

In [2]:
# Experiment / checkpoint selection.
model_name = "exp_01-23-1"
model_path = f"../../../bliss_output/DC2_only_locs_diffusion_exp/{model_name}/checkpoints/best_encoder.ckpt"

# Preferred GPU index. "cuda:1" only exists when at least two GPUs are visible, so
# clamp to the last visible GPU (or use the CPU) instead of failing later when
# tensors/checkpoints are moved to the device.
PREFERRED_GPU_INDEX = 1
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(PREFERRED_GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

# Compose the Hydra config "only_locs_notebook_config" from the config directory "."
# (resolved relative to this notebook).
with initialize(config_path=".", version_base=None):
    slim_notebook_cfg = compose("only_locs_notebook_config")

In [3]:
# --- Data: config values used by the analysis + the DC2 validation split ---------
dc2_survey_cfg = slim_notebook_cfg.surveys.dc2
tile_slen = dc2_survey_cfg.tile_slen
max_sources_per_tile = dc2_survey_cfg.max_sources_per_tile
r_band_min_flux = slim_notebook_cfg.notebook_var.r_band_min_flux

dc2: DC2DataModule = instantiate(dc2_survey_cfg)
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# --- Model: build from config, move to `device`, load checkpoint weights, eval mode ---
bliss_encoder: OnlyLocsDiffusionEncoder = instantiate(slim_notebook_cfg.encoder)
bliss_encoder = bliss_encoder.to(device=device)

# NOTE(review): torch.load is pickle-based — only point model_path at trusted checkpoints.
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder = bliss_encoder.eval()

In [None]:
# Show the sampling configuration of the loaded encoder, one labeled line per setting.
# The last four attributes were previously commented out; presumably only some
# diffusion-encoder variants define them, so they are printed only when present.
for attr_name in (
    "ddim_steps",
    "ddim_objective",
    "ddim_beta_schedule",
    "ddim_self_cond",
    "add_fake_tiles",
    "empty_tile_random_noise",
    "correct_bits",
):
    if hasattr(bliss_encoder, attr_name):
        print(f"{attr_name}: {getattr(bliss_encoder, attr_name)}")

In [None]:
# Threshold currently configured on the first catalog-parser factor (the next cell overrides it).
parser_factors = bliss_encoder.catalog_parser.factors
print(parser_factors[0].threshold)

In [6]:
# Override the decision threshold of the first catalog-parser factor.
# NOTE(review): factors[0] is presumably the source on/off (detection) factor — confirm
# against the encoder's catalog_parser config.
DETECTION_THRESHOLD = 0.5
bliss_encoder.catalog_parser.factors[0].threshold = DETECTION_THRESHOLD

In [None]:
# DDIM step-count sweep: evaluate the encoder's metrics on the DC2 validation set for
# several numbers of sampling steps and print the detection / n_est entries.
# The encoder's step settings are restored afterwards (even on error/interrupt) so
# later cells and re-runs start from the loaded configuration.
# NOTE: `detection_diffusion.ddim_sampling_eta` is NOT set here — the sweep uses
# whatever eta the encoder currently has.
DDIM_STEPS_SWEEP = [1, 2, 5, 10, 50, 100, 200]

orig_ddim_steps = bliss_encoder.ddim_steps
orig_sampling_timesteps = bliss_encoder.detection_diffusion.sampling_timesteps
try:
    for ddim_steps in DDIM_STEPS_SWEEP:
        dc2_val_dataloader = dc2.val_dataloader()
        bliss_encoder.ddim_steps = ddim_steps
        bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps

        bliss_encoder.mode_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm.tqdm(dc2_val_dataloader)):
                batch_on_device = move_data_to_device(batch, device=device)
                bliss_encoder.update_metrics(batch_on_device, batch_idx)

        for k, v in bliss_encoder.mode_metrics.compute().items():
            if "detection" in k:
                print(f"(ddim_steps={ddim_steps}) {k}: {v:0.5f}")
            if "n_est" in k:
                print(f"(ddim_steps={ddim_steps}) {k}: {int(v)}")
finally:
    bliss_encoder.ddim_steps = orig_ddim_steps
    bliss_encoder.detection_diffusion.sampling_timesteps = orig_sampling_timesteps

In [None]:
# Full 1000-step sampling with eta = 1.0 (DDPM-style sampling) on the DC2 validation set.
# The previous steps / timesteps / eta are restored afterwards so that re-running the
# DDIM sweep cell is not silently affected by this cell's eta = 1.0.
DDPM_STEPS = 1000
DDPM_ETA = 1.0  # use ddpm

prev_sampling_settings = (
    bliss_encoder.ddim_steps,
    bliss_encoder.detection_diffusion.sampling_timesteps,
    bliss_encoder.detection_diffusion.ddim_sampling_eta,
)
dc2_val_dataloader = dc2.val_dataloader()
try:
    bliss_encoder.ddim_steps = DDPM_STEPS
    bliss_encoder.detection_diffusion.sampling_timesteps = DDPM_STEPS
    bliss_encoder.detection_diffusion.ddim_sampling_eta = DDPM_ETA

    bliss_encoder.mode_metrics.reset()
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm.tqdm(dc2_val_dataloader)):
            batch_on_device = move_data_to_device(batch, device=device)
            bliss_encoder.update_metrics(batch_on_device, batch_idx)

    for k, v in bliss_encoder.mode_metrics.compute().items():
        if "detection" in k:
            print(f"{k}: {v:0.5f}")
        if "n_est" in k:
            # int(...) for a plain count, matching the sweep cell (was a raw tensor repr)
            print(f"{k}: {int(v)}")
finally:
    (
        bliss_encoder.ddim_steps,
        bliss_encoder.detection_diffusion.sampling_timesteps,
        bliss_encoder.detection_diffusion.ddim_sampling_eta,
    ) = prev_sampling_settings