In [7]:
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
from bliss.surveys.dc2 import DC2
import tqdm
from pathlib import Path
import pickle

In [8]:
# Compose the notebook configuration; Hydra is only initialized for the duration
# of the `with` block, using config_path=".".
with initialize(version_base=None, config_path="."):
    notebook_cfg = compose(config_name="notebook_config")

# Build the DC2 survey object from the config and run its preparation hooks
# before asking it for a dataloader.
dc2: DC2 = instantiate(notebook_cfg.surveys.dc2)
dc2.prepare_data()
dc2.setup()

dc2_train_dataloader = dc2.train_dataloader()

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Directory that receives every artifact written by this notebook; it is created
# (together with any missing parents) if it does not exist yet.
output_dir = Path("./get_percentile_output/")
output_dir.mkdir(exist_ok=True, parents=True)

In [10]:
# Quantile levels handed to torch.quantile when computing the per-batch thresholds;
# created directly on `device` so they live alongside the image batches.
asinh_quantiles_tensor = torch.tensor(
    [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99],
    device=device,
)
# Indices used to select along dimension 1 of each image batch.
bands = list(range(6))

In [11]:
# Per-batch quantiles require a full pass over the training dataloader, so the
# result is cached on disk and simply reloaded when the cache file already exists.
thresholds_path = output_dir / "thresholds.pkl"
if not thresholds_path.exists():
    thresholds = []
    for batch in tqdm.tqdm(dc2_train_dataloader):
        batch_images = batch["images"][:, bands].unsqueeze(2).to(device=device)
        # No `dim` is given, so torch.quantile flattens the batch and returns one
        # value per requested quantile level.
        batch_quantiles = torch.quantile(batch_images, q=asinh_quantiles_tensor)
        # Move the result to the CPU before storing it: pickled CUDA tensors cannot
        # be loaded on a machine without CUDA, and keeping them on the GPU would
        # hold one small device allocation alive per batch.
        thresholds.append(batch_quantiles.cpu())

    with open(thresholds_path, "wb") as output_f:
        pickle.dump(thresholds, output_f, pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: pickle.load can execute arbitrary code. This is acceptable only because
    # the file is a cache written by this notebook; do not point it at untrusted data.
    with open(thresholds_path, "rb") as input_f:
        thresholds = pickle.load(input_f)

100%|██████████| 3063/3063 [01:37<00:00, 31.44it/s]


In [13]:
# Collapse the per-batch quantiles into a single vector by taking the element-wise
# median across batches, then bring the result to the CPU for printing and saving.
thresholds_tensor = torch.stack(thresholds).median(dim=0).values.cpu()
print(thresholds_tensor)

# Persist the aggregated thresholds with torch.save.
thresholds_tensor_path = output_dir / "threshold_tensor.pkl"
with thresholds_tensor_path.open("wb") as output_f:
    torch.save(thresholds_tensor, output_f)

tensor([-0.3856, -0.1059, -0.0336,  0.0073,  0.0569,  0.1658,  0.6423])
