In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from os import environ
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from hydra import compose, initialize
from hydra.utils import instantiate

from bliss.catalog import TileCatalog

In [None]:
# Export BLISS_HOME as the directory two levels above the current working
# directory, then compose the Hydra config named "config" from this directory.
bliss_home = Path().resolve().parents[1]
environ["BLISS_HOME"] = str(bliss_home)

with initialize(config_path=".", version_base=None):
    cfg = compose(config_name="config")

# Load Data

In [None]:
# Load the pre-generated synthetic sample.
# NOTE: torch.load may unpickle arbitrary objects — only point it at trusted files.
data_path = Path("data/synthetic_image.pt")
with data_path.open("rb") as fh:
    data = torch.load(fh)

# Single-item batches in file order; keep only the first batch for this demo.
dataloader = DataLoader(data, batch_size=1, shuffle=False)
batch = next(iter(dataloader))

In [None]:
# Tile-level catalog for the batch (first argument 4 — presumably the tile side
# length in pixels; TODO confirm against TileCatalog) and its full-image form.
tile_cat = TileCatalog(4, batch["tile_catalog"])
full_cat = tile_cat.to_full_catalog()

# Target catalog: keep only the brightest source in each tile, drop sources whose
# flux is below the encoder's configured threshold, then convert to a full catalog.
target_cat = (
    tile_cat.get_brightest_sources_per_tile()
    .filter_tile_catalog_by_flux(min_flux=cfg.encoder.min_flux_threshold)
    .to_full_catalog()
)

# Load Model

In [None]:
# Build the encoder from the Hydra config and load its trained weights.
# The checkpoint location can be overridden via the BLISS_ENCODER_WEIGHTS
# environment variable instead of editing this cell; the default is the
# original (machine-specific) path.
weights_path = Path(
    environ.get(
        "BLISS_ENCODER_WEIGHTS",
        "/data/scratch/aakash/models/multi_source/single_band_filtered_flux.pt",
    )
)

model = instantiate(cfg.encoder)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; load_state_dict copies the values into the model's own
# (already allocated) parameters, so the result is otherwise unchanged.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval();

In [None]:
# Make predictions on the batch.
# predict_step is called directly here rather than through a trainer, and nothing
# visible in this notebook disables autograd. Do it explicitly so no graph is
# built and the returned tensors do not require grad (which would otherwise make
# later NumPy/matplotlib conversion fail).
with torch.no_grad():
    results = model.predict_step(batch, None, None)

# Plot Results

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

# Plot image: first image of the batch, channel index 2.
# NOTE(review): confirm that channel 2 is the band the single-band model was trained on.
image = batch["images"][0, 2]
h, w = image.shape
ax.matshow(image)

# Plot locs
est_cat = results["est_cat"].to_full_catalog()

# All true sources first (red); those that survive the target filter are drawn
# over them (blue), so the remaining visible red markers are the filtered-out ones.
full_cat.plot_plocs(ax, 0, "all", c="r", s=30, marker="X", linewidths=0.5, edgecolors="k", label="Filtered Out")
target_cat.plot_plocs(ax, 0, "all", c="b", s=30, marker="X", linewidths=0.5, edgecolors="w", label="Target")
# NOTE(review): the extra positional argument 4 is passed only for the estimated
# catalog — confirm its meaning in plot_plocs.
est_cat.plot_plocs(ax, 0, "all", 4, c="y", s=30, marker="P", linewidths=0.5, edgecolors="k", label="Predicted")

# Show grid and legend.
# matshow puts pixel centers at integer coordinates, so tile boundaries lie at
# k * tile_slen - 0.5. Each axis uses its own image size so non-square images
# get a correct grid.
tile_slen = 4  # grid spacing in pixels; same value passed to TileCatalog above
xticks = np.arange(-0.5, w - 0.5, tile_slen)
yticks = np.arange(-0.5, h - 0.5, tile_slen)
ax.set_xticks(xticks, (xticks + 0.5).astype(int))
ax.set_yticks(yticks, (yticks + 0.5).astype(int))
ax.grid(linestyle="--")

handles, legend_labels = ax.get_legend_handles_labels()
fig.legend(handles, legend_labels, loc="lower center", ncol=3, bbox_to_anchor=(0, -0.01, 1, 1), fontsize=10)
fig.tight_layout()