# Multiple light sources

Two bright (easily detected) stars in adjacent tiles. We compare samples drawn with independent and with dependent tiling.

### Preliminaries

In [None]:
# Render matplotlib figures inline, directly below the cell that produced them.
%matplotlib inline

In [None]:
from os import environ
# Expose only physical GPU 3 to CUDA (torch then sees it as cuda:0). This is set
# before torch is imported so the restriction is in place before CUDA initializes.
environ["CUDA_VISIBLE_DEVICES"] = "3"

from pathlib import Path
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
# BLISS project imports: the trained encoder and the per-tile catalog container.
from bliss.encoder.encoder import Encoder
from bliss.catalog import TileCatalog
from matplotlib import pyplot as plt

In [None]:
# BLISS_HOME is set to the directory two levels above the notebook's working
# directory (presumably the repository root — confirm), and the Hydra config is
# composed from the repo's bliss/conf directory.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])
with initialize(config_path="../../bliss/conf", version_base=None):
    cfg = compose("base_config")

encoder: Encoder = instantiate(cfg.encoder)
# Load the checkpoint onto the CPU: without map_location, torch.load restores each
# tensor to the device it was saved from, which may not be visible here because
# CUDA_VISIBLE_DEVICES exposes a single GPU. load_state_dict copies the weights into
# the module's own parameters, and the module is moved to the GPU just below.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
encoder.load_state_dict(
    torch.load(
        "../../data/pretrained_models/clahed_logged_20percent.pt",
        map_location="cpu",
    )
)
encoder.eval()
encoder.cuda()
# Inference only: disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)

simulator = instantiate(cfg.simulator)

### Generating Data

In [None]:
n = 50  # number of simulated images
n_tiles = 20  # each image is an n_tiles x n_tiles grid of tiles
tile_slen = 4  # tile side length in pixels (first argument to TileCatalog)
tile_grid = (n, n_tiles, n_tiles)

# Exactly two sources per image: one in tile (9, 10) and one in the neighboring
# tile (10, 10); every other tile is empty.
n_sources = torch.zeros(tile_grid, dtype=torch.long)
n_sources[:, [9, 10], 10] = 1

# Within-tile positions (fractions of a tile) default to the tile center. Only the
# first coordinate differs between the two stars (presumably the row / vertical
# offset), putting them about (10 + 0.01 - (9 + 0.7)) * 4 ~= 1.24 pixels apart.
locs = torch.full((*tile_grid, 1, 2), 0.5)
locs[:, 9, 10, 0, 0] = 0.7
locs[:, 10, 10, 0, 0] = 0.01

# source_type is 0 everywhere: both sources are stars.
source_type = torch.zeros((*tile_grid, 1, 1), dtype=torch.long)

true_catalog_dict = {
    "n_sources": n_sources,
    "source_type": source_type,
    "locs": locs,
    "star_fluxes": torch.full((*tile_grid, 1, 5), 4.0),  # 21 magnitude
    "galaxy_fluxes": torch.ones(*tile_grid, 1, 5),
}
true_catalog = TileCatalog(tile_slen, true_catalog_dict)

In [None]:
# Draw one random field id per image in the batch (presumably SDSS
# run/camcol/field identifiers, from the "rcf" naming — confirm).
num_images = true_catalog.n_sources.shape[0]
rcfs, rcf_indices = simulator.randomized_image_ids(num_images)

In [None]:
# Render the true catalog into images; only the images and their backgrounds
# (the first two of the four returned values) are used below.
sim_outputs = simulator.simulate_images(true_catalog, rcfs, rcf_indices)
image, background, _, _ = sim_outputs

In [None]:
import numpy as np

# Image 1, band 2, cropped to full-image pixels 32:48 on both axes: a 4x4 block of
# 4-pixel tiles around the two stars. Ticks sit on the tile boundaries (imshow
# centers pixel i on coordinate i, so boundaries fall at 3.5, 7.5, 11.5, 15.5) but
# are hidden: no labels and zero length.
tile_boundaries = np.arange(3.5, 16, 4)

plt.imshow(image[1, 2, 32:48, 32:48].numpy())
ax = plt.gca()
ax.set_xticks(tile_boundaries)
ax.set_yticks(tile_boundaries)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(axis='both', which='both', length=0)
plt.colorbar();

In [None]:
import numpy as np

# Image 1, band 2, cropped to full-image pixels 32:48 on both axes (a 4x4 block of
# 4-pixel tiles), with dotted white lines on the tile boundaries and red dots at the
# true star positions taken from `locs`.
tile_slen = 4  # pixels per tile side, as used to build true_catalog
crop_start, crop_stop = 32, 48  # full-image pixel bounds of the crop, both axes
img_idx, band = 1, 2
star_tiles = [(9, 10), (10, 10)]  # tiles holding the two true stars


def to_crop_coord(tile_idx, offset):
    """Map a tile index plus within-tile offset (fraction of a tile) to crop coords.

    The position in full-image pixel units is (tile_idx + offset) * tile_slen;
    imshow centers pixel i on coordinate i, so subtract the crop start and half a
    pixel.
    """
    return (tile_idx + float(offset)) * tile_slen - crop_start - 0.5


plt.imshow(image[img_idx, band, crop_start:crop_stop, crop_start:crop_stop].numpy())
plt.grid(color='white', linewidth=1, linestyle='dotted')
# Grid lines follow the ticks, placed on tile boundaries (3.5, 7.5, 11.5, 15.5 here).
tile_edges = np.arange(tile_slen - 0.5, crop_stop - crop_start, tile_slen)
plt.xticks(tile_edges)
plt.yticks(tile_edges)

for row, col in star_tiles:
    # locs[..., 0] is drawn on the vertical axis and locs[..., 1] on the
    # horizontal one; previously the horizontal offset was hard-coded to 0.5.
    row_offset, col_offset = locs[img_idx, row, col, 0]
    plt.plot(to_crop_coord(col, col_offset), to_crop_coord(row, row_offset), 'ro', markersize=7)

ax = plt.gca()
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.tick_params(axis='both', which='both', length=0)
plt.colorbar();

In [None]:
# Move the simulated images and their backgrounds to the GPU for the encoder.
batch = dict(
    images=image.cuda(),
    background=background.cuda(),
)

### Comparing joint and marginal

In [None]:
# Dependent (joint) tiling: sample with the encoder's current use_checkerboard
# setting (presumably enabled by default — confirm against the config).
# use_mode=False presumably draws a random sample rather than the mode.
dependent_cat = encoder.sample(
    batch,
    use_mode=False,
)

In [None]:
# Independent tiling: temporarily disable the checkerboard (presumably so that tiles
# are sampled without conditioning on their neighbors), then restore the previous
# setting. Otherwise re-running the dependent-sampling cell would silently draw
# independent samples too.
previous_use_checkerboard = encoder.use_checkerboard
encoder.use_checkerboard = False
try:
    independent_cat = encoder.sample(batch, use_mode=False)
finally:
    encoder.use_checkerboard = previous_use_checkerboard

In [None]:
# Detections per tile, summed over all images, for the two tiles that should hold
# the stars.
# NOTE(review): the stars were placed in tiles (9, 10) and (10, 10) of the input
# catalog; indexing [8:10, 9] presumably accounts for the encoder trimming a
# one-tile border from its output — confirm.
dependent_counts = dependent_cat["n_sources"][:, 8:10, 9]
dependent_counts.sum(0)

In [None]:
# Same per-tile detection counts (summed over all images) for the independent
# sampler, over the same two tiles.
independent_counts = independent_cat["n_sources"][:, 8:10, 9]
independent_counts.sum(0)

In [None]:
# Mean number of detections per image across the two star tiles (2.0 would mean
# both stars are always detected), as (dependent, independent).
tuple(
    cat["n_sources"][:, 8:10, 9].sum(1).float().mean().item()
    for cat in (dependent_cat, independent_cat)
)

In [None]:
# Fraction (not percent: the value lies in [0, 1], or is nan with no detections) of
# the independent sampler's detections in the two star tiles that are labeled as
# galaxies (source_type == 1). All true sources are stars, so these are
# misclassifications.
is_on = independent_cat["n_sources"][:, 8:10, 9]
labeled_galaxy = independent_cat["source_type"][:, 8:10, 9, 0, 0] == 1
(labeled_galaxy * is_on).sum() / is_on.sum()

In [None]:
# Fraction (not percent: the value lies in [0, 1], or is nan with no detections) of
# the dependent sampler's detections in the two star tiles that are labeled as
# galaxies (source_type == 1). All true sources are stars, so these are
# misclassifications.
is_on = dependent_cat["n_sources"][:, 8:10, 9]
labeled_galaxy = dependent_cat["source_type"][:, 8:10, 9, 0, 0] == 1
(labeled_galaxy * is_on).sum() / is_on.sum()

Conclusions from several runs:
 * When the two stars are 4 pixels apart, joint is slightly more likely to detect both. This is surprising because only joint can suppress one light source if it has already detected another nearby. But 4 pixels is probably far enough apart that there is minimal ambiguity.
 * When the stars get closer (within 2.4 pixels), there is quite a bit of ambiguity: many stars are classified as galaxies, and then only one galaxy is detected. The flux needs to increase a lot (to 4 or so, whereas a flux of only 0.5 is needed for a single detection) before both close stars are usually (>50% of the time) detected.
 * Joint seems to make far fewer (~30% fewer) misclassifications of stars as galaxies.