In [2]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display
from mpl_toolkits.mplot3d import Axes3D
import network_moments.torch.gaussian as gnm
from ipywidgets import Checkbox, interactive, FloatSlider
from torch.distributions import MultivariateNormal as gaussian

# Apply matplotlib's built-in dark theme globally to every figure drawn afterwards.
plt.style.use('dark_background')

def continuous_update(*args):
    """Show a 'Continuous Update' checkbox that drives the given widgets.

    The ``continuous_update`` attribute of every widget in ``args`` is set
    from the checkbox value immediately (initially False). It is then kept
    in sync whenever the checkbox is toggled.
    """
    toggle = Checkbox(value=False, description='Continuous Update')

    def sync(*_change):
        # Called once directly and later by ipywidgets with a change payload,
        # which is ignored.
        for widget in args:
            widget.continuous_update = toggle.value

    sync()
    toggle.observe(sync, 'value')
    display(toggle)

    
def f(x, y, z, v, w):
    """Scatter ReLU-clamped samples of a 2-D Gaussian.

    The plot legend shows a Monte Carlo estimate of E[max(x1, 0) * max(x2, 0)].

    Args:
        x: variance of the first component (cov[0, 0]).
        y: variance of the second component (cov[1, 1]).
        z: covariance between the components (cov[0, 1]).
        v: mean of the first component.
        w: mean of the second component.

    Returns:
        None. The function only draws, or prints an empty line when the
        covariance cannot be sampled.
    """
    # [[x, z], [z, y]] is positive definite iff x > 0, y > 0 and x*y > z**2.
    # Anything else cannot be sampled.
    if x <= 0 or y <= 0 or x * y <= z ** 2:
        print()
        return
    dtype = torch.float64
    mean = torch.tensor([v, w], dtype=dtype)
    cov = torch.tensor([[x, z], [z, y]], dtype=dtype)
    torch.manual_seed(0)  # identical draws on every slider update
    try:
        samples = gaussian(mean, cov).sample((100000,))
    except (ValueError, RuntimeError):
        # A numerically near-singular covariance can still be rejected by
        # argument validation (ValueError) or by the Cholesky factorisation
        # (RuntimeError).
        print()
        return
    samples.clamp_(min=0)
    # E[q1 q2] = cov(q1, q2) + E[q1] E[q2]. The cross term is the
    # off-diagonal entry of the 2x2 sample covariance.
    corr = gnm.utils.cov(samples)[0, 1] + samples.mean(dim=0).prod()
    plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), 1, label=corr.item())
    plt.legend()
    plt.show()

# Means (m1, m2), variances (s1, s2) and covariance (s12) that drive `f`.
v, w = (FloatSlider(min=-100, max=100, value=0, step=0.1, description=label)
        for label in ('m1', 'm2'))
x = FloatSlider(min=0, max=100, value=20, step=0.1, description='s1')
y = FloatSlider(min=0, max=100, value=40, step=0.1, description='s2')
z = FloatSlider(min=-100, max=100, value=17, step=0.1, description='s12')
continuous_update(x, y, z, v, w)
interactive(f, x=x, y=y, z=z, v=v, w=w)

Let $\mathbf{x}\sim\mathcal{N}\left(\mathbf{\mu}, \mathbf{\Sigma}\right)$, let $q(x) = \max(x, 0)$ act element-wise, and let $\Phi(x)$ and $\varphi(x)$ denote the CDF and PDF of the standard Gaussian distribution. Then

$\mathbb{E}\left[q(\mathbf{x})\right] = \mathbf{\mu}\odot\Phi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right) + \mathbf{\sigma}\odot\varphi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right)$
where $\mathbf{\sigma} = \sqrt{\text{diag}\left(\mathbf{\Sigma}\right)}$ with
$\odot$ and $\oslash$ as element-wise product and division.

$\mathbb{E}\left[q^2(\mathbf{x})\right] = 
\left(\mathbf{\mu}^2+\mathbf{\sigma}^2\right) \odot \Phi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right) + \mathbf{\mu} \odot \mathbf{\sigma} \odot \varphi\left(\mathbf{\mu}\oslash\mathbf{\sigma}\right)$
and $\text{var}\left[q(\mathbf{x})\right] = \mathbb{E}\left[q^2(\mathbf{x})\right] - \mathbb{E}\left[q(\mathbf{x})\right]^2$ (square taken element-wise).

$\left.\mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right]\right|_{\mathbf{\mu} = \mathbf{0}} = c\left(\mathbf{\Sigma}\oslash\mathbf{\sigma}\mathbf{\sigma}^\top\right) \odot \mathbf{\sigma}\mathbf{\sigma}^\top$
where $c(x) = \frac{1}{2\pi}\left(x\cos^{-1}(-x)+\sqrt{1-x^2}\right)$
(Note: $\left|c(x) - \Phi(x - 1)\right| < 0.0241$)

$\text{cov}\left[q(\mathbf{x})\right] = \mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right] - \mathbb{E}\left[q(\mathbf{x})\right]\mathbb{E}\left[q(\mathbf{x})\right]^\top$
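so that, since $\left.\mathbb{E}\left[q(\mathbf{x})\right]\right|_{\mathbf{\mu} = \mathbf{0}} = \mathbf{\sigma}/\sqrt{2\pi}$, $\left.\text{cov}\left[q(\mathbf{x})\right]\right|_{\mathbf{\mu} = \mathbf{0}} = \frac{\mathbf{\sigma}\mathbf{\sigma}^\top}{2\pi}\odot\left(\mathbf{R}\odot\cos^{-1}\left(-\mathbf{R}\right)+\sqrt{1-\mathbf{R}\odot\mathbf{R}} - 1\right)$ where $\mathbf{R} = \mathbf{\Sigma} \oslash \mathbf{\sigma}\mathbf{\sigma}^\top$ is the correlation matrix and $\cos^{-1}$ and $\sqrt{\cdot}$ act element-wise.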

$\mathbb{E}\left[q^2(\mathbf{x})\right] = \text{diag}\left(\mathbb{E}\left[q(\mathbf{x})q(\mathbf{x})^\top\right]\right)$

In [2]:
# # def g(s1, s2, s12):
# #     s1s2 = s1 * s2
# #     sqrt = math.sqrt(s1s2**2-s12**2)
# #     return (s12 * math.asin(s12 / s1s2) + sqrt) / (2 * math.pi) + s12 / 4

# def off(x):
#     y = x.view(-1, *x.shape[-2:]) if x.dim() > 2 else x.unsqueeze(0)
#     cov = torch.tensor([gnm.utils.cov(s)[1, 1] for s in y], dtype=x.dtype)
#     return cov.view(*x.shape[:-2]) if x.dim() > 2 else cov

# def corr(mean, s1, s2, s12):
#     cov = torch.tensor([[s1, s12], [s12, s2]], dtype=mean.dtype)
#     torch.manual_seed(0)
#     try:
#         samples = gaussian(torch.zeros(2, dtype=mean.dtype), cov).sample((10000,))
#     except:
#         n = mean.size(0) if mean.dim() > 1 else 1
#         return torch.zeros(n, dtype=mean.dtype)
#     out = (samples + mean.unsqueeze(-2)).clamp(min=0)
#     mean = out.mean(dim=-2).prod()
#     return off(out) + mean

# def f(s1, s2, s12):
#     dtype = torch.float64
#     xs, ys = torch.meshgrid([torch.linspace(-2, 2, dtype=dtype)] * 2)
#     zs = corr(torch.stack((xs, ys), dim=-1), s1, s2, s12)
#     fig = plt.figure()
#     ax = fig.add_subplot(111, projection='3d')
#     ax.plot_surface(xs.numpy(), ys.numpy(), zs.numpy())
#     plt.show()

# f(1, 1, 0)

# # s1 = FloatSlider(min=0, max=100, value=20, step=0.1, description='s1')
# # s2 = FloatSlider(min=0, max=100, value=40, step=0.1, description='s2')
# # s12 = FloatSlider(min=-100, max=100, value=17, step=0.1, description='s12')
# # continuous_update(s1, s2, s12)
# # interactive(f, s1=s1, s2=s2, s12=s12)

In [9]:
import math
import torch
import network_moments.torch.gaussian as gnm

# Short aliases for network_moments helpers used by the cells below.
outer, normal_density = gnm.utils.outer, gnm.utils.stats.gaussian.normal_density

def prn(a):
    """Print entries (0, 0), (1, 1) and (0, 1) of a 2x2 tensor in 1-digit scientific notation."""
    picked = (a[0, 0], a[1, 1], a[0, 1])
    print('<' + ', '.join('{:.1e}'.format(entry.item()) for entry in picked) + '>')

def err(a, b):
    """Relative error ||a - b|| / ||b|| (torch.norm), returned as a Python float."""
    return (a - b).norm().div(b.norm()).item()

# Monte Carlo check of `relu_cov`: compare its prediction of E[q q^T] with the
# empirical second moment of ReLU-clamped Gaussian samples.
# NOTE(review): `relu_cov` is defined in a later cell, so that cell must run
# before this one. Consider moving its definition above this cell.
factor = 10            # passed as `norm=factor**2` to rand.definite; also rescales the final diff
n_samples = 1000000    # Monte Carlo sample count
torch.manual_seed(0)   # seed torch's global RNG so the sampled draws are reproducible
cov = gnm.utils.rand.definite(2, norm=factor**2, dtype=torch.float64)
# relu_cov's off-diagonal terms use the zero-mean closed form (they depend only
# on the correlation), so the check runs at an explicit zero mean.
mean = torch.zeros(cov.size(0), dtype=cov.dtype, device=cov.device)
samples = torch.distributions.MultivariateNormal(mean, cov).sample((n_samples,)).clamp(min=0.0)

N = outer(samples.mean(0))            # empirical E[q] E[q]^T
out_cov = gnm.utils.cov(samples) + N  # empirical E[q q^T]
my_cov = relu_cov(mean, cov) + N      # predicted E[q q^T]
print(err(out_cov, my_cov))
prn(my_cov)
prn(out_cov)
prn((my_cov - out_cov).abs() / factor**2)

0.0012089187375359063
<4.4e+01, 2.4e+01, 9.5e+00>
<4.4e+01, 2.4e+01, 9.5e+00>
<5.5e-04, 8.8e-06, 2.0e-04>


$f\left(x_1, x_2\right) = \max\left(x_1, 0\right) \max\left(x_2, 0\right)$ where $x_1 \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $x_2 \sim \mathcal{N}(\mu_2, \sigma_2^2)$ are jointly Gaussian with covariance $\sigma_{1,2}$.

If $x_1 = x_2$ (so $\mu = \mu_1 = \mu_2$ and $\sigma = \sigma_1 = \sigma_2$), then $\mathbb{E}\left[f\right] = (\mu^2+\sigma^2)C\left(\frac{\mu}{\sigma}\right) + \mu\sigma P\left(\frac{\mu}{\sigma}\right)$ where $C(x)$ and $P(x)$ are the CDF and PDF of the standard normal distribution.

If $\mu_1 = \mu_2 = 0$, then $\mathbb{E}\left[f\right] = \sigma_1\sigma_2q(c)$ where $q(x) = \frac{x\cos^{-1}(-x)+\sqrt{1-x^2}}{2\pi}$ and $c = \frac{\sigma_{1,2}}{\sigma_1\sigma_2}$ (correlation $|c| \leq 1$).  
Note: $q$ maps $\left[-1, 1\right]$ to $\left[0, \frac{1}{2}\right]$, and $q(x) \approx C(x - 1)$ with a maximum absolute error of $0.0241$, which occurs toward the left end of the interval.

If $x_1 = x_2$ and $\mu = 0$, then $\mathbb{E}\left[f\right] = \sigma^2q\left(1\right) = \sigma^2C\left(0\right) = \frac{\sigma^2}{2}$.

In [5]:
def relu_cov(mean, covariance, stability=0.0):
    """Covariance of q(x) = max(x, 0) for x ~ N(mean, covariance).

    Diagonal entries are exact: E[q^2] - E[q]^2, with
    E[q^2] = (mu^2 + sigma^2) Phi(mu/sigma) + mu sigma phi(mu/sigma).
    Off-diagonal entries use the zero-mean closed form
    sigma_i sigma_j c(r_ij) - E[q_i] E[q_j], where
    c(r) = (r acos(-r) + sqrt(1 - r^2)) / (2 pi). They are therefore exact
    only when ``mean`` is zero.

    Args:
        mean: mean vector of x.
        covariance: covariance matrix of x.
        stability: correlations are clamped to
            [stability - 1, 1 - stability] before acos/sqrt. A positive
            value keeps them away from +/-1.

    Returns:
        The covariance matrix of max(x, 0).
    """
    var = torch.diagonal(covariance, 0, -2, -1)
    std = torch.sqrt(var)
    pdf, cdf = normal_density(mean / std)
    S = outer(std)  # sigma sigma^T
    # Correlation matrix, clamped so acos/sqrt stay well defined.
    V = (covariance / S).clamp_(stability - 1.0, 1.0 - stability)
    # Cheaper approximation: Q ~= Phi(V - 1), i.e. normal_density(V - 1.0, pdf=False),
    # with a maximum absolute error of about 0.0241.
    Q = (torch.acos(-V) * V + torch.sqrt(1.0 - (V**2.0))) / (2.0 * math.pi)
    N = outer(gnm.relu.mean(mean, std, std=True))  # E[q] E[q]^T
    second_moment = (mean**2 + var) * cdf + mean * std * pdf  # exact E[q^2]
    # Zero-mean E[q q^T] off the diagonal and exact E[q^2] on it.
    # torch.diagonal returns a view, so copy_ writes into C itself and C keeps
    # its full matrix shape.
    C = Q * S
    C.diagonal(0, -2, -1).copy_(second_moment)
    return C - N

# V = gnm.utils.rand.definite(100, dtype=torch.float64)
# %timeit (torch.acos(-V) * V + torch.sqrt(1.0 - (V**2.0))) / (2.0 * math.pi)
# %timeit gnm.utils.stats.gaussian.normal_density(V - 1.0, pdf=False)

In [7]:
# Correlation matrix of `cov`: divide by the outer product of the standard deviations.
cov / outer(torch.sqrt(torch.diagonal(cov)))

tensor([[ 1.0000, -0.9839],
        [-0.9839,  1.0000]], dtype=torch.float64)

In [8]:
def relu_cov(mean, covariance, stability=0.0):
    """Covariance of q(x) = max(x, 0) for x ~ N(mean, covariance).

    Diagonal entries are exact: E[q^2] - E[q]^2, with
    E[q^2] = (mu^2 + sigma^2) Phi(mu/sigma) + mu sigma phi(mu/sigma).
    Off-diagonal entries use the zero-mean closed form
    sigma_i sigma_j c(r_ij) - E[q_i] E[q_j], where
    c(r) = (r acos(-r) + sqrt(1 - r^2)) / (2 pi). They are therefore exact
    only when ``mean`` is zero.

    Args:
        mean: mean vector of x.
        covariance: covariance matrix of x.
        stability: correlations are clamped to
            [stability - 1, 1 - stability] before acos/sqrt.

    Returns:
        The covariance matrix of max(x, 0).
    """
    var = torch.diagonal(covariance, 0, -2, -1)
    std = torch.sqrt(var)
    S = gnm.utils.outer(std)  # sigma sigma^T
    # Correlation matrix, clamped so acos/sqrt stay well defined.
    V = (covariance / S).clamp_(stability - 1.0, 1.0 - stability)
    Q = (torch.acos(-V) * V + torch.sqrt(1.0 - (V**2.0))) / (2.0 * math.pi)
    M = gnm.utils.outer(gnm.relu.mean(mean, std, std=True))  # E[q] E[q]^T
    # Q * S has diagonal var/2 only for zero mean and stability == 0, so
    # overwrite it with the exact second moment E[q^2].
    pdf, cdf = gnm.utils.stats.gaussian.normal_density(mean / std)
    C = Q * S
    C.diagonal(0, -2, -1).copy_((mean**2 + var) * cdf + mean * std * pdf)
    return C - M