In [1]:
import torch
import tqdm
import os
import time

import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange, repeat
from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [2]:
model_name = "exp_03-17-1"
model_check_point_name = "encoder_143.ckpt"
model_path = f"../../../bliss_output/DC2_ynet_full_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"
cached_data_path = Path("/data/scratch/pduan/posterior_cached_files")
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    new_diffusion_notebook_cfg = compose("ynet_full_diffusion_notebook_config")

In [3]:
seed = 7272
pytorch_lightning.seed_everything(seed=seed)

Seed set to 7272


7272

In [4]:
tile_slen = new_diffusion_notebook_cfg.surveys.dc2.tile_slen
max_sources_per_tile = new_diffusion_notebook_cfg.surveys.dc2.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

dc2: DC2DataModule = instantiate(new_diffusion_notebook_cfg.surveys.dc2)
dc2.setup(stage="validate")

bliss_encoder: DiffusionEncoder = instantiate(new_diffusion_notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [5]:
print(bliss_encoder.ddim_steps)
print(bliss_encoder.ddim_objective)
print(bliss_encoder.ddim_beta_schedule)
print(bliss_encoder.detection_diffusion.ddim_sampling_eta)
print(bliss_encoder.ddim_self_cond)
print(bliss_encoder.catalog_parser.factors[0].threshold)

5
pred_noise
cosine
0.0
False
0.0


In [6]:
ddim_steps = 50
bliss_encoder.ddim_steps = ddim_steps
bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps
bliss_encoder.detection_diffusion.ddim_sampling_eta = 1.0
bliss_encoder.catalog_parser.factors[0].threshold = 0.0

In [7]:
bliss_encoder = torch.compile(bliss_encoder)

In [8]:
model = bliss_encoder.detection_diffusion.model

In [9]:
total_iters = 10

for batch_size in [8, 16, 32, 64, 128]:
    dc2.batch_size = batch_size
    dc2_val_dataloader = dc2.val_dataloader()
    one_batch = next(iter(dc2_val_dataloader))
    one_batch = move_data_to_device(one_batch, device=device)
    total_sample_time = 0.0
    total_infer_time = 0.0
    for _ in tqdm.tqdm(range(total_iters)):
        with torch.inference_mode():
            sample_start_time = time.time()
            bliss_encoder.sample(one_batch, return_inter_output=False)
            sample_end_time = time.time()
            total_sample_time += sample_end_time - sample_start_time

            for _ in range(ddim_steps):
                rand_x = torch.randn(batch_size, 10, 20, 20, device=device)
                rand_t = torch.full((batch_size, ), 999, 
                                    device=device, dtype=torch.long)
                rand_input_image = torch.randn(batch_size, 6, 14, 80, 80, device=device)
                infer_start_time = time.time()
                model(rand_x, rand_t, rand_input_image, None)
                infer_end_time = time.time()
                total_infer_time += infer_end_time - infer_start_time

    print(f"(b={batch_size}) sample time: {total_sample_time / total_iters:.2f}s")
    print(f"(b={batch_size}) infer time: {total_infer_time / total_iters:.2f}s")

100%|██████████| 10/10 [00:19<00:00,  1.96s/it]

(b=8) sample time: 1.04s
(b=8) infer time: 0.91s



100%|██████████| 10/10 [00:20<00:00,  2.06s/it]

(b=16) sample time: 1.07s
(b=16) infer time: 0.98s



100%|██████████| 10/10 [00:28<00:00,  2.81s/it]

(b=32) sample time: 1.43s
(b=32) infer time: 1.37s



100%|██████████| 10/10 [00:51<00:00,  5.17s/it]

(b=64) sample time: 2.63s
(b=64) infer time: 2.53s



100%|██████████| 10/10 [01:39<00:00,  9.95s/it]

(b=128) sample time: 5.07s
(b=128) infer time: 4.87s



