In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import torch
from os import environ
from pathlib import Path
from einops import rearrange
import pickle
from tqdm import tqdm
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from hydra import initialize, compose
from hydra.utils import instantiate

from pytorch_lightning.utilities import move_data_to_device

from bliss.catalog import FullCatalog, BaseTileCatalog, TileCatalog
from bliss.surveys.dc2 import DC2DataModule
from case_studies.redshift.evaluation.utils.load_lsst import get_lsst_full_cat
from case_studies.redshift.evaluation.utils.safe_metric_collection import SafeMetricCollection as MetricCollection
from case_studies.redshift.redshift_from_img.encoder.metrics import RedshiftMeanSquaredErrorBin

# BLISS_HOME is set to the directory three levels above the working directory.
working_dir = Path().resolve()
environ["BLISS_HOME"] = str(working_dir.parents[2])

# Directory used below for the cached metric pickles; created if it is missing.
output_dir = Path("/data/scratch/declan/redshift")
output_dir.mkdir(exist_ok=True, parents=True)

# Machine-specific inputs: change these to match your own training run and data.
model_path = "/data/scratch/jaloper/redshift/encoder_0.133145.ckpt"
lsst_root_dir = "/data/scratch/dc2_nfs/"

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [3]:
# Compose the "discrete_eval" config from config directory ".".
# ("notebook_discrete_plot" is the alternative config name that the original cell
# kept as a commented-out option.)
with initialize(version_base=None, config_path="."):
    notebook_cfg = compose(config_name="discrete_eval")

In [4]:
# Instantiate the data module described by the config and set it up for the
# "test" stage.
data_source_cfg = notebook_cfg.train.data_source
dataset = instantiate(data_source_cfg)
dataset.setup("test")


In [5]:
# Size of the test dataset (the recorded output for this cell is 25000).
len(dataset.test_dataset)

25000

In [6]:
import os
# Cap CPU threading at 16 threads.
# The environment variables only reach libraries (and child processes) that read
# them after this point. torch and numpy were already imported at the top of the
# notebook, so their already-loaded runtimes may never see these values;
# torch.set_num_threads applies the cap to the running process directly.
NUM_CPU_THREADS = 16
for _thread_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'NUMEXPR_NUM_THREADS'):
    os.environ[_thread_var] = str(NUM_CPU_THREADS)
torch.set_num_threads(NUM_CPU_THREADS)

In [7]:
# Echo the data module to confirm which class the config instantiated
# (the recorded output shows a bliss.surveys.dc2.DC2DataModule).
dataset

<bliss.surveys.dc2.DC2DataModule at 0x7f56124a9000>

In [None]:
# Walk the whole test loader once, moving each batch's images onto `device`.
# When the loop finishes, `batch` still refers to the last batch that was loaded.
n_test_batches = len(dataset.test_dataloader())
for batch_idx, batch in tqdm(enumerate(dataset.test_dataloader()), total=n_test_batches):
    batch["images"] = batch["images"].to(device)

### BLISS using the discrete variational distribution

In [8]:
# Inspect the config node that defines the encoder's mode metrics.
notebook_cfg.encoder.mode_metrics

{'_target_': 'torchmetrics.MetricCollection', '_convert_': 'partial', 'metrics': '${mode_sample_metrics}'}

In [9]:
# Inspect the config node for the encoder's sample metrics. In the recorded output it
# uses the same `${mode_sample_metrics}` interpolation as the mode metrics.
notebook_cfg.encoder.sample_metrics

{'_target_': 'torchmetrics.MetricCollection', '_convert_': 'partial', 'metrics': '${mode_sample_metrics}'}

In [10]:
# Inspect the config node for the encoder's discrete metrics (recorded output:
# a MetricCollection whose metrics come from `${discrete_metrics}`).
notebook_cfg.encoder.discrete_metrics

{'_target_': 'torchmetrics.MetricCollection', '_convert_': 'partial', 'metrics': '${discrete_metrics}'}

In [11]:
# Repeat of the earlier mode-metrics config inspection; the recorded output is identical.
notebook_cfg.encoder.mode_metrics

{'_target_': 'torchmetrics.MetricCollection', '_convert_': 'partial', 'metrics': '${mode_sample_metrics}'}

In [None]:
# Re-compose the config, this time from the "../redshift_from_img" config directory.
# NOTE(review): this rebinds `notebook_cfg`, so every later cell sees the "discrete"
# config rather than the "discrete_eval" one composed earlier -- confirm that is intended.
with initialize(version_base=None, config_path="../redshift_from_img"):
    notebook_cfg = compose(config_name="discrete")

In [12]:
# Build the encoder from the config, restore the trained weights stored under the
# checkpoint's "state_dict" key, and switch the module to evaluation mode.
bliss_encoder = instantiate(notebook_cfg.encoder)
bliss_encoder = bliss_encoder.to(device=device)

pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)

bliss_encoder.eval();  # trailing ';' keeps the notebook from echoing the module repr



In [None]:
# Inspect the mode metric collection attached to the instantiated encoder.
bliss_encoder.mode_metrics

In [None]:
# Inspect the discrete metric collection attached to the instantiated encoder.
bliss_encoder.discrete_metrics

In [None]:
# Echo the encoder's repr to review what was instantiated.
bliss_encoder

In [None]:
# Attribute reference only (no call): shows what `sample` is bound to on the
# variational distribution; it does not draw a sample.
bliss_encoder.var_dist.sample

In [None]:
# Attribute reference only (no call): checks that the variational distribution
# exposes `discrete_sample` and shows what it is bound to.
bliss_encoder.var_dist.discrete_sample

In [None]:
import torch

from bliss.catalog import TileCatalog
from bliss.encoder.encoder import Encoder

In [None]:
# Wrap the batch's tile catalog, then reduce it with get_brightest_sources_per_tile().
# `batch` is whatever batch an earlier loop over the test dataloader left behind.
tile_cat = TileCatalog(batch["tile_catalog"])
target_cat = tile_cat.get_brightest_sources_per_tile()

In [None]:
# Only the second element of the returned pair is used; the first is discarded.
_unused, x_cat_marginal = bliss_encoder.get_features_and_parameters(batch)

In [None]:
# Echo the encoder's variational distribution object.
bliss_encoder.var_dist

In [None]:
# Calls a private helper of the variational distribution; presumably it pairs each
# factor with its parameters (semantics not visible from this notebook -- verify).
var_dist = bliss_encoder.var_dist
fp_pairs = var_dist._factor_param_pairs(x_cat_marginal)

In [None]:
# Take the first element of the first pair. list() drains `fp_pairs` if it is a
# one-shot iterator, in which case re-running only this cell would fail.
first_pair = list(fp_pairs)[0]
factor = first_pair[0]

In [None]:
# Echo the selected factor object.
factor

In [None]:
# Attribute reference only (no call): shows what `discrete_sample` is bound to on the factor.
factor.discrete_sample

In [None]:
# Attribute reference only (no call): shows what `sample` is bound to on the factor.
factor.sample

In [13]:
# Evaluate the encoder on the test set and cache the results on disk, or reload the
# cache when it already exists. Two pickles are involved: the computed mode metrics
# and the computed discrete metrics.
bliss_discrete_output_path = output_dir / "bliss_output_discrete_large_split.pkl"
bliss_discrete_grid_output_path = output_dir / "bliss_output_discrete_grid_large_split.pkl"

# Both files are read in the cached branch, so recompute unless *both* are present
# (testing only one of them crashes on load when the other is missing).
cached_results_available = (
    bliss_discrete_output_path.exists() and bliss_discrete_grid_output_path.exists()
)

if not cached_results_available:
    # Clear state left behind by an earlier (possibly interrupted) evaluation run;
    # otherwise the batches it already processed would be counted twice.
    bliss_encoder.mode_metrics.reset()
    bliss_encoder.discrete_metrics.reset()

    test_loader = dataset.test_dataloader()  # build the loader once, reuse for len()
    # Pure evaluation: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
            batch["images"] = batch["images"].to(device)
            bliss_encoder.update_metrics(batch, batch_idx)
    bliss_mode_out_dict = bliss_encoder.mode_metrics.compute()
    bliss_discrete_out_dict = bliss_encoder.discrete_metrics.compute()

    with open(bliss_discrete_output_path, "wb") as outp:  # Overwrites any existing file.
        pickle.dump(bliss_mode_out_dict, outp, pickle.HIGHEST_PROTOCOL)
    with open(bliss_discrete_grid_output_path, "wb") as outp:  # Overwrites any existing file.
        pickle.dump(bliss_discrete_out_dict, outp, pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: unpickling can execute arbitrary code; only load caches you trust.
    with open(bliss_discrete_output_path, "rb") as inputp:
        bliss_mode_out_dict = pickle.load(inputp)
    with open(bliss_discrete_grid_output_path, "rb") as inputp:
        bliss_discrete_out_dict = pickle.load(inputp)

  0%|▏                                                                                                           | 9/6250 [00:07<1:28:04,  1.18it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate on the test set and cache the computed mode metrics, or reload the cache.
# NOTE(review): this reuses the same `bliss_encoder` as the discrete evaluation cell,
# yet the plotting cell labels these results "BLISS+Normal". Confirm the encoder and
# config have been switched before this cell is run.
bliss_output_path = output_dir / "bliss_output_large_split.pkl"

if not bliss_output_path.exists():
    # Start from clean metric state: any updates accumulated by an earlier evaluation
    # loop on this encoder would otherwise be folded into this result.
    bliss_encoder.mode_metrics.reset()

    test_loader = dataset.test_dataloader()  # build the loader once, reuse for len()
    # Pure evaluation: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
            batch["images"] = batch["images"].to(device)
            bliss_encoder.update_metrics(batch, batch_idx)
    bliss_out_dict = bliss_encoder.mode_metrics.compute()

    with open(bliss_output_path, "wb") as outp:  # Overwrites any existing file.
        pickle.dump(bliss_out_dict, outp, pickle.HIGHEST_PROTOCOL)
else:
    # NOTE: unpickling can execute arbitrary code; only load caches you trust.
    with open(bliss_output_path, "rb") as inputp:
        bliss_out_dict = pickle.load(inputp)

In [None]:
from matplotlib.ticker import FormatStrFormatter
# Plot every redshift metric against magnitude bin and save one PDF per metric.
metrics = ['outlier_fraction_cata', 'outlier_fraction', 'nmad', 'bias_abs', 'mse']
metric_labels = ['Catastrophic Outlier Fraction', 'Outlier Fraction', 'NMAD', 'Absolute Bias', 'MSE']
mag_ranges = ['<23.9', '23.9-24.1', '24.1-24.5', '24.5-24.9', '24.9-25.6', '>25.6']
n_bins = len(mag_ranges)

# NOTE(review): user-specific absolute path kept from the original; point it at your
# own checkout before running.
plot_dir = "/home/qiaozhih/bliss/case_studies/redshift/evaluation/plot"
os.makedirs(plot_dir, exist_ok=True)  # savefig does not create missing directories

sns.set_theme()
for metric, metric_label in zip(metrics, metric_labels):
    # float() accepts plain numbers and single-element tensors alike, so the values
    # can be handed to matplotlib regardless of how the metrics were stored.
    # NOTE(review): the BLISS+Normal bins are read in reverse order (5..0) while the
    # discrete results are read 0..5, exactly as in the original. Confirm the two
    # result dicts really number their magnitude bins in opposite directions.
    bliss_values = [
        float(bliss_out_dict[f'redshifts/{metric}_bin_{bin_idx}'])
        for bin_idx in reversed(range(n_bins))
    ]
    bliss_discrete_grid = [
        float(bliss_discrete_out_dict[f'redshifts/{metric}_bin_{bin_idx}'])
        for bin_idx in range(n_bins)
    ]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(mag_ranges, bliss_values, label="BLISS+Normal", marker='o', c="blue")
    ax.plot(mag_ranges, bliss_discrete_grid, label="BLISS+Discrete Bin w/ Grid Search", marker='o', c="orange")
    ax.set_xlabel('Magnitude')
    ax.tick_params(axis='x', labelrotation=45)
    ax.set_ylabel(metric_label)
    ax.set_ylim(bottom=0)  # anchor the y-axis at zero, leave the upper limit automatic
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    ax.legend(loc='upper left')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(plot_dir, f'different_dist_Bliss_{metric}.pdf'))