In [None]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule

from case_studies.dc2_mdt.utils.encoder import DiffusionEncoder

In [None]:
# Training run and checkpoint to evaluate; the checkpoint is expected under
# ../../../bliss_output/DC2_mdt_exp/<model_name>/checkpoints/.
model_name = "exp_04-10-2"
model_check_point_name = "encoder_67.ckpt"
model_path = "/".join(
    ("../../../bliss_output/DC2_mdt_exp", model_name, "checkpoints", model_check_point_name)
)

# Use GPU 7 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:7")
else:
    device = torch.device("cpu")

# Compose the "mdt_notebook_config" Hydra config from the ./mdt_config directory.
with initialize(config_path="./mdt_config", version_base=None):
    notebook_cfg = compose(config_name="mdt_notebook_config")

In [None]:
# Seed the RNGs via Lightning so the stochastic diffusion sampling below is repeatable.
seed = 7272
pytorch_lightning.seed_everything(seed)

In [None]:
# Tiling / flux-cut settings read from the notebook config.
# NOTE(review): not used by the cells shown here — presumably kept for later analysis.
tile_slen = notebook_cfg.surveys.dc2.tile_slen
max_sources_per_tile = notebook_cfg.surveys.dc2.max_sources_per_tile
r_band_min_flux = notebook_cfg.notebook_var.r_band_min_flux

# DC2 validation data; the batch size from the config is overridden here.
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.batch_size = 512
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Build the encoder directly on `device`, then load the trained weights into it.
bliss_encoder: DiffusionEncoder = instantiate(notebook_cfg.encoder).to(device=device)
# Load the checkpoint onto the CPU: only its "state_dict" entry is needed, and
# load_state_dict copies the tensors into the parameters already living on `device`.
# Mapping to `device` instead would put the whole checkpoint (presumably including
# optimizer state) on the GPU and keep a second copy of the weights there via
# the `pretrained_weights` global.
# torch.load unpickles the file — only point it at trusted checkpoints.
pretrained_weights = torch.load(model_path, map_location="cpu")["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [None]:
# Show the diffusion / sampling settings the loaded encoder will use, one per line,
# followed by the threshold of the first catalog-parser factor.
for attr_name in (
    "d_objective",
    "d_beta_schedule",
    "d_sampling_method",
    "d_sampling_timesteps",
    "ddim_eta",
):
    print(getattr(bliss_encoder, attr_name))
print(bliss_encoder.catalog_parser.factors[0].threshold)

In [None]:
# Sweep the number of DDPM sampling steps and report the non-binned mode metrics
# computed over the full validation set for each setting.
# NOTE: this mutates bliss_encoder's sampling settings; after the cell finishes they
# are left at the last value of the sweep.
DDPM_SAMPLING_STEPS = [20, 50, 100]

# The sampling method is the same for every sweep point, so set it once.
bliss_encoder.d_sampling_method = "ddpm"

for sampling_steps in DDPM_SAMPLING_STEPS:
    # Re-seed per setting so every step count sees the same random draws. With a
    # single seed at the top, later settings would start from a different RNG state,
    # confounding the comparison, and no single setting could be reproduced on its own.
    pytorch_lightning.seed_everything(seed=seed)
    bliss_encoder.d_sampling_timesteps = sampling_steps
    bliss_encoder.training_diffusion.ddpm_sampling_timesteps = sampling_steps

    bliss_encoder.mode_metrics.reset()
    dc2_val_dataloader = dc2.val_dataloader()
    with torch.no_grad():
        progress = tqdm.tqdm(dc2_val_dataloader, desc=f"ddpm steps={sampling_steps}")
        for batch_idx, batch in enumerate(progress):
            batch_on_device = move_data_to_device(batch, device=device)
            bliss_encoder.update_metrics(batch_on_device, batch_idx)

    # Per-bin metrics are skipped to keep the printout short.
    for k, v in bliss_encoder.mode_metrics.compute().items():
        if "bin" not in k:
            print(f"(steps={sampling_steps}) {k}: {v:.2e}")