In [1]:
import numpy as np
import torch
from torch.distributions import MultivariateNormal

In [2]:
# Toy inputs: 2 query Gaussians and 3 document Gaussians over 5 dimensions
# (`*m` = means, `*v` = per-dimension variances used as a diagonal covariance).
# NOTE(review): no RNG seed is set, so these values (and every printed result
# further down) change on each fresh run.
qm, qv = torch.rand(2, 5), torch.rand(2, 5)
# exp() of a uniform [0, 1) sample lies in [1, e), so document variances stay well away from zero.
dm, dv = torch.rand(3, 5).exp(), torch.rand(3, 5).exp()

In [3]:
def pytorch_kld(qmean, qvar, dmean, dvar):
    """Reference negative KL divergence between diagonal Gaussians.

    Row ``i`` of ``qmean``/``qvar`` parameterises the query distribution ``q_i``
    and row ``j`` of ``dmean``/``dvar`` the document distribution ``p_j``. The
    variances are placed on the diagonal of a full covariance matrix and the
    divergence is delegated to ``torch.distributions.kl_divergence``.

    Args:
        qmean: (num_queries, k) tensor of query means.
        qvar: (num_queries, k) tensor of query variances (expected positive so
            that the covariance matrix is valid).
        dmean: (num_docs, k) tensor of document means.
        dvar: (num_docs, k) tensor of document variances (expected positive).

    Returns:
        (num_queries, num_docs) tensor on ``qmean.device`` whose entry
        ``[i, j]`` is ``-KL(q_i || p_j)`` (note the sign flip: larger is a
        better match). The dtype follows the inputs but is never narrower than
        the default dtype, so float64 inputs are no longer truncated to float32.
    """
    # torch.zeros would otherwise allocate the default dtype and silently
    # downcast float64 divergences when they are written into the result.
    out_dtype = torch.promote_types(
        torch.promote_types(qmean.dtype, dmean.dtype), torch.get_default_dtype()
    )
    kl = torch.zeros(qmean.size(0), dmean.size(0), dtype=out_dtype, device=qmean.device)

    q = [
        MultivariateNormal(qmean[i, :], covariance_matrix=torch.diag(qvar[i, :]))
        for i in range(qmean.size(0))
    ]
    p = [
        MultivariateNormal(dmean[j, :], covariance_matrix=torch.diag(dvar[j, :]))
        for j in range(dmean.size(0))
    ]

    for i, q_i in enumerate(q):
        for j, p_j in enumerate(p):
            kl[i, j] = -1 * torch.distributions.kl_divergence(q_i, p_j)
    return kl

# Reference values: negated KL(q_i || p_j) for every query/document pair,
# computed through torch.distributions.
pytorch_kld_ = pytorch_kld(qm, qv, dm, dv)
print(f"Ground truth for KL divergence:\n{pytorch_kld_}")

Ground truth for KL divergence:
tensor([[-4.2275, -4.9790, -4.4476],
        [-5.8411, -5.5777, -6.4685]])


In [4]:
def mrl_eq10_kld(qmean, qvar, dmean, dvar, remove_constants=False):
    """Closed-form negative KL divergence between diagonal Gaussians, all pairs at once.

    With ``remove_constants=False`` the result is the full ``-KL(q_i || p_j)``,
    i.e. the same quantity as ``pytorch_kld``. With ``remove_constants=True``
    the terms that do not affect the ranking of documents for a given query
    are dropped (the dimensionality ``k``, the factor ``0.5`` and
    ``torch.log(qvar).sum()``), which yields Eq. 10 of the MRL paper.

    Args:
        qmean, qvar: (num_queries, k) query means and variances.
        dmean, dvar: (num_docs, k) document means and variances.
        remove_constants: drop the ranking-invariant terms listed above.

    Returns:
        (num_queries, num_docs) tensor of negated divergences / scores.
    """
    dim = qmean.size(1)

    # Broadcast queries along axis 0 and documents along axis 1.
    q_mu, q_var = qmean.unsqueeze(1), qvar.unsqueeze(1)
    d_mu, d_var = dmean.unsqueeze(0), dvar.unsqueeze(0)

    log_det = dvar.log().sum(dim=1).unsqueeze(0)
    if not remove_constants:
        # The query log-determinant is constant per query, so it only matters for the exact value.
        log_det = log_det - qvar.log().sum(dim=1).unsqueeze(1)

    trace = (q_var / d_var).sum(dim=2)
    mahalanobis = ((q_mu - d_mu) ** 2 / d_var).sum(dim=2)
    total = log_det + trace + mahalanobis

    if remove_constants:
        return -total
    return -(0.5 * (total - dim))

# Evaluate the closed form twice: with all terms (exact negative KL) and with
# the ranking-invariant constants removed (Eq. 10).
mrl_eq10_kld_ = mrl_eq10_kld(qm, qv, dm, dv)
mrl_eq10_kld_simplified_ = mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)
print(f"Same output as 'pytorch_kld(qm, qv, dm, dv)':\n{mrl_eq10_kld_}")
print(f"Eq. 10 simplified, constants removed:\n{mrl_eq10_kld_simplified_}")

Same output as 'pytorch_kld(qm, qv, dm, dv)':
tensor([[-4.2275, -4.9790, -4.4476],
        [-5.8411, -5.5777, -6.4685]])
Eq. 10 simplified, constants removed:
tensor([[ -9.9911, -11.4941, -10.4313],
        [-11.5406, -11.0140, -12.7955]])


In [5]:
def mrl_query_rep(qmean, qvar):
    """Build the query-side vectors of the dot-product formulation.

    Each output row is laid out as ``[1, qvar, qmean**2, qmean]`` and therefore
    has width ``1 + 3 * k`` where ``k = qmean.size(1)``.

    Args:
        qmean: (batch, k) tensor of query means.
        qvar: (batch, k) tensor of query variances.

    Returns:
        (batch, 1 + 3 * k) tensor on ``qmean.device``. The dtype follows the
        inputs but is never narrower than the default dtype, so float64 inputs
        are no longer truncated to float32.
    """
    batch_size = qmean.size(0)
    k = qmean.size(1)
    # torch.zeros would otherwise allocate the default dtype and silently
    # downcast float64 inputs when they are copied into the buffer.
    rep_dtype = torch.promote_types(
        torch.promote_types(qmean.dtype, qvar.dtype), torch.get_default_dtype()
    )
    rep = torch.zeros(batch_size, 1 + 3 * k, dtype=rep_dtype, device=qmean.device)
    rep[:, 0] = 1                          # pairs with the document-only bias term
    rep[:, 1:k + 1] = qvar                 # pairs with -1 / dvar (trace term)
    rep[:, k + 1:2 * k + 1] = qmean ** 2   # pairs with -1 / dvar (squared-mean term)
    rep[:, 2 * k + 1:] = qmean             # pairs with 2 * dmean / dvar (cross term)
    return rep

def mrl_document_rep(dmean, dvar):
    """Build the document-side vectors of the dot-product formulation.

    Each output row is laid out as
    ``[-(log(dvar) + dmean**2 / dvar).sum(), -1 / dvar, -1 / dvar, 2 * dmean / dvar]``
    and has width ``1 + 3 * k`` where ``k = dmean.size(1)``. Its inner product
    with a query row ``[1, qvar, qmean**2, qmean]`` equals
    ``-(sum(log dvar) + sum(qvar / dvar) + sum((qmean - dmean)**2 / dvar))``.

    Args:
        dmean: (batch, k) tensor of document means.
        dvar: (batch, k) tensor of document variances (used as divisors and
            inside ``log``, so expected to be positive).

    Returns:
        (batch, 1 + 3 * k) tensor on ``dmean.device``. The dtype follows the
        inputs but is never narrower than the default dtype, so float64 inputs
        are no longer truncated to float32.
    """
    batch_size = dmean.size(0)
    k = dmean.size(1)
    # torch.zeros would otherwise allocate the default dtype and silently
    # downcast float64 inputs when they are copied into the buffer.
    rep_dtype = torch.promote_types(
        torch.promote_types(dmean.dtype, dvar.dtype), torch.get_default_dtype()
    )
    rep = torch.zeros(batch_size, 1 + 3 * k, dtype=rep_dtype, device=dmean.device)
    rep[:, 0] = - (torch.log(dvar) + dmean**2 / dvar).sum(1)  # document-only terms
    rep[:, 1:k + 1] = - 1 / dvar                 # multiplies qvar
    rep[:, k + 1:2 * k + 1] = - 1 / dvar         # multiplies qmean ** 2
    rep[:, 2 * k + 1:] = (2 * dmean) / dvar      # multiplies qmean (cross term)
    return rep

def mrl_dot_product(qmean, qvar, dmean, dvar):
    """Score every query/document pair with one matrix product.

    The query and document parameters are first mapped to their
    ``1 + 3 * k`` dimensional representations; the product of the query
    matrix with the transposed document matrix gives a
    (num_queries, num_docs) score matrix.
    """
    query_side = mrl_query_rep(qmean, qvar)
    doc_side = mrl_document_rep(dmean, dvar)
    return torch.matmul(query_side, doc_side.transpose(0, 1))

# Score all query/document pairs through the dot product of their representations.
mrl_dot_product_ = mrl_dot_product(qm, qv, dm, dv)
print(f"Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)':\n{mrl_dot_product_}")

Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)':
tensor([[ -9.9911, -11.4941, -10.4313],
        [-11.5406, -11.0140, -12.7955]])


Final checks:

In [6]:
# Numerical comparison (torch.allclose, default tolerances) of the three formulations:
# library KL vs. full closed form, and simplified closed form vs. dot product.
print(f"Formal definition of KL div. matches unsimplified version of Eq.10: {torch.allclose(pytorch_kld_, mrl_eq10_kld_)}")
print(f"Simplified version (Eq. 10) matches dot product formulation: {torch.allclose(mrl_eq10_kld_simplified_, mrl_dot_product_)}")

Formal definition of KL div. matches unsimplified version of Eq.10: True
Simplified version (Eq. 10) matches dot product formulation: True
