In [1]:
import numpy as np
import torch
import torch.distributions as dis
from tqdm import tqdm

### Univariate Normal

In [18]:
# Target distribution q and three mixture components that share one scale.
q = dis.Normal(3, 1)

r1, r2, r3 = (dis.Normal(loc=mean, scale=1.5) for mean in (1, 2, 3))

In [19]:
def mixture_pdf(x):
	"""Return the density at `x` of the equal-weight mixture of r1, r2 and r3."""
	densities = (component.log_prob(x).exp() for component in (r1, r2, r3))
	return sum(densities) / 3

In [20]:
# Monte Carlo estimate of KL(q || mixture) from 5000 draws of q.
x = q.sample((5000,))

torch.mean(q.log_prob(x) - torch.log(mixture_pdf(x)))

tensor(0.3467)

In [21]:
# Average of the closed-form KL(q || r) over the three components.
total = 0.0
for r in (r1, r2, r3):
	kl_qr = dis.kl.kl_divergence(q, r)
	print(kl_qr)
	total += kl_qr / 3
total

tensor(1.0166)
tensor(0.3499)
tensor(0.1277)


tensor(0.4981)

In [None]:
# Hand evaluation of the closed-form KL between N(3, 1) and N(2, 1.5) (mean gap of 1);
# compare with the 0.3499 printed by the loop above.
np.round(
	0.5 * (np.log(1.5 ** 2) - np.log(1.) - 1. + (1. / 1.5) ** 2 + 1. ** 2 / 1.5 ** 2),
	4,
)

### Multivariate Normal

In [158]:
torch.manual_seed(711)

# Target q: 4-d Gaussian with a random mean and diagonal covariance exp(logvar_q).
mu_q = torch.rand(4)
logvar_q = torch.rand(4)
q = dis.MultivariateNormal(loc=mu_q, covariance_matrix=torch.diag(logvar_q.exp()))

# Three mixture components: random means and one log-variance per component.
mu_rs = torch.rand(3, 4)
logvar_rs = torch.rand(3, 1)

In [160]:
def mixture_pdf(x):
	"""Return the density at `x` of the equal-weight mixture of three isotropic Gaussians.

	Component k has mean mu_rs[k] and covariance exp(logvar_rs[k]) * I (4 x 4).
	"""
	identity = torch.eye(4)
	density_sum = 0
	for k in range(3):
		component = dis.MultivariateNormal(mu_rs[k], logvar_rs[k].exp() * identity)
		density_sum = density_sum + component.log_prob(x).exp()
	return density_sum / 3

In [166]:
# Monte Carlo estimate of KL(q || mixture): mean of log q(x) - log mix(x) over draws from q.
x = q.sample((5000,))
logr = torch.log(mixture_pdf(x)) - q.log_prob(x)
torch.mean(-logr)

tensor(0.2921)

In [159]:
# Average of the closed-form KL(q || r) over the three isotropic components.
total = 0.0
for i in range(3):
	r = dis.MultivariateNormal(mu_rs[i], logvar_rs[i].exp() * torch.eye(4))
	kl_qr = dis.kl.kl_divergence(q, r)
	print(kl_qr)
	total += kl_qr / 3
total

tensor(0.2470)
tensor(0.2983)
tensor(0.4409)


tensor(0.3287)

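In both the univariate and the multivariate case, the Monte Carlo estimate of `KL(N0 || Mix of N)` is smaller than the average of the component-wise divergences, `Mix of KL(N0 || N)` (0.3467 < 0.4981 and 0.2921 < 0.3287). This is consistent with Jensen's inequality: KL is convex in its second argument, so `KL(N0 || Mix of N) <= Mix of KL(N0 || N)`.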

### Check `kl_ub_loss`

In [2]:
def kl_mvn(mu1, mu2, logvar_diag1, logvar_diag2):
    """Closed-form KL(N1 || N2) between Gaussians with diagonal covariances.

    N1 = N(mu1, diag(exp(logvar_diag1))) and N2 = N(mu2, diag(exp(logvar_diag2))).

    Args:
        mu1: mean of N1, shape (..., m).
        mu2: mean of N2, shape (..., m).
        logvar_diag1: log of the diagonal variances of N1, shape (..., m).
        logvar_diag2: log of the diagonal variances of N2, shape (..., m).

    Returns:
        Tensor of KL divergences with the last dimension reduced; a 0-dim
        tensor for 1-D inputs. Leading dimensions broadcast against each other.
    """
    m = mu1.shape[-1]
    var2 = logvar_diag2.exp()
    delta = mu2 - mu1
    return 0.5 * (
        logvar_diag2.sum(dim=-1)
        - logvar_diag1.sum(dim=-1)
        - m
        + torch.sum(logvar_diag1.exp() / var2, dim=-1)
        # With a diagonal covariance the Mahalanobis term is an element-wise
        # division: no dense matrix is built or inverted, and the deprecated
        # `.T` on a 1-D tensor is no longer needed.
        + torch.sum(delta ** 2 / var2, dim=-1)
    )


In [3]:
# Fixed 3-d example with a hand-checkable KL: q = N((1, 0, 1), 2I), p = N((2, 2, 2), I).
q = dis.MultivariateNormal(
	loc=torch.tensor([1.0, 0.0, 1.0]),
	covariance_matrix=2.0 * torch.eye(3),
)

p = dis.MultivariateNormal(
	loc=torch.full((3,), 2.0),
	covariance_matrix=torch.eye(3),
)

dis.kl.kl_divergence(q, p)

tensor(3.4603)

In [4]:
# Hand evaluation of the closed-form KL(q || p) for the fixed example above.
0.5 * (
	np.log(1 / 8)  # log(det(cov_p) / det(cov_q)) = log(1 / 2**3)
	- 3  # dimension
	+ 6  # trace(cov_p^-1 @ cov_q) = 3 * 2
	+ 6  # squared distance between the means: 1 + 4 + 1
)

3.460279229160082

In [5]:
torch.manual_seed(0)
# Random 3-d test case: N1 has three independent log-variances, N2 shares a single one.
mu1 = torch.rand(3)
mu2 = torch.rand(3)
logvar_diag1 = torch.rand(3)
logvar_diag2 = torch.rand(1).repeat(3)

# Alternative inputs: the fixed, hand-checkable example.
# mu1 = torch.tensor((1.0, 0.0, 1.0))
# mu2 = torch.tensor((2.0, 2.0, 2.0))
# logvar_diag1 = torch.tensor((2.0, 2.0, 2.0)).log()
# logvar_diag2 = torch.tensor((1.0, 1.0, 1.0)).log()

In [6]:
# Reference value of KL(N1 || N2) from torch.distributions.
dis.kl.kl_divergence(
	dis.MultivariateNormal(loc=mu1, covariance_matrix=torch.diag(logvar_diag1.exp())),
	dis.MultivariateNormal(loc=mu2, covariance_matrix=torch.diag(logvar_diag2.exp())),
)

tensor(0.2020)

In [7]:
# Hand-written closed form on the same inputs; expected to agree with the library value.
kl_mvn(mu1, mu2, logvar_diag1, logvar_diag2)

  + (mu2 - mu1).T @ logvar_diag2.exp().diag().inverse() @ (mu2 - mu1)


tensor(0.2020)

In [9]:
# This implementation is correct (it matches dis.kl.kl_divergence in the checks below), but dis.kl.kl_divergence is the better choice in practice.
def kl_ub_loss(mu_minor, logvar_minor, mu_major, logvar_major, reduce: bool):
    """Average KL from each minor Gaussian to a set of isotropic major Gaussians.

    For every minor distribution b this returns the mean over n of
    KL(N(mu_minor[b], diag(exp(logvar_minor[b]))) || N(mu_major[n], exp(logvar_major) * I)).
    KL is convex in its second argument, so by Jensen's inequality this average
    is an upper bound on the KL divergence from the minor distribution to the
    equal-weight mixture of the major Gaussians.

    Args:
        mu_minor: means of the minor Gaussians, shape (B, m).
        logvar_minor: diagonal log-variances of the minor Gaussians, shape (B, m).
        mu_major: means of the major Gaussians, shape (N, m).
        logvar_major: log-variance shared by all major Gaussians and all
            dimensions; a 0-dim or single-element tensor.
        reduce: if True, return the mean over the B minor Gaussians;
            otherwise return the per-minor values.

    Returns:
        A scalar tensor when `reduce` is True, else a tensor of shape (B,).

    Raises:
        ValueError: if `mu_minor` or `mu_major` is not 2-D.
    """
    _, m = mu_minor.shape  # unpacking also rejects inputs that are not 2-D
    if mu_major.dim() != 2:
        raise ValueError("mu_major must be 2-D with shape (N, m)")

    # (1, B, m) - (N, 1, m) broadcasts to (N, B, m), so neither input has to be
    # copied with `.repeat` first.
    sq_dist = ((mu_minor[None, :, :] - mu_major[:, None, :]) ** 2).sum(dim=-1)

    # Mean-dependent part: squared distances averaged over the N major means.
    kl_variant = 0.5 * sq_dist.mean(dim=0) / logvar_major.exp()

    # Mean-independent part: log-determinant ratio, dimension and trace terms.
    kl_invariant = 0.5 * (
        m * logvar_major
        - logvar_minor.sum(dim=1)
        - m
        + logvar_minor.exp().sum(dim=1) / logvar_major.exp()
    )

    kl = kl_invariant + kl_variant

    return kl.mean() if reduce else kl

In [17]:
# One minor and one major distribution: expected to reproduce the 0.2020 obtained above.
kl_ub_loss(
	mu_minor=mu1.unsqueeze(0),
	logvar_minor=logvar_diag1.unsqueeze(0),
	mu_major=mu2.unsqueeze(0),
	logvar_major=logvar_diag2[0],
	reduce=False,
)

tensor([0.2020])

In [48]:
# torch.manual_seed(0)
B = 16
N = 1024

mu_minor = torch.rand(B, 4)
logvar_minor = torch.rand(B, 4)

mu_major = torch.rand(1024, 4)
logvar_major = torch.rand(1)

In [49]:
# Reference: pairwise KL(minor || major) from torch.distributions, averaged over the
# N major distributions. Tensors are tiled to a common (N, B, 4) batch layout using
# N and B rather than repeating the literal 1024.
_mu_minor = mu_minor[None, :, :].repeat(N, 1, 1)
_logvar_minor = logvar_minor[None, :, :].repeat(N, 1, 1)
_mu_major = mu_major[:, None, :].repeat(1, B, 1)
dis.kl.kl_divergence(
	dis.MultivariateNormal(_mu_minor, torch.diag_embed(_logvar_minor.exp())),
	dis.MultivariateNormal(_mu_major, logvar_major.exp() * torch.eye(4))
).mean(dim=0)

tensor([0.6279, 0.4496, 0.6519, 0.5348, 0.7484, 0.8422, 0.5366, 0.5087, 0.3609,
        0.6599, 0.5624, 0.4689, 0.7133, 0.4305, 0.5825, 0.5916])

In [156]:
# Brute-force reference: evaluate kl_mvn for every (major j, minor i) pair.
# The isotropic log-variance vector does not depend on i or j, so build it once
# instead of on each of the N * B iterations.
logvar_major_diag = logvar_major * torch.ones(4)
kl_mat = torch.zeros(N, B)
for i in range(B):
	for j in range(N):
		kl_mat[j, i] = kl_mvn(mu_minor[i], mu_major[j], logvar_minor[i], logvar_major_diag)

kl_mat.mean(dim=0)

tensor([0.6279, 0.4496, 0.6519, 0.5348, 0.7484, 0.8422, 0.5366, 0.5087, 0.3609,
        0.6599, 0.5624, 0.4689, 0.7133, 0.4305, 0.5825, 0.5916])

In [50]:
# Vectorized loss on the same inputs; expected to match the two reference results above.
kl_ub_loss(
	mu_minor=mu_minor,
	logvar_minor=logvar_minor,
	mu_major=mu_major,
	logvar_major=logvar_major,
	reduce=False,
)

tensor([0.6279, 0.4496, 0.6519, 0.5348, 0.7484, 0.8422, 0.5366, 0.5087, 0.3609,
        0.6599, 0.5624, 0.4689, 0.7133, 0.4305, 0.5825, 0.5916])

In [157]:
# Smallest pairwise KL for each minor distribution, with the index of the major attaining it.
torch.min(kl_mat, dim=0)

torch.return_types.min(
values=tensor([0.3322, 0.1379, 0.3092, 0.3422, 0.3024, 0.4357, 0.2025, 0.2606, 0.1442,
        0.3574, 0.3085, 0.1506, 0.3976, 0.1067, 0.3104, 0.2951]),
indices=tensor([837,  82, 300, 212, 133, 470, 216,  54, 542, 537, 206, 417, 282, 509,
        109, 251]))