In [1]:
# IPython magics (not plain Python): load the `autoreload` extension and set it
# to mode 2 so imported modules are reloaded before running code, letting edits
# to local source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [22]:
import torch
import torch.nn as nn
from torch import Tensor
from sbi.types import Shape
from pyknos.nflows.flows import Flow
from sbi.neural_nets.density_estimators import DensityEstimator
from typing import Optional, Tuple, Union, List
from sbi.utils import assert_all_finite, mog_log_prob
from copy import deepcopy

In [20]:
class MDNDensityEstimator(DensityEstimator):
    r"""Density estimator wrapping a mixture density network (MDN) flow.

    The mixture-of-Gaussians (MoG) parameters computed for the most recent
    context are cached on the instance as ``(logits, means, precisions)`` so
    they can be post-processed (conditioned, corrected for a proposal) before
    being evaluated or sampled.
    """

    def __init__(self, flow: Flow, condition_shape: torch.Size) -> None:
        super().__init__(flow, condition_shape)
        # Cached MoG parameters; populated by `set_context` or the setter.
        self._logits = None
        self._means = None
        self._precisions = None

    @property
    def embedding_net(self) -> nn.Module:
        r"""Return the embedding network."""
        return self.net._embedding_net

    @property
    def distribution(self):
        r"""Return the distribution of the density estimator."""
        return self.net._distribution

    @property
    def mog_parameters(self) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Return the cached ``(logits, means, precisions)`` of the MoG.

        Raises:
            RuntimeError: If no parameters are cached yet. The mixture
                components depend on a context (see `set_context`), so they
                cannot be computed lazily here.
        """
        if self._logits is None:
            raise RuntimeError(
                "MoG parameters are not set. Call `set_context` with a context "
                "(or assign `mog_parameters`) first."
            )
        return self._logits, self._means, self._precisions

    @mog_parameters.setter
    def mog_parameters(self, mog_parameters: Tuple[Tensor, Tensor, Tensor]):
        r"""Set the parameters of the mixture of Gaussians."""
        self._logits, self._means, self._precisions = mog_parameters

    @property
    def logsumexplogits(self) -> Tensor:
        r"""Return ``logsumexp`` of the cached logits over the component axis."""
        logits, _, _ = self.mog_parameters
        return torch.logsumexp(logits, dim=-1, keepdim=True)

    def reset_mog_parameters(self) -> None:
        r"""Clear the cached MoG parameters.

        The components can only be recomputed for a given context, so call
        `set_context` afterwards to populate them again.
        """
        self._logits = None
        self._means = None
        self._precisions = None

    def mog_log_prob(self, input: Tensor) -> Tensor:
        r"""Return the log probability of the input under the mixture of Gaussians."""
        # `mog_log_prob` here resolves to the module-level helper, not this method.
        return mog_log_prob(input, *self.mog_parameters)

    def correct_for_proposal(
        self, proposal: "MDNDensityEstimator", inplace: bool = True
    ) -> "MDNDensityEstimator":
        r"""Combine the cached MoG with ``proposal``'s MoG (proposal correction).

        Args:
            proposal: Estimator whose cached MoG parameters describe the proposal.
            inplace: If True, overwrite this estimator's cached parameters and
                return ``self``; otherwise return a corrected deep copy and
                leave ``self`` untouched.

        Returns:
            The estimator holding the corrected MoG parameters.
        """
        logits_d, m_d, prec_d = self.mog_parameters
        logits_pp, m_pp, prec_pp = proposal.mog_parameters

        # NOTE(review): `_proposal_posterior_transformation` is not defined in
        # this class; it must be provided by the base class or attached
        # elsewhere — TODO confirm.
        logits_pp, m_pp, prec_pp = self._proposal_posterior_transformation(
            logits_pp, m_pp, prec_pp, logits_d, m_d, prec_d
        )
        target = self if inplace else deepcopy(self)
        target.mog_parameters = (logits_pp, m_pp, prec_pp)
        return target

    def condition(
        self,
        condition: Tensor,
        mask: Optional[List[bool]] = None,
        inplace: bool = True,
    ) -> "MDNDensityEstimator":
        r"""Condition the cached MoG on fixed values for some of its dimensions.

        Args:
            condition: Values for the MoG dimensions; promoted to a 2D float32
                tensor before being passed to `condition_mog`.
            mask: One boolean per MoG dimension; indices where it is True are
                passed to `condition_mog` as the dimensions to sample (the rest
                are presumably fixed to ``condition`` — confirm against
                `condition_mog`).
            inplace: If True, overwrite the cached parameters and return
                ``self``; otherwise return a conditioned deep copy.

        Returns:
            The estimator holding the conditioned MoG parameters.

        Raises:
            ValueError: If ``mask`` is not given.
        """
        # Imported lazily so only callers of this method need the helper.
        from sbi.utils.conditional_density_utils import condition_mog

        if mask is None:
            raise ValueError("`mask` is required to select the dimensions to sample.")

        condition = torch.atleast_2d(torch.as_tensor(condition, dtype=torch.float32))
        dims_to_sample = torch.where(torch.as_tensor(mask, dtype=torch.bool))[0]

        logits, means, precisions = self.mog_parameters
        # `condition_mog` works on upper-triangular precision factors U with
        # precision = U^T U (its output is converted back that way below), so
        # pass factors rather than the precisions themselves.
        precision_factors = torch.linalg.cholesky(precisions, upper=True)
        cond_logits, cond_means, cond_precfs, _ = condition_mog(
            condition, dims_to_sample, logits, means, precision_factors
        )
        cond_precs = cond_precfs.transpose(3, 2) @ cond_precfs
        # TODO: check whether `cond_logits` need to be normalized here.

        target = self if inplace else deepcopy(self)
        target.mog_parameters = (cond_logits, cond_means, cond_precs)
        return target

    def set_context(self, context: Tensor, inplace: bool = True) -> "MDNDensityEstimator":
        r"""Compute and cache the MoG parameters for ``context``.

        Args:
            context: Condition tensor; embedded with `embedding_net` before the
                mixture components are computed.
            inplace: If True, overwrite the cached parameters and return
                ``self``; otherwise return a deep copy holding the new ones.

        Returns:
            The estimator holding normalized logits, means and precisions.
        """
        self._check_condition_shape(context)

        embedded_x = self.embedding_net(context)
        logits_d, m_d, prec_d, _, _ = self.distribution.get_mixture_components(embedded_x)
        # Normalize in log space so the mixture weights sum to one.
        norm_logits_d = logits_d - torch.logsumexp(logits_d, dim=-1, keepdim=True)

        target = self if inplace else deepcopy(self)
        target.mog_parameters = (norm_logits_d, m_d, prec_d)
        return target

    def marginalize(self, inplace: bool = True):
        # return copy of the MDN with marginalized parameters
        raise NotImplementedError

    def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
        r"""Sample from the cached MoG and map the draws through the flow's transform.

        Args:
            sample_shape: Requested sample shape. With two or more dimensions,
                the first is treated as a batch dimension (the MoG parameters
                are replicated along it) and the remaining ones give the number
                of draws per batch entry.
            condition: Context fed to the embedding net and to the inverse
                transform.

        Returns:
            Samples reshaped to ``(*sample_shape, -1)``.
        """
        self._check_condition_shape(condition)
        sample_shape = torch.Size(sample_shape)
        logits_p, m_p, prec_p = self.mog_parameters
        prec_factors_p = torch.linalg.cholesky(prec_p, upper=True)

        # Split the request into a batch size and a per-batch draw count so the
        # total number of draws equals `sample_shape.numel()`.
        if len(sample_shape) > 1:
            batch_size = sample_shape[0]
            num_samples = sample_shape[1:].numel()
        else:
            batch_size = 1
            num_samples = sample_shape.numel()

        # Replicate to use batched sampling from pyknos.
        if batch_size > 1:
            logits_p = logits_p.repeat(batch_size, 1)
            m_p = m_p.repeat(batch_size, 1, 1)
            prec_factors_p = prec_factors_p.repeat(batch_size, 1, 1, 1)

        # Get (optionally z-scored) MoG samples.
        theta = self.distribution.sample_mog(num_samples, logits_p, m_p, prec_factors_p)
        embedded_context = self.embedding_net(condition)
        if embedded_context is not None:
            # Merge the leading (batch, sample) dims of theta and repeat each
            # context row once per draw so both line up row-for-row.
            # NOTE(review): assumes the context batch matches the MoG batch — TODO confirm.
            theta = theta.reshape(-1, *theta.shape[2:])
            embedded_context = torch.repeat_interleave(
                embedded_context, num_samples, dim=0
            )

        theta, _ = self.net._transform.inverse(theta, context=embedded_context)

        if embedded_context is not None:
            # Split the context dimension from the sample dimension.
            theta = theta.reshape(-1, num_samples, *theta.shape[1:])

        return theta.reshape(*sample_shape, -1)

    def log_prob(
        self,
        input: Tensor,
        condition: Tensor,
        proposal: Optional["MDNDensityEstimator"] = None,
    ) -> Tensor:
        r"""Return the log-probability of ``input`` under the MoG for ``condition``.

        ``input`` and ``condition`` are broadcast against each other over their
        batch dimensions. Side effect: the cached MoG parameters are
        overwritten by `set_context` (and by `correct_for_proposal` when
        ``proposal`` is given).

        Args:
            input: Values to evaluate, shape ``(*batch_in, input_dim)``.
            condition: Contexts, shape ``(*batch_cond, *condition_shape)``.
            proposal: Optional proposal estimator; when given, the
                proposal-corrected MoG is evaluated.

        Returns:
            Log-probabilities with the broadcast batch shape.
        """
        self._check_condition_shape(condition)
        condition_dims = len(self._condition_shape)

        # PyTorch's automatic broadcasting. Slice by explicit length because
        # `shape[:-0]` would wrongly be empty for a 0-dim condition shape.
        batch_shape_in = input.shape[:-1]
        batch_shape_cond = condition.shape[: condition.dim() - condition_dims]
        batch_shape = torch.broadcast_shapes(batch_shape_in, batch_shape_cond)
        # Expand the input and condition to the same batch shape
        input = input.expand(batch_shape + (input.shape[-1],))
        condition = condition.expand(batch_shape + self._condition_shape)
        # Flatten required by nflows, but now both have the same batch shape
        input = input.reshape(-1, input.shape[-1])
        condition = condition.reshape(-1, *self._condition_shape)

        # z-score theta if z-scoring had been requested.
        # NOTE(review): `_maybe_z_score_theta` is not defined in this class —
        # presumably provided by the base class; TODO confirm.
        theta = self._maybe_z_score_theta(input)

        self.set_context(condition, inplace=True)
        if proposal is not None:
            self.correct_for_proposal(proposal)
        log_probs = self.mog_log_prob(theta)
        log_probs = log_probs.reshape(batch_shape)  # \hat{p} from eq (3) in [1]

        assert_all_finite(log_probs, "proposal posterior eval")

        return log_probs

    def loss(self, input: Tensor, condition: Tensor) -> Tensor:
        r"""Return the loss for training the density estimator.

        Args:
            input: Inputs to evaluate the loss on of shape (batch_size, input_size).
            condition: Conditions of shape (batch_size, *condition_shape).

        Returns:
            Negative log_probability (batch_size,)
        """

        return -self.log_prob(input, condition)