In [None]:
%load_ext autoreload
%autoreload 2
%aimport

%matplotlib inline 

import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

from bliss.catalog import TileCatalog, get_images_in_tiles
from bliss.models.galsim_decoder import PsfSampler



# check GPU is configured correctly
device = torch.device('cuda:0')
!echo $CUDA_VISIBLE_DEVICES

In [None]:
# load models
from hydra import compose, initialize
from hydra.utils import instantiate

# Compose the top-level "config" from the ../config directory, with no overrides.
with initialize(config_path="../config"):
    cfg = compose(config_name="config", overrides=[])

In [None]:
def validation_step(self, batch, image_str):
    """Encode one batch with a detection encoder and score it against the true catalog.

    Free-function version of an encoder's validation step. The caller chooses
    which image key of ``batch`` is encoded, so the same batch can be scored on
    different image variants (e.g. ``"images"`` vs ``"noisy_image"``).

    Args:
        self: Detection encoder providing ``max_detections``, ``tile_slen``,
            ``ptile_slen``, ``encode``, ``variational_mode`` and
            ``val_detection_metrics`` (passed explicitly, not bound).
        batch: Dict of tensors containing ``image_str``, ``"background"``,
            ``"locs"``, ``"log_fluxes"``, ``"galaxy_bools"`` and ``"n_sources"``.
        image_str: Key in ``batch`` of the image tensor to encode.

    Returns:
        The result of ``self.val_detection_metrics(true_full, est_full)``.
    """
    # Evaluation only: disable autograd so no graph is retained for the batch of
    # padded tiles.
    with torch.no_grad():
        # Truncate the true per-tile catalog to the encoder's max_detections so
        # both catalogs have the same per-tile capacity.
        max_det = self.max_detections
        catalog_dict = {
            key: batch[key][:, :, :, :max_det]
            for key in ("locs", "log_fluxes", "galaxy_bools")
        }
        catalog_dict["n_sources"] = batch["n_sources"].clamp(max=max_det)
        true_tile_catalog = TileCatalog(self.tile_slen, catalog_dict)
        true_full_catalog = true_tile_catalog.to_full_params()

        # Image and background are concatenated along dim 1 before tiling.
        image_ptiles = get_images_in_tiles(
            torch.cat((batch[image_str], batch["background"]), dim=1),
            self.tile_slen,
            self.ptile_slen,
        )
        # Flatten (n, nth, ntw) into a single leading dimension for the encoder.
        image_ptiles = rearrange(image_ptiles, "n nth ntw b h w -> (n nth ntw) b h w")
        dist_params = self.encode(image_ptiles)
        est_catalog_dict = self.variational_mode(dist_params)
        # Re-tile the flat estimates on the same tile grid as the true catalog.
        est_tile_catalog = TileCatalog.from_flat_dict(
            true_tile_catalog.tile_slen,
            true_tile_catalog.n_tiles_h,
            true_tile_catalog.n_tiles_w,
            est_catalog_dict,
        )
        est_full_catalog = est_tile_catalog.to_full_params()

        return self.val_detection_metrics(true_full_catalog, est_full_catalog)

In [None]:
from case_studies.psf_homogenization.galsim_decoder_std import GalsimBlendsSGwithPSF

# Simulation components from the Hydra config.
prior = instantiate(cfg.datasets.galsim_blended_std_psf.prior)
decoder = instantiate(cfg.datasets.galsim_blended_std_psf.decoder)
background = instantiate(cfg.datasets.galsim_blended_std_psf.background)

# Dataset parameters, all passed to the dataset below (kept small: one batch
# of 10 images for a quick evaluation).
tile_slen = 4
max_tile_n_sources = 1
num_workers = 0
batch_size = 10
n_batches = 1

# NOTE(review): positional args presumably are (pixel scale, PSF profile,
# min FWHM, max FWHM); confirm against PsfSampler. Here the PSF FWHM is fixed at 1.9.
psf_sampler = PsfSampler(0.396, "gaussian", 1.9, 1.9)

ds_psf = GalsimBlendsSGwithPSF(
    prior,
    decoder,
    background,
    tile_slen,
    max_tile_n_sources,
    num_workers=num_workers,
    batch_size=batch_size,
    n_batches=n_batches,
    psf_sampler=psf_sampler,
    std_psf_fwhm=1.5,
)

In [None]:
# Load the "homo" detection encoder (checkpoint dir "std 1.5 from rand 1.5").
DEVICE = "cuda:0"  # target device if the encoder is moved to GPU (it is not, below)
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_SG_std/std 1.5 from rand 1.5/checkpoints/epoch=989-val_loss=-0.219.ckpt"
detection_encoder_homo = instantiate(cfg.models.detection_encoder).eval()
# The encoder stays on CPU, so load the checkpoint onto CPU as well. Mapping it
# to cuda:0 would require a GPU and keep a redundant GPU copy of the weights
# alive in `model_checkpoint`.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_homo.load_state_dict(model_state_dict)
# To evaluate on GPU, move both the encoder (.to(DEVICE)) and each batch's tensors.

In [None]:
# Load the "unhomo" detection encoder (checkpoint dir "rand 1.5").
DEVICE = "cuda:0"  # target device if the encoder is moved to GPU (it is not, below)
PATH = "/home/wangchv/bliss/case_studies/psf_homogenization/output/sdss_detection_encoder_full_decoder_SG_rand/rand 1.5/checkpoints/epoch=879-val_loss=-0.225.ckpt"
detection_encoder_unhomo = instantiate(cfg.models.detection_encoder).eval()
# The encoder stays on CPU, so load the checkpoint onto CPU as well. Mapping it
# to cuda:0 would require a GPU and keep a redundant GPU copy of the weights
# alive in `model_checkpoint`.
model_checkpoint = torch.load(PATH, map_location="cpu")
model_state_dict = model_checkpoint["state_dict"]
detection_encoder_unhomo.load_state_dict(model_state_dict)
# To evaluate on GPU, move both the encoder (.to(DEVICE)) and each batch's tensors.

In [None]:
# Score both encoders on the same (first) validation batch: the `_homo` encoder
# on "images" and the `_unhomo` encoder on "noisy_image".
x = next(iter(ds_psf.val_dataloader()))
with torch.no_grad():  # inference only; avoid building autograd graphs
    print(validation_step(detection_encoder_homo, x, "images"))
    print(validation_step(detection_encoder_unhomo, x, "noisy_image"))

In [None]:
#generate images
# Visual check: for PSF FWHM t = 0.5, 0.6, ..., 1.5, simulate one small batch and
# show the first sample's "images" next to its "noisy_image".
for i in range(11):
    # round() removes float noise such as 0.6000000000000001 from t.
    t = round(0.5 + 0.1 * i, 1)
    # Cell-local names: do not overwrite the `psf_sampler` / `ds_psf` used by the
    # evaluation cells, which would silently change their data on re-run.
    psf_sampler_t = PsfSampler(0.396, "gaussian", t, t)
    ds_psf_t = GalsimBlendsSGwithPSF(
        prior,
        decoder,
        background,
        tile_slen,
        max_tile_n_sources,
        num_workers=0,
        batch_size=5,
        n_batches=1,
        psf_sampler=psf_sampler_t,
        std_psf_fwhm=0.8,
        valid_n_batches=1,
    )

    for x in ds_psf_t.val_dataloader():
        print(t)
        fig, (ax_img, ax_noisy) = plt.subplots(1, 2)
        # [0, 0]: first sample of the batch, first channel.
        ax_img.imshow(x["images"][0, 0].numpy())
        ax_img.set_title("images")
        ax_noisy.imshow(x["noisy_image"][0, 0].numpy())
        ax_noisy.set_title("noisy_image")
        plt.show()
        plt.close(fig)

In [None]:
#draw plots
import matplotlib.pyplot as plt
import numpy as np
# f1 vs. dataset PSF FWHM for each model; values are hard-coded results,
# presumably recorded from earlier runs of the evaluation cell. One value per PSF.
psfs = np.linspace(0.5, 1.5, 11)  # 0.5, 0.6, ..., 1.5 (avoids arange's float-step length pitfall)
homo = np.array([0.963, 0.963, 0.959, 0.946, 0.893, 0.826, 0.749, 0.682, 0.631, 0.520, 0.466])
unhomo = np.array([0.967, 0.967, 0.963, 0.944, 0.881, 0.777, 0.684, 0.598, 0.537, 0.464, 0.418])
unhomo_std = np.array([0.963, 0.969, 0.961, 0.961, 0.953, 0.948, 0.932, 0.918, 0.905, 0.897, 0.890])
unhomo_rand15 = np.array([0.868, 0.887, 0.909, 0.915, 0.927, 0.927, 0.922, 0.920, 0.904, 0.904, 0.890])
homo_rand15_std08 = np.array([0.900, 0.901, 0.900, 0.897, 0.902, 0.904, 0.902, 0.903, 0.895, 0.881, 0.882])
homo_rand15_std15 = np.array([0.899, 0.896, 0.891, 0.903, 0.889, 0.893, 0.891, 0.894, 0.890, 0.886, 0.881])

curves = {
    "homo model std 0.8 rand 0.7": homo,
    "unhomo model rand 0.7": unhomo,
    "std unhomo models": unhomo_std,
    "unhomo model rand 1.5": unhomo_rand15,
    "homo model std 0.8 rand 1.5": homo_rand15_std08,
    "homo model std 1.5 rand 1.5": homo_rand15_std15,
}

fig, ax = plt.subplots()
for label, values in curves.items():
    assert len(values) == len(psfs), f"{label}: {len(values)} values for {len(psfs)} PSFs"
    ax.plot(psfs, values, "o-", label=label)
ax.set_xlabel("dataset psf")
ax.set_ylabel("f1")
ax.legend()
# NOTE(review): the title previously read "Precision" while the y-axis reads "f1";
# aligned with the y-label. Confirm which metric these numbers are.
ax.set_title("f1 of homo vs unhomo model")
plt.show()

In [None]:
# Extended PSF range 0.5-2.0 for the two "rand 1.5" models; the first 11 values
# of each array match the 0.5-1.5 results in the previous plot.
psfs2 = np.linspace(0.5, 2.0, 16)  # 0.5, 0.6, ..., 2.0 (avoids arange's float-step length pitfall)
unhomo_rand15 = np.array([0.868, 0.887, 0.909, 0.915, 0.927, 0.927, 0.922, 0.920, 0.904, 0.904, 0.890, 0.859, 0.824, 0.724, 0.626, 0.500])
homo_rand15_std15 = np.array([0.899, 0.896, 0.891, 0.903, 0.889, 0.893, 0.891, 0.894, 0.890, 0.886, 0.881, 0.855, 0.832, 0.769, 0.729, 0.660])

curves2 = {
    "unhomo model 1.5": unhomo_rand15,
    "homo model std 1.5 rand 1.5": homo_rand15_std15,
}

fig, ax = plt.subplots()
for label, values in curves2.items():
    assert len(values) == len(psfs2), f"{label}: {len(values)} values for {len(psfs2)} PSFs"
    ax.plot(psfs2, values, "o-", label=label)
ax.set_xlabel("dataset psf")
ax.set_ylabel("f1")
ax.legend()
# NOTE(review): the title previously read "Precision" while the y-axis reads "f1";
# aligned with the y-label. Confirm which metric these numbers are.
ax.set_title("f1 of homo vs unhomo model")
plt.show()