# Masking Multivariate Normal CDF

In [1]:
#!git clone --recursive https://github.com/SebastienMarmin/torch-mvnorm

In [2]:
#import sys
#sys.path.append("torch-mvnorm/")

In [1]:
import scipy
import scipy.stats as st

In [2]:
import torch
import torch.distributions as td
from mvnorm import multivariate_normal_cdf

In [11]:
import torch
import torch.distributions as td
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from itertools import product
import numpy as np
from torch.distributions.utils import _standard_normal, broadcast_all
from entmax import sparsemax


class GaussianSparsemax(td.Distribution):
    """Distribution of ``sparsemax(z)`` where ``z ~ Normal(loc, scale)`` coordinate-wise.

    The last dimension of ``loc``/``scale`` is the event dimension (size K);
    all leading dimensions are batch dimensions. Samples are produced by
    drawing from the underlying Normal and applying ``sparsemax`` along the
    last dimension.

    Args:
        loc: locations of the underlying Normal, broadcast against ``scale``.
        scale: positive scales of the underlying Normal.
        validate_args: forwarded to ``torch.distributions.Distribution``.
    """

    arg_constraints = {
        'loc': constraints.real,
        'scale': constraints.positive
    }
    support = td.constraints.simplex
    has_rsample = True

    @staticmethod
    def all_faces(K):
        """Generate a list of 2**K - 1 bit vectors indicating all possible faces of a K-dimensional simplex.

        Each face is a length-K tuple of 0/1 flags; the all-zeros tuple (the
        first element produced by ``product``) is dropped.
        """
        # NOTE: this takes no class argument, so it must be a staticmethod;
        # as a classmethod the class itself would be bound to ``K``.
        return list(product([0, 1], repeat=K))[1:]

    def __init__(self, loc, scale, validate_args=None):
        # Parameters must be set before the base constructor runs, because
        # argument validation reads them back via ``arg_constraints``.
        self.loc, self.scale = broadcast_all(loc, scale)
        batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution with parameters expanded to ``batch_shape``.

        The original instance is left untouched.
        """
        new = self._get_checked_instance(GaussianSparsemax, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape + self.event_shape)
        new.scale = self.scale.expand(batch_shape + self.event_shape)
        # Initialise the *new* instance (not ``self``): a zero-argument
        # ``super().__init__`` here would overwrite self's batch shape and
        # leave ``new`` without one.
        super(GaussianSparsemax, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterised sample: sparsemax of a Normal draw."""
        # sample_shape + batch_shape + (K,)
        z = td.Normal(loc=self.loc, scale=self.scale).rsample(sample_shape)
        return sparsemax(z, dim=-1)

    def log_prob(self, y, pivot_alg='first', tiny=1e-12, huge=1e12):
        """Log-density of ``y`` (shape comments below use [B, K]).

        The value is the sum of three terms: a Gaussian log-pdf over the
        differences between the non-zero coordinates and a pivot coordinate,
        the log of a multivariate normal CDF covering the zero coordinates,
        and the log of the number of non-zero coordinates.

        Args:
            y: points whose strictly positive coordinates define the face.
            pivot_alg: ``'first'`` uses the first non-zero coordinate as the
                pivot; ``'random'`` samples the pivot uniformly among the
                non-zero coordinates.
            tiny: variance given to the masked (non-zero) coordinates in the
                CDF term.
            huge: magnitude of the negative mean given to the masked
                (non-zero) coordinates in the CDF term.

        Returns:
            Tensor of shape [B].

        Raises:
            ValueError: if ``pivot_alg`` is not ``'first'`` or ``'random'``.
        """
        if pivot_alg not in ('first', 'random'):
            # Fail early with a clear message instead of an unbound local below.
            raise ValueError(
                f"Unknown pivot_alg {pivot_alg!r}; expected 'first' or 'random'."
            )

        K = y.shape[-1]
        # [B, K]
        loc = self.loc
        scale = self.scale
        var = scale ** 2

        # The face contains the set of coordinates greater than zero
        # [B, K]
        face = y > 0

        # Choose a pivot coordinate (a non-zero coordinate)
        # [B]
        if pivot_alg == 'first':
            # argmax over a 0/1 mask selects a position holding the maximum (1).
            ind_pivot = torch.argmax(face.float(), -1)
        else:  # pivot_alg == 'random'
            ind_pivot = td.Categorical(
                probs=face.float()/(face.float().sum(-1, keepdims=True))
            ).sample()
        # Select a batch of pivots
        # [B, K]
        pivot_indicator = torch.nn.functional.one_hot(ind_pivot, K).bool()
        # All non-zero coordinates but the pivot
        # [B, K]
        others = torch.logical_xor(face, pivot_indicator)
        # The value of the pivot coordinate
        # [B]
        t = (y * pivot_indicator.float()).sum(-1)
        # Pivot mean and variance
        # [B]
        t_mean = torch.where(pivot_indicator, loc, torch.zeros_like(loc)).sum(-1)
        t_var = torch.where(pivot_indicator, var, torch.zeros_like(var)).sum(-1)

        # Difference with respect to the pivot
        # [B, K]
        y_diff = torch.where(others, y - t.unsqueeze(-1), torch.zeros_like(y))
        # [B, K]
        mean_diff = torch.where(
            others,
            loc - t_mean.unsqueeze(-1),
            torch.zeros_like(loc)
        )

        # Joint log pdf for the non-zeros
        # [B, K, K]
        diag = torch.diag_embed(torch.where(others, var, torch.ones_like(var)))
        offset = t_var.unsqueeze(-1).unsqueeze(-1)
        # We need a multivariate normal for the non-zero coordinates in `others`
        # but to batch mvns we will need to use K-by-K covariances
        # we can do so by embedding the lower-dimensional mvn in a higher dimensional mvn
        # with cov=I.
        # [B, K, K]
        cov_mask = others.unsqueeze(-1) & others.unsqueeze(-2)
        cov = torch.where(cov_mask, diag + offset, diag)
        # This computes log prob of y[others] under the lower dimensional mvn
        # plus log N(0|0,1) for each of the other dimensions
        # [B]
        log_prob = td.MultivariateNormal(mean_diff, cov).log_prob(y_diff)
        # so we discount the contribution from the masked coordinates
        # [B, K]
        log_prob0 = td.Normal(torch.zeros_like(mean_diff), torch.ones_like(mean_diff)).log_prob(torch.zeros_like(y_diff))
        log_prob = log_prob - torch.where(others, torch.zeros_like(log_prob0), log_prob0).sum(-1)

        # Joint log prob for the zeros (needs the cdf)
        # [B]
        constant_term = 1. / torch.where(face, 1./var, torch.zeros_like(var)).sum(-1)
        # Again, we aim to reason with lower-dimensional mvns via
        # a K-dimensional interface. For that, I will mask the coordinates in face.
        # The non-zeros get a tiny variance
        # [B, K, K]
        diag_corrected = torch.diag_embed(torch.where(face, torch.zeros_like(var) + tiny, var))
        # [B, 1, 1]
        offset_corrected = constant_term.unsqueeze(-1).unsqueeze(-1)
        # These are the zeros only.
        # [B, K, K]
        not_face = torch.logical_not(face)
        cov_corrected_mask = not_face.unsqueeze(-1) & not_face.unsqueeze(-2)
        cov_corrected = torch.where(cov_corrected_mask, diag_corrected + offset_corrected, diag_corrected)

        # The non-zeros get a large negative mean.
        # [B]
        mean_constant_term = constant_term * torch.where(face, (y - loc)/var, torch.zeros_like(y)).sum(-1)
        # [B, K]
        #  see that for non-zeros I move the location to something extremely negative
        #  in combination with tiny variance this makes the density of 0 evaluate to 0
        #  and the cdf of 0 evaluate to 1, for those coordinates
        mean_corrected = torch.where(face, torch.zeros_like(y) - huge, loc + mean_constant_term.unsqueeze(-1))

        # [B]
        # NOTE(review): assumes multivariate_normal_cdf(value, loc, cov) returns
        # P(X <= value) per batch element -- confirm against the mvnorm package.
        cdf = multivariate_normal_cdf(
            torch.zeros_like(y),
            mean_corrected, cov_corrected
        )
        log_cdf = cdf.log()

        # Log of the number of non-zero coordinates
        # [B]
        log_det = face.float().sum(-1).log()

        # [B]
        return log_prob + log_cdf + log_det


@td.register_kl(GaussianSparsemax, GaussianSparsemax)
def _kl_gaussiansparsemax_gaussiansparsemax(p, q):
    """Stochastic estimate of KL(p || q) from one reparameterised draw of ``p``.

    Draws a single ``p.rsample()`` and returns the log-density ratio
    ``log p(x) - log q(x)`` at that draw; the result is therefore random.
    """
    draw = p.rsample()
    log_ratio = p.log_prob(draw) - q.log_prob(draw)
    return log_ratio

In [4]:
def density(y, base_dists, ind_pivot=0):
    """
    Evaluate the density of a single point y (a 1-D NumPy array).

    base_dists holds one frozen univariate distribution per coordinate,
    each exposing .mean() and .var(). The strictly positive coordinates of
    y form the face; ind_pivot selects which of them (by position among the
    positive coordinates) is used as the pivot.
    """
    from scipy.stats import multivariate_normal

    face = y > 0
    support_idx = np.nonzero(face)[0]
    zero_idx = np.nonzero(1-np.array(face))[0]
    pivot = support_idx[ind_pivot]
    rest_idx = np.array(list(set(support_idx) - {pivot}))

    # Float copy of y that keeps only the strictly positive entries.
    dense = np.zeros(len(face))
    dense[support_idx] = y[support_idx]

    # Per-coordinate moments of the base distributions, gathered once.
    mu = [dist.mean() for dist in base_dists]
    var = [dist.var() for dist in base_dists]

    # Contribution of the non-zeros: pdf of the differences to the pivot.
    if len(rest_idx):
        y_diff = np.array([dense[s] - dense[pivot] for s in rest_idx])
        mean_diff = np.array([mu[s] - mu[pivot] for s in rest_idx])
        cov = np.diag(np.array([var[s] for s in rest_idx])) + var[pivot]
        val = multivariate_normal(mean=mean_diff, cov=cov).pdf(y_diff)
    else:
        val = 1

    # Contribution of the zeros: cdf at the origin of a shifted Gaussian.
    if len(zero_idx):
        precision_sum = sum([1/var[s] for s in support_idx])
        constant_term = 1 / precision_sum
        cov_corrected = np.diag(np.array([var[r] for r in zero_idx])) + constant_term
        weighted_residual = sum(np.array([(dense[s] - mu[s])/var[s] for s in support_idx]))
        mean_constant_term = constant_term * weighted_residual
        mean_corrected = np.array([mu[r] + mean_constant_term for r in zero_idx])
        origin = np.zeros(len(zero_idx))
        val *= multivariate_normal(mean=mean_corrected, cov=cov_corrected).cdf(origin)

    # Determinant of the transformation Y -> U = g(Y) for the nonzeros:
    # the number of strictly positive coordinates.
    val *= sum(face)
    return val

In [10]:
batch_size = 5


def get_loss(values, locs, scales, reduce=True):
    """Negative log-likelihood of ``values`` under GaussianSparsemax(locs, scales)."""
    nll = - GaussianSparsemax(locs, scales).log_prob(values)
    if reduce:
        return nll.mean()
    return nll


# For each dimensionality K: compare the batched torch log_prob against the
# per-sample NumPy/SciPy reference `density`, then run a numerical gradient check.
for K in [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 10, 10]:
    print(f"batch_size={batch_size} K={K}")
    samples, locs, scales, reference = [], [], [], []
    for _ in range(batch_size):
        # Random parameters; the scale is made positive with a softplus.
        loc = torch.tensor(np.random.normal(0., 1., size=K), requires_grad=True)
        raw_scale = torch.tensor(np.random.normal(1., 1., size=K), requires_grad=True)
        scale = torch.nn.functional.softplus(raw_scale)
        locs.append(loc)
        scales.append(scale)
        draw = GaussianSparsemax(loc, scale).sample()
        samples.append(draw)
        # Reference log-density from one scipy normal per coordinate.
        loc_np = loc.detach().numpy()
        scale_np = scale.detach().numpy()
        base_dists = [st.norm(loc=loc_np[k], scale=scale_np[k]) for k in range(K)]
        reference.append(np.log(density(draw.detach().numpy(), base_dists)))

    # make a batch
    Y = torch.stack(samples)
    U = torch.stack(locs)
    S = torch.stack(scales)
    batched = GaussianSparsemax(U, S)
    log_prob_first = batched.log_prob(Y, pivot_alg='first')
    log_prob_random = batched.log_prob(Y, pivot_alg='random')

    expected = torch.tensor(reference)
    print(" pass pivot='first':", torch.isclose(expected, log_prob_first, atol=1e-3).all().item())
    print(" pass pivot='random':", torch.isclose(expected, log_prob_random, atol=1e-3).all().item())

    gradients_ok = torch.autograd.gradcheck(
        get_loss,
        [Y, U, S],
        eps=1e-02,
        atol=0.1,
        rtol=0.001,
        raise_exception=True,
        check_sparse_nnz=False,
        nondet_tol=0.1,
        check_undefined_grad=True,
        check_grad_dtypes=False,
        check_batched_grad=False
    )
    print(" pass gradient check:", gradients_ok)

batch_size=5 K=1
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=1
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=2
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=2
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=3
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=3
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=4
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=4
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=5
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=5
 pass pivot='first': True
 pass pivot='random': True
 pass gradient check: True
batch_size=5 K=10
 pass pivot=

In [13]:
# Two GaussianSparsemax distributions over 10 coordinates, each with
# locations drawn from Normal(0, 1) and scales drawn from Gamma(1, 1).
p = GaussianSparsemax(td.Normal(0., 1.).sample((10,)), td.Gamma(1., 1.).sample((10,)))
q = GaussianSparsemax(td.Normal(0., 1.).sample((10,)), td.Gamma(1., 1.).sample((10,)))

In [31]:
# KL of p against itself: once from a single draw, and once averaged over 100 expanded copies.
td.kl_divergence(p, p), td.kl_divergence(p.expand((100,)), p.expand((100,))).mean(0)

(tensor(4.5300e-06), tensor(-3.9628e-07, dtype=torch.float64))

In [32]:
# KL(p || q) through the registered estimator, which uses a single draw from p.
td.kl_divergence(p, q)

tensor(8.6511)

In [36]:
# Same estimator averaged over 1000 expanded copies of p and q.
td.kl_divergence(p.expand((1000,)), q.expand((1000,))).mean(0)

tensor(8.3027, dtype=torch.float64)

In [37]:
# 1000 draws from p.
x = p.sample((1000,))

In [38]:
# Average log-density ratio log p(x) - log q(x) over the draws in x.
(p.log_prob(x) - q.log_prob(x)).mean(0)

tensor(8.9562, dtype=torch.float64)