## Imports

In [None]:
# Reload edited project modules automatically before each cell runs, and render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from os import environ
environ["CUDA_VISIBLE_DEVICES"] = "3"

from astropy.io import fits
from astropy.wcs import WCS

from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

from bliss.main import predict
from bliss.catalog import TileCatalog, FullCatalog

import torch
from pytorch_lightning.callbacks import Callback

import numpy as np
from matplotlib import pyplot as plt

torch.set_grad_enabled(False)

# Checkpoint used both as the encoder's pretrained weights and as the weights for `predict`.
# Alternative checkpoint kept for reference:
# ckpt = "/home/regier/bliss_output/sep21_minimalist/version_0/checkpoints/encoder_2_0.59.ckpt"
ckpt = "/home/regier/bliss_output/20250821_m2_anyorderfalse_1xval_nometrics_minimalist/version_0/checkpoints/encoder_4.ckpt"
with initialize(config_path=".", version_base=None):
    # Hydra expects overrides as an ordered list of strings (a set literal has no defined order).
    cfg0 = compose("config", [
        f"train.pretrained_weights={ckpt}",
        f"predict.weight_save_path={ckpt}",
        "cached_simulator.splits=0:80/80:90/97:98",
        "cached_simulator.num_workers=0",
        "encoder.minimalist_conditioning=True",
        "encoder.use_checkerboard=True",
        "encoder.n_sampler_colors=4",
    ])

# Title for the confusion-matrix plots; reports the configured K instead of a hard-coded 4.
plot_title = (
    "Independent ($K=1$)"
    if cfg0.encoder.n_sampler_colors == 1
    else f"Autoregressive ($K={cfg0.encoder.n_sampler_colors}$)"
)

## Load SDSS image data

In [None]:
# SDSS r-band frame (per the filename) and its world-coordinate system.
f = fits.open('/home/regier/bliss/tests/data/sdss/2583/2/136/frame-r-002583-2-0136.fits')
w = WCS(header=f[0].header)

# Sky position of pixel (x=310, y=630): the lower-left corner of the 100x100-pixel study area.
w.pixel_to_world(310, 630)

In [None]:
from matplotlib import pyplot as plt

# Raw frame, linear scaling, origin at the bottom-left like the pixel coordinates.
plt.imshow(f[0].data, cmap='Greys_r', origin='lower')
print("Behold, the M2 globular cluster!")

In [None]:
# Log stretch: shift so the faintest pixel maps to log(1) = 0.
logimage = np.log(1 + (f[0].data - f[0].data.min()))
plt.imshow(logimage, cmap='Greys_r', origin='lower');

In [None]:
from matplotlib.patches import Rectangle

# Outline the 100x100-pixel study area (lower-left corner at x=310, y=630) in red.
plt.imshow(logimage, cmap='Greys_r', origin='lower')
rect = Rectangle((310, 630), width=100, height=100, linewidth=2, edgecolor='r', facecolor='none')
_ = plt.gca().add_patch(rect)
plt.xticks([])
plt.yticks([]);

In [None]:
# 100x100 study-area cutout: rows 630:730, columns 310:410.
original = f[0].data[630:730, 310:410]

# arcsinh stretch after subtracting the median.
arcsinh_median = np.arcsinh(original - np.median(original))

# The same stretch after clipping the brightest 2% of pixels.
clipped = np.clip(original, None, np.quantile(original, 0.98))
arcsinh_clipped = np.arcsinh(clipped - np.median(clipped))

In [None]:
# Compare the raw cutout with the two arcsinh stretches side by side.
fig, axs = plt.subplots(1, 3, figsize=(10, 10))

images = [original, arcsinh_median, arcsinh_clipped]
# The stretch is arcsinh (inverse hyperbolic sine); "arcsinc" was a typo.
titles = ['original', 'arcsinh', 'arcsinh with clipping']

for ax, img, title in zip(axs, images, titles):
    ax.imshow(img, origin='lower', cmap='Greys_r')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

## Load and view the HST catalog (ground truth)

In [None]:
# HST ACS catalog for NGC 7089 (M2), downloaded with:
# wget https://archive.stsci.edu/pub/hlsp/acsggct/ngc7089/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt
hubble_cat_file = "/home/regier/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt"
# Columns 9, 21, 22 are used below as (magnitude, RA, Dec).
hubble_cat = np.loadtxt(hubble_cat_file, skiprows=3, usecols=(9, 21, 22))

hst_r_mag_all, ra, dec = (torch.from_numpy(hubble_cat[:, j]) for j in range(3))

# Project every HST source into SDSS-frame pixel coordinates via the frame's WCS.
plocs_all = FullCatalog.plocs_from_ra_dec(ra, dec, w)

In [None]:
# HST sources inside the study area. plocs columns are (row, col):
# col in (310, 410) and row in (630, 730).
in_bounds = (
    (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)
    & (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)
)
in_bounds.sum()

In [None]:
# Magnitudes and pixel locations of the in-bounds HST sources.
hst_r_mag, plocs = hst_r_mag_all[in_bounds], plocs_all[in_bounds]

In [None]:
# Locations relative to the study area's lower-left corner (row 630, col 310).
plocs_square = plocs - torch.tensor([630, 310])

from bliss.catalog import convert_mag_to_nmgy, convert_nmgy_to_mag
hst_r_nmgy = convert_mag_to_nmgy(hst_r_mag)

# These magnitudes are about 15% off because the HST F606W filter curve isn't
# exactly the SDSS r-band filter curve, so the fluxes are scaled up by 1.15.
sdss_r_nmgy = 1.15 * hst_r_nmgy
sdss_r_mag = convert_nmgy_to_mag(sdss_r_nmgy)

In [None]:
# Catalog fields for all in-bounds HST sources (batch of one).
# NOTE(review): fluxes have a single band here, while `true_cat` below expands to
# 5 bands. Catalogs built from this dict are later rendered and band index 2 is
# read; confirm the decoder handles a 1-band flux tensor.
d = {
    "plocs": plocs_square[None],
    "fluxes": sdss_r_nmgy[None, :, None],
    "n_sources": torch.tensor([plocs.shape[0]]),
    "source_type": torch.zeros(1, plocs.shape[0], 1).long(),
}

In [None]:
# Full catalog over the 100x100 study area; show its total source count.
true_cat_all = FullCatalog(100, 100, d)
true_cat_all["n_sources"].sum()

In [None]:
# Tile version of the full HST catalog. The arguments (2, 11) are presumably the
# tile side length and max sources per tile (TODO: confirm). The total count
# shows how many sources survive the conversion.
true_tile_cat_all = true_cat_all.to_tile_catalog(2, 11)
true_tile_cat_all["n_sources"].sum()

In [None]:
# Sources brighter than r = 22.565 mag; also show that cutoff in nanomaggies.
is_bright = sdss_r_mag < 22.565
(is_bright.sum(), convert_mag_to_nmgy(22.565))

In [None]:
# Magnitude cutoff chosen so the study area holds the target number of sources (1114).
cutoff_mag_1114 = 22.130
cutoff_nmgy_1114 = convert_mag_to_nmgy(cutoff_mag_1114)
# Last expression, so the resulting count is actually displayed for checking against 1114.
(sdss_r_mag < cutoff_mag_1114).sum()

In [None]:
# Five-band catalog of the bright sources only; the single r-band flux is
# replicated across all 5 bands.
d = {
    "plocs": plocs_square[is_bright][None],
    "fluxes": sdss_r_nmgy[is_bright][None, :, None].expand(-1, -1, 5),
    "n_sources": torch.tensor([plocs[is_bright].shape[0]]),
    "source_type": torch.zeros(1, plocs[is_bright].shape[0], 1).long(),
}
true_cat = FullCatalog(100, 100, d)
true_cat["n_sources"].sum()

In [None]:
# Tile version of the bright catalog (second argument presumably max sources per tile; TODO confirm).
true_tile_cat = true_cat.to_tile_catalog(2, 5)
true_tile_cat["n_sources"].sum()

In [None]:
# Overlay HST source positions on the clipped arcsinh cutout for several magnitude cutoffs.
# NOTE(review): 22.065 differs from the 22.565 bright cutoff used above; confirm which is intended.
fig, axs = plt.subplots(1, 3, figsize=(10, 10))

cutoffs = [20, 22.065, 24]

for ax, cutoff in zip(axs, cutoffs):
    # Loop-local mask name. Reusing `is_bright` here would overwrite the global
    # mask that defines `true_cat`, so re-running that cell later would silently
    # use the last cutoff (24).
    below_cutoff = sdss_r_mag < cutoff
    plocs_square_bright = plocs_square[below_cutoff]
    ax.imshow(arcsinh_clipped, origin='lower', cmap='Greys_r')
    # plocs are (row, col), so the column goes on the x axis.
    ax.scatter(plocs_square_bright[:, 1], plocs_square_bright[:, 0], s=5, c='r')
    ax.set_title(f"magnitude < {cutoff}")
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()


## BLISS performance on M2

In [None]:
# Run BLISS on the SDSS frame and score its mode catalog against the bright HST sources.
preds = predict(cfg0.predict)
bliss_cat, = preds.values()  # singleton dict
# symmetric_crop(3) trims the border (presumably 3 tiles per side; TODO confirm) before converting.
bliss_cat = bliss_cat.symmetric_crop(3).to_full_catalog(cfg0.encoder.tile_slen)

matcher = instantiate(cfg0.encoder.matcher)
mode_metrics = instantiate(cfg0.encoder.mode_metrics)

matching = matcher.match_catalogs(true_cat, bliss_cat)
c_dp_real = mode_metrics(true_cat, bliss_cat, matching)

# Descriptive names. The F1 score was previously stored in `f`, which clobbered
# the FITS HDUList `f` that earlier cells read from.
det_precision = c_dp_real["detection_precision"].item()
det_recall = c_dp_real["detection_recall"].item()
det_f1 = c_dp_real["detection_f1"].item()
print(f"precision: {det_precision:.4}  recall: {det_recall:.4}  f1: {det_f1:.4}")

In [None]:
# Load the SDSS images through BLISS and slice band index 2 over the study area
# plus a 6-pixel border on every side (112x112).
dataset = instantiate(cfg0.surveys.sdss, load_image_data=True)
dataset.prepare_data()
sdss_frame, = dataset.predict_dataloader()
obs_image_padded = sdss_frame["images"][:, 2:3, (630 - 6):(730 + 6), (310 - 6):(410 + 6)]
# Drop the 6-pixel border again to get the 100x100 study area.
obs_image_cropped = obs_image_padded[0, 0, 6:-6, 6:-6]

In [None]:
# Draw posterior samples (not the mode) for the padded SDSS cutout. The image is
# replicated 50 times per batch and the batch is predicted 20 times (1000 samples).
batch = {
    "images": obs_image_padded.expand(50, -1, -1, -1).cuda(),
}

cfg_sample = OmegaConf.merge(cfg0, {"encoder": {"predict_mode_not_samples": False}})

encoder = instantiate(cfg_sample.train.encoder).cuda()
enc_state_dict = torch.load(cfg_sample.train.pretrained_weights)
# Unwrap "state_dict" only for .ckpt files, like the other loaders in this notebook.
if cfg_sample.train.pretrained_weights.endswith(".ckpt"):
    enc_state_dict = enc_state_dict["state_dict"]
encoder.load_state_dict(enc_state_dict)
encoder.eval()

counts = []
for _ in range(20):
    sample_cat = encoder.predict_step(batch, 0).symmetric_crop(3)
    # Per-sample number of sources brighter than the 1114-source flux cutoff.
    counts.append((sample_cat.on_fluxes > cutoff_nmgy_1114).sum([1, 2, 3, 4]))


In [None]:
# 5% quantile, mean and 95% quantile of the sampled bright-source counts.
cs = torch.cat(counts).float()
c_ci_real = tuple(v.item() for v in (cs.quantile(0.05), cs.mean(), cs.quantile(0.95)))
print(c_ci_real)

## BLISS performance on synthetic data

In [None]:
class NllCallback(Callback):
    """Collect per-batch sampler NLL and mode detection metrics during `trainer.predict`.

    `report()` prints the mean of each quantity followed by its standard error
    in parentheses.
    """

    def __init__(self):
        super().__init__()
        self.nlls = []
        self.precisions = []
        self.recalls = []
        self.f1s = []

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Sum compute_sampler_nll over dims 1 and 2, keeping one value per example.
        nlls = pl_module.compute_sampler_nll(batch).sum([1, 2])
        self.nlls.append(nlls)

        # Metrics for this batch alone: compute, then reset before the next batch.
        pl_module.update_metrics(batch, batch_idx)
        m = pl_module.mode_metrics["detection_performance"].compute()
        self.precisions.append(m["detection_precision"].item())
        self.recalls.append(m["detection_recall"].item())
        self.f1s.append(m["detection_f1"].item())
        pl_module.mode_metrics.reset()

    @staticmethod
    def _mean_and_se(values):
        """Return (mean, standard error). Uses the sample std (ddof=1), matching torch.std."""
        arr = np.asarray(values, dtype=float)
        if arr.size < 2:
            return float(arr.mean()) if arr.size else float("nan"), float("nan")
        return float(arr.mean()), float(np.std(arr, ddof=1) / np.sqrt(arr.size))

    def report(self):
        if not self.nlls:
            print("No batches were recorded.")
            return

        nlls = torch.cat(self.nlls)
        nll_sd = nlls.std().item() / np.sqrt(nlls.size(0))
        print(f"Mean NLL: {nlls.mean().item():.2f} ({nll_sd:.2f})")

        for label, values in (("precision", self.precisions), ("recall", self.recalls), ("F1", self.f1s)):
            mean, se = self._mean_and_se(values)
            print(f"Mean {label}: {mean:.4f} ({se:.4f})")


# Evaluate the pretrained encoder on the test split, collecting NLL and detection metrics.
data_module = instantiate(cfg0.train.data_source)
data_module.setup("fit")
data_module.setup("test")
test_dl = data_module.test_dataloader()

encoder = instantiate(cfg0.train.encoder)
enc_state_dict = torch.load(cfg0.train.pretrained_weights)
if cfg0.train.pretrained_weights.endswith(".ckpt"):
    # Checkpoint files keep the weights under "state_dict".
    enc_state_dict = enc_state_dict["state_dict"]
encoder.load_state_dict(enc_state_dict)

nll_cb = NllCallback()
trainer = instantiate(cfg0.predict.trainer, callbacks=[nll_cb])
trainer.predict(encoder, dataloaders=[test_dl], return_predictions=False)
nll_cb.report()

In [None]:
def synthetic_metric(cfg, min_flux=1.5):
    """Evaluate the pretrained encoder on the synthetic test split.

    Args:
        cfg: Composed config. It is deep-copied, so the caller's config is never modified.
        min_flux: Value written to ``train.data_source.nontrain_transforms[2]['min_flux']``
            before the data module is built.

    Returns:
        ``(detection_performance, two_point)``: the computed mode and sample metrics,
        read from the encoder after ``trainer.test``.
    """
    from copy import deepcopy

    # Deep copy so the nested override below cannot leak back into e.g. cfg0.
    cfg = deepcopy(cfg)
    cfg.train.data_source.nontrain_transforms[2]['min_flux'] = min_flux
    data_module = instantiate(cfg.train.data_source)
    data_module.setup("fit")
    data_module.setup("test")

    encoder = instantiate(cfg.train.encoder)
    enc_state_dict = torch.load(cfg.train.pretrained_weights)
    if cfg.train.pretrained_weights.endswith(".ckpt"):
        enc_state_dict = enc_state_dict["state_dict"]
    encoder.load_state_dict(enc_state_dict)

    # Use the trainer from the config passed in, not the global cfg0.
    trainer = instantiate(cfg.predict.trainer)
    trainer.test(encoder, datamodule=data_module)

    # trainer.test's return value could be used instead, but the per-bin outputs
    # are needed as a vector, so read the metric objects directly.
    dp = encoder.mode_metrics["detection_performance"]
    two_pt = encoder.sample_metrics["two_point"]

    return dp.compute(), two_pt.compute()

In [None]:
# Fully synthetic evaluation: detection metrics and two-point correlation.
c_dp_synthetic, c_two_pt_synthetic = synthetic_metric(cfg=cfg0)

### Assess the two-point correlation function

In [None]:
# Two-point correlation of sampled catalogs versus distance. Each key is parsed
# as a radius by dropping its first two characters (key format assumed; confirm).
# Values are converted to Python floats because they may be tensors, possibly on
# GPU. Points are sorted by radius so the line is drawn left to right.
two_pt_points = sorted(
    (float(key[2:]), float(value)) for key, value in c_two_pt_synthetic.items()
)
radii = [radius for radius, _ in two_pt_points]
two_pt_values = [value for _, value in two_pt_points]

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(radii, two_pt_values, marker="s",
        label=f"Rank $K$={cfg0.encoder.n_sampler_colors} checkerboard")
ax.axhline(y=0, color='black', linestyle='dotted', label='ideal')
ax.legend()
ax.set_xscale("log")
ax.set_xlabel("Distance (pixels)")
ax.set_ylabel("Two-point correlation")
ax.set_xticks([0.1, 0.3, 1, 3])
ax.set_xticklabels(["0.1", "0.3", "1", "3"])
fig.tight_layout()

In [None]:
class CiCallback(Callback):
    """Accumulate, per predicted image, true minus sampled counts of bright sources.

    A source is counted when its flux exceeds ``cutoff_nmgy``. When no cutoff is
    given, the notebook-level ``cutoff_nmgy_1114`` is used, looked up at batch time.
    """

    def __init__(self, cutoff_nmgy=None):
        super().__init__()
        self.cutoff_nmgy = cutoff_nmgy
        self.residual_sources = []

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        cutoff = cutoff_nmgy_1114 if self.cutoff_nmgy is None else self.cutoff_nmgy
        # Crop both catalogs identically so the counts cover the same region.
        sampled_fluxes = outputs.symmetric_crop(3).on_fluxes
        true_fluxes = TileCatalog(batch["tile_catalog"]).symmetric_crop(3).on_fluxes
        bliss_sources = (sampled_fluxes > cutoff).sum([1, 2, 3, 4])
        true_sources = (true_fluxes > cutoff).sum([1, 2, 3, 4])
        self.residual_sources.append(true_sources - bliss_sources)

    def report(self):
        if not self.residual_sources:
            print("No batches were recorded.")
            return
        counts = torch.cat(self.residual_sources).float()
        mean = counts.mean().item()
        mean_std = counts.std().item() / np.sqrt(counts.size(0))
        print(f"Mean residual sources: {mean:.2f} ({mean_std:.2f})")

def synthetic_calibration(cfg):
    """Print the mean (and standard error) of true-minus-sampled bright-source counts.

    Runs on the synthetic test split with the encoder in sampling mode
    (``predict_mode_not_samples=False``).
    """
    cfg_sample = OmegaConf.merge(cfg, {"encoder": {"predict_mode_not_samples": False}})

    data_module = instantiate(cfg.train.data_source)
    data_module.setup("test")
    test_dl = data_module.test_dataloader()

    encoder = instantiate(cfg_sample.train.encoder)
    weights_path = cfg_sample.train.pretrained_weights
    weights = torch.load(weights_path)
    if weights_path.endswith(".ckpt"):
        weights = weights["state_dict"]
    encoder.load_state_dict(weights)

    ci_cb = CiCallback()
    trainer = instantiate(cfg.predict.trainer, callbacks=[ci_cb])
    trainer.predict(encoder, dataloaders=[test_dl], return_predictions=False)
    ci_cb.report()

# The composed config at notebook scope is `cfg0`; `cfg` only exists as a function parameter.
synthetic_calibration(cfg0)


## Assess the model and BLISS fit visually

In [None]:
# Noise-free rendering of the full HST tile catalog; keep band index 2 of the first image.
decoder = instantiate(cfg0.decoder, with_noise=False)
truth_images, _psf_params = decoder.render_images(true_tile_cat_all)
true_recon_all = truth_images[0][2]

In [None]:
# Re-run prediction and render the BLISS mode catalog without noise.
preds = predict(cfg0.predict)
(bliss_cat,) = preds.values()  # singleton dict
# Every detection is rendered as source_type 0.
# NOTE(review): this source_type takes the shape of `fluxes` (one entry per band),
# while the HST catalogs use a trailing dimension of 1; confirm this is intended.
bliss_cat["source_type"] = torch.zeros_like(bliss_cat["fluxes"], dtype=torch.long)
bliss_images, _psf_params = decoder.render_images(bliss_cat)
# Band index 2, with the 6-pixel border removed to match the study area.
bliss_recon = bliss_images[0, 2, 6:-6, 6:-6]

In [None]:
# Observed SDSS cutout next to noise-free renderings of the HST and BLISS catalogs,
# all on the same arcsinh stretch and color limits.
titles = ['SDSS image', 'HST reconstruction', 'Our reconstruction']
images = [obs_image_cropped, true_recon_all, bliss_recon]

# Stretch parameters come from the observed image so all panels are comparable.
clip_max = obs_image_cropped.quantile(0.99)
background = np.median(obs_image_cropped)
# Subtract the background, then scale: arcsinh((x - median) / 50). The previous
# expression `img - median / 50` only subtracted median/50, due to operator precedence.
images = [np.arcsinh((img.clip(max=clip_max) - background) / 50) for img in images]

fig, axs = plt.subplots(1, 3, figsize=(10, 10))

# Shared limits so intensities mean the same thing in every panel.
vmin = min(img.min() for img in images)
vmax = max(img.max() for img in images)

# No global plt.set_cmap: every panel passes its colormap explicitly, and changing
# the global default would leak into later plots.
for ax, img, title in zip(axs, images, titles):
    ax.imshow(img, origin='lower', vmin=vmin, vmax=vmax, cmap='Greys_r')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()

## Flux Prior Elicitation

In [None]:
# HST sources in a 300x300-pixel window around the study area, excluding the study
# area itself, so the prior is fit on sources disjoint from the evaluation area.
oob = (
    (plocs_all[:, 1] > 210) & (plocs_all[:, 1] < 510)
    & (plocs_all[:, 0] > 530) & (plocs_all[:, 0] < 830)
    & ~in_bounds
)
oob.sum()  # some of this region (about half) is outside of our HST cat coverage

In [None]:
# Out-of-area fluxes in SDSS r-band nanomaggies, using the same 1.15 HST-to-SDSS
# scaling as for the study area. Keep sources brighter than magnitude 24.
hst_oob = hst_r_mag_all[oob]
hst_oob_nmgy = 1.15 * convert_mag_to_nmgy(hst_oob)
hst_oob_mag = convert_nmgy_to_mag(hst_oob_nmgy)
training_data = hst_oob_nmgy[hst_oob_mag < 24]
(training_data.shape[0], training_data.max().item())

In [None]:
from scipy.stats import truncpareto
# Fit a truncated Pareto flux prior (scipy's default maximum-likelihood fit).
alpha, trunc, loc, scale = truncpareto.fit(training_data)
(alpha, trunc, loc, scale)

In [None]:
from scipy.stats import truncpareto

# Fitted ("new") and previous ("old") prior densities over a log-spaced flux grid
# spanning the out-of-area fluxes, against their normalized histogram.
x = np.logspace(hst_oob_nmgy.log10().min(), hst_oob_nmgy.log10().max(), num=100)

new_pdf = truncpareto.pdf(x, alpha, trunc, loc, scale)
old_pdf = truncpareto.pdf(x, 0.5, 1014, 0, 0.63)
_ = plt.plot(x, new_pdf, 'r-', lw=3, alpha=0.7, label='new prior')
_ = plt.plot(x, old_pdf, 'g-', lw=3, alpha=0.7, label='old prior')
_ = plt.hist(hst_oob_nmgy, bins=100, density=True, log=True, label='star_fluxes histogram')
plt.legend()

In [None]:
from scipy.stats import truncpareto

# NOTE(review): the grid starts at log10 of the minimum flux but ends at a linear
# flux of 100, mixing log and linear scales. The parameters here also differ from
# the 'old prior' used in the neighboring cells; confirm both. `x` is reused by the next cell.
x = np.linspace(hst_oob_nmgy.log10().min(), 100, num=100)

old_pdf = truncpareto.pdf(x, 0.01, 100, 3.0, 3.0)
_ = plt.plot(x, old_pdf, 'g-', lw=3, alpha=0.7, label='old prior')
plt.legend()

In [None]:
# New vs. old prior density on log-log axes, over the grid `x` defined in the previous cell.
new_pdf = truncpareto.pdf(x, alpha, trunc, loc, scale)
old_pdf = truncpareto.pdf(x, 0.5, 1014, 0, 0.63)
_ = plt.plot(x, new_pdf, 'r-', lw=3, alpha=0.7, label='new prior')
_ = plt.plot(x, old_pdf, 'g-', lw=3, alpha=0.7, label='old prior')
plt.legend()
plt.loglog()

In [None]:
# Draw 1500 fluxes from the fitted prior and show the 10 largest.
# Seeded so the displayed values are reproducible across re-runs.
samples = truncpareto.rvs(alpha, trunc, loc, scale, size=1500, random_state=0)
sorted(samples, reverse=True)[:10]

In [None]:
# Largest 100 band-index-2 fluxes from one draw of the configured prior.
prior = instantiate(cfg0.prior)
prior.sample().on_fluxes[0, :, :, :, 2].flatten().topk(100).values

In [None]:
# estimate rate with oob data
(hst_oob_mag < 24).sum() / (4 * 1e4)

In [None]:
# the per-tile source density
(1114 / 50**2) / (1 - truncpareto.cdf(cutoff_nmgy_1114, alpha, trunc, loc, scale))

## Semi-synthetic M2 inference

In [None]:
from copy import deepcopy

decoder = instantiate(cfg0.decoder, with_noise=False)

# Shift the HST catalog by a 6-pixel border so it sits inside a 112x112 padded
# frame. The border is cropped off again when the image is displayed and evaluated.
d2 = deepcopy(true_cat_all)
d2["plocs"] += 6
true_cat_pad = FullCatalog(112, 112, d2)

truth_images, _ = decoder.render_images(true_cat_pad.to_tile_catalog(2, 11))

In [None]:
# Keep band index 2 as a size-1 band dimension, matching how the SDSS cutout was sliced.
semisynth_image = truth_images[:, 2:3]
# Display without the 6-pixel border.
plt.imshow(semisynth_image[0, 0, 6:-6, 6:-6].numpy(), cmap='Greys_r', origin='lower');

In [None]:
def semisynth_dp(cfg, batch=None, true_cat=None):
    """Detection metrics of the BLISS mode catalog on the semi-synthetic M2 image.

    Args:
        cfg: Composed config; the encoder is forced into mode prediction.
        batch: Encoder input dict with an "images" entry. Defaults to the
            notebook-level ``semisynth_image`` moved to CUDA.
        true_cat: Ground-truth FullCatalog. Defaults to the notebook-level ``true_cat_all``.

    Returns:
        Output of ``encoder.mode_metrics`` for the matched true/predicted catalogs.
    """
    if batch is None:
        batch = {"images": semisynth_image.cuda()}
    if true_cat is None:
        true_cat = true_cat_all

    cfg_mode = OmegaConf.merge(cfg, {"encoder": {"predict_mode_not_samples": True}})

    encoder = instantiate(cfg_mode.train.encoder).cuda()
    enc_state_dict = torch.load(cfg_mode.train.pretrained_weights)
    # Unwrap "state_dict" only for .ckpt files, like the other loaders in this notebook.
    if cfg_mode.train.pretrained_weights.endswith(".ckpt"):
        enc_state_dict = enc_state_dict["state_dict"]
    encoder.load_state_dict(enc_state_dict)
    encoder.eval()

    bliss_cat = encoder.predict_step(batch, 0)
    bliss_cat = bliss_cat.symmetric_crop(3).to_full_catalog(cfg_mode.encoder.tile_slen)
    true_cat_cuda = true_cat.to("cuda:0")
    matching = encoder.matcher.match_catalogs(true_cat_cuda, bliss_cat)
    return encoder.mode_metrics(true_cat_cuda, bliss_cat, matching)


batch = {
    "images": semisynth_image.cuda(),
}
# Pass inputs explicitly rather than relying on notebook globals inside the function.
c_dp_semisynth = semisynth_dp(cfg0, batch=batch, true_cat=true_cat_all)

In [None]:
# Per-bin recall and precision for the fully synthetic, semi-synthetic and real (M2) runs.

def _binned_metric(metrics, prefix):
    """Values whose key is `prefix` plus an integer bin index, ordered by that index.

    Handles multi-digit indices; a `key[:-1] == prefix` test silently drops bins >= 10.
    """
    bins = {}
    for key, value in metrics.items():
        suffix = key[len(prefix):]
        if key.startswith(prefix) and suffix.isdigit():
            bins[int(suffix)] = value.item()
    return [bins[idx] for idx in sorted(bins)]


# Flux bin cutoffs (nmgy) converted to magnitudes, then reversed
# (presumably so magnitudes ascend, brightest first).
mbc = cfg0.star_metrics.detection_performance.base_flux_bin_cutoffs
mbc = convert_nmgy_to_mag(torch.tensor(mbc)).tolist()
mbc.reverse()

titles = ["Fully Synthetic", "Semi-Synthetic", "Real"]
dp_metrics = [c_dp_synthetic, c_dp_semisynth, c_dp_real]

# Tick labels: "< m0", then "[m_i, m_{i+1}]" for each consecutive pair.
# NOTE(review): no label is produced for the "> last cutoff" bin; confirm the
# metric's bin count and order match these labels.
xlabels = [f"< {mbc[0]:.1f}"] + [f"[{lo:.1f}, {hi:.1f}]" for lo, hi in zip(mbc[:-1], mbc[1:])]

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

for title, dp in zip(titles, dp_metrics):
    axs[0].plot(_binned_metric(dp, "detection_recall_bin_"), marker="s", label=title)
    axs[1].plot(_binned_metric(dp, "detection_precision_bin_"), marker="s", label=title)

axs[0].set_title("Recall")
axs[1].set_title("Precision")

for ax in axs:
    ax.set_xticks(range(len(xlabels)))
    ax.set_xticklabels(xlabels, rotation=45)
    ax.set_ylim([0, 1])
    ax.legend()

plt.tight_layout()

## PPC confusion matrices

In [None]:
from pytorch_lightning.callbacks import Callback
import seaborn as sns
from matplotlib import pyplot as plt
from bliss.catalog import TileCatalog


class VsbcCallback(Callback):
    """Confusion matrix of true vs. sampled source counts near tile boundaries.

    For each pair of tiles adjacent along dimension 1 of the tile grid (tile i and
    tile i+1), this counts on-sources whose within-tile coordinate ``locs[..., 0]``
    (presumably the coordinate along that same grid dimension; confirm) lies within
    ``frac`` of the shared edge. That is, sources in tile i+1 with coordinate < frac
    plus sources in tile i with coordinate > 1 - frac. Each side is clamped to at
    most 2, so counts range over 0..4 and index a 5x5 matrix (rows: true count,
    columns: sampled count).
    """

    def __init__(self, frac=1 / 8):
        super().__init__()
        self.frac = frac
        self.confusion_matrix = torch.zeros((5, 5), dtype=torch.int64)

    def _boundary_counts(self, cat):
        """Clamped per-boundary count of on-sources within `frac` of each shared tile edge."""
        coord = cat["locs"][:, :, :, :, 0]
        near_low = ((coord < self.frac) * cat.is_on_mask).sum(-1).clamp(0, 2)
        near_high = ((coord > (1 - self.frac)) * cat.is_on_mask).sum(-1).clamp(0, 2)
        return near_low[:, 1:, :] + near_high[:, :-1, :]

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        true_n_sources = self._boundary_counts(TileCatalog(batch["tile_catalog"]))
        sampled_n_sources = self._boundary_counts(outputs)

        # The matrix lives on the CPU. index_put_ requires indices on the CPU or on
        # the indexed tensor's device, so move them off the GPU first.
        indices = (true_n_sources.reshape(-1).cpu(), sampled_n_sources.reshape(-1).cpu())
        values = torch.ones(indices[0].size(0), dtype=torch.int64)
        self.confusion_matrix.index_put_(indices, values, accumulate=True)

    def report(self):
        print(self.confusion_matrix)

# Build the boundary-count confusion matrix over the test split, using posterior
# samples (predict_mode_not_samples=False) rather than the mode.
# NOTE(review): this uses cfg0.train.trainer, whereas the other prediction runs use
# cfg0.predict.trainer; confirm which is intended.
vsbc_callback = VsbcCallback()
trainer = instantiate(cfg0.train.trainer, callbacks=[vsbc_callback])
encoder = instantiate(cfg0.train.encoder, predict_mode_not_samples=False)
enc_state_dict = torch.load(cfg0.train.pretrained_weights)
if cfg0.train.pretrained_weights.endswith(".ckpt"):
    enc_state_dict = enc_state_dict["state_dict"]  # unwrap checkpoint
encoder.load_state_dict(enc_state_dict)

data_source = instantiate(cfg0.train.data_source)
data_source.setup("test")

trainer.predict(encoder, dataloaders=[data_source.test_dataloader()], return_predictions=False)
vsbc_callback.report()


In [None]:
def make_confusion_matrix(counts, title=""):
    """Heatmap of `counts`, colored by log(1 + count) and annotated with raw counts.

    Rows are actual source counts (y axis); columns are predicted counts (x axis).
    """
    cell_labels = [[f'{n:,}' for n in row] for row in counts.tolist()]
    log_counts = (counts + 1).log()
    ax = sns.heatmap(
        log_counts,
        annot=cell_labels,
        fmt='s',
        cbar=False,
        cmap='Blues',
        annot_kws={"fontsize": 14},
    )
    axis_label_size = 16
    ax.set_xlabel("Predicted source count", fontsize=axis_label_size)
    ax.set_ylabel("Actual source count", fontsize=axis_label_size)
    ax.set_title(title, pad=20, fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.show()


make_confusion_matrix(vsbc_callback.confusion_matrix, title=plot_title)

In [None]:
def make_factor_matrix(counts, title=""):
    """Heatmap of the asymmetry between mirrored cells of a confusion matrix.

    Cell (i, j) holds (counts[i, j] - counts[j, i]) / min(counts[i, j], counts[j, i]),
    annotated with a trailing "×" and colored by log(1 + |value|). Cells whose
    minimum is below 100 are zeroed and left unannotated, and diagonal annotations
    are blank.
    """
    counts_t = counts.T
    floor = torch.min(counts, counts_t)
    diff_factor = (counts - counts_t) / floor
    # Format before masking, exactly as the values were computed.
    raw_labels = [[f'{val:.1f}\u00D7' for val in row] for row in diff_factor.tolist()]

    too_few = floor < 100
    diff_factor[too_few] = 0
    annotations = [
        ["" if i == j or too_few[i, j] else text for j, text in enumerate(row)]
        for i, row in enumerate(raw_labels)
    ]

    ax = sns.heatmap(
        (1 + diff_factor.abs()).log(),
        annot=annotations,
        fmt='s',
        cbar=False,
        cmap='YlOrRd',
        vmin=0,
        vmax=1.5,
        annot_kws={"fontsize": 15},
    )
    axis_label_size = 16
    ax.set_xlabel("Predicted source count", fontsize=axis_label_size)
    ax.set_ylabel("Actual source count", fontsize=axis_label_size)
    ax.set_title(title, pad=20, fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=15)
    plt.show()

make_factor_matrix(vsbc_callback.confusion_matrix, title=plot_title)