In [None]:
import torch
import matplotlib.pyplot as plt
from hydra import compose, initialize

import sys
from pathlib import Path
# Put the directory two levels above the current working directory at the front
# of sys.path, presumably so the `maps_to_cosmology` imports below can resolve.
# NOTE(review): this depends on where the kernel was started (presumably the
# notebook's own directory, two levels below the repository root) — TODO confirm.
sys.path.insert(0, str(Path.cwd().parents[1]))

from maps_to_cosmology.encoder import Encoder
from maps_to_cosmology.datamodule import ConvergenceMapsModule

# Prefer the first CUDA GPU when one is visible; otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

Compose the `train_npe` Hydra config from `../configs`; the data-module settings below are read from it:

In [None]:
# Compose the `train_npe` config from ../configs; the data-module settings and
# seed used further down are read from it.
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="train_npe")

Load the trained encoder from a checkpoint saved by a previous training run, move it to `device`, and switch it to evaluation mode:

In [None]:
# NOTE(review): absolute path on a shared scratch filesystem — edit this to
# evaluate a different training run.
ckpt = "/data/scratch/blissWL_checkpoints/example_encoder.ckpt"

# map_location puts the stored tensors straight onto `device` (e.g. a checkpoint
# saved on GPU can still be opened on a CPU-only machine). Chaining .to()/.eval()
# into the assignment switches off training-only behaviour and keeps the cell from
# printing the entire module repr as its output.
encoder = Encoder.load_from_checkpoint(ckpt, map_location=device).to(device).eval()

Instantiate the test dataloader:

In [None]:
# Build the data module from the `convergence_maps` section of the config.
# The config seed is passed through — presumably so the val/test split is the
# same one used during training (TODO confirm against ConvergenceMapsModule).
data_cfg = cfg.convergence_maps
datamodule = ConvergenceMapsModule(
    data_dir=cfg.paths.data_dir,
    batch_size=data_cfg.batch_size,
    num_workers=data_cfg.num_workers,
    val_split=data_cfg.val_split,
    test_split=data_cfg.test_split,
    seed=cfg.seed,
)

datamodule.setup()
test_loader = datamodule.test_dataloader()

Run the encoder on every convergence map in the test set, collecting the posterior mean and standard deviation it predicts for each cosmological parameter, along with the corresponding true parameters:

In [None]:
# Run the encoder over the whole test split, collecting per-map posterior means
# and standard deviations together with the true parameters.
# The encoder output is read as interleaved columns: even columns are posterior
# means, odd columns are treated as log-variances (stdev = sqrt(exp(logvar))).
# NOTE(review): the clamp bound presumably mirrors the one used in training — TODO confirm.
LOGVAR_CLAMP = 10.0

# Per-batch accumulators; kept separate from the final tensors so no name is
# rebound from a list to a tensor.
mean_batches, stdev_batches, param_batches = [], [], []

with torch.no_grad():
    for maps, params in test_loader:
        out = encoder(maps.to(device))

        mean_batches.append(out[:, 0::2].cpu())

        # Clamp before exponentiating so extreme log-variances cannot overflow
        # to inf or underflow to 0.
        log_var = out[:, 1::2].clamp(-LOGVAR_CLAMP, LOGVAR_CLAMP)
        stdev_batches.append(log_var.exp().sqrt().cpu())

        # .cpu() is a no-op for CPU batches and keeps torch.cat below on one device.
        param_batches.append(params.cpu())

posterior_means = torch.cat(mean_batches, dim=0)
posterior_stdevs = torch.cat(stdev_batches, dim=0)
true_params = torch.cat(param_batches, dim=0)

# Sanity checks: every mean has a matching stdev (an odd number of encoder output
# columns would break the interleaved layout), and there is one row per test map.
assert posterior_means.shape == posterior_stdevs.shape, (
    posterior_means.shape, posterior_stdevs.shape)
assert posterior_means.shape[0] == true_params.shape[0], (
    posterior_means.shape, true_params.shape)

Plot the true cosmological parameters against the corresponding posterior means; points on the dashed diagonal are perfectly recovered:

In [None]:
# Column i of `true_params` / `posterior_means` is plotted under param_names[i].
# NOTE(review): confirm this ordering matches the dataset's parameter columns.
param_names = ["omega_c", "omega_b", "sigma_8", "h_0", "n_s", "w_0"]

# Size the grid from the number of parameters: zip() below stops at the shorter of
# (axes, names), so a fixed 2x3 grid would silently drop any extra parameter.
n_cols = 3
n_rows = (len(param_names) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

for i, (ax, name) in enumerate(zip(axes.flat, param_names)):
    ax.scatter(true_params[:, i], posterior_means[:, i], alpha=0.5, s=10)
    # One shared range for both axes so the dashed y = x reference is a true diagonal.
    lims = [min(ax.get_xlim()[0], ax.get_ylim()[0]),
            max(ax.get_xlim()[1], ax.get_ylim()[1])]
    ax.plot(lims, lims, 'k--', alpha=0.5)
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_xlabel(f"True {name}")
    ax.set_ylabel(f"Posterior mean {name}")
    ax.set_title(name)

# Hide any unused panels in the last row.
for ax in axes.flat[len(param_names):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()