## Gaussian Network Moments (GNMs)

Let $\mathbf{x}\sim\mathcal{N}\left(\mathbf{\mu}, \mathbf{\Sigma}\right)$ and $q(x) = \max(x, 0)$ (applied element-wise), where $\Phi(x)$ and $\varphi(x)$ are the CDF and PDF of the standard normal distribution. Then

$\mathbb{E}\left[q(\mathbf{x})\right] = \mathbf{\mu}\odot\Phi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right) + \mathbf{\sigma}\odot\varphi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right)$
where $\mathbf{\sigma} = \sqrt{\text{diag}\left(\mathbf{\Sigma}\right)}$, and
$\odot$ and $\oslash$ denote element-wise product and division.

$\mathbb{E}\left[q^2(\mathbf{x})\right] = 
\left(\mathbf{\mu}^2+\mathbf{\sigma}^2\right) \odot \Phi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right) + \mathbf{\mu} \odot \mathbf{\sigma} \odot \varphi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right)$
from which the variance follows as $\text{var}\left[q(\mathbf{x})\right] = \mathbb{E}\left[q^2(\mathbf{x})\right] - \mathbb{E}\left[q(\mathbf{x})\right]^2$

$\left.\mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right]\right|_{\mathbf{\mu} = \mathbf{0}} = c\left(\mathbf{\Sigma}\oslash\mathbf{\sigma}\mathbf{\sigma}^\top\right) \odot \mathbf{\sigma}\mathbf{\sigma}^\top$
where $c(x) = \frac{1}{2\pi}\left(x\cos^{-1}(-x)+\sqrt{1-x^2}\right)$
(Note: $\left|c(x) - \Phi(x - 1)\right| < 0.0241$)

$\text{cov}\left[q(\mathbf{x})\right] = \mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right] - \mathbb{E}\left[q(\mathbf{x})\right]\mathbb{E}\left[q(\mathbf{x})\right]^\top$
where $\left.\text{cov}\left[q(\mathbf{x})\right]\right|_{\mathbf{\mu} = \mathbf{0}} = \left.\mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right]\right|_{\mathbf{\mu} = \mathbf{0}} - \frac{1}{2\pi}\mathbf{\sigma}\mathbf{\sigma}^\top$

In [0]:
from math import gamma, pi, sqrt

import torch
from torch.autograd import Function

from scipy.special import hermite as scipy_hermite
from scipy.special import gammainc as scipy_gammainc


class GammaInc(Function):
    """Regularized lower incomplete gamma function ``P(a, x)`` with autograd.

    The forward pass delegates to `scipy.special.gammainc`. Only `x` is
    differentiable; the gradient with respect to `a` is reported as `None`.
    """

    @staticmethod
    def forward(ctx, a, x):  # pylint: disable=unused-argument
        """Evaluate ``P(a, x)`` element-wise.

        Args:
            a: Shape parameter; must be a real Python number (it is passed
                to `math.gamma` in the backward pass).
            x: Tensor of evaluation points.

        Returns:
            Tensor with the dtype and device of `x`.
        """
        ctx.a = a
        ctx.save_for_backward(x)
        values = scipy_gammainc(a, x.detach().cpu().numpy())
        # `as_tensor` also handles the 0-d (scalar) result scipy returns for
        # 0-d inputs; the legacy `x.new(...)` constructor does not.
        return torch.as_tensor(values, dtype=x.dtype, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Apply the chain rule with ``dP/dx = x**(a - 1) * exp(-x) / Gamma(a)``."""
        x, = ctx.saved_tensors
        a = ctx.a
        local_grad = (-x).exp() * x ** (a - 1) / gamma(a)
        return None, grad_output * local_grad


gammainc = GammaInc.apply

def hermite(n, x):
    """Evaluate the degree-`n` Hermite polynomial from `scipy.special.hermite`.

    Horner's scheme is applied to the polynomial's coefficients, which are
    ordered from the highest degree down. `x` may be a tensor (the result
    then has its shape and dtype) or a plain Python number.
    """
    coeffs = scipy_hermite(n).coeffs.tolist()
    if torch.is_tensor(x):
        acc = x.new_zeros(x.size())
    else:
        acc = 0
    acc += coeffs[0] if coeffs else 0
    for coeff in coeffs[1:]:
        acc *= x
        acc += coeff
    return acc

def erf_exp_integral(a, b, c, n):
    """Integrate `erf(a*x+b) / exp(x**2)` from `c` to infinity.

    A series truncated after the term ``i = n`` is evaluated. Where
    ``|a| <= 1`` it is applied to ``(a, b, c)`` directly. Elsewhere
    (presumably to keep the series well convergent) it is applied to
    transformed arguments obtained by integration by parts. With
    ``v = 1 / a`` and ``z = a * c + b``::

        I(a, b, c) = sign(a) * (sqrt(pi)/2 - I(|v|, -b*v, sign(a)*z))
                     - sqrt(pi)/2 * erf(c) * erf(z)

    Args:
        a, b, c: Tensors with the slope, offset and lower limit (assumed
            broadcast-compatible with each other -- TODO confirm at call
            sites).
        n: Non-negative int, index of the last series term kept.

    Returns:
        Tensor with the truncated-series value of the integral.
    """
    assert isinstance(n, int) and n >= 0
    sign_a = a.sign()
    # Select between the direct series (|a| <= 1) and the transformed one.
    cond = a.abs() <= 1
    v = a.where(cond, 1 / a)
    x = a.where(cond, v.abs()) * 0.5  # half of the effective slope
    y = b.where(cond, -b * v)  # effective offset
    z = c.where(cond, a * c + b)  # effective lower limit, up to its sign
    w = c.where(cond, sign_a * z).sign()  # sign of the effective lower limit
    series = 0
    z_squared = z * z
    # x**e * H_h(y) / Gamma(d), with H_h from the `hermite` helper.
    term = lambda h, e, d: 1 / gamma(d) * x**e * hermite(h, y)
    for i in range(n + 1):
        s0 = term(2 * i + 1, 2 * i + 2, i + 2)
        s1 = term(2 * i, 2 * i + 1, i + 3 / 2)
        # Regularized lower (P) and upper (1 - P) incomplete gamma factors.
        f0 = w * gammainc(i + 3 / 2, z_squared)
        f1 = 1 - gammainc(i + 1, z_squared)
        series += f0 * s0 + f1 * s1
    neg_y = -y
    erf_z = z.erf()
    # w * |erf(z)| is erf of the signed effective lower limit (erf is odd).
    out = w * erf_z.abs() * neg_y.erf() + (y / (1 + v * v).sqrt()).erf()
    out = sqrt(pi) / 2 * out + (neg_y * y).exp() * series
    # Map the transformed result back for |a| > 1 (see the docstring):
    # both constants carry the factor sqrt(pi)/2 = integral of exp(-x**2)
    # over the half line.
    return out.where(cond, sign_a * (sqrt(pi) / 2 - out)
                     - sqrt(pi) / 2 * c.erf() * erf_z)

In [0]:
def approximate_integral(mean_std, correlation, cdf=None):
    """Approximate the integral ``I`` used for the ReLU cross-moments.

    The exact quantity is::

        I = erf_exp_integral(a, b, c, float('inf'))
        a = correlation / icorrelation
        b = sqrt(1 / 2) * mean_std / icorrelation
        c = -sqrt(1 / 2) * mean_std[:, None]
        icorrelation = sqrt(1 - correlation * correlation)

    i.e. ``I`` is the integral of ``erf(a*x + b) * exp(-x**2)`` from ``c``
    to infinity. This function replaces it by a closed-form fit. The fit
    matches the exact zero-mean value
    ``sqrt(pi) / 2 - acos(correlation) / sqrt(pi)``, and it reduces to
    ``sqrt(pi) * cdf`` when ``correlation == 1``.

    Args:
        mean_std: Tensor of means divided by standard deviations.
        correlation: Tensor of correlation coefficients in ``[-1, 1]``.
        cdf: Optional precomputed standard normal CDF of `mean_std`;
            computed here when `None`.

    Returns:
        Tensor approximating ``I``. `mean_std` and `cdf` are combined with
        `correlation` by ordinary tensor broadcasting.
    """
    # `out` is the approximate, mean-dependent component of `I`.
    x = (correlation + 1) / 2
    a = x**(1 / (2 / sqrt(pi) + 1))
    a = a * (sqrt(2) - sqrt(pi) / 2) + sqrt(pi) / 2
    b = (2 * x**sqrt(1 / pi)).erf()
    b = b * sqrt(1 / 2) + (2 - sqrt(2) / 2)
    out = (-mean_std.abs() / a).exp()**b
    if cdf is None:
        cdf = 1 / 2 * (1 + (mean_std * sqrt(1 / 2)).erf())
    return sqrt(pi) * cdf - sqrt(1 / pi) * correlation.acos() * out

In [0]:
def gaussian_relu_moments(covariance, mean=None):
    """Compute the mean and (co)variance of ``relu(x)`` for Gaussian ``x``.

    Args:
        covariance: Either a 1-D tensor of variances (diagonal covariance)
            or a 2-D covariance matrix of ``x``.
        mean: Mean of ``x``. ``None`` means zero mean, for which simpler
            closed-form expressions are used.

    Returns:
        ``(relu_covariance, relu_mean)``. When ``covariance`` is 1-D,
        ``relu_covariance`` holds only the variances. For a 2-D
        ``covariance`` with a non-zero ``mean``, the off-diagonal entries
        rely on ``approximate_integral`` and are therefore approximate.
    """
    assert 0 < covariance.ndim <= 2
    diagonal = covariance.ndim == 1  # only the variances are given
    variance = covariance if diagonal else covariance.diagonal()
    std = variance.sqrt()
    if not diagonal:
        sigma = std[:, None] @ std[None, :]
        # Round-off can push |correlation| slightly above 1, which would
        # make the acos/sqrt below return NaN.
        correlation = (covariance / sigma).clamp(-1, 1)

    if mean is None:  # zero mean
        # E[relu(x)] = sigma * pdf(0)
        relu_mean = 1 / sqrt(2 * pi) * std
        if diagonal:
            # var = E[relu(x)^2] - E[relu(x)]^2 = sigma^2/2 - sigma^2/(2 pi)
            relu_variance = (1 / 2 - 1 / (2 * pi)) * variance
            return relu_variance, relu_mean
        # cov = sigma_i sigma_j / (2 pi) * (r acos(-r) + sqrt(1 - r^2) - 1)
        rho = correlation
        relu_covariance = 1 / (2 * pi) * sigma * (
            rho * (-rho).acos() + (1 - rho * rho).sqrt() - 1)
        return relu_covariance, relu_mean

    # Non-zero mean: closed-form first and second moments per coordinate.
    mean_std = mean / std
    cdf = 1 / 2 * (1 + (mean_std * sqrt(1 / 2)).erf())
    pdf = 1 / sqrt(2 * pi) * (-1 / 2 * mean_std * mean_std).exp()
    std_pdf = std * pdf
    relu_mean = mean * cdf + std_pdf
    relu_2nd_moment = (mean * mean + variance) * cdf + mean * std_pdf
    relu_variance = relu_2nd_moment - relu_mean * relu_mean
    if diagonal:
        return relu_variance, relu_mean

    # Cross moments E[relu(x_i) relu(x_j)] for the full covariance matrix.
    mu = mean[:, None] @ mean[None, :]
    mu_std = mean[:, None] @ std[None, :]
    relu_mu = relu_mean[:, None] @ relu_mean[None, :]
    ir = (1 - correlation * correlation).sqrt()
    m_sir = mean_std / ir
    arg = m_sir.t() - m_sir * correlation
    # With the exact integral I, cdf_i + I / sqrt(pi) = 2 P(x_i > 0, x_j > 0);
    # `approximate_integral` supplies an approximation of I.
    integral = approximate_integral(mean_std, correlation, cdf)
    t1 = (mu + covariance) * (sqrt(1 / pi) * integral + cdf[:, None])
    t2 = pdf * (1 + (arg * sqrt(0.5)).erf()) * mu_std
    t3 = sigma * ir * (-0.5 * (arg * arg + mean_std**2)).exp()
    relu_cross_correlation = 0.5 * (t1 + t2 + t2.t() + 1 / pi * t3)
    relu_covariance = relu_cross_correlation - relu_mu
    # The diagonal (ir == 0) is ill-defined above; use the exact variances.
    relu_covariance.diagonal().copy_(relu_variance)
    return relu_covariance, relu_mean

In [0]:
def rand_cov(dim, *, dtype=None, device=None):
    """Generate a random symmetric positive-definite covariance matrix.

    The eigenvalues are drawn uniformly from ``(0, 1]`` and rescaled so that
    their Euclidean norm is ``sqrt(2 * dim)``. The eigenvectors are the
    orthonormal factor of the QR decomposition of a Gaussian random matrix.

    Args:
        dim: Size of the ``(dim, dim)`` matrix.
        dtype: Forwarded to the random tensor constructors.
        device: Forwarded to the random tensor constructors.
    """
    eigen = 1 - torch.rand(dim, dtype=dtype, device=device)
    eigen *= (2 * dim)**0.5 / eigen.norm()
    # `torch.linalg.qr` replaces the deprecated `Tensor.qr`.
    q, _ = torch.linalg.qr(torch.randn(dim, dim, dtype=dtype, device=device))
    return (q * eigen) @ q.t()

def compute_moments(m, unbiased=True):
    """Return the sample mean and covariance of `m`.

    The last dimension of `m` indexes observations and the second-to-last
    indexes variables; any leading dimensions are carried through as batch
    dimensions. With `unbiased` truthy the sum of outer products is divided
    by ``N - 1`` (Bessel's correction), otherwise by ``N``.
    """
    num_samples = m.shape[-1]
    mean = m.mean(-1)
    centered = m - mean.unsqueeze(-1)
    dof = num_samples - 1 if unbiased else num_samples
    covariance = (1 / dof) * (centered @ centered.transpose(-2, -1))
    return mean, covariance

# Sanity check on a random 3-D Gaussian: compare the analytic ReLU moments
# from `gaussian_relu_moments` with Monte Carlo estimates from `mc` samples.
# No seed is set here, so the printed numbers differ between runs.
mc = 1000000
mean = torch.randn(3).double() * 2
covariance = rand_cov(mean.size(-1), device=mean.device, dtype=mean.dtype)
relu_covariance, relu_mean = gaussian_relu_moments(covariance, mean)
normal = torch.distributions.MultivariateNormal(mean, covariance)
# Samples are (mc, 3); transpose to (3, mc) because `compute_moments`
# treats the last dimension as observations.
mc_mean, mc_covariance = compute_moments(
    normal.sample([mc]).clamp_min_(0).t())

print('Test mean')
print(relu_mean.tolist())
print(mc_mean.tolist())
print(torch.allclose(relu_mean, mc_mean, rtol=1e-1))

print('Test variance')
print(relu_covariance.diagonal().tolist())
print(mc_covariance.diagonal().tolist())
print(torch.allclose(relu_covariance.diagonal(),
                     mc_covariance.diagonal(),
                     rtol=1e-1, atol=1e-3))

print('Test covariance')
# Strictly upper-triangular indices, i.e. the off-diagonal entries.
i, j = torch.triu_indices(mean.size(-1), mean.size(-1), 1)
print(relu_covariance[..., i, j].tolist())
print(mc_covariance[..., i, j].tolist())
# Compares the full matrices (diagonal included).
print(torch.allclose(relu_covariance, mc_covariance,
                     rtol=1e-1, atol=1e-3))

Test mean
[2.023749878604507, 0.03952140018354473, 1.7729310927458177]
[2.0231033917960968, 0.03960097996071463, 1.7726876231735722]
True
Test variance
[0.7228076900485041, 0.03188508095127822, 1.4355168310137292]
[0.7225533421565001, 0.03197248546042809, 1.4344567013906944]
True
Test covariance
[-0.02138807537179481, 0.17485284576809956, -0.02046914856611238]
[-0.05886193084773982, 0.17374803049963766, -0.01998272445296804]
[2.7520910518841295, 0.9936814567494819, 0.9762362312446333]
False


In [0]:
from argparse import Namespace

import torch
from torch import nn

@torch.no_grad()
def linearize(net, x):
    """Compute the affine approximation for a model.

    Args:
        net: Model.
        x: Point of approximation without batch dimension.

    Returns:
        (weights_matrix, bias_vector).
        If `net` was a linear layer, the data of the parameters
        will be returned instead of a computed clone (handle with care).

    """
    assert not net.training
    if isinstance(net, nn.Sequential) and len(net) == 1:
        return linearize(net[0], x)
    output = net(x.unsqueeze(0))[0]
    if isinstance(net, nn.Linear):
        weight = net.weight.data
        if net.bias is None:
            bias = weight.new_zeros(net.out_features)
        else:
            bias = net.bias.data
    else:
        x = x.repeat(output.numel(), *[1]*x.dim())
        eye = x.new_zeros(output.numel(), output.numel())
        eye.diagonal().fill_(1)
        eye = eye.view(output.numel(), *output.shape)
        with torch.enable_grad():
            net(x.requires_grad_(True)).backward(eye)
        weight = x.grad.view(output.numel(), -1)
        bias = output.view(-1) - weight @ x[0].view(-1)
    return Namespace(output=output, weight=weight, bias=bias)

@torch.no_grad()
def staged_linearization(net, x):
    """Linearize a sequential model before and after its last ReLU layer.

    The model is split at its last `nn.ReLU` into a head (the layers before
    it) and a tail (the layers after it). The head is linearized at `x`. The
    tail is linearized at the head's output clamped at zero, i.e. at the
    ReLU's output.

    Args:
        net: `nn.Sequential` in evaluation mode with more than two layers.
        x: Point of approximation without batch dimension.

    Returns:
        ``(a, b)``: the `linearize` results for the head and the tail, with
        their ``output`` fields removed.

    Raises:
        ValueError: If `net` contains no `nn.ReLU` layer.
    """
    assert isinstance(net, nn.Sequential) and not net.training
    assert len(net) > 2
    relu_indices = [k for k, layer in enumerate(net)
                    if isinstance(layer, nn.ReLU)]
    if not relu_indices:
        raise ValueError('Could not find any ReLU layer')
    last = relu_indices[-1]
    a = linearize(net[:last].train(False), x)
    # Slice by absolute index so that a trailing ReLU yields an empty
    # (identity) tail instead of the whole model.
    b = linearize(net[last + 1:].train(False), a.output.clamp_min_(0))
    del a.output, b.output
    return a, b

def get_lenet(input_size=28, channles=1, num_classes=10):
    """Build a LeNet-style classifier as an `nn.Sequential`.

    Two stages of 5x5 convolution (padding 2), 2x2 max-pooling and ReLU are
    followed by a 1024-unit hidden layer. Each pooling halves the spatial
    size, so the flattened feature size is ``64 * (input_size // 4) ** 2``.

    Args:
        input_size: Height and width of the square input images.
        channles: Number of input channels.
        num_classes: Number of output logits.
    """
    reduced = input_size // 4
    layers = []
    for in_ch, out_ch in ((channles, 32), (32, 64)):
        layers += [
            nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        ]
    layers += [
        nn.Flatten(),
        nn.Linear(64 * reduced * reduced, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, num_classes),
    ]
    return nn.Sequential(*layers)

In [0]:
net = get_lenet().train(False).double()
shape = (1, 28, 28)
mean = torch.randn(shape).view(-1).double()
a, b = staged_linearization(net.train(False), mean.view(shape))
cov = rand_cov(mean.numel(), device=mean.device, dtype=mean.dtype) / 20

# Analytic propagation: affine head -> Gaussian ReLU moments -> affine tail.
a_mean = a.weight @ mean + a.bias
a_cov = a.weight @ cov @ a.weight.t()
r_cov, r_mean = gaussian_relu_moments(a_cov, a_mean)
b_mean = b.weight @ r_mean + b.bias
b_cov = b.weight @ r_cov @ b.weight.t()

# Monte Carlo reference, evaluated in batches. Drawing all `mc` inputs at
# once would need ~6 GB for the samples alone and ~200 GB for the first
# convolution's output.
mc = int(1e6)
batch_size = 1000
normal = torch.distributions.MultivariateNormal(mean, cov)
chunks = []
with torch.no_grad():
    for start in range(0, mc, batch_size):
        size = min(batch_size, mc - start)
        samples = normal.sample([size]).view(size, *shape)
        chunks.append(net(samples).view(size, -1))
out = torch.cat(chunks)
o_mean, o_cov = compute_moments(out.t())

In [0]:
# Ratios of the Monte Carlo output moments (o_*) to the analytic ones (b_*);
# values near 1 mean the two agree.
print(o_mean / b_mean)
print((o_cov / b_cov).diagonal())