# Working with Variational Distributions

We'll illustrate how to use the `VariationalDist` and `VariationalDistSpec` objects in tandem with the `Encoder` to sample and to compute the negative log-likelihood (NLL). Let's load and simulate some data according to the `base_config`.

In [None]:
import sys
import os
os.chdir('/home/yolandz/bliss')
from bliss.encoder.variational_dist import VariationalDistSpec, VariationalDist
from bliss.encoder.unconstrained_dists import UnconstrainedNormal
import torch
import numpy as np
from os import environ
from pathlib import Path
from hydra import initialize, compose
from hydra.utils import instantiate
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from bliss.catalog import TileCatalog
from case_studies.redshift_estimation.catalog import RedshiftTileCatalog

environ["BLISS_HOME"] = "/home/declan/current/bliss"
with initialize(config_path="../../../bliss/conf", version_base=None):
    cfg = compose("base_config", overrides=["surveys.sdss.load_image_data=true"])


In [2]:
import torch
data_path = '/data/scratch/declan/redshift_estimation/dataset_0.pt'
loaded_data = torch.load(data_path)


In [3]:
tile_cat = loaded_data[0]["tile_catalog"]
tile_cat


{'locs': tensor([[[[7.0744e-01, 1.5549e-01]],
 
          [[4.4622e-01, 1.7127e-01]],
 
          [[4.6269e-01, 1.5261e-02]],
 
          [[6.6563e-01, 9.8993e-02]],
 
          [[4.4525e-01, 7.8420e-01]],
 
          [[2.6686e-01, 8.7291e-01]],
 
          [[3.5706e-02, 2.0148e-01]],
 
          [[2.1426e-01, 9.2284e-01]],
 
          [[4.3947e-01, 6.3854e-01]],
 
          [[1.1949e-01, 3.8836e-01]],
 
          [[7.3251e-01, 3.4021e-01]],
 
          [[3.4474e-01, 5.0138e-01]],
 
          [[2.9869e-01, 6.8957e-01]],
 
          [[1.9654e-01, 6.4135e-01]],
 
          [[8.1384e-01, 7.8318e-01]],
 
          [[2.9213e-01, 6.2674e-01]],
 
          [[1.0082e-01, 1.5406e-01]],
 
          [[6.5921e-01, 6.1923e-01]],
 
          [[2.6321e-01, 1.9433e-01]],
 
          [[8.7317e-02, 4.3605e-01]]],
 
 
         [[[6.4604e-01, 3.8550e-01]],
 
          [[9.9488e-02, 9.7894e-01]],
 
          [[1.1560e-01, 6.7290e-01]],
 
          [[6.8576e-02, 6.9570e-01]],
 
          [[2.9871e-01, 8.151

In [4]:
(tile_cat).keys()
tile_cat['locs'].shape[:-1]

torch.Size([20, 20, 1])

In [5]:
encoder = instantiate(cfg.train.encoder)
target_cat = RedshiftTileCatalog(encoder.tile_slen, tile_cat)

ValueError: not enough values to unpack (expected 4, got 3)

In [None]:
simulator = instantiate(cfg.simulator)
type(simulator)

bliss.simulator.simulated_dataset.SimulatedDataset

#### The cell below takes several minutes to run.

In [None]:
test_batch = simulator.get_batch()

In [None]:
test_batch.keys()

dict_keys(['tile_catalog', 'images', 'background', 'deconvolution', 'psf_params'])

In [None]:
test_batch['tile_catalog'].keys()

dict_keys(['locs', 'n_sources', 'source_type', 'galaxy_fluxes', 'galaxy_params', 'star_fluxes'])

Let's instantiate the `Encoder` and run its primary method (for our purposes): `infer`, which operates on simulated batches of data. More precisely, `infer` operates on the images of the batch (`test_batch['images']` in our naming so far). First, let's hard-code the so-called `target_cat`, i.e. the target catalog, to be the ground truth $z$. We need this to construct the variable `truth_callback` below (we won't worry too much about the motivation behind this for now).

In [None]:
encoder = instantiate(cfg.train.encoder)
target_cat = TileCatalog(encoder.tile_slen, test_batch["tile_catalog"])

In [None]:
target_cat

TileCatalog(64 x 20 x 20)

In [None]:
# filter out undetectable sources
if encoder.min_flux_threshold > 0:
    target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)
    
# make predictions/inferences
target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
truth_callback = lambda _: target_cat1
pred = encoder.infer(test_batch, truth_callback)

Let's examine the outputs of the `infer` method.

In [None]:
pred.keys()

dict_keys(['x_features', 'marginal', 'history_cat', 'white_history_mask', 'white', 'black'])

We'll be most concerned with the dict entry `marginal`. As the name suggests, this will contain all information necessary for constructing the marginal variational distributions. In general, if a variational family is mean-field, it factorizes as the product of the marginals

$$
q(z_1, z_2 \mid x) = q(z_1 \mid x) q(z_2 \mid x)
$$

hence the naming. For each latent variable of interest, we just need to know the marginal variational distribution on it when using this particular form of mean-field variational family. BLISS uses this type of variational family for the most part, although there are some subtleties whereby adjacent tiles do interact with one another. Let's focus on the pure mean-field case for now. If we have a distribution on each latent variable of interest, we can compute

$$
\log q(z_1, \dots, z_K \mid x) = \sum_{i=1}^K \log q(z_i \mid x)
$$

because the logarithm of a product is the sum of the logarithms. We shall see that this is essentially how the `VariationalDist` objects compute the NLL, which is just the negative of this sum.
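As a small standalone check of this identity, here is a sketch with two made-up per-tile factors built directly from `torch.distributions`. The factor names, shapes, and values are assumptions chosen for illustration; nothing here calls BLISS.

```python
import torch
from torch.distributions import Categorical, Normal

torch.manual_seed(0)

# Two hypothetical per-tile factors for a batch of 4 images with 3 x 3 tiles.
batch_shape = (4, 3, 3)
q_on = Categorical(probs=torch.softmax(torch.randn(*batch_shape, 2), dim=-1))
q_flux = Normal(loc=torch.randn(batch_shape), scale=torch.ones(batch_shape))

# Stand-in "true" values of the two latent variables on every tile.
z_on = torch.randint(0, 2, batch_shape)
z_flux = torch.randn(batch_shape)

# Mean-field: the joint log-density is the sum of the marginal log-densities.
log_q = q_on.log_prob(z_on) + q_flux.log_prob(z_flux)
nll = -log_q  # one NLL value per tile

# Cross-check against the log of the product of the two densities.
joint_density = q_on.log_prob(z_on).exp() * q_flux.log_prob(z_flux).exp()
print(nll.shape, torch.allclose(log_q, joint_density.log(), atol=1e-4))
```

Summing log-densities rather than multiplying densities is also the numerically safer choice, since products of many small densities underflow quickly.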

In [None]:
type(pred['marginal'])

bliss.encoder.variational_dist.VariationalDist

In [None]:
pred['marginal'].factors

{'on_prob': Categorical(probs: torch.Size([64, 20, 20, 2])),
 'loc': TruncatedDiagonalMVN(Normal(loc: torch.Size([64, 20, 20, 2]), scale: torch.Size([64, 20, 20, 2]))),
 'galaxy_prob': Categorical(probs: torch.Size([64, 20, 20, 2])),
 'galsim_disk_frac': TransformedDistribution(),
 'galsim_beta_radians': TransformedDistribution(),
 'galsim_disk_q': TransformedDistribution(),
 'galsim_a_d': LogNormal(),
 'galsim_bulge_q': TransformedDistribution(),
 'galsim_a_b': LogNormal(),
 'star_flux_u': LogNormal(),
 'star_flux_g': LogNormal(),
 'star_flux_r': LogNormal(),
 'star_flux_i': LogNormal(),
 'star_flux_z': LogNormal(),
 'galaxy_flux_u': LogNormal(),
 'galaxy_flux_g': LogNormal(),
 'galaxy_flux_r': LogNormal(),
 'galaxy_flux_i': LogNormal(),
 'galaxy_flux_z': LogNormal()}

We see that `pred['marginal']` returns an object of type `VariationalDist`. This variational distribution is conditional on the $x$ defined by our particular simulated batch.

The `factors` attribute contains each of the marginal factors. Some of the shapes are revealing. Clearly 64 x 20 x 20 corresponds to the number of images in a simulated batch and the tiles corresponding to the `base_config` we've used. These can be altered if desired.

In [None]:
cfg.prior.batch_size, cfg.prior.n_tiles_h, cfg.prior.n_tiles_w

(64, 20, 20)

This informs some of the dimensions we see above. For example, `loc` should be 2D ***per source***: an (x, y) coordinate pair. Hence the 2 in the last dimension for `loc`. Evidently, many of these distributions are per-source.

Recall that BLISS's variational distribution currently allows for at most 2 sources per tile. This does not seem to be accounted for in the above. In practice, BLISS first detects the brightest source in each tile and then, having accounted for it, tries to find a second source. The `base_config` used here skips this two-stage process: the prior is constrained to have at most 1 source per tile, so it's not necessary. To detect multiple sources per tile, one would enable the two-stage process by setting the configurable `encoder.double_detect` to `True` (and potentially some other configurables), as in the commented-out cell below.

For now, we won't worry about this and will be satisfied with generating and detecting at most 1 source per tile.

In [None]:
print(cfg.prior.max_sources) # Should be =1
print(cfg.encoder.double_detect) #Should be False

1
False


In [None]:
# Would change these configurables to detect multiple sources via a two-layer detection stage in encoder.
# cfg.prior.max_sources = 2
# cfg.encoder.double_detect = True

How does the `infer` method of `Encoder` produce a `VariationalDist` object? Through the `VariationalDistSpec` class. This "variational distribution specification" contains the information needed to construct the variational distribution. Here's a look at the class below. The most important attribute is `factor_specs`, which specifies the factor for each latent variable of interest.

```
class VariationalDistSpec(torch.nn.Module):
    def __init__(self, survey_bands, tile_slen):
        super().__init__()

        self.survey_bands = survey_bands
        self.tile_slen = tile_slen

        self.factor_specs = {
            "on_prob": UnconstrainedBernoulli(),
            "loc": UnconstrainedTDBN(),
            "galaxy_prob": UnconstrainedBernoulli(),
            # galsim parameters
            "galsim_disk_frac": UnconstrainedLogitNormal(),
            "galsim_beta_radians": UnconstrainedLogitNormal(high=torch.pi),
            "galsim_disk_q": UnconstrainedLogitNormal(),
            "galsim_a_d": UnconstrainedLogNormal(),
            "galsim_bulge_q": UnconstrainedLogitNormal(),
            "galsim_a_b": UnconstrainedLogNormal(),
        }
...
```

The `Encoder` is instantiated with a `VariationalDistSpec` attribute. The `infer` method then uses the `make_dist` method of `VariationalDistSpec` to create the `VariationalDist` object.

In [None]:
encoder.vd_spec

VariationalDistSpec()

We will present just a little more detail on how `Encoder` produces the final variational distribution. Look inside `Encoder.infer` for more detail. 

For any given batch, the images are first normalized, then passed through a feature net. 

In [None]:
normalized_images = encoder.image_normalizer.get_input_tensor(test_batch)
x_features = encoder.features_net(normalized_images)
x_features.shape

torch.Size([64, 256, 20, 20])

It appears that each tile is represented by a vector of length 256. This tensor is then passed to `marginal_net`, which produces all necessary variational parameters.

In [None]:
x_cat = encoder.marginal_net(x_features)
x_cat.shape

torch.Size([64, 20, 20, 38])

The shape of `x_cat` is revealing. It tells us that per-tile, we have 38 variational parameters. Let's recover this number 38 from the variational distribution specification to check that everything matches.

In [None]:
sum(dist.dim for dist in encoder.vd_spec.factor_specs.values())

38

So this magical number 38 just matches the dimensions totaled over all the variational factors. If we added more factors, this number would have to increase (to 40, 45, etc.).
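One way to see where 38 could come from is to tally parameters by hand. The per-factor counts below are assumptions, not values read from BLISS: one logit for each Bernoulli factor, four numbers for `loc` (a mean and a scale for each coordinate), and a location and scale for every other factor. Under those assumptions the 19 factors listed earlier add up to the observed total.

```python
# Assumed number of unconstrained parameters per kind of factor (illustrative only).
PARAMS_PER_KIND = {
    "bernoulli": 1,      # a single logit
    "location": 4,       # mean and scale for each of the (x, y) coordinates
    "two_parameter": 2,  # location and scale (LogNormal, LogitNormal)
}

bands = ["u", "g", "r", "i", "z"]
galsim_names = ["disk_frac", "beta_radians", "disk_q", "a_d", "bulge_q", "a_b"]

# Same factor names as in the `factors` dict printed earlier.
factor_kinds = {"on_prob": "bernoulli", "loc": "location", "galaxy_prob": "bernoulli"}
factor_kinds.update({f"galsim_{name}": "two_parameter" for name in galsim_names})
factor_kinds.update({f"star_flux_{b}": "two_parameter" for b in bands})
factor_kinds.update({f"galaxy_flux_{b}": "two_parameter" for b in bands})

n_params = sum(PARAMS_PER_KIND[kind] for kind in factor_kinds.values())
print(len(factor_kinds), n_params)  # prints "19 38"
```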

Let's show how `VariationalDist` is created from `x_cat`. As said above, this is done using the `make_dist` method of `VariationalDistSpec`.

```
def make_dist(self, x_cat):
    factors = self._parse_factors(x_cat)
    return VariationalDist(factors, self.survey_bands, self.tile_slen)

```
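The body of `_parse_factors` is not shown here. As a rough idea of what such a step might involve, the sketch below splits the last dimension of a tensor shaped like `x_cat` into one chunk per factor and maps each chunk to a `torch.distributions` object. The function `parse_factors`, the two-entry `factor_dims` dict, and the sigmoid/exp transforms are all assumptions made for illustration, not the BLISS implementation.

```python
import torch
from torch.distributions import Categorical, LogNormal

# Hypothetical spec: factor name -> number of unconstrained parameters.
factor_dims = {"on_prob": 1, "star_flux_r": 2}


def parse_factors(x_cat):
    """Split the last dimension of x_cat into one chunk per factor."""
    chunks = torch.split(x_cat, list(factor_dims.values()), dim=-1)
    params = dict(zip(factor_dims, chunks))

    # One logit -> probabilities of (source off, source on).
    p_on = torch.sigmoid(params["on_prob"])
    on_prob = Categorical(probs=torch.cat([1 - p_on, p_on], dim=-1))

    # Two numbers -> location and (positive) scale of a LogNormal.
    loc, raw_scale = params["star_flux_r"].unbind(dim=-1)
    star_flux_r = LogNormal(loc, raw_scale.exp())

    return {"on_prob": on_prob, "star_flux_r": star_flux_r}


# Random stand-in for the network output: 64 images, 20 x 20 tiles.
x_cat = torch.randn(64, 20, 20, sum(factor_dims.values()))
factors = parse_factors(x_cat)
print(factors["on_prob"].probs.shape, factors["star_flux_r"].batch_shape)
```

The point of the unconstrained parameterization is that the network can output any real numbers, and the transforms guarantee valid probabilities and positive scales.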

In [None]:
vd = encoder.vd_spec.make_dist(x_cat)

Once the `VariationalDist` object has been created, one can compute the NLL and sample using its `compute_nll` and `sample` methods.

In [None]:
vd.compute_nll(target_cat1).shape

torch.Size([64, 20, 20])

In [None]:
vd.sample(use_mode=True)

TileCatalog(64 x 20 x 20)

# Redshift-Focused Variational Distributions

I've extended the `VariationalDistSpec` and `VariationalDist` classes to `RedshiftVariationalDistSpec` and `RedshiftVariationalDist`. The main difference is that the new classes add a variational distribution on redshift to the list of factors above. The config `redshift.yaml` extends `base_config.yaml`, so it likewise generates and detects at most one source per tile. We can make it more complicated later.

In [None]:
from catalog import RedshiftTileCatalog
from variational_dist import RedshiftVariationalDistSpec, RedshiftVariationalDist

environ["BLISS_HOME"] = "/home/declan/current/bliss"
with initialize(config_path=".", version_base=None):
    cfg = compose("redshift", overrides=["surveys.sdss.load_image_data=true"])

In [None]:
import torch

# Path to your generated data file
data_path = '/data/scratch/declan/redshift_estimation/your_data_file.pt'

# Load the data
loaded_data = torch.load(data_path)

# Access the tile_catalog or other keys as needed
tile_catalog = loaded_data['tile_catalog']

In [None]:
simulator = instantiate(cfg.simulator)
test_batch = simulator.get_batch()
test_batch['tile_catalog'].keys()

dict_keys(['locs', 'n_sources', 'source_type', 'galaxy_fluxes', 'galaxy_params', 'star_fluxes', 'redshifts'])

In [None]:
encoder = instantiate(cfg.train.encoder)
target_cat = RedshiftTileCatalog(encoder.tile_slen, test_batch["tile_catalog"])

In [None]:
type(encoder)

bliss.encoder.encoder.Encoder

We do not need to modify/extend the Encoder class at this point in time.

In [None]:
# filter out undetectable sources
if encoder.min_flux_threshold > 0:
    target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)
    
# get brightest sources
target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)

Let's make an example `RedshiftVariationalDistSpec` object to illustrate the new `factor_spec` on redshift.

In [None]:
rvds = RedshiftVariationalDistSpec(cfg.prior.survey_bands, cfg.prior.tile_slen)

In [None]:
list(rvds.factor_specs.items())

[('on_prob',
  <bliss.encoder.unconstrained_dists.UnconstrainedBernoulli at 0x7efca310fc70>),
 ('loc',
  <bliss.encoder.unconstrained_dists.UnconstrainedTDBN at 0x7efca310fca0>),
 ('galaxy_prob',
  <bliss.encoder.unconstrained_dists.UnconstrainedBernoulli at 0x7efca310ec20>),
 ('galsim_disk_frac',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogitNormal at 0x7efca310fb80>),
 ('galsim_beta_radians',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogitNormal at 0x7efca310f280>),
 ('galsim_disk_q',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogitNormal at 0x7efca310eef0>),
 ('galsim_a_d',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogNormal at 0x7efca310ee60>),
 ('galsim_bulge_q',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogitNormal at 0x7efca310f520>),
 ('galsim_a_b',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogNormal at 0x7efca310f5e0>),
 ('star_flux_u',
  <bliss.encoder.unconstrained_dists.UnconstrainedLogNormal at 0x7efca310f130>),
 ('star_flux_

We see that `rvds.factor_specs` contains a new variational distribution unique to this project: a distribution on redshift. It's a Gaussian distribution for now, but it could be made to be anything one desires.

The `Encoder` object should have a `RedshiftVariationalDistSpec` as its `vd_spec` attribute now. This was done by changing the appropriate `_target_` in the config.

In [None]:
encoder.vd_spec

RedshiftVariationalDistSpec()

Using the output of `Encoder.infer`, we thus get an instance of `RedshiftVariationalDist` that we can use to compute the NLL, sample, etc.

In [None]:
truth_callback = lambda _: target_cat1
pred = encoder.infer(test_batch, truth_callback)

In [None]:
pred['marginal']

RedshiftVariationalDist()

In [None]:
sum(dist.dim for dist in encoder.vd_spec.factor_specs.values())

40

There are now 40 variational parameters per source, up from 38 before. We added two: a location and scale parameter for the Gaussian distribution that describes redshift.
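To make the "two extra parameters" concrete, here is a minimal sketch of a factor spec that turns two unconstrained numbers per tile into a Gaussian. The class `ToyUnconstrainedNormal`, its `get_dist` method, and the exp transform with clamping are hypothetical stand-ins; the actual redshift factor in this project may be parameterized differently.

```python
import torch
from torch.distributions import Normal


class ToyUnconstrainedNormal:
    """Hypothetical factor spec: 2 unconstrained numbers -> Normal(loc, scale)."""

    dim = 2  # one location and one raw (log) scale per tile

    def get_dist(self, params):
        loc = params[..., 0]
        # Exponentiate so the scale is positive; clamp for numerical stability.
        scale = params[..., 1].clamp(-10, 10).exp()
        return Normal(loc, scale)


# Stand-in network output: every location at redshift 1, every raw scale at 0.
params = torch.zeros(64, 20, 20, ToyUnconstrainedNormal.dim)
params[..., 0] = 1.0
q_redshift = ToyUnconstrainedNormal().get_dist(params)

# Per-tile NLL of a catalog whose redshifts are all exactly 1.
nll = -q_redshift.log_prob(torch.full((64, 20, 20), 1.0))
print(nll.shape)
```

Because the factor only declares `dim = 2`, adding it to a spec raises the per-tile parameter count by exactly two, which is consistent with the jump from 38 to 40.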

Let's compute NLL and sample to make sure everything works.

In [None]:
rvd = pred['marginal']

In [None]:
rvd.sample(use_mode=True)

RedshiftTileCatalog(64 x 20 x 20)

In [None]:
rvd.compute_nll(target_cat1).shape

torch.Size([64, 20, 20])

### Example Training Loop

***This cell runs pretty slowly. All on CPU for illustration.***

In [None]:
niter = 300
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
for i in range(niter):
    # Many of the lines below are redundant because we have a single batch
    # so quantities don't change.
    
    target_cat = RedshiftTileCatalog(encoder.tile_slen, test_batch["tile_catalog"])
    # filter out undetectable sources
    if encoder.min_flux_threshold > 0:
        target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)

    # make predictions/inferences
    target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
    truth_callback = lambda _: target_cat1
    pred = encoder.infer(test_batch, truth_callback)
    rvd = pred['marginal']
    
    # Main gradient step code
    optimizer.zero_grad()
    loss = rvd.compute_nll(target_cat1).mean()
    loss.backward()
    optimizer.step()
    
    if i % 10 == 0:
        print('Iteration {}: Loss {}'.format(i, loss.item()))
    

Iteration 0: Loss 31.29745101928711
Iteration 10: Loss 5.976312160491943
Iteration 20: Loss 5.461785316467285
Iteration 30: Loss 4.9790143966674805
Iteration 40: Loss 4.622425079345703
Iteration 50: Loss 4.379827976226807
Iteration 60: Loss 4.2173662185668945
Iteration 70: Loss 4.101867198944092
Iteration 80: Loss 4.002026557922363
Iteration 90: Loss 3.91727352142334
Iteration 100: Loss 3.8605949878692627
Iteration 110: Loss 3.7549383640289307
Iteration 120: Loss 3.747856855392456
Iteration 130: Loss 3.656813383102417
Iteration 140: Loss 3.6292967796325684
Iteration 150: Loss 3.5320780277252197
Iteration 160: Loss 3.470863103866577
Iteration 170: Loss 3.44598388671875
Iteration 180: Loss 3.3535513877868652
Iteration 190: Loss 3.4474265575408936
Iteration 200: Loss 3.3439383506774902
Iteration 210: Loss 3.2493293285369873
Iteration 220: Loss 3.5173134803771973
Iteration 230: Loss 3.2241787910461426
Iteration 240: Loss 3.138590097427368
Iteration 250: Loss 3.0926873683929443
Iteration 26

Let's check how we're doing on the redshift variational distributions. Recall the prior on redshift is extremely concentrated as this is a toy case for now.

In [None]:
cfg.prior.redshift_min, cfg.prior.redshift_max

(0.99, 1.01)

In [None]:
q = rvd.factors
q["redshift"].loc.flatten()

tensor([0.9924, 0.9927, 0.9909,  ..., 0.9837, 0.9853, 0.9877],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
q["redshift"].scale.flatten()

tensor([0.0416, 0.0426, 0.0425,  ..., 0.0421, 0.0418, 0.0419],
       grad_fn=<ReshapeAliasBackward0>)

The locations are looking pretty good. The scales are still way too big, but at least the variational distributions err on the side of being overdispersed rather than overconfident. The variational family is misspecified: the prior is $\textrm{Unif}(0.99, 1.01)$, and we have not modified the decoder $p(x \mid z)$ at all. In other words, redshift has no impact on the data right now, so the posterior on redshift should equal the prior.

Of course we don't achieve that, because the variational distribution is constrained to be Gaussian. We'd hope that with more training we could get more highly concentrated Gaussians, with most of their mass approximately in the interval $[0.99, 1.01]$, but the above suffices for now as a sanity check.
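For a rough target to compare against: among Gaussians, the one that minimizes the expected NLL of draws from a fixed distribution matches that distribution's mean and variance. Assuming the images carry no information about redshift, the best Gaussian factor would therefore match the moments of the uniform prior. The quick calculation below uses only the prior bounds shown above and a representative learned scale of 0.042 read off the printed output.

```python
import math

low, high = 0.99, 1.01

# Mean and standard deviation of Unif(low, high).
target_loc = (low + high) / 2
target_scale = (high - low) / math.sqrt(12)

learned_scale = 0.042  # representative value from the printed scales
ratio = learned_scale / target_scale

print(f"target loc   = {target_loc:.4f}")
print(f"target scale = {target_scale:.5f}")
print(f"learned scale is about {ratio:.1f}x the target")
```

So the learned locations are already within about 0.01 of the target, while the scales would need to shrink by roughly a factor of seven.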