In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [2]:
# Which training run / checkpoint of the DC2 y-net diffusion encoder to evaluate.
model_name = "exp_01-27-3"
model_check_point_name = "encoder_48.ckpt"
model_path = f"../../../bliss_output/DC2_ynet_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"

# Preferred GPU index. If this machine has fewer GPUs than that, fall back to the
# first GPU instead of building a device that `.to(device)` would later reject.
# Without CUDA at all, run on CPU.
GPU_INDEX = 2
if torch.cuda.is_available():
    gpu_index = GPU_INDEX if torch.cuda.device_count() > GPU_INDEX else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Compose the notebook's Hydra config from the local config directory.
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    new_diffusion_notebook_cfg = compose("ynet_diffusion_notebook_config")

In [3]:
# Survey-level settings read from the composed notebook config.
dc2_survey_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_survey_cfg.tile_slen
max_sources_per_tile = dc2_survey_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# DC2 data module, prepared for the validation stage, plus its validation loader.
dc2: DC2DataModule = instantiate(dc2_survey_cfg)
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Encoder built from config on the target device, then loaded with the
# checkpoint's "state_dict" and switched to eval mode (`;` hides the repr).
bliss_encoder: DiffusionEncoder = instantiate(new_diffusion_notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [None]:
# Show the encoder's current DDIM sampler settings, one per line.
for sampler_attr in ("ddim_steps", "ddim_objective", "ddim_beta_schedule"):
    print(getattr(bliss_encoder, sampler_attr))

In [None]:
# Inspect the threshold currently set on the first catalog-parser factor.
first_factor_threshold = bliss_encoder.catalog_parser.factors[0].threshold
print(first_factor_threshold)

In [6]:
# Override the first catalog-parser factor's threshold; nothing later resets it,
# so every evaluation that follows uses 0.0.
# NOTE(review): the practical effect of a 0.0 threshold depends on the factor's
# implementation (not visible here) — confirm.
first_parser_factor = bliss_encoder.catalog_parser.factors[0]
first_parser_factor.threshold = 0.0

In [None]:
# Sweep the number of DDIM sampling steps and report detection metrics on the DC2
# validation set for each setting.
# NOTE: this cell does not set `detection_diffusion.ddim_sampling_eta`; it uses
# whatever value the sampler currently holds.
DDIM_STEPS_SWEEP = [1, 2, 5, 10, 20, 50, 100, 200]

for ddim_steps in DDIM_STEPS_SWEEP:
    dc2_val_dataloader = dc2.val_dataloader()
    # Keep the encoder-level setting and the diffusion sampler's step count in sync.
    bliss_encoder.ddim_steps = ddim_steps
    bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps

    # Clear metric state so each setting is scored independently.
    bliss_encoder.mode_metrics.reset()
    progress = tqdm.tqdm(dc2_val_dataloader, desc=f"ddim_steps={ddim_steps}")
    for batch_idx, batch in enumerate(progress):
        batch_on_device = move_data_to_device(batch, device=device)
        with torch.no_grad():
            bliss_encoder.update_metrics(batch_on_device, batch_idx)

    # Print only detection metrics (5 decimals) and source counts (as ints).
    for k, v in bliss_encoder.mode_metrics.compute().items():
        if "detection" in k:
            print(f"(ddim_steps={ddim_steps}) {k}: {v:0.5f}")
        if "n_est" in k:
            print(f"(ddim_steps={ddim_steps}) {k}: {int(v)}")

In [None]:
# Evaluate with 1000 sampling steps and eta = 1.0 (the original author's note marks
# this configuration as DDPM-style sampling).
DDPM_STEPS = 1000
DDPM_ETA = 1.0

diffusion_sampler = bliss_encoder.detection_diffusion
# Save the current sampler settings and restore them afterwards. Otherwise eta = 1.0
# would silently persist into any later re-run of the DDIM sweep cell, which never
# sets eta itself.
saved_sampler_settings = (
    bliss_encoder.ddim_steps,
    diffusion_sampler.sampling_timesteps,
    diffusion_sampler.ddim_sampling_eta,
)

dc2_val_dataloader = dc2.val_dataloader()
bliss_encoder.ddim_steps = DDPM_STEPS
diffusion_sampler.sampling_timesteps = DDPM_STEPS
diffusion_sampler.ddim_sampling_eta = DDPM_ETA  # use ddpm

try:
    bliss_encoder.mode_metrics.reset()
    for batch_idx, batch in enumerate(tqdm.tqdm(dc2_val_dataloader, desc="ddpm")):
        batch_on_device = move_data_to_device(batch, device=device)
        with torch.no_grad():
            bliss_encoder.update_metrics(batch_on_device, batch_idx)
    ddpm_metrics = bliss_encoder.mode_metrics.compute()
finally:
    (
        bliss_encoder.ddim_steps,
        diffusion_sampler.sampling_timesteps,
        diffusion_sampler.ddim_sampling_eta,
    ) = saved_sampler_settings

# Same reporting format as the DDIM sweep: detection metrics to 5 decimals, counts as ints.
for k, v in ddpm_metrics.items():
    if "detection" in k:
        print(f"{k}: {v:0.5f}")
    if "n_est" in k:
        print(f"{k}: {int(v)}")