### Generating images from RedshiftTileCatalog, working with VariationalDist and VariationalDistSpec Objects

In [None]:
import sys
import os

# Switch the working directory to the local BLISS checkout (the same directory that
# BLISS_HOME is pointed at further down in this cell).
# NOTE(review): hard-coded, machine-specific path -- adjust when running elsewhere.
os.chdir('/home/declan/current/bliss')

from bliss.encoder.variational_dist import VariationalDistSpec, VariationalDist
from bliss.encoder.unconstrained_dists import UnconstrainedNormal
import torch
import numpy as np
from os import environ
from pathlib import Path
from hydra import initialize, compose
from hydra.utils import instantiate
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from bliss.catalog import TileCatalog

environ["BLISS_HOME"] = "/home/declan/current/bliss"
# Compose the "redshift" config, looked up under config_path=".".
# Hydra's compose() takes `overrides` as a list of "key=value" strings, so pass a
# list rather than a set literal (a set has no defined ordering once more than one
# override is supplied).
with initialize(config_path=".", version_base=None):
    cfg = compose("redshift", overrides=["surveys.sdss.load_image_data=true"])


In [None]:
simulator = instantiate(cfg.simulator)

In [None]:
type(simulator)

In [None]:
prior = simulator.catalog_prior

In [None]:
prior

In [None]:
yo = simulator.catalog_prior.sample()

In [None]:
yo

In [None]:
vars(yo).keys()

In [None]:
yo['data'].keys()

Let's try to generate actual data using a `RedshiftTileCatalog`. All we have to do (for now) is plug into the existing infrastructure and ignore `RedshiftTileCatalog.redshifts`.

In [None]:
cfg.generate

In [None]:
cfg.paths.data

In [None]:
# Instantiate the simulator from the config, overriding num_workers to 0, and draw
# a single batch from it for inspection in the cells below.
simulated_dataset = instantiate(cfg.simulator, num_workers=0)
test_batch = simulated_dataset.get_batch()

In [None]:
test_batch.keys()

In [None]:
test_batch['tile_catalog'].keys()

In [None]:
test_batch['tile_catalog']['redshifts'].shape

### TODO: Change the shape of `redshifts` (what should it be?)
### TODO: Make `redshifts` an attribute of `RedshiftTileCatalog` (currently hidden in `.data`)

# Playing Around With VarDist and VarDistSpec

In [1]:
import sys
import os

# Switch the working directory to the local BLISS checkout (the same directory that
# BLISS_HOME is pointed at further down in this cell).
# NOTE(review): hard-coded, machine-specific path -- adjust when running elsewhere.
os.chdir('/home/declan/current/bliss')

from bliss.encoder.variational_dist import VariationalDistSpec, VariationalDist
from bliss.encoder.unconstrained_dists import UnconstrainedNormal
import torch
import numpy as np
from os import environ
from pathlib import Path
from hydra import initialize, compose
from hydra.utils import instantiate
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf
from bliss.catalog import TileCatalog

environ["BLISS_HOME"] = "/home/declan/current/bliss"
# Compose the "non_redshift" config, looked up under config_path=".".
# Hydra's compose() takes `overrides` as a list of "key=value" strings, so pass a
# list rather than a set literal (a set has no defined ordering once more than one
# override is supplied).
with initialize(config_path=".", version_base=None):
    cfg = compose("non_redshift", overrides=["surveys.sdss.load_image_data=true"])

In [5]:
simulator = instantiate(cfg.simulator)

In [6]:
type(simulator)

bliss.simulator.simulated_dataset.SimulatedDataset

In [7]:
test_batch = simulator.get_batch()

In [8]:
from bliss.encoder.variational_dist import VariationalDist, VariationalDistSpec

In [9]:
test_batch.keys()

dict_keys(['tile_catalog', 'images', 'background', 'deconvolution', 'psf_params'])

In [10]:
test_batch['tile_catalog'].keys()

dict_keys(['locs', 'n_sources', 'source_type', 'galaxy_fluxes', 'galaxy_params', 'star_fluxes'])

In [11]:
test_batch['images'].shape

torch.Size([64, 5, 80, 80])

In [12]:
vds = VariationalDistSpec(cfg.prior.survey_bands, cfg.prior.tile_slen)

In [13]:
# NOTE: this call raises (see the RuntimeError in the output). The input is split
# along dimension 3 into the per-factor sizes, which sum to 38, but the raw image
# batch has size 80 in that dimension.
vd = vds.make_dist(test_batch['images'])

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 80 (input tensor's size at dimension 3), but got split_sizes=[1, 4, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [14]:
test_batch['images'].shape

torch.Size([64, 5, 80, 80])

In [15]:
# Collect the `dim` of every factor spec, in the iteration order of vds.factor_specs.
split_sizes = [factor_spec.dim for factor_spec in vds.factor_specs.values()]

In [16]:
split_sizes

[1, 4, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

In [17]:
sum(split_sizes)

38

In [18]:
encoder = instantiate(cfg.train.encoder)
encoder

Encoder(
  (metrics): CatalogMetrics()
  (image_normalizer): ImageNormalizer()
  (features_net): FeaturesNet(
    (preprocess3d): Sequential(
      (0): Conv3d(5, 64, kernel_size=(6, 5, 5), stride=(1, 1, 1), padding=(0, 2, 2))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU()
    )
    (backbone): Sequential(
      (0): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (activation): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
          (activation): SiLU(inplace=True)
        )
        (1): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(5, 5), 

In [19]:
target_cat = TileCatalog(encoder.tile_slen, test_batch["tile_catalog"])

In [21]:
# Filter out undetectable sources (only when the encoder has a positive flux threshold).
if encoder.min_flux_threshold > 0:
    target_cat = target_cat.filter_tile_catalog_by_flux(min_flux=encoder.min_flux_threshold)

# Make predictions/inferences, handing the encoder the brightest sources per tile
# (band index 2, nothing excluded) through the truth callback.
target_cat1 = target_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)


def truth_callback(_):
    """Return the fixed catalog `target_cat1`, ignoring whatever argument is passed in."""
    return target_cat1


pred = encoder.infer(test_batch, truth_callback)

In [22]:
pred.keys()

dict_keys(['marginal', 'white', 'black', 'history_cat', 'x_features', 'white_history_mask'])

In [25]:
pred['x_features'].shape

torch.Size([64, 256, 20, 20])

In [27]:
pred['history_cat']

TileCatalog(64 x 20 x 20)

In [30]:
test_batch['tile_catalog']['locs'].shape

torch.Size([64, 20, 20, 1, 2])

In [34]:
pred['marginal'].pred

{'on_prob': Categorical(probs: torch.Size([64, 20, 20, 2])),
 'loc': TruncatedDiagonalMVN(Normal(loc: torch.Size([64, 20, 20, 2]), scale: torch.Size([64, 20, 20, 2]))),
 'galaxy_prob': Categorical(probs: torch.Size([64, 20, 20, 2])),
 'galsim_disk_frac': TransformedDistribution(),
 'galsim_beta_radians': TransformedDistribution(),
 'galsim_disk_q': TransformedDistribution(),
 'galsim_a_d': LogNormal(),
 'galsim_bulge_q': TransformedDistribution(),
 'galsim_a_b': LogNormal(),
 'star_flux_u': LogNormal(),
 'star_flux_g': LogNormal(),
 'star_flux_r': LogNormal(),
 'star_flux_i': LogNormal(),
 'star_flux_z': LogNormal(),
 'galaxy_flux_u': LogNormal(),
 'galaxy_flux_g': LogNormal(),
 'galaxy_flux_r': LogNormal(),
 'galaxy_flux_i': LogNormal(),
 'galaxy_flux_z': LogNormal()}

In [36]:
pred['marginal'].pred['galaxy_flux_z']

LogNormal()

In [56]:
yo = pred['marginal'].pred['on_prob']

In [58]:
vars(yo).keys()

dict_keys(['probs', '_param', '_num_events', '_batch_shape', '_event_shape'])

In [59]:
yo.batch_shape

torch.Size([64, 20, 20])

In [60]:
yo.event_shape

torch.Size([])

In [61]:
x_features = encoder.get_features(test_batch)

In [63]:
x_cat_marginal = encoder.marginal_net(x_features)

In [64]:
x_cat_marginal.shape

torch.Size([64, 20, 20, 38])

In [65]:
# The marginal_net output has size 38 in its last dimension, which equals the sum of
# the per-factor split sizes, so make_dist accepts it (unlike the raw images).
vd = vds.make_dist(x_cat_marginal)

In [66]:
vd

VariationalDist()

In [37]:
vars(pred['marginal'].pred['galaxy_flux_z'])

{'transforms': [ExpTransform()],
 'base_dist': Normal(loc: torch.Size([64, 20, 20]), scale: torch.Size([64, 20, 20])),
 '_batch_shape': torch.Size([64, 20, 20]),
 '_event_shape': torch.Size([]),
 '_validate_args': False}

In [44]:
pred['marginal'].pred['loc'].base_dist._event_shape[0]

2

In [None]:
# Display the predicted distribution for the `galsim_a_d` factor.
# (The original cell was cut off mid-expression and would not parse.)
pred['marginal'].pred['galsim_a_d']

In [46]:
pred['marginal'].pred['galsim_a_d'].base_dist._event_shape

torch.Size([])

In [54]:
# Tally the dimensionality of every predicted factor: print each factor's name
# followed by its dimension, and accumulate the sum in `total`.
total = 0
for key, value in pred['marginal'].pred.items():
    print(key)
    try:
        dim = value.base_dist._event_shape[-1]
    except (AttributeError, IndexError):
        # AttributeError: the distribution has no `base_dist` (e.g. Categorical).
        # IndexError: the base distribution's event shape is empty (scalar event).
        # Either way the factor is counted as one-dimensional.
        dim = 1
    total += dim
    print(dim)
    

on_prob
1
loc
2
galaxy_prob
1
galsim_disk_frac
1
galsim_beta_radians
1
galsim_disk_q
1
galsim_a_d
1
galsim_bulge_q
1
galsim_a_b
1
star_flux_u
1
star_flux_g
1
star_flux_r
1
star_flux_i
1
star_flux_z
1
galaxy_flux_u
1
galaxy_flux_g
1
galaxy_flux_r
1
galaxy_flux_i
1
galaxy_flux_z
1


In [53]:
total

20

In [31]:
# NOTE: this call raises (see the RuntimeError in the output). pred['x_features'] has
# size 20 in dimension 3, while the per-factor split sizes sum to 38.
vd = vds.make_dist(pred['x_features'])

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 20 (input tensor's size at dimension 3), but got split_sizes=[1, 4, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

#### Below is my first round of playing around: working with my config a bit and trying to generate (dummy) redshifts

In [None]:
import sys
import os

In [None]:
os.getcwd()

In [None]:
os.chdir('/home/declan/current/bliss')

In [None]:
sys.path.append('./case_studies/redshift')

In [None]:
from bliss.encoder.variational_dist import VariationalDistSpec, VariationalDist
from bliss.encoder.unconstrained_dists import UnconstrainedNormal
import torch

In [None]:
import numpy as np
from os import environ
from pathlib import Path
from hydra import initialize, compose
from hydra.utils import instantiate
import matplotlib.pyplot as plt

In [None]:
environ["BLISS_HOME"] = "/home/declan/current/bliss"
# Compose the "redshift" config, looked up under config_path=".".
# Hydra's compose() takes `overrides` as a list of "key=value" strings, so pass a
# list rather than a set literal (a set has no defined ordering once more than one
# override is supplied).
with initialize(config_path=".", version_base=None):
    cfg = compose("redshift", overrides=["surveys.sdss.load_image_data=true"])

In [None]:
cfg.simulator

In [None]:
simulator = instantiate(cfg.simulator)

In [None]:
simulator

In [None]:
vars(simulator)

In [None]:
simulator.catalog_prior.redshift_max

In [None]:
simulator.catalog_prior.galaxy_flux_min

In [None]:
simulator.catalog_prior._sample_redshifts().shape

In [None]:
yo = simulator.catalog_prior.sample()

In [None]:
yo.redshifts

In [None]:
type(yo)

In [None]:
yo.n_sources.shape

In [None]:
# Look up the "redshifts" entry with ordinary subscript syntax rather than calling
# the __getitem__ dunder directly.
yo["redshifts"]

In [None]:
yo.allowed_params

In [None]:
simulator = instantiate(cfg.simulator)

In [None]:
tc = simulator.catalog_prior.sample()

In [None]:
tc

In [None]:
vars(tc).keys()

In [None]:
tc.n_sources.shape

In [None]:
tc.n_tiles_h

In [None]:
tc.n_sources[0][0][0]

In [None]:
tc.tile_slen

In [None]:
tc.data.shape

In [None]:
tc.data