In [1]:
import torch
import numpy as np
import math
import os
import pytorch_lightning as pl
import torch
from hydra.utils import instantiate

In [2]:
from os import environ
environ["CUDA_VISIBLE_DEVICES"] = "4"

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict

environ["BLISS_HOME"] = str("/home/declan/current/bliss")

In [3]:
os.environ["BLISS_HOME"]

'/home/declan/current/bliss'

In [4]:
os.getcwd()

'/home/declan/current/bliss/case_studies/redshift_estimation/notebooks'

In [5]:
with initialize(config_path="../", version_base=None):
    cfg = compose("redshift", {
        "predict.weight_save_path=/home/declan/current/bliss/redshift_output/version_4/checkpoints/best_encoder.ckpt"
        })

We'd really like to call the `predict` function from `bliss.main` here. Still working on getting that going. In the meantime, we can "initialize" an encoder using the config `redshift.yaml` and load the weights manually like so. The cell below will take a while as I still have it loading all of the data, which probably overkill.

In [6]:
pl.seed_everything(cfg.train.seed)

# setup dataset and encoder
# taken from train in main.py
dataset = instantiate(cfg.train.data_source)
encoder = instantiate(cfg.train.encoder)

Global seed set to 42


In [7]:
type(dataset)

bliss.simulator.simulated_dataset.CachedSimulatedDataset

In [8]:
type(encoder)

bliss.encoder.encoder.Encoder

We'll access the test Dataloader that CachedSimulatedDataset can construct for us.

In [9]:
test_dataloader = dataset.test_dataloader()
type(test_dataloader)

torch.utils.data.dataloader.DataLoader

Let's access some observations from the test dataloader.

In [10]:
observation = next(iter(test_dataloader))

In [11]:
observation.keys()

dict_keys(['images', 'background', 'deconvolution', 'psf_params', 'tile_catalog'])

In [12]:
observation['images'].shape

torch.Size([32, 5, 80, 80])

In [13]:
observation['tile_catalog'].keys()

dict_keys(['locs', 'n_sources', 'source_type', 'galaxy_fluxes', 'galaxy_params', 'star_fluxes', 'redshifts'])

Let's use the untrained encoder for prediction (it should perform very badly). 

In [14]:
est_cat = encoder.sample(observation, use_mode=True)

In [15]:
est_cat = est_cat.to_dict()
est_cat.keys()

dict_keys(['locs', 'n_sources', 'star_fluxes', 'source_type', 'galaxy_params', 'galaxy_fluxes', 'redshifts'])

When `observation` is passed to `encoder`, the encoder ignores the ground truth `observation['tile_catalog']` to make the prediction obviously. But now we can compare the prediction to the ground truth.

In [16]:
est_cat['redshifts'].shape # 32 x 18 x 18 x1
torch.round(est_cat['redshifts'][0].reshape(18,18), decimals=2)

tensor([[ 0.0200, -0.0500, -0.0300, -0.0900,  0.0300,  0.1700, -0.3000,  0.2500,
          0.0400, -0.0200,  0.0800,  0.0000,  0.0400, -0.1400, -0.1800, -0.0900,
          0.0100, -0.1600],
        [-0.0300, -0.0200,  0.0500,  0.0100, -0.0600,  0.0300, -0.1800, -1.0200,
          0.4100,  0.1600, -0.0100,  0.1900,  0.1600,  0.2200, -0.0300, -0.4300,
          0.2100,  0.2200],
        [-0.0300, -0.0500, -0.0800, -0.0200, -0.0500,  0.2100, -0.1500,  0.0200,
          0.1400,  0.3500,  0.0900,  0.4000,  0.1000,  0.0700,  0.1900,  0.1200,
          0.1900, -0.1800],
        [-0.0300, -0.0300,  0.0300,  0.2100,  0.6000, -1.5300, -0.1200,  0.2800,
         -0.5500, -0.3900, -0.0900, -0.0000,  0.1200,  0.1100, -0.1700,  0.1700,
         -0.0400,  0.0400],
        [-0.0200, -0.0200,  0.0500, -0.0700,  0.0900, -0.4200, -0.2400, -0.4900,
         -0.5500,  0.0300, -0.1300, -0.1500, -0.4700, -0.3700,  0.1900,  0.0300,
         -0.1400,  0.1300],
        [-0.0000, -0.0200, -0.1100, -0.0200, -0.14

So it does terribly as expected. But we can load the weights and it should perform much better.

In [17]:
PATH = "/home/declan/current/bliss/redshift_output/version_4/checkpoints/best_encoder.ckpt"
checkpoint = torch.load(PATH)
encoder.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [18]:
est_cat = encoder.sample(observation, use_mode=True)
est_cat = est_cat.to_dict()
torch.round(est_cat['redshifts'][0].reshape(18,18), decimals=4)

tensor([[0.9982, 0.9962, 0.9970, 0.9988, 0.9977, 0.9975, 0.9971, 0.9983, 0.9974,
         0.9973, 0.9973, 0.9974, 0.9984, 0.9986, 0.9977, 0.9977, 0.9970, 0.9980],
        [0.9979, 0.9983, 0.9976, 0.9972, 0.9979, 0.9978, 0.9978, 0.9989, 0.9975,
         0.9980, 0.9979, 0.9969, 0.9979, 0.9986, 0.9974, 0.9984, 0.9970, 0.9981],
        [0.9979, 0.9978, 0.9983, 0.9990, 0.9984, 0.9973, 0.9957, 0.9974, 0.9980,
         0.9982, 0.9978, 0.9969, 0.9972, 0.9970, 0.9978, 0.9979, 0.9971, 0.9978],
        [0.9989, 0.9989, 0.9980, 0.9976, 0.9974, 0.9964, 0.9979, 0.9972, 0.9977,
         0.9989, 0.9977, 0.9978, 0.9993, 0.9972, 0.9973, 0.9965, 0.9977, 0.9979],
        [0.9985, 0.9983, 0.9980, 0.9985, 0.9977, 0.9982, 0.9981, 0.9974, 0.9977,
         0.9983, 0.9983, 0.9978, 0.9972, 0.9980, 0.9989, 0.9967, 0.9975, 0.9988],
        [0.9972, 0.9976, 0.9976, 0.9987, 0.9972, 0.9974, 0.9970, 0.9974, 0.9981,
         0.9997, 0.9981, 0.9978, 0.9972, 0.9978, 0.9985, 0.9972, 0.9974, 0.9972],
        [0.9973, 0.997

Looking pretty good!