# One Star Example

Demonstrates that
1) the independent tiling posterior approximation becomes increasingly bad as a star approaches a tile border
2) the checkerboard tiling posterior approximation remains reasonable regardless of star position

In [None]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

Import necessary packages and pick a GPU

In [None]:
from os import environ
# Expose only GPU 0 to this process. This is set before importing torch (and thus
# before any CUDA initialization) so the restriction is guaranteed to take effect.
environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from bliss.catalog import TileCatalog
from matplotlib import pyplot as plt
import numpy as np

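Load the encoder with pre-trained weights. (This encoder was trained on data in which 20% of tiles contain sources. That is quite high for one-star data, but the mismatched rate shouldn't detract from this example. Note that the checkpoint directory name, `jul25_toy_example_10_percent`, suggests 10%; confirm which rate applies.)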

In [None]:
from hydra import initialize, compose
from hydra.utils import instantiate

# Checkpoint to evaluate. This is an absolute path on the original author's machine;
# edit it for your own setup.
WEIGHT_SAVE_PATH = (
    "/home/regier/bliss_output/jul25_toy_example_10_percent/"
    "version_0/checkpoints/best_encoder.ckpt"
)

with initialize(config_path=".", version_base=None):
    # A list, not a set: `compose` expects a sequence of override strings, and a
    # set literal would make the override order depend on string hashing.
    overrides = [
        f"predict.weight_save_path={WEIGHT_SAVE_PATH}",
        "decoder.with_noise=true",
        "decoder.with_dither=false",
        "encoder.predict_mode_not_samples=false",
        "train.trainer.logger=null",
        "train.trainer.max_epochs=0",
        "+train.trainer.num_sanity_val_steps=0",
        "cached_simulator.num_workers=0",
        "cached_simulator.splits=0:80/0:90/99:100",
        # "paths.cached_data=/data/scratch/regier/toy_example",
    ]
    cfg = compose("toy_example", overrides)

decoder = instantiate(cfg.simulator.decoder)

# Disable autograd globally: this notebook only runs inference.
torch.set_grad_enabled(False)

trainer = instantiate(cfg.train.trainer)

data_source = instantiate(cfg.train.data_source)
data_source.setup("fit")
data_source.setup("test")

encoder = instantiate(cfg.encoder).cuda()
encoder.eval()
state_dict = torch.load(cfg.predict.weight_save_path)["state_dict"]
encoder.load_state_dict(state_dict);

In [None]:
# First NLL evaluation: checkerboard tiling with 4 sampler colors.
encoder.use_checkerboard = True
encoder.n_sampler_colors = 4


from pytorch_lightning.callbacks import Callback


class NllCallback(Callback):
    """Collect the sampler NLL of every predict batch and summarize it.

    Attach to a ``Trainer`` and run ``trainer.predict``; afterwards call
    ``report()`` to print the mean NLL and its standard error.
    """

    def __init__(self) -> None:
        super().__init__()
        # One tensor of NLL values per predict batch.
        self.nlls = []

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Compute the sampler NLL for ``batch`` and store it.

        The NLL is summed over dims 1 and 2 (presumably the tile-grid dimensions;
        TODO confirm against ``compute_sampler_nll``).
        """
        nlls = pl_module.compute_sampler_nll(batch).sum([1, 2])
        self.nlls.append(nlls)

    def report(self):
        """Print the mean NLL and, in parentheses, its standard error of the mean."""
        if not self.nlls:
            print("Mean NLL: no predict batches were recorded")
            return
        nlls = torch.cat(self.nlls)
        # Standard error of the mean = std / sqrt(number of values). Dividing by the
        # number of *batches* (as before) is neither the sample count nor its sqrt.
        sem = nlls.std().item() / nlls.numel() ** 0.5
        print(f"Mean NLL: {nlls.mean().item():.2f} ({sem:.2f})")

# Predict over the test split; the callback accumulates sampler NLLs, then prints a summary.
nll_callback = NllCallback()
trainer = instantiate(cfg.train.trainer, callbacks=[nll_callback])
trainer.predict(
    encoder,
    dataloaders=[data_source.test_dataloader()],
    return_predictions=False,
)
nll_callback.report()

In [None]:
# Second NLL evaluation: checkerboard tiling with 2 sampler colors.
encoder.use_checkerboard = True
encoder.n_sampler_colors = 2

nll_callback = NllCallback()
trainer = instantiate(cfg.train.trainer, callbacks=[nll_callback])
trainer.predict(
    encoder,
    dataloaders=[data_source.test_dataloader()],
    return_predictions=False,
)
nll_callback.report()

In [None]:
# Third NLL evaluation: independent tiling (no checkerboard).
encoder.use_checkerboard = False

nll_callback = NllCallback()
trainer = instantiate(cfg.train.trainer, callbacks=[nll_callback])
trainer.predict(
    encoder,
    dataloaders=[data_source.test_dataloader()],
    return_predictions=False,
)
nll_callback.report()

## Bright star

Create a batch of synthetic catalogs, each containing one bright star, at three positions: 0, 0.667, and 0.133 pixels from the border. (The first position is perfectly ambiguous, the second is somewhat ambiguous, the third is unambiguous.)

The nice thing about this setting is the lack of ambiguity in the (exact) posterior: there should be about one star detected. But is there with each of these posterior approximations?

In [None]:
def gen_loc_shift_data(pixel_shift, flux=10.0, n=100, add_galaxy=False):
    """Simulate ``n`` replicate scenes containing one star (and optionally a galaxy).

    The star sits in tile column 10 of a 20x20 tile grid. ``pixel_shift`` is its
    vertical position in pixels. It is split into a tile row (``pixel_shift // 4``)
    and a within-tile offset (``pixel_shift % 4``), which is stored as a fraction of
    the 4-pixel tile. The horizontal location stays at the tile center (0.5).

    Galaxy type, shape parameters and fluxes are always written at tile (10, 9).
    The galaxy is only switched on through ``n_sources`` when ``add_galaxy`` is True.

    Args:
        pixel_shift: vertical star position in pixels (assumes 4-pixel tiles).
        flux: star flux written to every band before rendering.
        n: number of replicate images.
        add_galaxy: if True, also place a galaxy in tile (10, 9).

    Returns:
        ``(true_catalog, batch)``. ``true_catalog`` has its fluxes reduced to band 2.
        ``batch`` holds band-2 "images" and "psf_params" on CUDA, plus the catalog.
    """
    grid = 20          # tiles per side
    tile_slen = 4      # pixels per tile
    star_col = 10
    gal_row, gal_col = 10, 9

    star_row = int(pixel_shift // tile_slen)
    offset_px = pixel_shift % tile_slen

    n_sources = torch.zeros(n, grid, grid)
    n_sources[:, star_row, star_col] = 1
    if add_galaxy:
        n_sources[:, gal_row, gal_col] = 1

    # Every source defaults to the tile center; only the star's vertical loc moves.
    locs = torch.full((n, grid, grid, 1, 2), 0.5)
    locs[:, star_row, star_col, 0, 0] = offset_px / tile_slen

    # Type 0 is used for the star, 1 for the galaxy tile.
    source_type = torch.zeros(n, grid, grid, 1, 1)
    source_type[:, gal_row, gal_col] = 1

    galaxy_params = torch.zeros(n, grid, grid, 1, 6)
    galaxy_params[:, gal_row, gal_col, 0, [3, 5]] = 10.0

    star_fluxes = torch.zeros(n, grid, grid, 1, 5)
    star_fluxes[:, star_row, star_col] = flux

    galaxy_fluxes = torch.zeros(n, grid, grid, 1, 5)
    galaxy_fluxes[:, gal_row, gal_col] = 40.0

    true_catalog = TileCatalog({
        "n_sources": n_sources,
        "source_type": source_type,
        "locs": locs,
        "star_fluxes": star_fluxes,
        "galaxy_fluxes": galaxy_fluxes,
        "galaxy_params": galaxy_params,
    })

    images, psf_params = decoder.render_images(true_catalog)

    # Keep only band 2, done by hand instead of via CachedDataset + OneBandTransform.
    keep_band = slice(2, 3)
    for flux_key in ("star_fluxes", "galaxy_fluxes"):
        true_catalog[flux_key] = true_catalog[flux_key][..., keep_band]

    batch = {
        "images": images[:, keep_band].cuda(),
        "psf_params": psf_params[:, keep_band].cuda(),
        "tile_catalog": true_catalog,
    }
    return true_catalog, batch

Simulate one image for each catalog

Plot the r-band of a sample image: first the full image, then a zoomed-in view of the 4×4 block of tiles around the star.

In [None]:
# 10 replicate images of a flux-10 star at pixel 40. That is tile row 40 // 4 = 10
# with offset 0, i.e. exactly on the border between tile rows 9 and 10.
true_catalog, batch = gen_loc_shift_data(40, flux=10.0, n=10)

# Show the first replicate. The dotted grid lines sit at 4k + 3.5, which is
# between pixel centers, so (assuming 4-pixel tiles) they mark the tile borders.
i = 0
plt.set_cmap('viridis')
plt.imshow(batch["images"][i, 0].cpu().numpy())
plt.grid(color='white', linewidth=1, linestyle='dotted')
plt.xticks(np.arange(20) * 4 + 3.5);
plt.yticks(np.arange(20) * 4 + 3.5);
ax = plt.gca()
# Keep the grid lines but hide tick labels and tick marks.
ax.set_xticklabels([]);
ax.set_yticklabels([]);
ax.tick_params(axis='both', which='both', length=0);

In [None]:
# Star position in pixel coordinates of the crop below. Add 10 to get the global
# tile row/col, subtract 8 for the crop origin (tile 8 = pixel 32), multiply by 4
# pixels per tile, and subtract 0.5 to match imshow's pixel-center convention.
loc1 = (true_catalog["locs"][i, 10, 10, 0] + 10 - 8) * 4 - 0.5
loc1 = loc1.cpu().numpy()

# Zoom in on pixels 32-47, i.e. the 4x4 block of tiles 8-11 around the star.
plt.imshow(batch["images"][i, 0, 32:48, 32:48].cpu().numpy())
plt.grid(color='white', linewidth=1, linestyle='dotted')
plt.xticks(np.arange(4) * 4 + 3.5)
plt.yticks(np.arange(4) * 4 + 3.5)
#plt.plot(loc1[1], loc1[0], 'ro', markersize=7)
ax = plt.gca()
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(axis='both', which='both', length=0)
plt.show()

Clear the GPU memory so we don't run out in case we re-run this notebook

In [None]:
import gc

# Release unreferenced objects and cached CUDA blocks so re-runs don't run out of memory.
gc.collect()
torch.cuda.empty_cache()
encoder = encoder.cuda()
# Kept as the cell's last expression so Jupyter actually displays it:
# GPU memory currently allocated, in GB.
torch.cuda.memory_allocated() / 1e9

### Checkerboard Tiling

Predict two catalogs: one from the mode of the variational distribution and one sampled from it.

In [None]:
encoder.use_checkerboard = True

# Catalog from the mode of the variational distribution.
encoder.predict_mode_not_samples = True
mode_cat = encoder.predict_step(batch, 0, 0)

# Catalog from a single sample of the variational distribution.
encoder.predict_mode_not_samples = False
sample_cat = encoder.predict_step(batch, 0, 0)

The source should be found in tile [10, 10], moving in the first dimension
from 0 to 0.5. We restrict our attention to 2 tiles per image (rows 9 and 10 of column 10) to avoid spurious detections,
which are inevitable in a large enough image due to Gaussian noise.

For all three locations, every replicate (the batch above has 10) shows 1 source in the mode of the variational distribution.

In [None]:
# Sources detected in tile rows 9-10 of column 10, one count per image.
mode_cat["n_sources"][:, 9:11, 10].sum(dim=1)

The sample cat isn't as consistent, but there's clear dependence on the location (border vs interior). The twos are low-flux detections.

In [None]:
# Sources detected in tile rows 9-10 of column 10, one count per image.
sample_cat["n_sources"][:, 9:11, 10].sum(dim=1)

In [None]:
# Same counts after dropping detections with band-0 flux below 5.0.
sample_cat2 = sample_cat.filter_by_flux(min_flux=5.0, band=0)
sample_cat2["n_sources"][:, 9:11, 10].sum(dim=1)

### Independent Tiling

In [None]:
encoder.use_checkerboard = False

# Catalog from the mode of the variational distribution.
encoder.predict_mode_not_samples = True
mode_cat = encoder.predict_step(batch, 0, 0)

# Catalog from a single sample of the variational distribution.
encoder.predict_mode_not_samples = False
sample_cat = encoder.predict_step(batch, 0, 0)

In [None]:
# Sources detected in tile rows 9-10 of column 10, one count per image.
mode_cat["n_sources"][:, 9:11, 10].sum(dim=1)

In [None]:
# Same counts after dropping detections with band-0 flux below 5.0.
mode_cat2 = mode_cat.filter_by_flux(min_flux=5.0, band=0)
mode_cat2["n_sources"][:, 9:11, 10].sum(dim=1)

In [None]:
# Sources detected in tile rows 9-10 of column 10, one count per image.
sample_cat["n_sources"][:, 9:11, 10].sum(dim=1)

In [None]:
# Same counts after dropping detections with band-0 flux below 5.0.
sample_cat2 = sample_cat.filter_by_flux(min_flux=5.0, band=0)
sample_cat2["n_sources"][:, 9:11, 10].sum(dim=1)

In [None]:
# Release unreferenced objects and cached CUDA blocks before the long sweep below.
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.memory_allocated() / 1e9  # show current memory usage in GB

### All-pixel-shifts comparison

In [None]:
import seaborn as sns
import numpy as np

# Vertical star positions (pixels) to sweep: dense near the tile borders at
# offsets 0 and 4, sparse near the tile center.
near_border = np.array((0.0, 0.03, 0.05, 0.08, 0.1, 0.13, 0.2, 0.3, 0.8, 1.5))
one_tile = np.concatenate([near_border, np.flip(4 - near_border)])
# Repeat over three adjacent tiles and shift to pixels 32-44 (tile rows 8-10).
# np.unique removes the exact duplicates where consecutive copies meet (4.0 and
# 8.0 would otherwise appear twice) and keeps the positions sorted.
pixel_shifts = np.unique(np.concatenate([one_tile, one_tile + 4, one_tile + 8])) + 32
display(pixel_shifts)

encoder.predict_mode_not_samples = False

def all_pixel_shifts(flux, n_repeats=10):
    """Estimate detection accuracy for every star position in ``pixel_shifts``.

    For each vertical position in the notebook-level ``pixel_shifts``, a batch is
    simulated with ``gen_loc_shift_data``. The encoder then draws ``n_repeats``
    posterior samples with independent tiling and with checkerboard tiling. A trial
    is correct when exactly one source with flux at least ``flux / 2`` is found in
    the two tile rows nearest the star.

    Args:
        flux: true star flux used for simulation; also sets the ``flux / 2`` cut.
        n_repeats: sampling passes per batch and tiling scheme.

    Returns:
        ``[independent_accuracies, checkerboard_accuracies]``: lists of fractions in
        [0, 1], aligned with ``pixel_shifts``.
    """
    # This experiment measures samples, not the mode. Set it explicitly instead of
    # relying on whatever state an earlier cell left the encoder in.
    encoder.predict_mode_not_samples = False
    accuracy = [[], []]

    for pixel_shift in pixel_shifts:
        _, batch = gen_loc_shift_data(pixel_shift, flux=flux)
        n_images = batch["images"].shape[0]
        # First of the two tile rows to inspect: the star's tile plus the neighbor
        # across the nearer tile border. This is invariant across samples.
        ht = int((pixel_shift - 2) // 4)

        for use_cb in (False, True):
            encoder.use_checkerboard = use_cb
            n_correct = 0
            for _ in range(n_repeats):
                sample_cat = encoder.predict_step(batch, 0, 0)
                bright_cat = sample_cat.filter_by_flux(flux / 2, band=0)
                n_detected = bright_cat["n_sources"][:, ht:(ht + 2), 10].sum([1])
                n_correct += (n_detected == 1).sum().item()
            # Normalize by the real number of trials rather than a hard-coded 1000.
            accuracy[int(use_cb)].append(n_correct / (n_repeats * n_images))
            print(f"{int(use_cb)}, {pixel_shift}: {n_correct}")

    return accuracy

# Sweep all star positions with a bright (flux 6.5) star.
accuracy_bright = all_pixel_shifts(flux=6.5)

In [None]:
def plot_accuracy(accuracy, ylim=(0.5, 1.01)):
    """Plot detection accuracy against star position for both tiling schemes.

    Args:
        accuracy: ``[independent, checkerboard]`` accuracy lists aligned with the
            notebook-level ``pixel_shifts`` (as returned by ``all_pixel_shifts``).
        ylim: y-axis limits. The default matches the previous fixed range. Pass
            ``None`` to autoscale, e.g. when accuracies fall below 0.5 and would
            otherwise be clipped out of view.
    """
    # Seaborn styles must be set before the axes are created to take effect.
    sns.set_style("whitegrid")
    sns.set_context("paper", font_scale=1.5)
    fig, ax = plt.subplots()
    ax.plot(pixel_shifts, accuracy[0], '-', label="Independent")
    ax.plot(pixel_shifts, accuracy[1], '-', label="Checkerboard")
    ax.set_xlabel("Star vertical position (pixels)")
    ax.set_ylabel("Accuracy")
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()
    fig.tight_layout()
    plt.show()

# Accuracy vs. star position for the bright-star sweep.
plot_accuracy(accuracy_bright);

## Faint Star

An ambiguous detection (~80% detection prob) at the border and at the center.

We rerun the same pixel-shift sweep with a faint star (flux 1.3) and plot the accuracy of both tilings.

In [None]:
# Same sweep with a faint (flux 1.3) star, then plot accuracy vs. star position.
accuracy_faint = all_pixel_shifts(1.3)
plot_accuracy(accuracy_faint);

## Added galaxy, bright star

Render a batch with the star exactly on a tile border (pixel 40) and a galaxy in the neighboring tile [10, 9].

In [None]:
# Star exactly on the border between tile rows 9 and 10 (pixel 40), plus a galaxy in tile [10, 9].
true_catalog, batch = gen_loc_shift_data(40.0, add_galaxy=True)

In [None]:
i = 0
# Star position in pixel coordinates of the 16x16 crop (tiles 8-11 = pixels 32-47),
# using imshow's pixel-center convention (hence the -0.5).
loc1 = (true_catalog["locs"][i, 10, 10, 0] + 10 - 8) * 4 - 0.5

plt.imshow(batch["images"][i, 0, 32:48, 32:48].cpu().numpy(), cmap='viridis')
plt.grid(color='white', linewidth=1, linestyle='dotted')
plt.xticks(np.arange(4) * 4 + 3.5)
plt.yticks(np.arange(4) * 4 + 3.5)
# Red markers: the star, and the galaxy at the center of tile [10, 9]
# (crop column (9.5 - 8) * 4 - 0.5 = 5.5, row (10.5 - 8) * 4 - 0.5 = 9.5).
plt.plot(loc1[1], loc1[0], 'ro', markersize=7)
plt.plot(5.5, 9.5, 'ro', markersize=7)
ax = plt.gca()
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(axis='both', which='both', length=0)
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def marginal_detections(pred_marginal):  # noqa: WPS430
    """Sample a catalog from the marginal prediction with tile [10, 10] forced empty.

    This is passed to ``encoder.infer`` as the history sampler. Forcing the tile
    off only affects the conditional predictions, not the marginal ones.

    Args:
        pred_marginal: marginal variational distribution exposing ``sample(use_mode=...)``.

    Returns:
        The sampled catalog, with ``n_sources`` at tile [10, 10] set to 0 in every image.
    """
    est_cat = pred_marginal.sample(use_mode=False)
    # Apply to every image instead of only the notebook-global index `i`, so the
    # function no longer depends on hidden notebook state.
    est_cat["n_sources"][:, 10, 10] = 0
    return est_cat

pred = encoder.infer(batch, marginal_detections)

# Marginal per-tile probability at index 1 of `on_prob` (presumably the
# "source present" class; TODO confirm).
on_prob = pred["marginal"].factors["on_prob"].probs[:, :, :, 1]
# Create a square heatmap using seaborn (tiles 8-11 in both directions, image i)
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(on_prob[i, 8:12, 8:12].cpu().numpy(), annot=True, fmt=".2f", cmap="YlGnBu", linecolor='black', linewidths=0.5, cbar=False, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])

# Add black border to the bottom and right side
ax.hlines([4], *ax.get_xlim(), colors='black', linewidths=2)
ax.vlines([4], *ax.get_ylim(), colors='black', linewidths=2);

In [None]:
from copy import deepcopy
# Conditional ("black") per-tile probabilities at index 1 of `on_prob`.
on_prob = deepcopy(pred["black"].factors["on_prob"].probs[:, :, :, 1])
# Set tiles where white_history_mask <= 0.5 to NaN; seaborn leaves NaN cells blank.
on_prob = torch.where(pred["white_history_mask"] > .5, on_prob, torch.nan * torch.ones_like(on_prob))
# Create a square heatmap using seaborn (tiles 8-11 in both directions, image i)
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(on_prob[i, 8:12, 8:12].cpu().numpy(), annot=True, fmt=".2f", cmap="YlGnBu", vmax=1.0, linecolor='black', linewidths=0.5, cbar=False, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])

# Add black border to the bottom and right side
ax.hlines([4], *ax.get_xlim(), colors='black', linewidths=2)
ax.vlines([4], *ax.get_ylim(), colors='black', linewidths=2);

In [None]:
def marginal_detections(pred_marginal):  # noqa: WPS430
    """Sample a catalog from the marginal prediction with a source forced on in tile [10, 10].

    This is passed to ``encoder.infer`` as the history sampler, so the conditional
    predictions are computed as if the star's tile contains a source.

    Args:
        pred_marginal: marginal variational distribution exposing ``sample(use_mode=...)``.

    Returns:
        The sampled catalog, with ``n_sources`` at tile [10, 10] set to 1 in every image.
    """
    est_cat = pred_marginal.sample(use_mode=False)
    # Apply to every image instead of only the notebook-global index `i`, so the
    # function no longer depends on hidden notebook state.
    est_cat["n_sources"][:, 10, 10] = 1
    return est_cat

# Re-run inference with tile [10, 10] forced on in the history sample.
pred = encoder.infer(batch, marginal_detections)

from copy import deepcopy
# Conditional ("black") per-tile probabilities at index 1 of `on_prob`.
on_prob = deepcopy(pred["black"].factors["on_prob"].probs[:, :, :, 1])
# Set tiles where white_history_mask <= 0.5 to NaN; seaborn leaves NaN cells blank.
on_prob = torch.where(pred["white_history_mask"] > .5, on_prob, torch.nan * torch.ones_like(on_prob))
# Create a square heatmap using seaborn (tiles 8-11 in both directions, image i)
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(on_prob[i, 8:12, 8:12].cpu().numpy(), annot=True, fmt=".2f", cmap="YlGnBu", linecolor='black', vmax=1.0, linewidths=0.5, cbar=False, ax=ax)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])

# Add black border to the bottom and right side
ax.hlines([4], *ax.get_xlim(), colors='black', linewidths=2)
ax.vlines([4], *ax.get_ylim(), colors='black', linewidths=2);