### Testing SBI's Implementation of the Simformer for Model Comparison

Run the following commands inside your local clone of the `sbi` repository to fetch the Simformer pull request (#1621) into a local `simformer` branch and check it out.

```bash

git fetch origin pull/1621/head:simformer
git checkout simformer

```

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np

from synference import SBI_Fitter, load_unc_model_from_hdf5

### Load SBI Fitter, create noise models and feature arrays


In [None]:
# Absolute path to the pre-computed BPASS model library (HDF5).
# NOTE(review): machine-specific path — adjust before running elsewhere.
bpass_path = "/home/tharvey/work/synference/libraries/grid_BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode.hdf5"  # noqa: E501

# Build the fitter from the HDF5 library; the first argument is presumably the
# model name used to label outputs — confirm against SBI_Fitter.init_from_hdf5.
fitter_bpass = SBI_Fitter.init_from_hdf5(
    "BPASS_Chab_DenseBasis_SFH_0.01_z_14_logN_5.0_Calzetti_v3_multinode", bpass_path
)

# Empirical per-filter noise models, loaded from a machine-specific HDF5 file.
noise_model_path = "/home/tharvey/work/synference/priv/data/JOF_psfmatched_asinh_noise_model.h5"
noise_models = load_unc_model_from_hdf5(noise_model_path)

# Noise-model keys listed here are excluded from `nm` below.
hst_filters = ["F435W", "F606W", "F775W", "F814W", "F850LP"]

# Re-key every remaining noise model as "JWST/NIRCam.<key>".
nm = {f"JWST/NIRCam.{i}": j for i, j in noise_models.items() if i not in hst_filters}

# Keyword arguments forwarded to create_feature_array_from_raw_photometry below.
fa_input = dict(
    extra_features=["redshift"],
    empirical_noise_models=nm,
    include_errors_in_feature_array=False,
    scatter_fluxes=False,
    # normed_flux_units="asinh",
    # Filters passed as `photometry_to_remove`; presumably these are dropped
    # from the library photometry before the feature array is built — confirm.
    photometry_to_remove=[
        "CTIO/DECam.u",
        "CTIO/DECam.g",
        "CTIO/DECam.r",
        "CTIO/DECam.i",
        "CTIO/DECam.z",
        "CTIO/DECam.Y",
        "LSST/LSST.u",
        "LSST/LSST.g",
        "LSST/LSST.r",
        "LSST/LSST.i",
        "LSST/LSST.z",
        "LSST/LSST.Y",
        "PAN-STARRS/PS1.g",
        "PAN-STARRS/PS1.r",
        "PAN-STARRS/PS1.i",
        "PAN-STARRS/PS1.w",
        "PAN-STARRS/PS1.z",
        "PAN-STARRS/PS1.y",
        "Paranal/VISTA.Z",
        "Paranal/VISTA.Y",
        "Paranal/VISTA.J",
        "Paranal/VISTA.H",
        "Paranal/VISTA.Ks",
        "Subaru/HSC.g",
        "Subaru/HSC.r",
        "Subaru/HSC.i",
        "Subaru/HSC.z",
        "Subaru/HSC.Y",
        "CFHT/MegaCam.u",
        "CFHT/MegaCam.g",
        "CFHT/MegaCam.r",
        "CFHT/MegaCam.i",
        "CFHT/MegaCam.z",
        "Euclid/VIS.vis",
        "Euclid/NISP.Y",
        "Euclid/NISP.J",
        "Euclid/NISP.H",
        "HST/ACS_WFC.F475W",
        "HST/WFC3_IR.F105W",
        "JWST/NIRCam.F070W",
        "HST/WFC3_IR.F110W",
        "HST/WFC3_IR.F125W",
        "JWST/NIRCam.F140M",
        "HST/WFC3_IR.F140W",
        "HST/WFC3_IR.F160W",
        "JWST/NIRCam.F360M",
        "JWST/NIRCam.F430M",
        "JWST/NIRCam.F460M",
        "JWST/NIRCam.F480M",
        "JWST/MIRI.F560W",
        "JWST/MIRI.F770W",
        "JWST/MIRI.F1000W",
        "JWST/MIRI.F1130W",
        "JWST/MIRI.F1280W",
        "JWST/MIRI.F1500W",
        "JWST/MIRI.F1800W",
        "JWST/MIRI.F2100W",
        "JWST/MIRI.F2550W",
        "Spitzer/IRAC.I1",
        "Spitzer/IRAC.I2",
        "Spitzer/IRAC.I3",
        "Spitzer/IRAC.I4",
        "HST/ACS_WFC.F435W",
        "HST/ACS_WFC.F606W",
        "HST/ACS_WFC.F775W",
        "HST/ACS_WFC.F814W",
        "HST/ACS_WFC.F850LP",
    ],
)

# Build the feature array; "Av" is mapped through np.log10 (presumably so the
# network fits log10(Av) — confirm). The trailing semicolon suppresses the
# notebook's echo of the return value.
fitter_bpass.create_feature_array_from_raw_photometry(
    **fa_input, parameter_transformations={"Av": np.log10}
);

In [None]:
from synference.utils import make_serializable

In [None]:
# Train one Simformer through the fitter and keep the returned object.
# sde_type="vp" presumably selects the variance-preserving SDE — confirm.
posterior = fitter_bpass.run_single_simformer(verbose=True, sde_type="vp")

In [None]:
# Rows of the feature array selected by the fitter's test indices, plus the
# fitter's stored test targets.
X_test = fitter_bpass.feature_array[fitter_bpass._test_indices]
y_test = fitter_bpass._y_test
# Pick one test row at random (unseeded, so it changes on every run).
random_sample_idx = np.random.choice(X_test.shape[0])
random_sample = X_test[random_sample_idx]

# Display the chosen observation.
random_sample

In [None]:
# Pass the Simformer's internal posterior through synference's
# make_serializable helper and display the result.
make_serializable(fitter_bpass.simformer._posterior)

In [None]:
# Display the vector-field estimator held by the Simformer's posterior.
fitter_bpass.simformer._posterior.vector_field_estimator

In [None]:
import dill
import torch

# Save with dill
# NOTE(review): dill tracing is switched on here and never switched off in
# this cell, so later dill calls will keep printing trace output.
dill.detect.trace(True)

# NOTE(review): the file name says "posterior" but the object written is
# `fitter_bpass.simformer` — confirm which object is meant to be stored.
with open("bpass_simformer_posterior_vp.pt", "wb") as f:
    dill.dump(fitter_bpass.simformer, f)

# Unpickle using dill
# This rebinds the notebook-level name `posterior` to the reloaded object.
with open("bpass_simformer_posterior_vp.pt", "rb") as f:
    posterior = dill.load(f)
    print(posterior)

In [None]:
# Reload the saved object with synference's CPU_Unpickler instead of dill.load
# (presumably so tensors saved on a GPU are mapped to the CPU — confirm).
from synference.utils import CPU_Unpickler

# Rebinds the notebook-level name `posterior` to the reloaded object.
with open("bpass_simformer_posterior_vp.pt", "rb") as f:
    posterior = CPU_Unpickler(f).load()
    print(posterior)

In [None]:
random_sample

In [None]:
?fitter_bpass.posteriors.sample

In [None]:
# Draw 1000 posterior samples conditioned on the randomly chosen observation.
# `steps` is presumably the number of SDE integration steps — confirm.
samples = fitter_bpass.posteriors.sample(x=random_sample, sample_shape=(1000,), steps=1000)

In [None]:
# Corner plot of the posterior samples, with the true test parameters of the
# selected observation overplotted as `truths`.
import corner

# The trailing semicolon suppresses the notebook's echo of the returned figure.
corner.corner(
    samples.cpu().numpy(),
    labels=fitter_bpass.fitted_parameter_names,
    show_titles=True,
    truths=fitter_bpass._y_test[random_sample_idx],
);

In [None]:
# Loss histories recorded in the Simformer's training summary.
val_loss = fitter_bpass.simformer._summary["validation_loss"]
train_loss = fitter_bpass.simformer._summary["training_loss"]

import matplotlib.pyplot as plt

# Plot both histories against their index (labelled "Epoch" on the x-axis).
plt.plot(train_loss, label="Train Loss")
plt.plot(val_loss, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

In [None]:
from sbi.inference.posteriors.direct_posterior import DirectPosterior

# Wrap the current `posterior` object in an sbi DirectPosterior using the
# fitter's prior, rebinding `posterior` to the wrapper.
# NOTE(review): `posterior` is whatever an earlier cell last bound it to;
# confirm it is an estimator type that DirectPosterior accepts.
posterior = DirectPosterior(posterior_estimator=posterior, prior=fitter_bpass.prior)

In [None]:
# make edge mask
# parameters attend to themselves and redshift
# NOTE(review): unfinished cell — `build_l` looks like a truncated attribute
# name left over from tab completion; confirm the intended method.


fitter_bpass.simformer.build_l

# Evidence Networks (adapted from LtU-ILI)

In [None]:
import logging
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# def smooth_sign(x, k=100.):
#     return 2. * torch.sigmoid(k * x) - 1


def parity_odd_power(x, alpha=2):
    """Sign-preserving power: multiply ``x`` by ``|x|`` raised to ``alpha - 1``.

    The result has magnitude ``|x| ** alpha`` and the sign of ``x``, so the
    function is odd in ``x`` whatever the exponent.
    """
    magnitude = torch.pow(torch.abs(x), alpha - 1)
    return x * magnitude


def leaky_parity_odd_power(x, alpha=2):
    """Return ``x`` plus its parity-odd power, so a linear term is always present."""
    odd_term = parity_odd_power(x, alpha)
    return x + odd_term


class POPExpLoss(nn.Module):
    """Mean exponential loss on a parity-odd-power transformed prediction."""

    def forward(self, model_pred, model_label):
        """Return the mean of ``exp((0.5 - label) * f(pred))``.

        ``f`` applies ``leaky_parity_odd_power`` with ``alpha=1`` and then
        clamps the result to ``[-50, 50]`` to keep the exponent bounded.
        """
        transformed = leaky_parity_odd_power(model_pred, alpha=1)
        bounded = torch.clamp(transformed, -50, 50)
        exponent = (0.5 - model_label) * bounded
        return torch.exp(exponent).mean()


class ExpLoss(nn.Module):
    """Log-sum-exp form of the exponential evidence-network loss."""

    def forward(self, model_pred, model_label):
        """Return ``logsumexp((0.5 - label) * pred)`` reduced over dimension 0."""
        return torch.logsumexp((0.5 - model_label) * model_pred, dim=0)


class EvidenceNetwork(nn.Module):
    """A deep neural network designed to approximate the log-Bayes factor.

    The architecture consists of an optional embedding network, followed by an
    initial fully-connected layer, one hidden layer, a series of residual
    blocks, a post-residual layer and a final single-unit output layer.

    The initial and hidden layers always use batch normalization and the
    post-residual layer never does; `batch_norm_flag` only controls the
    residual blocks.

    Args:
        input_size (int): The dimension of the input features that are fed into the
            first fully-connected layer. This should match the output dimension
            of the `embedding_net`.
        embedding_net (nn.Module, optional): A PyTorch module to preprocess or
            embed the raw input data before it enters the main network. If
            None (the default), a fresh nn.Identity() is created per instance.
        layer_width (int, optional): The number of neurons in the hidden layers.
            Defaults to 16.
        added_layers (int, optional): The number of residual blocks to include in
            the network. Defaults to 3.
        batch_norm_flag (int, optional): A flag to enable (1) or disable (0)
            batch normalization inside the residual blocks. Defaults to 1.
        alpha (int, optional): The exponent for the `leaky_parity_odd_power`
            activation function applied to the output. Defaults to 2.
    """

    def __init__(
        self,
        input_size,
        embedding_net=None,
        layer_width=16,
        added_layers=3,
        batch_norm_flag=1,
        alpha=2,
    ):
        """Initializes the EvidenceNetwork with specified architecture."""
        super().__init__()
        self.input_size = input_size
        # Build the default per instance: a module created in the signature
        # would be evaluated once and shared by every network.
        self.embedding_net = nn.Identity() if embedding_net is None else embedding_net
        self.layer_width = layer_width
        self.added_layers = added_layers
        self.bn = batch_norm_flag
        self.alpha = alpha

        # Resizes the (embedded) input to the hidden width.
        self.initial_layer = self.simple_layer(self.input_size, self.layer_width, 1)

        self.hidden_layers = nn.ModuleList(
            [self.simple_layer(self.layer_width, self.layer_width, 1) for _ in range(1)]
        )

        self.residual_layers = nn.ModuleList(
            [
                self.residual_layer(self.layer_width, self.layer_width, self.bn)
                for _ in range(self.added_layers)
            ]
        )

        self.post_residual = self.simple_layer(
            self.layer_width, self.layer_width, batch_norm_flag=0
        )
        self.output_layer = nn.Linear(self.layer_width, 1)

    def simple_layer(self, in_features, out_features, batch_norm_flag=1):
        """Creates a fully-connected layer with LeakyReLU(0.1) activation.

        Args:
            in_features (int): Input width of the linear layer.
            out_features (int): Output width of the linear layer.
            batch_norm_flag (int, optional): If 1, append a BatchNorm1d layer
                after the activation. Defaults to 1.

        Returns:
            nn.Sequential: Linear -> LeakyReLU [-> BatchNorm1d].
        """
        layers = [nn.Linear(in_features, out_features), nn.LeakyReLU(0.1)]
        if batch_norm_flag == 1:
            # Normalize over the features this layer actually emits, so the
            # helper stays correct when out_features differs from layer_width.
            layers.append(nn.BatchNorm1d(out_features))
        return nn.Sequential(*layers)

    def residual_layer(self, in_features, out_features, batch_norm_flag=1):
        """Creates the body of a residual block: two stacked simple layers.

        The skip connection itself is added in `forward`.
        """
        layers = [
            self.simple_layer(in_features, out_features, batch_norm_flag),
            self.simple_layer(out_features, out_features, batch_norm_flag),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass through the EvidenceNetwork.

        Args:
            x: Input batch passed to `embedding_net`.

        Returns:
            The network output after the final linear layer, an affine
            rescaling and the `leaky_parity_odd_power` activation.
        """
        # Embedding Network
        x = self.embedding_net(x)

        # Initial layer (resizing input to layer width)
        x = self.initial_layer(x)

        # Hidden layers
        for layer in self.hidden_layers:
            x = layer(x)

        # Residual layers (skip connection around each block)
        for layer in self.residual_layers:
            x = x + layer(x)

        # Output layer
        x = self.post_residual(x)
        x = self.output_layer(x)
        # Shrink and slightly offset the raw output before the power activation.
        x = 0.1 * x + 0.001
        x = leaky_parity_odd_power(x, alpha=self.alpha)
        return x


class K_EvidenceNetwork:
    """A runner class to train an EvidenceNetwork for approximating the Bayes Factor (K).

    This class handles data preparation, training loops, validation, early stopping,
    and restoring the best weights. It can optionally use a custom embedding
    network to process the data before the main fully-connected layers.

    Args:
        embedding_net (nn.Module, optional): A PyTorch module to preprocess input
            data. If None, an identity mapping is used. Defaults to None.
        layer_width (int, optional): The number of neurons in the hidden layers of
            the EvidenceNetwork. Defaults to 16.
        added_layers (int, optional): The number of residual blocks in the
            EvidenceNetwork. Defaults to 3.
        batch_norm_flag (int, optional): A flag to enable (1) or disable (0)
            batch normalization. Defaults to 1.
        alpha (int, optional): The exponent for the leaky_parity_odd_power
            activation function. Defaults to 2.
        train_args (dict, optional): A dictionary of training hyperparameters to
            override the defaults. Defaults to None (no overrides).
        device (str, optional): The device to run training on (e.g., 'cpu' or
            'cuda'). Defaults to 'cpu'.

    Attributes:
        model (EvidenceNetwork): The trained PyTorch model. This is None until
            the `train` method is successfully called.
        best_val (float): The best validation loss achieved during the most
            recent call to `train`.
        loss_fn (nn.Module): The loss function instance used for training.
    """

    def __init__(
        self,
        embedding_net=None,
        layer_width=16,
        added_layers=3,
        batch_norm_flag=1,
        alpha=2,
        train_args=None,
        device="cpu",
    ):
        """Initializes the K_EvidenceNetwork with specified architecture and training parameters."""
        # Explicit None check: container modules define __len__, so an empty
        # one is falsy and `embedding_net or ...` would silently replace it.
        self.embedding_net = nn.Identity() if embedding_net is None else embedding_net
        self.layer_width = layer_width
        self.added_layers = added_layers
        self.batch_norm_flag = batch_norm_flag
        self.alpha = alpha
        self.train_args = dict(
            training_batch_size=32,
            learning_rate=1e-5,
            stop_after_epochs=30,
            clip_max_norm=5,
            max_epochs=int(1e4),
            validation_fraction=0.1,
        )
        if train_args:
            self.train_args.update(train_args)
        self.device = device

        self.embedding_net.to(self.device)

        self.loss_fn = ExpLoss()
        self.best_val = float("inf")
        self.model = None

    def _loss(self, model, theta, x):
        """Compute the loss for a batch: `theta` are model labels, `x` the inputs."""
        logK = model(x)
        # Labels become a column so they line up with the (batch, 1) network output.
        return self.loss_fn(logK, theta.view(-1, 1))

    def _train_epoch(self, model, train_loader, val_loader, optimizer):
        """Run one training pass followed by one validation pass.

        Args:
            model: The network being optimized.
            train_loader: Iterable of `(x, label)` training batches.
            val_loader: Iterable of `(x, label)` validation batches.
            optimizer: Optimizer stepping the model parameters.

        Returns:
            tuple[float, float]: Sample-weighted mean training and validation loss.

        Raises:
            ValueError: If either loader yields no samples.
        """
        model.train()
        loss_train, count_train = [], 0
        for x, theta in train_loader:
            x, theta = x.to(self.device), theta.to(self.device)
            optimizer.zero_grad()
            loss = self._loss(model, theta, x)
            loss.backward()

            norm = nn.utils.clip_grad_norm_(model.parameters(), self.train_args["clip_max_norm"])

            # Skip the update when the gradient norm is inf/NaN.
            if torch.isfinite(norm):
                optimizer.step()
            # Record
            loss_train.append(loss.item() * len(theta))
            count_train += len(theta)
        if count_train == 0:
            raise ValueError("Training loader produced no samples.")
        loss_train = sum(loss_train) / count_train

        model.eval()
        with torch.no_grad():
            loss_val, count_val = [], 0
            for x, theta in val_loader:
                x, theta = x.to(self.device), theta.to(self.device)
                loss_val.append(self._loss(model, theta, x).item() * len(theta))
                count_val += len(theta)
        if count_val == 0:
            raise ValueError("Validation loader produced no samples.")
        loss_val = sum(loss_val) / count_val
        return loss_train, loss_val

    def train(self, loader1, loader2, show_progress_bars=True):
        """Trains the evidence network.

        Args:
            loader1: A data loader object for the first model (label 0). Must
                     have a `get_all_data()` method returning a numpy array.
            loader2: A data loader object for the second model (label 1), with
                     the same `get_all_data()` interface.
            show_progress_bars (bool): Whether to display tqdm progress bars.

        Returns:
            dict: Training summary with keys "training_loss",
            "validation_loss" (per-epoch lists), "best_validation_loss" and
            "epochs_trained".

        Raises:
            ValueError: If the train or validation split holds fewer samples
                than one batch, which would leave a loader empty.
        """
        # Aggregate data from different models, and label them
        # Assumes loaders have a 'get_all_data' method returning numpy arrays
        x1 = loader1.get_all_data().astype(np.float32)
        x2 = loader2.get_all_data().astype(np.float32)
        x = np.concatenate([x1, x2], axis=0)
        labels = np.concatenate([np.zeros(len(x1)), np.ones(len(x2))], axis=0)

        # Train/Validation split
        indices = np.random.permutation(len(x))
        val_size = int(len(x) * self.train_args["validation_fraction"])
        val_indices, train_indices = indices[:val_size], indices[val_size:]
        x_train, labels_train = x[train_indices], labels[train_indices]
        x_val, labels_val = x[val_indices], labels[val_indices]

        # Create data loaders
        batch_size = self.train_args["training_batch_size"]
        train_data = TensorDataset(torch.tensor(x_train), torch.tensor(labels_train))
        val_data = TensorDataset(torch.tensor(x_val), torch.tensor(labels_val))
        train_loader = DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        val_loader = DataLoader(
            val_data,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
        )
        # drop_last=True discards partial batches, so a split smaller than one
        # batch yields an empty loader; fail early with a clear message.
        if len(train_loader) == 0 or len(val_loader) == 0:
            raise ValueError(
                f"Not enough data for batch size {batch_size}: "
                f"{len(x_train)} training and {len(x_val)} validation samples."
            )

        # Determine the input dimension for the dense network
        if not isinstance(self.embedding_net, nn.Identity):
            with torch.no_grad():
                # Push one training batch through the embedding net to read
                # off its output dimension.
                sample_input = next(iter(train_loader))[0].to(self.device)
                embedded_output = self.embedding_net(sample_input)
                ndim = embedded_output.shape[-1]
        else:
            # If no embedding net, the input dim is from the data itself
            ndim = x_train.shape[-1]

        # Create model, passing the embedding net
        model = EvidenceNetwork(
            input_size=ndim,
            embedding_net=self.embedding_net,
            layer_width=self.layer_width,
            added_layers=self.added_layers,
            batch_norm_flag=self.batch_norm_flag,
            alpha=self.alpha,
        )
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.train_args["learning_rate"])

        # Train model
        # Reset the best score so a repeated call is not compared against a
        # previous run (which could otherwise restore the untrained weights).
        self.best_val = float("inf")
        wait = 0
        epochs_trained = 0
        summary = {"training_loss": [], "validation_loss": []}
        best_model_state = deepcopy(model.state_dict())

        with tqdm(
            iter(range(self.train_args["max_epochs"])),
            unit=" epochs",
            disable=not show_progress_bars,
        ) as tq:
            for epoch in tq:
                loss_train, loss_val = self._train_epoch(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    optimizer=optimizer,
                )
                epochs_trained = epoch + 1
                if show_progress_bars:
                    tq.set_postfix(loss_train=f"{loss_train:.4f}", loss_val=f"{loss_val:.4f}")

                summary["training_loss"].append(loss_train)
                summary["validation_loss"].append(loss_val)

                if loss_val < self.best_val:
                    self.best_val = loss_val
                    best_model_state = deepcopy(model.state_dict())
                    wait = 0
                else:
                    # Count this stale epoch before testing, so training stops
                    # after exactly `stop_after_epochs` epochs without
                    # improvement.
                    wait += 1
                    if wait >= self.train_args["stop_after_epochs"]:
                        logging.info(f"Stopping early after {epochs_trained} epochs.")
                        break
            else:
                logging.warning(
                    f"Training did not converge in {self.train_args['max_epochs']} epochs."
                )

        summary["best_validation_loss"] = self.best_val
        summary["epochs_trained"] = epochs_trained
        self.model = model
        # Restore the weights from the epoch with the lowest validation loss.
        self.model.load_state_dict(best_model_state)
        return summary

    def predict(self, x):
        """Predict the log-Bayes Ratio for a given input.

        Args:
            x: Array-like input; promoted to at least 2-D and cast to float32.

        Returns:
            numpy.ndarray: The model output moved to the CPU.

        Raises:
            ValueError: If `train` has not been called yet.
        """
        if self.model is None:
            raise ValueError("Model has not been trained yet. Please call .train() first.")
        self.model.eval()
        with torch.no_grad():
            x_tensor = torch.from_numpy(np.atleast_2d(x).astype(np.float32)).to(self.device)
            return self.model(x_tensor).cpu().numpy()

Set up the Evidence Networks comparison between BPASS and LtU-ILI


In [None]:
# Evidence-network runner with default architecture and training settings.
network = K_EvidenceNetwork()

# Create data loaders for two models
# NOTE(review): K_EvidenceNetwork.train calls `get_all_data()` on its loaders,
# which a torch DataLoader does not provide — wrap or adapt before training.

from torch.utils.data import DataLoader, TensorDataset

# Pair the fitter's test features with its test targets.
bpass_dataset = TensorDataset(
    torch.tensor(fitter_bpass._X_test), torch.tensor(fitter_bpass._y_test)
)
bpass_loader = DataLoader(bpass_dataset, batch_size=64, shuffle=False)

# Second model's loader is not set up yet.
# NOTE(review): the draft below mixes `fitter_p` and `fitter_ltu` — confirm
# the intended fitter name before enabling it.
# ltu_dataset = TensorDataset(torch.tensor(fitter_p._X_test), torch.tensor(fitter_ltu._y_test))
# ltu_loader = DataLoader(ltu_dataset, batch_size=64, shuffle=False)