In [1]:
%load_ext autoreload
%autoreload 2
%aimport

%matplotlib inline 

import math

import torch
import matplotlib.pyplot as plt
from einops import rearrange
from typing import Union, Dict, Optional
from torch import Tensor, nn
from torch.distributions import Categorical, Normal
from torch.nn import functional as F

from bliss.catalog import TileCatalog, get_images_in_tiles, get_is_on_from_n_sources
from bliss.reporting import DetectionMetrics
from case_studies.psf_homogenization.psf_decoder import PsfSampler, GalsimBlendswithPSF
from bliss.models.detection_encoder import (
    DetectionEncoder,
    LogBackgroundTransform,
    ConcatBackgroundTransform,
    EncoderCNN,
    make_enc_final,
)


# check GPU is configured correctly
device = torch.device('cuda:0')
!echo $CUDA_VISIBLE_DEVICES

Modules to reload:
all-except-skipped

Modules to skip:

5


In [2]:
# load the Hydra config (datasets and models are instantiated from it in later cells)
from hydra import compose, initialize
from hydra.utils import instantiate
from bliss.encoder import Encoder

# Compose the top-level "config" from the relative Hydra config directory, with no overrides.
with initialize(config_path="../config"):
    cfg = compose(config_name="config", overrides=[])

In [3]:
# set up test data: galsim blends rendered with a PSF drawn by `psf_sampler`, with a standard
# PSF FWHM of 1.0 passed as `std_psf_fwhm`
prior = instantiate(cfg.datasets.galsim_blended_galaxies_psf.prior)
decoder = instantiate(cfg.datasets.galsim_blended_galaxies_psf.decoder)
# NOTE(review): the background comes from `galsim_blended_galaxies`, not
# `galsim_blended_galaxies_psf` like prior/decoder — confirm this is intended.
background = instantiate(cfg.datasets.galsim_blended_galaxies.background)
tile_slen = 4
max_tile_n_sources = 1
# These are passed straight to the dataset below, so they describe what is actually generated.
num_workers = 0
batch_size = 1000
n_batches = 1
psf_sampler = PsfSampler(0.8, 0.8)  # NOTE(review): presumably the (min, max) sampled PSF FWHM — confirm

ds_psf = GalsimBlendswithPSF(
    prior,
    decoder,
    background,
    tile_slen,
    max_tile_n_sources,
    num_workers=num_workers,
    batch_size=batch_size,
    n_batches=n_batches,
    psf_sampler=psf_sampler,
    std_psf_fwhm=1.0,
    valid_n_batches=1,
)



In [3]:
@torch.no_grad()
def validation_step(self, batch, image_str):
    """Run a detection encoder on one batch and score its detections against the truth.

    Written in method form so a loaded encoder can be passed as ``self``, e.g.
    ``validation_step(detection_encoder_homo, batch, "images")``.

    Args:
        self: detection encoder; must provide ``max_detections``, ``tile_slen``,
            ``ptile_slen``, ``encode``, ``variational_mode`` and ``val_detection_metrics``.
        batch: dict of tensors from the dataloader holding the true tile catalog
            (``locs``, ``log_fluxes``, ``galaxy_bools``, ``n_sources``), the ``background``
            and the images stored under ``image_str``.
        image_str: key in ``batch`` of the images fed to the encoder.

    Returns:
        The result of ``self.val_detection_metrics(true_full_catalog, est_full_catalog)``.

    Runs under ``torch.no_grad()``: this is inference only. Without it autograd keeps every
    intermediate activation of the encoder alive, and the returned metrics carry a ``grad_fn``.
    """
    max_det = self.max_detections

    # True catalog, truncated to at most `max_detections` sources per tile.
    # NOTE(review): assumes dim 3 of these tensors indexes the sources within a tile — confirm.
    catalog_dict = {
        key: batch[key][:, :, :, 0:max_det] for key in ("locs", "log_fluxes", "galaxy_bools")
    }
    catalog_dict["n_sources"] = batch["n_sources"].clamp(max=max_det)
    true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
    true_full_catalog = true_tile_catalog.to_full_params()

    # Images and background are stacked along dim 1 and cut into `ptile_slen` tiles on the
    # `tile_slen` grid; the (image, tile-row, tile-col) dims are then flattened into one batch
    # dim for the encoder.
    image_ptiles = get_images_in_tiles(
        torch.cat((batch[image_str], batch["background"]), dim=1),
        self.tile_slen,
        self.ptile_slen,
    )
    image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")

    # `variational_mode` turns the encoder output into a flat estimated catalog, which is laid
    # back onto the same tile grid as the true catalog before comparing.
    dist_params = self.encode(image_ptiles)
    est_catalog_dict = self.variational_mode(dist_params)
    est_tile_catalog = TileCatalog.from_flat_dict(
        true_tile_catalog.tile_slen,
        true_tile_catalog.n_tiles_h,
        true_tile_catalog.n_tiles_w,
        est_catalog_dict,
    )
    est_full_catalog = est_tile_catalog.to_full_params()

    return self.val_detection_metrics(true_full_catalog, est_full_catalog)


In [5]:
# Detection encoder from the `..._homo/std=1.0` training run.
DEVICE = "cuda:0"  # not used for loading; see map_location below
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_homo/std=1.0/checkpoints/epoch=989-val_loss=-0.015.ckpt"
detection_encoder_homo = instantiate(cfg.models.detection_encoder).eval()
# Nothing here moves the encoder to DEVICE, and load_state_dict copies the weights into the
# encoder's existing parameters anyway. Mapping the checkpoint to the CPU avoids staging the whole
# checkpoint on the GPU (and keeping it there via `model_checkpoint`), and lets the cell run
# without CUDA.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_homo.load_state_dict(model_state_dict)

<All keys matched successfully>

In [6]:
# Detection encoder from the `..._unhomo/1.0~1.0` training run.
DEVICE = "cuda:0"  # not used for loading; see map_location below
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_unhomo/1.0~1.0/checkpoints/epoch=829-val_loss=-0.006.ckpt"
detection_encoder_unhomo = instantiate(cfg.models.detection_encoder).eval()
# Nothing here moves the encoder to DEVICE, and load_state_dict copies the weights into the
# encoder's existing parameters anyway. Mapping the checkpoint to the CPU avoids staging the whole
# checkpoint on the GPU (and keeping it there via `model_checkpoint`), and lets the cell run
# without CUDA.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_unhomo.load_state_dict(model_state_dict)

<All keys matched successfully>

In [7]:
# Score both encoders on every validation batch: the homogenized encoder reads "images",
# the unhomogenized one reads "noisy_image".
for val_batch in ds_psf.val_dataloader():
    for encoder, image_key in (
        (detection_encoder_homo, "images"),
        (detection_encoder_unhomo, "noisy_image"),
    ):
        print(validation_step(encoder, val_batch, image_key))

{'tp': tensor(2387), 'fp': tensor(166), 'precision': tensor(0.9350), 'recall': tensor(0.6921), 'f1': tensor(0.7954), 'avg_distance': tensor(0.4657, grad_fn=<DivBackward0>), 'n_galaxies_detected': tensor(2387)}
{'tp': tensor(2315), 'fp': tensor(160), 'precision': tensor(0.9354), 'recall': tensor(0.6724), 'f1': tensor(0.7824), 'avg_distance': tensor(0.5780, grad_fn=<DivBackward0>), 'n_galaxies_detected': tensor(2315)}


In [4]:
from case_studies.psf_homogenization.galsim_blends_sg import GalsimBlendsSGwithPSF
# Test data for the `GalsimBlendsSGwithPSF` dataset: PSF drawn by `psf_sampler`, with a standard
# PSF FWHM of 1.0 passed as `std_psf_fwhm`.
prior = instantiate(cfg.datasets.galsim_blended_std_psf.prior)
decoder = instantiate(cfg.datasets.galsim_blended_std_psf.decoder)
background = instantiate(cfg.datasets.galsim_blended_std_psf.background)
tile_slen = 4
max_tile_n_sources = 1
# These are passed straight to the dataset below, so they describe what is actually generated.
num_workers = 0
batch_size = 1000
n_batches = 1
psf_sampler = PsfSampler(1.0, 1.0)  # NOTE(review): presumably the (min, max) sampled PSF FWHM — confirm

ds_psf = GalsimBlendsSGwithPSF(
    prior,
    decoder,
    background,
    tile_slen,
    max_tile_n_sources,
    num_workers=num_workers,
    batch_size=batch_size,
    n_batches=n_batches,
    psf_sampler=psf_sampler,
    std_psf_fwhm=1.0,
    valid_n_batches=1,
)


In [5]:
# Detection encoder from the `..._SG_std` training run (bound to the `homo` name used below).
DEVICE = "cuda:0"  # not used for loading; see map_location below
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_SG_std/version_0/checkpoints/epoch=859-val_loss=0.028.ckpt"
detection_encoder_homo = instantiate(cfg.models.detection_encoder).eval()
# Nothing here moves the encoder to DEVICE, and load_state_dict copies the weights into the
# encoder's existing parameters anyway. Mapping the checkpoint to the CPU avoids staging the whole
# checkpoint on the GPU (and keeping it there via `model_checkpoint`), and lets the cell run
# without CUDA.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_homo.load_state_dict(model_state_dict)

<All keys matched successfully>

In [6]:
# Detection encoder from the `..._SG_rand` training run (bound to the `unhomo` name used below).
DEVICE = "cuda:0"  # not used for loading; see map_location below
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_SG_rand/version_0/checkpoints/epoch=999-val_loss=0.029.ckpt"
detection_encoder_unhomo = instantiate(cfg.models.detection_encoder).eval()
# Nothing here moves the encoder to DEVICE, and load_state_dict copies the weights into the
# encoder's existing parameters anyway. Mapping the checkpoint to the CPU avoids staging the whole
# checkpoint on the GPU (and keeping it there via `model_checkpoint`), and lets the cell run
# without CUDA.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_unhomo.load_state_dict(model_state_dict)

<All keys matched successfully>

In [7]:
# Score both encoders on every validation batch: the `homo`-named encoder reads "images",
# the `unhomo`-named one reads "noisy_image".
for val_batch in ds_psf.val_dataloader():
    for encoder, image_key in (
        (detection_encoder_homo, "images"),
        (detection_encoder_unhomo, "noisy_image"),
    ):
        print(validation_step(encoder, val_batch, image_key))

{'tp': tensor(245), 'fp': tensor(17), 'precision': tensor(0.9351), 'recall': tensor(0.9919), 'f1': tensor(0.9627), 'avg_distance': tensor(0.0994, grad_fn=<DivBackward0>), 'n_galaxies_detected': tensor(159)}
{'tp': tensor(247), 'fp': tensor(16), 'precision': tensor(0.9392), 'recall': tensor(0.9960), 'f1': tensor(0.9667), 'avg_distance': tensor(0.1035, grad_fn=<DivBackward0>), 'n_galaxies_detected': tensor(161)}
