### Outline

The goal of this notebook is to give a brief illustration of the full BLISS pipeline. We will identify several of the main data objects used along the way, but avoid mathematical intricacies and details of the training objective. In outline, the steps are:

1. Generation of synthetic data
2. Training of the encoder network
3. *(not currently in this notebook)* Validation/evaluation of the encoder network on held-out data (either simulated or real).

---

#### 1. Generation of synthetic data

This would typically be done using something like ```bliss mode=generate``` from the command line. This calls the ```generate``` function within ```bliss/main.py```, using the default ```DictConfig``` object specified by the various ```.yaml``` files.

The code below loads the file ```m2_config.yaml``` from ```case_studies/dependent_tiling``` as a DictConfig for use by ```hydra```. You may have to change some absolute and relative paths to get this to load for you.

In [None]:
import sys
import os

In [None]:
os.getcwd()

In [None]:
# Change twhit to your username
os.chdir('/home/twhit/bliss')
os.getcwd()

In [None]:
from bliss.encoder.variational_dist import VariationalDistSpec, VariationalDist
from bliss.encoder.unconstrained_dists import UnconstrainedNormal
import torch
import pytorch_lightning as pl
import numpy as np
from os import environ
from pathlib import Path
from hydra import initialize, compose
from hydra.utils import instantiate
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from bliss.catalog import TileCatalog

In [None]:
# Change twhit to your username
environ["BLISS_HOME"] = "/home/twhit/bliss"
with initialize(config_path="../../case_studies/dependent_tiling", version_base=None):
    cfg = compose("m2_config", overrides=["surveys.sdss.load_image_data=true"])

In [None]:
type(cfg)

In [None]:
print(OmegaConf.to_yaml(cfg, resolve=False, sort_keys=False))

You can browse the above printouts to get a feel for how the config is structured. Our project will eventually add some configurables and we'll have our own config similar to the above.

The ```generate``` function from ```bliss/main.py``` takes arguments as

```
def generate(gen_cfg: DictConfig):
   ...

```

and so we can plug in a `DictConfig` like the one above to generate data. The actual `generate` function is considerably more complex than the version below; among other things, it caches previously simulated data to save time. Our altered function below is for illustration only and generates a single batch of simulated data.

In [None]:
cfg.generate

In [None]:
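def generate(gen_cfg: DictConfig):
    # Build the simulator specified in gen_cfg.simulator; num_workers=0 overrides
    # the config value (0 usually means data is produced in the main process)
    simulated_dataset = instantiate(gen_cfg.simulator, num_workers=0)

    # Draw and return a single batch of simulated data
    return simulated_dataset.get_batch()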
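The simplified `generate` relies on Hydra's `instantiate`, which builds an object from a config node whose `_target_` key names a class or function, passing the remaining keys (plus keyword overrides such as `num_workers=0`) as arguments. To show the idea, here is a toy re-implementation for plain dicts. It is our own sketch, not Hydra's code: the real `instantiate` also handles lists, `_partial_`, `_recursive_`, `_convert_`, and more.

```python
import importlib


def toy_instantiate(config, **overrides):
    """Toy stand-in for hydra.utils.instantiate, for plain dicts only.

    Imports the object named by the dotted path in `_target_`, recursively
    instantiates nested dicts that have their own `_target_`, applies keyword
    overrides, and calls the target with the remaining keys as kwargs.
    """
    if not isinstance(config, dict) or "_target_" not in config:
        return config  # plain values are passed through unchanged
    module_path, _, attr_name = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {k: toy_instantiate(v) for k, v in config.items() if k != "_target_"}
    kwargs.update(overrides)  # e.g. num_workers=0 in the generate function
    return target(**kwargs)


# A nested config: inner nodes are instantiated before the outer call
toy_cfg = {
    "_target_": "builtins.dict",
    "step": {"_target_": "datetime.timedelta", "days": 1},
    "ratio": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3},
}
obj = toy_instantiate(toy_cfg)
print(obj["step"], obj["ratio"])  # 1 day, 0:00:00 1/3
```

Overrides passed as keyword arguments win over config values, which is how `num_workers=0` reaches the simulator above.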

In [None]:
# Running this cell will take a few minutes
simulated_batch_of_data = generate(cfg.generate)

Note that the full config has three main subconfigs, `cfg.generate`, `cfg.train`, and `cfg.predict`, which helps keep things organized. We passed only `cfg.generate` to the `generate` function.

In [None]:
simulated_batch_of_data.keys()

The simulated batch is a `dict`. In the full pipeline, batches like this can be saved to disk (see the `generate` function in `bliss/main.py`), but we won't worry about that here, since we'll just work with this single batch. Let's examine some of its entries.

In [None]:
tc = simulated_batch_of_data['tile_catalog']
tc.keys()

In [None]:
tc['locs'].shape

In [None]:
tc['n_sources'].shape

In [None]:
tc['source_type'].shape

In [None]:
tc['galaxy_fluxes'].shape

In [None]:
tc['galaxy_params'].shape

The first three dimensions of all these objects are 32 x 56 x 56. These numbers represent the following:
- 32 = number of synthetic (simulated) images
- 56 = number of tiles lengthwise
- 56 = number of tiles widthwise

BLISS operates by dividing a given image into *tiles* of a fixed size, which can be thought of as breaking the image into bite-size pieces. The number of tiles and the number of images in simulated batches are controllable from the config object, e.g.

In [None]:
cfg.prior.batch_size, cfg.prior.n_tiles_h, cfg.prior.n_tiles_w

BLISS constrains the number of objects per tile to be between 0 and 5. Tiles are small enough (in terms of pixel size) to make this reasonable. The tensor `tc['n_sources']` records the number of sources in each tile of each image.

In [None]:
tc['n_sources'][0] #56 x 56 tensor telling us the number of sources in each tile for the first image.

In [None]:
tc['n_sources'].max(), tc['n_sources'].min()

Each object (or source) can be one of several different types (although I think it's generally either a star or a galaxy). Since there are at most 5 sources per tile, `tc['source_type']` has a per-tile shape of 5 x 1: for each tile, it gives the type of each of up to 5 sources. If a tile has fewer than 5 sources, the unused slots are padding that gets masked out. The 5 that appears in the fourth dimension of the other tensors arises for the same reason. We can gather from the shapes above that `locs` contains a 2D coordinate for each source, `galaxy_fluxes` contains 5 parameters for each source, `galaxy_params` contains 6 parameters for each source, and so on.

***All of these are latent random variables $z$ that are used to generate the image. Given a tile catalog like the above, we have all the information necessary to generate synthetic images $x$. The inference problem is then to take an image $x$ and construct a distribution on all of these quantities $z$. In other words, given an image, we divide it into tiles and for each tile we aim to recover the number of sources, the type of each source, the location of each, the fluxes of each, etc.***
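To make the padded per-tile layout described above concrete, here is a toy sketch that bins a handful of sources into tiles, recording each tile's source count, zero-padded within-tile locations, and a mask marking which slots hold real sources. It uses plain Python lists rather than tensors, and its conventions (locations rescaled to [0, 1) within each tile, sources beyond `max_sources` dropped, the function name itself) are assumptions chosen for illustration, not necessarily how BLISS constructs a `TileCatalog`.

```python
from collections import defaultdict


def build_toy_tile_catalog(sources, tile_slen, n_tiles_h, n_tiles_w, max_sources):
    """Bin (row, col) pixel positions into tiles with padded per-tile slots.

    Returns (n_sources, locs, is_on):
      n_sources[i][j] -- number of sources kept in tile (i, j)
      locs[i][j]      -- max_sources (row, col) pairs in [0, 1), zero-padded
      is_on[i][j]     -- max_sources booleans marking the real (non-padding) slots
    Sources falling outside the tile grid are ignored.
    """
    per_tile = defaultdict(list)
    for r, c in sources:
        i, j = int(r // tile_slen), int(c // tile_slen)
        # position inside the tile, rescaled to [0, 1)
        loc = ((r - i * tile_slen) / tile_slen, (c - j * tile_slen) / tile_slen)
        if len(per_tile[(i, j)]) < max_sources:  # drop sources beyond the cap
            per_tile[(i, j)].append(loc)

    grid = [(i, j) for i in range(n_tiles_h) for j in range(n_tiles_w)]
    n_sources = [[len(per_tile[(i, j)]) for j in range(n_tiles_w)] for i in range(n_tiles_h)]
    locs = [
        [per_tile[(i, j)] + [(0.0, 0.0)] * (max_sources - len(per_tile[(i, j)]))
         for j in range(n_tiles_w)]
        for i in range(n_tiles_h)
    ]
    is_on = [[[k < n for k in range(max_sources)] for n in row] for row in n_sources]
    assert len(grid) == n_tiles_h * n_tiles_w
    return n_sources, locs, is_on


# Three sources land in tile (0, 0) but only two slots exist, so one is dropped
sources = [(0.5, 0.5), (1.0, 1.5), (1.5, 0.5), (4.0, 5.0)]
n_sources, locs, is_on = build_toy_tile_catalog(sources, tile_slen=2, n_tiles_h=3, n_tiles_w=3, max_sources=2)
print(n_sources)   # [[2, 0, 0], [0, 0, 0], [0, 0, 1]]
print(locs[2][2])  # [(0.0, 0.5), (0.0, 0.0)]  <- second slot is padding
```

The mask `is_on` plays the role of "masked out" above: any per-source quantity stored in a padding slot is ignored.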

Let's examine a sample synthetic image below.

In [None]:
ims = simulated_batch_of_data['images']

In [None]:
ims.shape

As expected, there are 32 images. Here the dimension 5 does not correspond to the number of sources per tile, but rather to the *photometric band* of each image: u, g, r, i, z for SDSS data. Read more here: https://www.sdss4.org/instruments/camera/#Filters

Again, these are specified in the config:

In [None]:
cfg.prior.survey_bands

We gather that each image in each band is 112 x 112 pixels. Since we have 56 x 56 tiles, each tile is 112 / 56 = 2 pixels on a side, i.e. 2 x 2 pixels. Again, this is specified in the config.

In [None]:
cfg.prior.tile_slen
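The tiling arithmetic above (112 / 2 = 56 tiles per side) can be checked with a small sketch that splits a single-band image into non-overlapping square tiles. This is plain Python with a hypothetical helper name, for illustration only; it is not BLISS's own tiling code.

```python
def split_into_tiles(image, tile_slen):
    """Split a 2D image (list of rows) into a grid of tile_slen x tile_slen tiles."""
    height, width = len(image), len(image[0])
    if height % tile_slen or width % tile_slen:
        raise ValueError("image dimensions must be divisible by tile_slen")
    n_tiles_h, n_tiles_w = height // tile_slen, width // tile_slen
    return [
        [
            [row[j * tile_slen:(j + 1) * tile_slen] for row in image[i * tile_slen:(i + 1) * tile_slen]]
            for j in range(n_tiles_w)
        ]
        for i in range(n_tiles_h)
    ]


# A 112 x 112 "image" whose pixel values encode their own position
image = [[r * 112 + c for c in range(112)] for r in range(112)]
tiles = split_into_tiles(image, tile_slen=2)
print(len(tiles), len(tiles[0]))  # 56 56
print(tiles[0][1])                # [[2, 3], [114, 115]]
```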

Let's examine some simulated images.

In [None]:
plt.imshow(ims[0][0]) #u-band for first of 32 images

In [None]:
plt.imshow(ims[0][1]) #g-band for first of 32 images

In [None]:
plt.imshow(ims[0][2]) #r-band for first of 32 images

In [None]:
plt.imshow(ims[0][3]) #i-band for first of 32 images

In [None]:
plt.imshow(ims[0][4]) #z-band for first of 32 images

We notice that some objects seem to be missing in some bands, or appear fainter in some bands than in others. This is normal: objects emit different amounts of light at different wavelengths, so in some wavelength ranges (i.e. in a particular band) an object may be faint or not visible at all.

We won't worry about the other keys of `simulated_batch_of_data` for now (`background`, `deconvolution`, and `psf_params`). We can explore these later as needed.
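To make the map from a catalog $z$ to an image $x$ more concrete, here is a deliberately crude forward model for a single band: each point source contributes its flux, spread by a Gaussian point-spread function (PSF), on top of a flat sky background, with optional Gaussian noise whose variance equals the expected counts (a rough stand-in for Poisson noise). The Gaussian PSF, the constant background, the pixel-center convention, and the function name are all assumptions for illustration; BLISS's simulator has its own PSF and background models (note the `psf_params` and `background` keys above) and is far more involved.

```python
import math
import random


def render_image(sources, height, width, psf_sigma=1.0, background=10.0, noise_seed=None):
    """Toy single-band forward model x ~ p(x | z).

    sources: list of (row, col, flux) with positions in pixel units.
    Returns a list of rows holding the expected counts per pixel (Gaussian PSF
    evaluated at pixel centers, plus a flat background). If noise_seed is given,
    adds Gaussian noise with variance equal to the expected counts.
    """
    norm = 1.0 / (2.0 * math.pi * psf_sigma ** 2)  # normalizes the 2D Gaussian
    image = [[background] * width for _ in range(height)]
    for r0, c0, flux in sources:
        for r in range(height):
            for c in range(width):
                d2 = (r + 0.5 - r0) ** 2 + (c + 0.5 - c0) ** 2  # squared distance to pixel center
                image[r][c] += flux * norm * math.exp(-d2 / (2.0 * psf_sigma ** 2))
    if noise_seed is not None:
        rng = random.Random(noise_seed)
        image = [[v + rng.gauss(0.0, math.sqrt(v)) for v in row] for row in image]
    return image


clean = render_image([(8.0, 8.0, 500.0)], height=16, width=16)
print(sum(map(sum, clean)))  # approximately 16 * 16 * 10 + 500 = 3060
noisy = render_image([(8.0, 8.0, 500.0)], height=16, width=16, noise_seed=0)
```

Running a multi-band version of this with a different flux per band would reproduce the effect noted above, where a source is bright in one band and nearly invisible in another.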

---

#### 2. Training the encoder

The encoder $q_{\phi}(z \mid x)$ learns a *variational distribution* on $z$ conditional on an observed $x$. We use machine learning and amortization to automate this process: Given any $x$, we get a distribution on $z$ by passing $x$ through a neural network (whose parameters are $\phi$). For us, $z$ is a complicated object consisting of all the parameters in the tile catalog above (or more). 

The training objective is given by

$$
\max_\phi \thinspace \Bigl[\mathbb{E}_{p(z,x)} \log q_\phi(z \mid x)\Bigr].
$$

In words, we want to find the neural network parameters $\phi$ that maximize the expected log variational density, where the expectation is over pairs $z,x$ drawn from the generative model $p(z,x)$. For us, the generative model is given by

$$
p(z,x) = p(z) p(x \mid z)
$$

where $p(z)$ is the prior and $p(x \mid z)$ is the likelihood, which generates an image $x$ from a tile catalog $z$ drawn from the prior. We won't focus much at all on $p(x \mid z)$. Instead, we'll focus on the prior, and we will alter the tile catalog $z$ (e.g., to include new parameters that are specific to weak lensing, such as shear and convergence). The information for the prior is again given by the config:

In [None]:
print(OmegaConf.to_yaml(cfg.prior, resolve=False, sort_keys=False))

In [None]:
cfg.prior.galaxy_a_bd_ratio

These numbers can be considered hyperparameters that define the prior. We don't need to infer them; rather, for a tile catalog $z$ sampled from the prior $p(z)$, we want to infer $z$ given its corresponding image $x$.
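As a rough illustration of what sampling from a prior over tile catalogs involves, here is a sketch of just one component: the number of sources per tile. It assumes a Poisson count with a hypothetical mean `mean_sources`, clamped at `max_sources`; whether BLISS's prior uses exactly this form, and with which hyperparameter names, should be checked against `cfg.prior` and the prior code. A full prior would also draw source types, locations, fluxes, and galaxy shapes for the occupied slots.

```python
import math
import random


def sample_poisson(rng, lam):
    """Draw one Poisson(lam) variate with Knuth's multiplication method (fine for small lam)."""
    threshold = math.exp(-lam)
    k, p = 0, rng.random()
    while p > threshold:
        k += 1
        p *= rng.random()
    return k


def sample_tile_counts(n_tiles_h, n_tiles_w, mean_sources, max_sources, seed=0):
    """Sample a grid of per-tile source counts: Poisson(mean_sources), clamped at max_sources."""
    rng = random.Random(seed)
    return [
        [min(sample_poisson(rng, mean_sources), max_sources) for _ in range(n_tiles_w)]
        for _ in range(n_tiles_h)
    ]


counts = sample_tile_counts(56, 56, mean_sources=0.2, max_sources=5)
flat = [c for row in counts for c in row]
print(sum(flat) / len(flat))  # empirical mean, close to 0.2
```

Here `mean_sources` and `max_sources` play the role of the fixed hyperparameters discussed above: they shape the distribution of $z$, but they are not themselves inferred.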

Training is performed using the ```train``` function from ```bliss/main.py```. As above, we pass `cfg.train` to the `train` function, which is reproduced below (with some lines omitted).

In [None]:
def train(train_cfg: DictConfig):
    # setup seed
    pl.seed_everything(train_cfg.seed)

    # setup dataset, encoder, and trainer
    dataset = instantiate(train_cfg.data_source)
    encoder = instantiate(train_cfg.encoder)
    trainer = instantiate(train_cfg.trainer)

    # train!
    trainer.fit(encoder, datamodule=dataset)

Let's examine the three main objects that the training procedure evidently uses: a dataset, an encoder, and a trainer. Again, these are all specified by the config and instantiated in the `train` function above.

In [None]:
cfg.train.data_source

We see that the data source is a cached dataset, i.e. many saved $z,x$ pairs from the generative model $p(z,x)$ that were previously written to disk. These are used to approximate the objective by an average over the cached pairs:

$$
\max_\phi \frac{1}{n} \sum_{i=1}^n \log q_\phi(z_i \mid x_i),
$$

where $n$ is the number of cached $z,x$ pairs. In our toy setup we only have the single simulated batch generated above (32 images and their tile catalogs). To fit it into the full framework, we would evidently need to wrap it as a `CachedSimulatedDataset`.
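The empirical objective above can be tried out on a toy problem where the answer is known. The sketch below is not BLISS: it uses a scalar conjugate Gaussian model, $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, 1)$, whose exact posterior is $z \mid x \sim \mathcal{N}(x/2, 1/2)$, and a linear-Gaussian "encoder" $q(z \mid x) = \mathcal{N}(a x + b, \sigma^2)$. For this family, maximizing $\frac{1}{n}\sum_i \log q(z_i \mid x_i)$ is ordinary least squares for $(a, b)$ with $\sigma^2$ set to the mean squared residual, so the fit should land near the exact posterior.

```python
import math
import random


def simulate_pairs(n, noise_sd=1.0, seed=0):
    """Draw (z, x) pairs from a toy generative model: z ~ N(0, 1), x | z ~ N(z, noise_sd^2)."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)
        pairs.append((z, z + rng.gauss(0.0, noise_sd)))
    return pairs


def fit_linear_gaussian_encoder(pairs):
    """Maximize (1/n) sum_i log q(z_i | x_i) for q(z | x) = N(a*x + b, sigma2), in closed form."""
    n = len(pairs)
    mean_x = sum(x for _, x in pairs) / n
    mean_z = sum(z for z, _ in pairs) / n
    cov_xz = sum((x - mean_x) * (z - mean_z) for z, x in pairs) / n
    var_x = sum((x - mean_x) ** 2 for _, x in pairs) / n
    a = cov_xz / var_x                      # least-squares slope
    b = mean_z - a * mean_x                 # least-squares intercept
    sigma2 = sum((z - (a * x + b)) ** 2 for z, x in pairs) / n  # mean squared residual
    return a, b, sigma2


def avg_log_q(pairs, a, b, sigma2):
    """The empirical objective (1/n) sum_i log q(z_i | x_i)."""
    return sum(
        -0.5 * math.log(2 * math.pi * sigma2) - (z - a * x - b) ** 2 / (2 * sigma2)
        for z, x in pairs
    ) / len(pairs)


pairs = simulate_pairs(20000)
a, b, sigma2 = fit_linear_gaussian_encoder(pairs)
print(a, b, sigma2)  # each close to the exact posterior values a = 0.5, b = 0, sigma2 = 0.5
```

The point of the toy: training only ever sees simulated $(z, x)$ pairs, never an explicit posterior, yet maximizing the average $\log q$ recovers it. BLISS replaces the linear map with a neural network and the scalar $z$ with a full tile catalog, and optimizes by gradient steps instead of a closed form.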

In [None]:
cfg.train.encoder

All these configurables will instantiate an `Encoder` object. This object will perform the function described above, i.e. given an image $x$ it will give us a distribution on the tile catalog $z$. 

In [None]:
cfg.train.trainer

The trainer evidently wraps up the training procedure with lots of information about logging, checkpoints, metrics, etc. We won't worry so much about this for now.

Our simplified version of training for illustrative purposes will try to unwrap some of the abstraction above. Let's first instantiate the encoder so we can examine it.

In [None]:
encoder = instantiate(cfg.train.encoder)
encoder

If you're familiar with PyTorch, you'll recognize that the encoder is essentially a huge neural network as described above, with some fancy preprocessing, normalization, metrics, etc. Let's examine some of the `Encoder` object's methods, which are reproduced below.

```
def _single_detection_nll(self, target_cat, pred):
    marginal_loss = pred["marginal"].compute_nll(target_cat)

    if not self.use_checkerboard:
        return marginal_loss

    white_loss = pred["white"].compute_nll(target_cat)
    white_loss_mask = 1 - pred["white_history_mask"]
    white_loss *= white_loss_mask

    black_loss = pred["black"].compute_nll(target_cat)
    black_loss_mask = pred["white_history_mask"]
    black_loss *= black_loss_mask

    # we divide by two because we score two predictions for each tile
    return (marginal_loss + white_loss + black_loss) / 2

def _double_detection_nll(self, target_cat1, target_cat, pred):
    target_cat2 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=1)

    nll_marginal_z1 = self._single_detection_nll(target_cat1, pred)
    nll_cond_z2 = pred["second"].compute_nll(target_cat2)
    nll_marginal_z2 = self._single_detection_nll(target_cat2, pred)
    nll_cond_z1 = pred["second"].compute_nll(target_cat1)

    none_mask = target_cat.n_sources == 0
    loss0 = nll_marginal_z1 * none_mask

    one_mask = target_cat.n_sources == 1
    loss1 = (nll_marginal_z1 + nll_cond_z2) * one_mask

    two_mask = target_cat.n_sources >= 2
    loss2a = nll_marginal_z1 + nll_cond_z2
    loss2b = nll_marginal_z2 + nll_cond_z1
    lse_stack = torch.stack([loss2a, loss2b], dim=-1)
    loss2_unmasked = -torch.logsumexp(-lse_stack, dim=-1)
    loss2 = loss2_unmasked * two_mask

    return loss0 + loss1 + loss2
```

Here, NLL stands for negative log likelihood; for us it corresponds to $-\log q_\phi(z \mid x)$. Minimizing this quantity is equivalent to maximizing $\log q_\phi(z \mid x)$, as formulated above. Let's compute the NLL loss for the encoder. We expect it to be poor, because the encoder has only been initialized, not trained. The following code snippets are adapted from the `_compute_loss` function of the `Encoder` class in `bliss/encoder/encoder.py`.

In [None]:
batch = simulated_batch_of_data #renaming to something shorter
batch_size = batch["images"].size(0)
target_cat = TileCatalog(encoder.tile_slen, batch["tile_catalog"])

In [None]:
target_cat

The object `target_cat` is the "target catalog" of interest: a `TileCatalog` that we create from the simulated batch.

In [None]:
# filter out undetectable sources
if encoder.min_flux_threshold > 0:
    target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)
    
# make predictions/inferences
target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
truth_callback = lambda _: target_cat1
pred = encoder.infer(batch, truth_callback)

The code above performs some per-tile preprocessing, such as filtering out sources that are too dim to detect. The `infer` method of the `Encoder` object on the last line operates directly on the images of the batch, i.e. `batch['images']`. It is a complex method that splits the image into tiles designated as white or black in a 'checkerboard' scheme; this scheme helps with detecting objects near tile boundaries, which is a hard problem in its own right. We don't need the details of how all of this works for now, but we do want to understand the form of the resulting object, which is stored in the variable `pred`.
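For intuition about the checkerboard idea, here is a generic sketch (not BLISS's implementation) of a checkerboard tile mask. Its key property is that every tile's up/down/left/right neighbors all have the other color, so a second pass over one color can condition on predictions already made for all of its edge neighbors. Which color BLISS predicts first, and how the conditioning is done, are details not reproduced here; the helper names are hypothetical.

```python
def checkerboard_mask(n_tiles_h, n_tiles_w):
    """Return a 2D list with 1 for 'white' tiles and 0 for 'black' tiles."""
    return [[(i + j + 1) % 2 for j in range(n_tiles_w)] for i in range(n_tiles_h)]


def edge_neighbors(i, j, n_tiles_h, n_tiles_w):
    """Coordinates of the up/down/left/right neighbors that lie inside the grid."""
    candidates = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
    return [(a, b) for a, b in candidates if 0 <= a < n_tiles_h and 0 <= b < n_tiles_w]


mask = checkerboard_mask(4, 4)
for row in mask:
    print(row)
# [1, 0, 1, 0]
# [0, 1, 0, 1]
# [1, 0, 1, 0]
# [0, 1, 0, 1]
```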

In [None]:
pred.keys()

These names will seem somewhat mysterious, and that's okay. We can learn more about them and how they are computed throughout the semester. The NLL functions reproduced above take a target catalog and the pred object above, and use these to compute the NLL loss. In other words, the quantity

$$
- \log q_\phi(z \mid x)
$$

that we aim to compute is given by the following: firstly, $z$ is the `target_cat` of type `TileCatalog`. Recall that because we're generating synthetic data, the latent variable $z$ is not hidden, but known. The $x$ is given by the images from `batch['images']`, and these are operated on by the `infer` method of the encoder. The resulting computations yield the objects in `pred.keys()` above, which can be used to compute $-\log q_\phi(z \mid x)$ for this particular data batch via the functions `_single_detection_nll` and `_double_detection_nll`. We don't need to go into detail as to how these are computed for now.
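One line of `_double_detection_nll` is worth unpacking: for tiles with two or more sources, `-torch.logsumexp(-lse_stack, dim=-1)` combines the two possible orderings of the two brightest sources. Treating `loss2a` roughly as $-\log q(z_1) - \log q(z_2 \mid z_1)$ and `loss2b` as the same with the roles swapped, the result is $-\log\bigl(q(z_1)\,q(z_2 \mid z_1) + q(z_2)\,q(z_1 \mid z_2)\bigr)$, i.e. the NLL when either ordering counts as a match. The scalar sketch below is our own plain-Python illustration, not BLISS code; it computes this stably by factoring out the smaller loss, the same max-subtraction trick that logsumexp implementations use.

```python
import math


def order_invariant_nll(nll_a, nll_b):
    """Return -log(exp(-nll_a) + exp(-nll_b)), computed without underflow.

    Scalar analogue of -torch.logsumexp(-torch.stack([nll_a, nll_b], dim=-1), dim=-1).
    """
    m = min(nll_a, nll_b)  # factor out the larger probability exp(-m)
    return m - math.log(math.exp(-(nll_a - m)) + math.exp(-(nll_b - m)))


print(order_invariant_nll(3.0, 3.0))        # 3 - log 2, about 2.307
print(order_invariant_nll(1000.0, 1001.0))  # about 999.687; exp(-1000) alone would underflow to 0
```

The result is never larger than the better of the two orderings' losses, so the encoder is not penalized for listing the two sources in the "wrong" order.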

In [None]:
loss = encoder._double_detection_nll(target_cat1, target_cat, pred)

In [None]:
loss.shape

In [None]:
loss.mean()

Loss is evidently computed on a per-image per-tile basis. We'll average across all of these because they all deserve equal weighting in our toy setup here. Now that we've illustrated how to compute the loss, let's wrap this all into a training loop to fit the encoder. We'll fit by optimizing the parameters directly rather than wrapping the procedure into a PyTorch Lightning routine as is done in the true BLISS code. 

This is extremely simplistic: we have a single batch of data $z,x$ that we generated above.  Nevertheless, in the training loop below, we still redefine/recompute `pred` and the target catalogs. In a true training procedure with many different batches of images, this would need to be done within the loop because we'll have a different batch of images every time.

***The cell below runs very slowly because everything runs on the CPU. The actual codebase is optimized for GPUs, but naively moving the code below to a GPU will lead to an out-of-memory error.***

In [None]:
# This is a very small number of iterations, but it still takes 5-10 minutes
# In practice, BLISS training will be a lot faster

niter = 30
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

for i in range(niter):
    # Many of the lines below are redundant because we have a single batch
    # so quantities don't change.
    
    target_cat = TileCatalog(encoder.tile_slen, batch["tile_catalog"])
    # filter out undetectable sources
    if encoder.min_flux_threshold > 0:
        target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)

    # make predictions/inferences
    target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
    truth_callback = lambda _: target_cat1
    pred = encoder.infer(batch, truth_callback)
    
    # Main gradient step code
    optimizer.zero_grad()
    loss = encoder._double_detection_nll(target_cat1, target_cat, pred).mean()
    loss.backward()
    optimizer.step()
    
    print('Iteration {}: Loss {}'.format(i, loss.item()))
    

The training procedure above is very rough: there's no learning rate tuning, scheduling, etc., and we don't run the fitting procedure all the way to convergence due to time constraints. Nevertheless, this is enough to get a feel for how fitting the encoder generally goes. In practice, much of this is abstracted away within PyTorch Lightning routines.

We want to see how the encoder is doing. Recall that we have been training on a single batch of $z,x$ pairs. It's reasonable to expect that, with enough training time, the encoder learns to output the correct $z$ given $x$ as input. Let's check. We use the `sample` method of the encoder and simply take the posterior mode for now, i.e. the mode of the distribution $q_\phi(z \mid x)$, though we could also draw more diverse samples from the distribution itself.

In [None]:
posterior_mode = encoder.sample(batch, use_mode=True)

In [None]:
type(posterior_mode)

In [None]:
vars(posterior_mode).keys()
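The difference between the posterior mode and a posterior sample can be seen for a single discrete quantity, such as the number of sources in one tile. The sketch below assumes a hypothetical categorical distribution over 0, 1, or 2 sources with made-up probabilities; it is not the encoder's `sample` method, which handles every variable in the catalog. The mode is the single most probable value, while samples vary in proportion to the probabilities.

```python
import random


def categorical_mode(probs):
    """Index of the most probable outcome."""
    return max(range(len(probs)), key=lambda k: probs[k])


def categorical_sample(probs, rng):
    """Draw one outcome index with probability proportional to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Hypothetical probabilities that one tile contains 0, 1, or 2 sources
tile_probs = [0.2, 0.7, 0.1]
rng = random.Random(0)
draws = [categorical_sample(tile_probs, rng) for _ in range(10_000)]
print(categorical_mode(tile_probs))                     # 1
print([draws.count(k) / len(draws) for k in range(3)])  # roughly [0.2, 0.7, 0.1]
```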

Let's compare the posterior mode to the true target catalog $z$, which is stored in the variable `target_cat`.

In [None]:
posterior_mode.n_sources.shape, target_cat.n_sources.shape

The posterior mode covers 54 x 54 tiles, while the target catalog covers 56 x 56: BLISS usually omits the outermost ring of border tiles. We do the same with the target catalog when evaluating.

In [None]:
posterior_mode.n_sources.shape, target_cat.n_sources[:,1:-1,1:-1, ...].shape

In [None]:
# Fraction of (non-border) tiles where the predicted number of sources is exactly right
(posterior_mode.n_sources == target_cat.n_sources[:, 1:-1, 1:-1, ...]).float().mean()

The line above tells us the proportion of tiles across the 32 images in our batch in which the posterior mode identifies the correct number of sources. This proportion would probably increase if we trained longer. Note that the variational distribution constrains us to have at most 2 sources per tile, so it's not surprising that some are wrong.

In [None]:
# Fraction of (non-border) tiles where the predicted count is off by at most one
(torch.abs(posterior_mode.n_sources - target_cat.n_sources[:, 1:-1, 1:-1, ...]) <= 1).float().mean()

The line above tells us the proportion of tiles in our 32 images in which the number of sources detected differs from the true number of sources by no more than 1. In other words, even when the number of sources is wrong, it's usually off by no more than 1, mistaking 3 sources for 2, for example.
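The two evaluation cells above can be mirrored in plain Python to make the bookkeeping explicit: crop the border tiles from the true counts, then compute the fraction of tiles whose predicted count is within a tolerance of the truth (tolerance 0 for exact matches, 1 for off-by-at-most-one). The helper names and the toy grids below are made up for illustration.

```python
def crop_border(grid, width=1):
    """Drop `width` tiles from every edge of a 2D grid (list of rows)."""
    return [row[width:len(row) - width] for row in grid[width:len(grid) - width]]


def count_accuracy(pred_counts, true_counts, tolerance=0):
    """Fraction of tiles where |predicted - true| <= tolerance."""
    if len(pred_counts) != len(true_counts) or any(
        len(p) != len(t) for p, t in zip(pred_counts, true_counts)
    ):
        raise ValueError("grids must have the same shape")
    diffs = [abs(p - t) for prow, trow in zip(pred_counts, true_counts) for p, t in zip(prow, trow)]
    return sum(d <= tolerance for d in diffs) / len(diffs)


toy_truth = [
    [9, 9, 9, 9],
    [9, 0, 1, 9],
    [9, 3, 2, 9],
    [9, 9, 9, 9],
]  # border tiles (the 9s) are dropped before comparing
toy_pred = [[0, 1], [2, 2]]
cropped = crop_border(toy_truth)
print(count_accuracy(toy_pred, cropped))               # 0.75 (exact matches)
print(count_accuracy(toy_pred, cropped, tolerance=1))  # 1.0 (off by at most one)
```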

If you want, you can also check if the predicted locations, fluxes, etc. look approximately correct. We can do this together at a later date.