# One Star Example

This notebook demonstrates that
1) the independent-tiling posterior approximation becomes increasingly bad as a star approaches a tile border, and
2) the dependent-tiling posterior approximation remains reasonable regardless of star position.

In [None]:
%matplotlib inline

Import necessary packages and pick a GPU

In [None]:
from os import environ
environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
from hydra import initialize, compose
from hydra.utils import instantiate
from bliss.encoder.encoder import Encoder
from bliss.catalog import TileCatalog
from matplotlib import pyplot as plt

Load the encoder with pre-trained weights. (This encoder was trained on data in which 20% of tiles contain sources, which is quite high for one-star data, but the mismatched rate shouldn't detract from this example.)

In [None]:
with initialize(config_path="../../bliss/conf", version_base=None):
    cfg = compose("base_config")

encoder: Encoder = instantiate(cfg.encoder)
encoder.load_state_dict(torch.load("../../data/pretrained_models/clahed_logged_20percent.pt"))
encoder.cuda()
encoder.eval()
torch.set_grad_enabled(False)

simulator = instantiate(cfg.simulator)

## Bright star

Create a batch of synthetic catalogs, each with one bright star at one of three positions: 0, 0.0667, and 0.133 tile widths (0, about 0.27, and about 0.53 pixels) from the tile border. (The first position is perfectly ambiguous, the second is somewhat ambiguous, and the third is unambiguous.)

The nice thing about this setting is that the exact posterior is unambiguous: there should be about one star detected. But is that the case for each of these posterior approximations?
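The next cell builds the first coordinate of each position as `torch.arange(3) / (3 * 5)`. As a quick check of what those offsets mean, here is a minimal sketch that converts them to pixels, assuming (as in BLISS's `TileCatalog`) that `locs` are tile-relative in [0, 1) and that tiles are 4 pixels wide, matching the `4` passed to `TileCatalog`. `TILE_SLEN`, `N_POSITIONS`, and `bright_star_offsets` are illustrative names, not part of BLISS.

```python
# Sketch, assuming locs are tile-relative in [0, 1) and tiles are 4 pixels wide
# (TILE_SLEN mirrors the first argument passed to TileCatalog).
TILE_SLEN = 4
N_POSITIONS = 3


def bright_star_offsets(n_positions=N_POSITIONS, tile_slen=TILE_SLEN):
    """Return (tile-width offset, pixel offset) for each bright-star position."""
    offsets = []
    for k in range(n_positions):
        tile_frac = k / (n_positions * 5)  # same as torch.arange(3) / (3 * 5)
        offsets.append((tile_frac, tile_frac * tile_slen))
    return offsets


for tile_frac, px in bright_star_offsets():
    print(f"{tile_frac:.4f} tile widths = {px:.3f} pixels from the tile border")
```

All three stars stay within about half a pixel of the border, so even the "unambiguous" position is close to the edge in pixel terms.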

In [None]:
loc_grid_size = 3
loc_replicates = 20
n = loc_grid_size * loc_replicates
n_sources = torch.zeros(n, 20, 20, dtype=torch.long)
n_sources[:, 10, 10] = 1
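locs = torch.arange(loc_grid_size) / (loc_grid_size * 5)  # 0, 1/15, 2/15 of a tile width
locs = torch.stack([locs, torch.ones(loc_grid_size) * 0.5], dim=1)  # second coordinate fixed at the tile center
# broadcast each location to every tile of every replicate, then flatten (location, replicate) into the batch dim
locs = locs.view(loc_grid_size, 1, 1, 1, 1, 2).expand([-1, loc_replicates, 20, 20, 1, -1])
locs = locs.reshape(n, 20, 20, 1, 2)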
true_catalog_dict = {
    "n_sources": n_sources,
    "source_type": torch.zeros(n, 20, 20, 1, 1, dtype=torch.long),
    "locs": locs,
    "star_fluxes": torch.ones(n, 20, 20, 1, 5) * 5,  # 20.75 magnitude
    "galaxy_fluxes": torch.ones(n, 20, 20, 1, 5),
}
true_catalog = TileCatalog(4, true_catalog_dict)
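Later cells summarize results with `.view(3, 20)`, which relies on how the `view`/`expand`/`reshape` above orders the batch: the location index varies slowest, so image `i` has location `i // loc_replicates`. The following pure-Python sketch (hypothetical helper names, no BLISS or torch calls) mirrors that ordering to show that each row of the reshaped result holds a single location.

```python
# Hypothetical pure-Python mirror of the view/expand/reshape above: after flattening
# (location, replicate) into the batch dimension, image i has location i // loc_replicates.
def location_of(image_idx, loc_replicates):
    return image_idx // loc_replicates


def view_rows(values, n_rows, n_cols):
    """Split a flat list into n_rows rows, like tensor.view(n_rows, n_cols)."""
    assert len(values) == n_rows * n_cols
    return [values[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


loc_grid_size, loc_replicates = 3, 20
labels = [location_of(i, loc_replicates) for i in range(loc_grid_size * loc_replicates)]
rows = view_rows(labels, loc_grid_size, loc_replicates)
print([sorted(set(row)) for row in rows])  # [[0], [1], [2]]: one location per row
```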

Simulate one image for each catalog

In [None]:
rcfs, rcf_indices = simulator.randomized_image_ids(true_catalog.n_sources.size(0))
image, background, _, _ = simulator.simulate_images(true_catalog, rcfs, rcf_indices)
batch = {"images": image.cuda(), "background": background.cuda()}

Show the r-band of one image

In [None]:
plt.imshow(image[8, 2].numpy())
plt.colorbar();

Clear the GPU memory so we don't run out in case we re-run this notebook

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.memory_allocated() / 1e9  # show current memory usage in GB

### Dependent Tiling

Predict two catalogs: one from the mode of the variational distribution and one sampled from it.

In [None]:
mode_cat, sample_cat = encoder.predict_step(batch, 0, 0).values()
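To see why a mode catalog can be perfectly consistent across replicates while a sampled catalog is not, consider a simplified model in which one tile's on/off variable is Bernoulli(p). This is an illustration only; the encoder's actual variational factors are richer. The mode is a deterministic function of p, whereas a sample varies from draw to draw. `bernoulli_mode`, `bernoulli_sample`, and the value of `p` are hypothetical.

```python
import random


def bernoulli_mode(p: float) -> int:
    # mode of Bernoulli(p): 1 if p > 0.5, else 0 (a tie at exactly 0.5 goes to 0 here)
    return int(p > 0.5)


def bernoulli_sample(p: float, rng: random.Random) -> int:
    # one random draw from Bernoulli(p)
    return int(rng.random() < p)


rng = random.Random(0)
p = 0.9  # hypothetical per-tile detection probability
modes = [bernoulli_mode(p) for _ in range(20)]
samples = [bernoulli_sample(p, rng) for _ in range(20)]
print(modes)    # always 1
print(samples)  # mostly 1, occasionally 0
```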

Because one border tile is cropped, the source should be found in tile [9, 9], moving in the first dimension from the border (0) toward the tile center (0.5). We restrict our attention to a 6x6 block of 36 tiles to avoid spurious detections, which are inevitable in a large enough image due to Gaussian noise.
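A pure-Python sketch of the bookkeeping behind `n_sources[:, 6:12, 6:12].sum([1, 2])`: the true tile (10, 10) shifts to (9, 9) after cropping, which falls inside the 6x6 window, while a noise detection far from the star does not. The crop of one tile on each side, the 18x18 grid size, and the helper names are assumptions for illustration, not taken from BLISS.

```python
# Sketch with hypothetical names: assume prediction drops one border tile on each side,
# so true tile (10, 10) appears at (9, 9) in the predicted grid.
def crop_shift(tile_rc, n_cropped=1):
    return (tile_rc[0] - n_cropped, tile_rc[1] - n_cropped)


def window_count(grid, rows=range(6, 12), cols=range(6, 12)):
    """Sum per-tile source counts over a 6x6 window, like n_sources[:, 6:12, 6:12].sum([1, 2])."""
    return sum(grid[r][c] for r in rows for c in cols)


H = W = 18  # hypothetical predicted grid: 20 tiles minus one cropped on each side
grid = [[0] * W for _ in range(H)]
r, c = crop_shift((10, 10))
grid[r][c] = 1  # the true star
grid[0][0] = 1  # a hypothetical noise detection far from the star
print((r, c), window_count(grid))  # (9, 9) 1
```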

For all three locations, all 20 replicates show 1 source in the variational distribution mode.

In [None]:
mode_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(3, 20)

The sampled catalog isn't as consistent, but there's clear dependence on the location (border vs. interior). Entries of 2 come from an extra low-flux detection.

In [None]:
sample_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(3, 20)

### Independent Tiling

In [None]:
encoder.use_checkerboard = False
mode_cat, sample_cat = encoder.predict_step(batch, 0, 0).values()
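Our reading of `use_checkerboard`, which is an assumption not confirmed in this notebook, is that dependent tiling splits the tiles into two interleaved checkerboard groups, infers the first group, and then infers the second conditioned on the first group's detections; with the flag off, every tile is inferred independently. The sketch below only shows the checkerboard partition itself (`checkerboard_masks` is a hypothetical helper): every tile in one group has its edge neighbors in the other, so a star on a shared border is always seen by a tile that can condition on its neighbor.

```python
def checkerboard_masks(n_rows: int, n_cols: int):
    """Split tiles into two interleaved groups by the parity of (row + col)."""
    first = [[(r + c) % 2 == 0 for c in range(n_cols)] for r in range(n_rows)]
    second = [[not v for v in row] for row in first]
    return first, second


first, second = checkerboard_masks(4, 4)
for row in first:
    print("".join("A" if v else "B" for v in row))  # A = first group, B = second group
```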

In [None]:
mode_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(3, 20)

In [None]:
sample_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(3, 20)

Let's look at the underlying probabilities without sampling, which is possible for independent tiling.

Oh, these are actually pretty well calibrated regardless of star position.

In [None]:
pred = encoder.infer(batch, lambda _: _)

on_prob = pred["marginal"].factors["on_prob"].probs[:, :, :, 1]
plt.imshow(on_prob[0].cpu().numpy())
plt.colorbar();
on_prob[:, 9:11, 10].sum(1)
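The last line above adds the on-probabilities of two adjacent tiles, which under independent tiling is the expected number of detections across them. A calibrated expected count can still hide a poor joint distribution: if the mass is split evenly across the two tiles, the expected count is 1, but the chance of exactly one detection is only 50%. The following sketch treats each tile as an independent Bernoulli variable; the probabilities are hypothetical, not values from this notebook.

```python
def expected_count(probs):
    # independent Bernoulli tiles: E[count] = sum of the on-probabilities
    return sum(probs)


def prob_exactly_one(probs):
    """P(exactly one of the independent Bernoulli tiles is on)."""
    total = 0.0
    for i, p_i in enumerate(probs):
        term = p_i
        for j, p_j in enumerate(probs):
            if j != i:
                term *= 1 - p_j
        total += term
    return total


border = [0.5, 0.5]  # hypothetical: mass split evenly across the two tiles at the border
print(expected_count(border), prob_exactly_one(border))  # 1.0 0.5
```

The remaining probability goes to zero detections (25%) or two detections (25%), which is the kind of inconsistency the sampled independent-tiling catalogs can show.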

## Dim Star

A dim star whose detection is ambiguous (roughly 70% to 80% detection probability), placed at the tile border and at positions toward the tile center.

First, we generate a true catalog.

In [None]:
# a tile is 4 pixels wide, so offsets of 0, 0.25, 0.5, and 0.75 tile widths are
# whole-pixel shifts: every star sits at the same offset within its pixel
loc_grid_size = 4
loc_replicates = 100
n = loc_grid_size * loc_replicates
n_sources = torch.zeros(n, 20, 20, dtype=torch.long)
n_sources[:, 10, 10] = 1
locs = torch.arange(loc_grid_size) / loc_grid_size + 0.0
locs = torch.stack([locs, torch.ones(loc_grid_size) * 0.5], dim=1)
locs = locs.view(loc_grid_size, 1, 1, 1, 1, 2).expand([-1, loc_replicates, 20, 20, 1, -1])
locs = locs.reshape(n, 20, 20, 1, 2)
true_catalog_dict = {
    "n_sources": n_sources,
    "source_type": torch.zeros(n, 20, 20, 1, 1, dtype=torch.long),
    "locs": locs,
    # with flux of 1.5, stars at all positions are detected easily (prob > 0.97 for leftmost, 1. for rest)
    # with flux of 0.3, stars at no position are detected (prob < 0.25 for all)
    # with flux of 0.5, stars at all positions are detected with prob in [0.7, 0.8]
    "star_fluxes": torch.ones(n, 20, 20, 1, 5) * 0.5,
    "galaxy_fluxes": torch.ones(n, 20, 20, 1, 5),
}
true_catalog = TileCatalog(4, true_catalog_dict)
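The bright-star cell annotates a flux of 5 as 20.75 magnitude, which matches the SDSS nanomaggy convention m = 22.5 - 2.5 log10(f). Assuming the same units here, this sketch translates the fluxes tried in the comments above into magnitudes; `nmgy_to_mag` is an illustrative helper, not a BLISS function.

```python
import math


def nmgy_to_mag(flux_nmgy: float) -> float:
    # assumes fluxes are in nanomaggies: m = 22.5 - 2.5 * log10(f)
    return 22.5 - 2.5 * math.log10(flux_nmgy)


for flux in (5.0, 1.5, 0.5, 0.3):
    print(flux, round(nmgy_to_mag(flux), 2))
```

Under that assumption, the dim star (flux 0.5) is about 23.25 magnitude, roughly 2.5 magnitudes fainter than the bright star.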

Now we render the images and backgrounds.

In [None]:
rcfs, rcf_indices = simulator.randomized_image_ids(true_catalog.n_sources.size(0))
image, background, _, _ = simulator.simulate_images(true_catalog, rcfs, rcf_indices)
batch = {"images": image.cuda(), "background": background.cuda()}

### Independent tiling

In [None]:
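encoder.use_checkerboard = False  # already False from above; set explicitly so this cell doesn't rely on earlier state
mode_cat, sample_cat = encoder.predict_step(batch, 0, 0).values()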

It looks like the source count in the mode catalog is similar regardless of source position.

In [None]:
mode_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(4, 100)

Remind me, what are the 4 locations we're considering?

In [None]:
locs[:, 10, 10, 0, 0].view(4, 100)[:, 0]
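Assuming tile-relative `locs` and 4-pixel tiles, this sketch converts the four first-coordinate offsets to pixels and to the distance from the nearer tile edge. It shows that only offset 0 lies on the border, 0.5 is the tile center, and 0.25 and 0.75 are equally far from an edge. `dim_locs` and `distance_to_border` are illustrative names.

```python
TILE_SLEN = 4  # pixels per tile side, matching the first argument to TileCatalog
dim_locs = [k / 4 for k in range(4)]  # mirrors torch.arange(4) / 4


def distance_to_border(tile_frac: float) -> float:
    # distance (in tile widths) to the nearer of the two tile edges in this dimension
    return min(tile_frac, 1 - tile_frac)


for x in dim_locs:
    print(x, x * TILE_SLEN, distance_to_border(x))
```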

We're "right" most often when the source is exactly in the middle, and wrong most often when the source is at the border. These disparities are concerning because tiling is solely a construct of our inference procedure.

Note that this concerns the mode, not the samples. For an ambiguous source at the border, the probability mass is split between two tiles, so neither tile's detection probability exceeds 50%, and the mode reports no source in either tile.
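A toy calculation of the split-mass argument: take a total detection probability of 0.75 (a hypothetical value inside the 0.7 to 0.8 range noted in the dim-star catalog comments). Centered in one tile, the mode turns that tile on; split evenly across two border tiles, each tile gets 0.375 and the mode turns neither on. `mode_count` is an illustrative helper that treats tiles as independent Bernoulli variables.

```python
def mode_count(tile_probs):
    # number of tiles whose Bernoulli mode is "on" (p > 0.5)
    return sum(1 for p in tile_probs if p > 0.5)


p_detect = 0.75  # hypothetical total detection probability
centered = [p_detect]                   # all mass in one tile
border = [p_detect / 2, p_detect / 2]   # mass split evenly across two tiles
print(mode_count(centered), mode_count(border))  # 1 0
```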

In [None]:
(mode_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(4, 100) == 1).sum(1)

### Dependent tiling

In [None]:
encoder.use_checkerboard = True
mode_cat, sample_cat = encoder.predict_step(batch, 0, 0).values()

There's no pattern here for which position leads to the best mode; the centered position is not clearly better than the border.

In [None]:
(mode_cat.n_sources[:, 6:12, 6:12].sum([1, 2]).view(4, 100) == 1).sum(1)

What about the samples? They tell the same story as the bright-star samples. The point of looking at a dim star was to show that, with independent tiling, even the mode is bad.