In [None]:
from deephpx.ili import (
    HealpixEmbeddingNet,
    make_ili_torch_loader_from_files,
    SplitSpec,
    train_lampe_posterior,
)
import ili
import numpy as np


In [None]:

# Data sources: a glob matching one HEALPix map file per simulation, plus a
# single .npy array holding the corresponding parameter vectors.
maps_glob = "/path/to/maps/*.fits.gz"  # *.fits and *.npy globs are also accepted
theta_path = "/path/to/theta.npy"      # expected shape (N, D)

# Torch data loader with a seeded train/validation split (10% held out).
# cache_dir is optional but recommended; output_ordering requests NEST-ordered maps.
loader = make_ili_torch_loader_from_files(
    maps=maps_glob,
    theta=theta_path,
    split=SplitSpec(seed=0, val_fraction=0.1),
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    output_ordering="NEST",
    cache_dir="/tmp/deephpx_cache",
)



In [None]:
# Build a box-uniform prior over the D simulation parameters, with vector
# low/high bounds of shape (D,).
# The bounds are hard-coded to [-1, 1], so theta must already be rescaled to
# that range. Validate this up front: a training sample outside the prior
# support has zero prior density, and the failure would otherwise only surface
# later as -inf / NaN losses during training.
theta = np.load(theta_path)
if theta.ndim != 2:
    raise ValueError(f"expected theta with shape (N, D), got shape {theta.shape}")
D = theta.shape[1]
prior_low = -np.ones(D)
prior_high = np.ones(D)
if np.any(theta < prior_low) or np.any(theta > prior_high):
    raise ValueError(
        f"theta has values in [{theta.min()}, {theta.max()}], outside the prior "
        f"bounds [{prior_low.min()}, {prior_high.max()}]; rescale theta or widen the prior"
    )
prior = ili.utils.Uniform(low=prior_low, high=prior_high, device="cuda")

# Embedding network that compresses each single-channel nside=64 HEALPix map
# into a 128-dimensional summary vector for the density estimator.
# NOTE(review): nside must match the resolution of the maps produced by the
# loader - confirm. conv_channels presumably gives one width per level.
embedding_net = HealpixEmbeddingNet(
    nside=64,
    in_channels=1,
    embedding_dim=128,
    levels=3,
    conv_channels=(16, 32, 64),
    K=3,  # NOTE(review): presumably the convolution filter order/size - confirm
    pool="average",
    global_pool="mean",
    device="cuda",
)


In [None]:

# Train a "maf" neural density estimator on top of the embedding network,
# writing outputs under ./_out. Returns the trained posterior together with
# the training summaries and the runner object.
posterior, summaries, runner = train_lampe_posterior(
    loader,
    prior=prior,
    embedding_net=embedding_net,
    model="maf",
    nde_kwargs={
        "hidden_features": 128,
        "num_transforms": 5,
        "x_normalize": True,
        "theta_normalize": True,
    },
    train_args={
        "learning_rate": 5e-4,
        "stop_after_epochs": 20,  # presumably early-stopping patience - confirm
        "clip_max_norm": 5.0,
    },
    out_dir="./_out",
    device="cuda",
)
