In [1]:
import numpy as np
import torch
from torch.distributions import MultivariateNormal

In [2]:
# Fix the RNG seed so that "Restart Kernel -> Run All" draws the same tensors
# (and therefore prints the same scores) on every run.
torch.manual_seed(0)

# Queries: 2 diagonal Gaussians in k=5 dimensions; rand() puts means and variances in [0, 1).
qm, qv = torch.rand(2, 5), torch.rand(2, 5)
# Documents: 3 diagonal Gaussians; exp(rand()) puts means and variances in [1, e),
# so the document variances used as divisors below are bounded away from zero.
dm, dv = torch.rand(3, 5).exp(), torch.rand(3, 5).exp()

## KL divergence as computed by native PyTorch library.

$\DeclareMathOperator{\kld}{KLD}$
$\DeclareMathOperator{\trace}{tr}$
$\DeclareMathOperator{\mean}{\mu}$
$\DeclareMathOperator{\cov}{\Sigma}$

\begin{equation}
    \kld(Q \| D) = \frac{1}{2} \Big[ \log \frac{\det \cov_D}{\det \cov_Q} - k + \trace\{\cov_D^{-1}\cov_Q\} + (\mean_Q - \mean_D)^\intercal \cov_D^{-1}(\mean_Q - \mean_D) \Big]
\end{equation}

In [3]:
def pytorch_kld(qmean, qvar, dmean, dvar):
    """Reference negative KL divergence between diagonal Gaussians via torch.distributions.

    Args:
        qmean: (num_queries, k) tensor of query means.
        qvar: (num_queries, k) tensor of query variances (diagonal of the covariance).
        dmean: (num_docs, k) tensor of document means.
        dvar: (num_docs, k) tensor of document variances (diagonal of the covariance).

    Returns:
        (num_queries, num_docs) tensor whose entry [i, j] is -KL(Q_i || D_j),
        allocated with the dtype and device of ``qmean``.
    """
    # One full-covariance distribution per row; torch.diag turns the variance
    # vector into a diagonal covariance matrix.
    queries = [
        MultivariateNormal(qmean[i, :], covariance_matrix=torch.diag(qvar[i, :]))
        for i in range(qmean.size(0))
    ]
    documents = [
        MultivariateNormal(dmean[j, :], covariance_matrix=torch.diag(dvar[j, :]))
        for j in range(dmean.size(0))
    ]
    # Allocate in the input dtype: with the default dtype, float64 inputs would be
    # silently downcast to float32 when written into the result.
    kl = torch.zeros(len(queries), len(documents), dtype=qmean.dtype, device=qmean.device)
    for i, query in enumerate(queries):
        for j, document in enumerate(documents):
            # Negated so that a larger value means the two distributions are closer.
            kl[i, j] = -torch.distributions.kl_divergence(query, document)
    return kl

# Reference scores from torch.distributions; the later formulations are compared against these.
pytorch_kld_ = pytorch_kld(qmean=qm, qvar=qv, dmean=dm, dvar=dv)
print(f"Ground truth for KL divergence:\n{pytorch_kld_}")

Ground truth for KL divergence:
tensor([[-7.2569, -4.2285, -4.5962],
        [-9.0958, -6.4639, -6.0405]])


## Simplified KL divergence scoring function as per Eq. 10 of MRL paper (corrected).

\begin{equation}
    \kld(Q \| D) \overset{\text{rank}}{=} \sum_{i=1}^k \log \sigma^2_{i_D} + \sum_{i=1}^k \frac{\sigma^2_{i_Q}}{\sigma^2_{i_D}} + \sum_{i=1}^k \frac{\mu_{i_Q}^2}{\sigma^2_{i_D}} - \sum_{i=1}^k \frac{2\mu_{i_Q}\mu_{i_D}}{\sigma^2_{i_D}} + \sum_{i=1}^k \frac{\mu_{i_D}^2}{\sigma^2_{i_D}}
\end{equation}

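With `remove_constants=False` (the default), the terms that are constant with respect to the document ranking ($\frac{1}{2}$, $k$ and $\sum_i \log \sigma^2_{i_Q}$) are kept and the full KL divergence below is computed; `remove_constants=True` gives the simplified form above. Note that the implementation returns the *negated* divergence, so a higher score means a closer match.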

\begin{equation}
    \kld(Q \| D) = \frac{1}{2} \Bigg[ \sum_{i=1}^k \log \sigma^2_{i_D} - \sum_{i=1}^k \log \sigma^2_{i_Q} - k + \sum_{i=1}^k \frac{\sigma^2_{i_Q}}{\sigma^2_{i_D}} + \sum_{i=1}^k \frac{\mu_{i_Q}^2}{\sigma^2_{i_D}} - \sum_{i=1}^k \frac{2\mu_{i_Q}\mu_{i_D}}{\sigma^2_{i_D}} + \sum_{i=1}^k \frac{\mu_{i_D}^2}{\sigma^2_{i_D}} \Bigg]
\end{equation}

In [4]:
def mrl_eq10_kld(qmean, qvar, dmean, dvar, remove_constants=False):
    """Closed-form negative KL divergence between every query and every document.

    All inputs are 2-D (rows = distributions, columns = the k dimensions) and hold the
    means / diagonal variances of the Gaussians. The result has one row per query and
    one column per document.

    remove_constants=False: full divergence, i.e. the same quantity as pytorch_kld.
    remove_constants=True:  Eq. 10 of the MRL paper, which drops everything that does
        not affect the document ranking for a fixed query: the factor 0.5, the
        dimensionality k and the query term torch.log(qvar).sum().
    """
    dim = qmean.size(1)

    # Add singleton axes so that query i is paired with document j by broadcasting.
    q_mu, q_var = qmean.unsqueeze(1), qvar.unsqueeze(1)
    d_mu, d_var = dmean.unsqueeze(0), dvar.unsqueeze(0)

    # log det(Sigma_D), plus -log det(Sigma_Q) when constants are kept.
    log_det = torch.log(dvar).sum(1).unsqueeze(0)
    if not remove_constants:
        log_det = log_det - torch.log(qvar).sum(1).unsqueeze(1)

    # tr(Sigma_D^-1 Sigma_Q) and the squared mean difference weighted by 1 / dvar.
    trace = (q_var / d_var).sum(2)
    mean_gap = ((q_mu - d_mu) ** 2 / d_var).sum(2)

    total = log_det + trace + mean_gap
    if remove_constants:
        return -total
    return -(0.5 * (total - dim))

# Full divergence (constants kept) and the ranking-only Eq. 10 score for the same inputs.
mrl_eq10_kld_ = mrl_eq10_kld(qm, qv, dm, dv, remove_constants=False)
mrl_eq10_kld_simplified_ = mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)

print(f"Same output as 'pytorch_kld(qm, qv, dm, dv)':\n{mrl_eq10_kld_}")
print(f"Eq. 10 simplified, constants removed:\n{mrl_eq10_kld_simplified_}")

Same output as 'pytorch_kld(qm, qv, dm, dv)':
tensor([[-7.2569, -4.2285, -4.5962],
        [-9.0958, -6.4639, -6.0405]])
Eq. 10 simplified, constants removed:
tensor([[-15.9074,  -9.8506, -10.5861],
        [-16.2117, -10.9480, -10.1012]])


## Simplified KL divergence scoring function as a dot product.

\begin{equation}
    \kld(Q \| D) \overset{\text{rank}}{=} \vec{q}^\intercal \cdot \vec{d}
\end{equation}
\begin{align}
    & \vec{q} = \Big[ 1, \sigma^2_{1_Q}, \dotsc, \sigma^2_{k_Q}, \mu_{1_Q}^2, \dotsc, \mu_{k_Q}^2, \mu_{1_Q}, \dotsc, \mu_{k_Q} \Big] \\
    & \vec{d} = \Bigg[ \gamma_{D}, \frac{1}{\sigma^2_{1_D}}, \dotsc, \frac{1}{\sigma^2_{k_D}}, \frac{1}{\sigma^2_{1_D}}, \dotsc, \frac{1}{\sigma^2_{k_D}}, -\frac{2\mu_{1_D}}{\sigma^2_{1_D}}, \dotsc, -\frac{2\mu_{k_D}}{\sigma^2_{k_D}} \Bigg]
\end{align}
Where
\begin{equation}
    \gamma_{D} = \sum_{i=1}^k \Bigg( \log \sigma^2_{i_D} + \frac{\mu_{i_D}^2}{\sigma^2_{i_D}} \Bigg)
\end{equation}

Output should be the same as Eq. 10 of MRL with constants removed.

In [5]:
def mrl_query_rep(qmean, qvar):
    """Build the query-side vector of the simplified dot-product score.

    Layout along dim 1, with k = qmean.size(1):
        [ 1 | qvar (k) | qmean ** 2 (k) | qmean (k) ]

    Args:
        qmean: (batch, k) tensor of query means.
        qvar: (batch, k) tensor of query variances.

    Returns:
        (batch, 3 * k + 1) tensor with the dtype and device of ``qmean``.
    """
    batch_size, k = qmean.size(0), qmean.size(1)
    # new_zeros inherits dtype and device from qmean; torch.zeros(..., device=...) alone
    # would fall back to the default dtype and silently downcast float64 inputs.
    rep = qmean.new_zeros((batch_size, 3 * k + 1))
    rep[:, 0] = 1                          # multiplies the document-only constant
    rep[:, 1:k + 1] = qvar                 # pairs with 1 / dvar (trace term)
    rep[:, k + 1:2 * k + 1] = qmean ** 2   # pairs with 1 / dvar (quadratic term)
    rep[:, 2 * k + 1:] = qmean             # pairs with dmean / dvar (cross term)
    return rep

def mrl_document_rep(dmean, dvar):
    """Build the document-side vector of the simplified dot-product score.

    Layout along dim 1, with k = dmean.size(1):
        [ -gamma_D | -1 / dvar (k) | -1 / dvar (k) | 2 * dmean / dvar (k) ]
    where gamma_D = sum_i (log(dvar_i) + dmean_i ** 2 / dvar_i).

    Every entry has the opposite sign of the corresponding term in the KL expansion,
    so the dot product with [ 1 | qvar | qmean ** 2 | qmean ] is the negated
    simplified divergence.

    Args:
        dmean: (batch, k) tensor of document means.
        dvar: (batch, k) tensor of document variances.

    Returns:
        (batch, 3 * k + 1) tensor with the dtype and device of ``dmean``.
    """
    batch_size, k = dmean.size(0), dmean.size(1)
    # new_zeros inherits dtype and device from dmean; torch.zeros(..., device=...) alone
    # would fall back to the default dtype and silently downcast float64 inputs.
    rep = dmean.new_zeros((batch_size, 3 * k + 1))
    rep[:, 0] = - (torch.log(dvar) + dmean**2 / dvar).sum(1)   # -gamma_D
    rep[:, 1:k + 1] = - 1 / dvar                               # weights for qvar
    rep[:, k + 1:2 * k + 1] = - 1 / dvar                       # weights for qmean ** 2
    rep[:, 2 * k + 1:] = (2 * dmean) / dvar                    # weights for qmean
    return rep

def mrl_dot_product(qmean, qvar, dmean, dvar):
    """Score every (query, document) pair as an inner product of their vector forms.

    The query and document vectors come from mrl_query_rep and mrl_document_rep;
    the result has one row per query and one column per document.
    """
    query_vectors = mrl_query_rep(qmean, qvar)
    document_vectors = mrl_document_rep(dmean, dvar)
    return torch.matmul(query_vectors, document_vectors.t())

# Dot-product form of the simplified score, evaluated on the same query/document tensors.
mrl_dot_product_ = mrl_dot_product(qmean=qm, qvar=qv, dmean=dm, dvar=dv)
print(f"Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)':\n{mrl_dot_product_}")

Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)':
tensor([[-15.9074,  -9.8506, -10.5861],
        [-16.2117, -10.9480, -10.1012]])


## Full KL divergence as a dot product.

\begin{equation}
    \kld(Q \| D) = \frac{1}{2} \big(\vec{q'}^\intercal \cdot \vec{d'} - k \big)
\end{equation}
\begin{align}
    & \vec{q'} = \Big[ 1, \gamma_{Q}, \sigma^2_{1_Q}, \dotsc, \sigma^2_{k_Q}, \mu_{1_Q}^2, \dotsc, \mu_{k_Q}^2, \mu_{1_Q}, \dotsc, \mu_{k_Q} \Big] \\
    & \vec{d'} = \Bigg[ \gamma_{D}, -1, \frac{1}{\sigma^2_{1_D}}, \dotsc, \frac{1}{\sigma^2_{k_D}}, \frac{1}{\sigma^2_{1_D}}, \dotsc, \frac{1}{\sigma^2_{k_D}}, -\frac{2\mu_{1_D}}{\sigma^2_{1_D}}, \dotsc, -\frac{2\mu_{k_D}}{\sigma^2_{k_D}} \Bigg]
\end{align}
Where
\begin{equation}
    \gamma_{D} = \sum_{i=1}^k \Bigg( \log \sigma^2_{i_D} + \frac{\mu_{i_D}^2}{\sigma^2_{i_D}} \Bigg)
\end{equation}
and
\begin{equation}
    \gamma_{Q} = \sum_{i=1}^k \log \sigma^2_{i_Q}
\end{equation}

In [6]:
def mrl_query_rep_full(qmean, qvar):
    """Build the query-side vector of the full-KL dot-product formulation.

    Layout along dim 1, with k = qmean.size(1):
        [ 1 | gamma_Q | qvar (k) | qmean ** 2 (k) | qmean (k) ]
    where gamma_Q = sum_i log(qvar_i).

    Args:
        qmean: (batch, k) tensor of query means.
        qvar: (batch, k) tensor of query variances.

    Returns:
        (batch, 3 * k + 2) tensor with the dtype and device of ``qmean``.
    """
    batch_size, k = qmean.size(0), qmean.size(1)
    # new_zeros inherits dtype and device from qmean; torch.zeros(..., device=...) alone
    # would fall back to the default dtype and silently downcast float64 inputs.
    rep = qmean.new_zeros((batch_size, 3 * k + 2))
    rep[:, 0] = 1                          # multiplies the document-only constant
    rep[:, 1] = torch.log(qvar).sum(1)     # gamma_Q: query log-determinant
    rep[:, 2:k + 2] = qvar                 # pairs with 1 / dvar (trace term)
    rep[:, k + 2:2 * k + 2] = qmean ** 2   # pairs with 1 / dvar (quadratic term)
    rep[:, 2 * k + 2:] = qmean             # pairs with dmean / dvar (cross term)
    return rep

def mrl_document_rep_full(dmean, dvar):
    """Build the document-side vector of the full-KL dot-product formulation.

    Layout along dim 1, with k = dmean.size(1):
        [ -gamma_D | 1 | -1 / dvar (k) | -1 / dvar (k) | 2 * dmean / dvar (k) ]
    where gamma_D = sum_i (log(dvar_i) + dmean_i ** 2 / dvar_i).

    Every entry has the opposite sign of the corresponding term in the KL expansion
    (the ``1`` in slot 1 multiplies the query's sum of log-variances), so the dot
    product with the query vector yields a negated quantity.

    Args:
        dmean: (batch, k) tensor of document means.
        dvar: (batch, k) tensor of document variances.

    Returns:
        (batch, 3 * k + 2) tensor with the dtype and device of ``dmean``.
    """
    batch_size, k = dmean.size(0), dmean.size(1)
    # new_zeros inherits dtype and device from dmean; torch.zeros(..., device=...) alone
    # would fall back to the default dtype and silently downcast float64 inputs.
    rep = dmean.new_zeros((batch_size, 3 * k + 2))
    rep[:, 0] = - (torch.log(dvar) + dmean**2 / dvar).sum(1)   # -gamma_D
    rep[:, 1] = 1                                              # weight for gamma_Q
    rep[:, 2:k + 2] = - 1 / dvar                               # weights for qvar
    rep[:, k + 2:2 * k + 2] = - 1 / dvar                       # weights for qmean ** 2
    rep[:, 2 * k + 2:] = (2 * dmean) / dvar                    # weights for qmean
    return rep

def mrl_dot_product_full(qmean, qvar, dmean, dvar):
    """Full negated KL divergence for every (query, document) pair, via one matrix product.

    The inner product of the mrl_query_rep_full and mrl_document_rep_full vectors is
    shifted by the dimensionality k and halved; the result has one row per query
    and one column per document.
    """
    dim = qmean.size(1)
    query_vectors = mrl_query_rep_full(qmean, qvar)
    document_vectors = mrl_document_rep_full(dmean, dvar)
    inner = torch.matmul(query_vectors, document_vectors.t())
    return 0.5 * (inner + dim)

# Dot-product form with constants kept, evaluated on the same query/document tensors.
mrl_dot_product_full_ = mrl_dot_product_full(qmean=qm, qvar=qv, dmean=dm, dvar=dv)
print(f"Same output as 'pytorch_kld(qm, qv, dm, dv)':\n{mrl_dot_product_full_}")

Same output as 'pytorch_kld(qm, qv, dm, dv)':
tensor([[-7.2569, -4.2285, -4.5962],
        [-9.0958, -6.4639, -6.0405]])


## Some final sanity checks for matching outputs.

1. The formal definition of KL divergence (as per PyTorch implementation), should match the unsimplified version of the (corrected) Eq. 10 of the MRL paper.
2. The simplified KL divergence-based scoring function in Eq. 10, should match its dot product formulation.
3. The formal definition of KL divergence (as per PyTorch implementation), should match the dot product formulation of the unsimplified version of Eq. 10, that includes constants.

In [7]:
# Evaluate each equivalence once, print it so the notebook shows the outcome, and then
# assert it so that a mismatch stops "Run All" instead of scrolling past as a silent False.
eq10_matches_pytorch = torch.allclose(pytorch_kld_, mrl_eq10_kld_)
dot_matches_simplified = torch.allclose(mrl_eq10_kld_simplified_, mrl_dot_product_)
dot_full_matches_pytorch = torch.allclose(pytorch_kld_, mrl_dot_product_full_)

print(f"Formal definition of KL div. matches unsimplified version of Eq.10: {eq10_matches_pytorch}")
print(f"Simplified version of Eq. 10 matches dot product formulation: {dot_matches_simplified}")
print(f"Formal definition of KL div. matches dot product formulation including constants: {dot_full_matches_pytorch}")

assert eq10_matches_pytorch, "mrl_eq10_kld (constants kept) disagrees with pytorch_kld"
assert dot_matches_simplified, "mrl_dot_product disagrees with simplified mrl_eq10_kld"
assert dot_full_matches_pytorch, "mrl_dot_product_full disagrees with pytorch_kld"

Formal definition of KL div. matches unsimplified version of Eq.10: True
Simplified version of Eq. 10 matches dot product formulation: True
Formal definition of KL div. matches dot product formulation including constants: True
