In [1]:
import torch
import numpy as np
import math
import os
import pytorch_lightning as pl
import torch
from hydra.utils import instantiate

In [2]:
from os import environ
environ["CUDA_VISIBLE_DEVICES"] = "4"

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict

environ["BLISS_HOME"] = str("/home/declan/current/bliss")

In [3]:
os.environ["BLISS_HOME"]

'/home/declan/current/bliss'

In [4]:
os.getcwd()

'/home/declan/current/bliss/case_studies/redshift_estimation/notebooks'

I have stored the two checkpoints in the following locations:
   1.  `/data/scratch/declan/sdss_encoder_ckpt.ckpt` for the SDSS-like galaxies
   2.  `/data/scratch/declan/dc2_encoder_ckpt.ckpt` for the DC2-like galaxies

Recall that the corresponding directories containing the data used to train these two encoder are in
   1. `/data/scratch/declan/sdss_like_galaxies` for the SDSS-like galaxies
   2. `/data/scratch/declan/dc2_like_galaxies` for the DC2-like galaxies

Below, I will load the DC2 checkpoint for example. You should change both fields below to the corresponding ones for SDSS-like data. Note I give a checkpoint path and the location of the training data.

In [10]:
with initialize(config_path="../", version_base=None):
    cfg = compose("redshift", {
        "predict.weight_save_path=/data/scratch/declan/dc2_encoder_ckpt.ckpt",
        "cached_simulator.cached_data_path=/data/scratch/declan/dc2_like_galaxies"
        })

This cell will take a while because it loads all of the training data (100 GB)

In [11]:
pl.seed_everything(cfg.train.seed)

# setup dataset and encoder
dataset = instantiate(cfg.train.data_source)
encoder = instantiate(cfg.train.encoder)

Global seed set to 42


We aren't technically in `predict` mode so we need to manually load the checkpoint to the encoder.

In [12]:
PATH = "/data/scratch/declan/dc2_encoder_ckpt.ckpt"
checkpoint = torch.load(PATH)
encoder.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

Your tasks are as follows:
- From the dataset, plot representative example images (you only need a few, like 2-3. You can pick the best ones to include in the write-up).
- Feed these images through the encoder to get predictions for all quantities. Plot predictions of location (e.g. with an "x" market) overlaid on the top of the example images. If stuff is working well, these should be right in the center of the galaxies approximately. Can you think of any clever ways to visualize redshift predictions for one example image?
- Create a scatterplot of predicted redshift vs. true redshift for all data in the training set. You can do an out-of-sample plot for validation data as well.
- Compute metrics such as MSE and NLL averaged across the training dataset. You can also do an out-of-sample plot for validation data as well.


Below, I work with an example observation just to provide a quick example.

In [13]:
train_dataloader = dataset.train_dataloader() # the data we trained on

In [14]:
val_dataloader = dataset.val_dataloader() # didn't train on, but used to choose checkpoint I think

You will have to iterate through the whole dataloader of training and/or validation data.

In [16]:
observation = next(iter(train_dataloader))

In [17]:
observation.keys()

dict_keys(['images', 'background', 'deconvolution', 'psf_params', 'tile_catalog'])

In [19]:
observation['images'].shape

torch.Size([32, 5, 80, 80])

Recall the tile catalog contains the "true" values used to generate the images. 

In [21]:
observation['tile_catalog'].keys()

dict_keys(['locs', 'n_sources', 'source_type', 'galaxy_fluxes', 'galaxy_params', 'star_fluxes', 'redshifts'])

Let's use the untrained encoder for prediction (it should perform very badly). 

In [25]:
est_cat = encoder.sample(observation, use_mode=True) # I'm using the mode to predict

The estimated catalog `est_cat` now contains the predicted values for each quantity. We can compre to the ground truth above.

In [26]:
est_cat = est_cat.to_dict()
est_cat.keys()

dict_keys(['locs', 'n_sources', 'star_fluxes', 'source_type', 'galaxy_params', 'galaxy_fluxes', 'redshifts'])

When `observation` is passed to `encoder`, the encoder ignores the ground truth `observation['tile_catalog']` to make the prediction obviously. But now we can compare the prediction to the ground truth.

In [27]:
torch.round(est_cat['redshifts'][0].reshape(18,18), decimals=2) # for one image

tensor([[1.3300, 1.3700, 1.3200, 1.3300, 1.3500, 1.3100, 1.3500, 1.3100, 1.3400,
         1.3600, 1.3600, 1.3700, 1.3200, 1.3200, 1.3500, 1.3800, 1.3800, 1.4900],
        [1.3500, 1.3700, 1.3100, 1.3400, 1.3600, 1.3300, 1.4000, 1.3800, 1.3600,
         1.3400, 1.3800, 1.3600, 1.3700, 1.3200, 1.3600, 1.4100, 1.4000, 1.4200],
        [1.3500, 1.3500, 1.3200, 1.3300, 1.3300, 1.3300, 1.3800, 1.4000, 1.3600,
         1.3300, 1.4000, 1.3400, 1.3500, 1.3400, 1.3500, 1.3900, 1.3500, 1.3300],
        [1.3300, 1.3000, 1.4000, 1.3100, 1.3900, 1.3200, 1.3500, 1.3800, 1.3300,
         1.3500, 1.4100, 1.2900, 1.4300, 1.3800, 1.3700, 1.3500, 1.4100, 1.3500],
        [1.3500, 1.3500, 1.2900, 1.3500, 1.3500, 1.3300, 1.3400, 1.3400, 1.4200,
         1.3800, 1.3800, 1.3400, 1.4000, 1.4200, 1.3300, 1.3300, 1.3700, 1.3500],
        [1.3600, 1.2900, 1.3300, 1.3300, 1.3200, 1.3300, 1.3300, 1.3400, 1.3400,
         1.3400, 1.3300, 1.3100, 1.3300, 1.3200, 1.3400, 1.3500, 1.3500, 1.3300],
        [1.3600, 1.280

In [29]:
observation['tile_catalog']['redshifts'][0].shape

torch.Size([20, 20, 1, 1])

The shapes are off, as we see that our estimated catalog is 18x18. This is because when training on the true images, the edge is filtered out. We can ignore it as well. 

In [30]:
true_redshifts = observation['tile_catalog']['redshifts'][0][1:-1, 1:-1].reshape((18,18))

In [31]:
true_redshifts

tensor([[0.4424, 1.2134, 2.8498, 1.9530, 0.5773, 0.5007, 0.9868, 2.4184, 1.1428,
         0.4347, 0.5634, 1.6868, 0.4143, 1.6364, 2.8903, 0.5156, 1.9588, 0.4310],
        [0.6704, 2.3828, 0.9407, 2.6686, 0.5478, 1.0735, 0.5573, 2.7453, 2.8398,
         2.0776, 1.4965, 1.2019, 2.1638, 0.7854, 2.0735, 1.5511, 1.1001, 0.8776],
        [0.1326, 1.4051, 2.2161, 1.8173, 1.0458, 2.9106, 2.1056, 1.1803, 0.1648,
         1.1238, 1.2248, 1.7003, 1.2856, 0.6593, 0.9938, 2.2289, 2.4001, 2.4243],
        [0.8146, 1.7009, 0.6363, 0.9262, 1.6455, 1.0841, 0.5214, 1.3127, 0.4255,
         0.6747, 2.0102, 2.2205, 1.0597, 1.7882, 0.8592, 2.5242, 0.7988, 0.8940],
        [0.6264, 0.7303, 0.6737, 0.8170, 1.4574, 1.8295, 1.2547, 1.5315, 0.9660,
         0.7526, 0.7736, 1.8675, 0.8731, 0.9475, 1.9461, 2.3829, 1.4235, 1.3090],
        [1.2252, 0.9115, 1.0557, 0.9719, 0.5587, 0.5436, 0.5807, 0.3490, 0.3271,
         2.5016, 1.2764, 0.5977, 1.4651, 2.8698, 0.4914, 1.2269, 1.1538, 0.4595],
        [2.0515, 2.591

At a glance, looks pretty poor. I suspect SDSS is better and something simply went wrong with DC2 (we'll investigate).