In [1]:
import numpy as np
import torch
from torch.distributions import MultivariateNormal

In [2]:
# Seed the global torch RNG so every output below is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# One query Gaussian and three document Gaussians, each with a diagonal covariance over 5 dims.
# torch.rand samples from [0, 1): query variances lie in [0, 1) (an exact 0 would make
# log(qvar) = -inf), while exp() puts document means/variances in [1, e).
qm, qv = torch.rand(1, 5), torch.rand(1, 5)
dm, dv = torch.rand(3, 5).exp(), torch.rand(3, 5).exp()

In [3]:
def pytorch_kld(qmean, qvar, dmean, dvar):
    """Reference score: negative KL divergence -KL(q_i || d_j) for every query/document pair.

    Each row of ``qmean``/``qvar`` (and of ``dmean``/``dvar``) is the mean and the diagonal
    of the covariance of one Gaussian. The pairwise KL comes from
    ``torch.distributions.kl_divergence`` on ``MultivariateNormal`` objects, so this
    serves as ground truth for the closed-form implementations in this notebook.

    Args:
        qmean, qvar: query means / variances, one Gaussian per row (n_queries rows).
        dmean, dvar: document means / variances, one Gaussian per row (n_docs rows).

    Returns:
        Tensor of shape (n_queries, n_docs) holding -KL(q_i || d_j), on ``qmean``'s
        device and with ``qmean``'s dtype.
    """
    # One distribution object per row; torch.diag turns each variance row into a diagonal covariance.
    queries = [MultivariateNormal(mean, covariance_matrix=torch.diag(var)) for mean, var in zip(qmean, qvar)]
    docs = [MultivariateNormal(mean, covariance_matrix=torch.diag(var)) for mean, var in zip(dmean, dvar)]
    # Match the input dtype so float64 inputs are not silently truncated to the default float32.
    kl = torch.zeros(len(queries), len(docs), dtype=qmean.dtype, device=qmean.device)
    for i, q in enumerate(queries):
        for j, p in enumerate(docs):
            # Negated: higher score = more similar, as used for ranking.
            kl[i, j] = -torch.distributions.kl_divergence(q, p)
    return kl

# pytorch_kld returns -KL (a similarity score); KL itself is never negative, so label it accordingly.
print("Ground truth for negative KL divergence:", pytorch_kld(qm, qv, dm, dv))

Ground truth for KL divergence: tensor([[-4.1362, -4.1960, -3.5086]])


In [4]:
def mrl_eq10_kld(qmean, qvar, dmean, dvar, remove_constants=False):
    """Closed-form negative KL divergence between diagonal Gaussians (Eq. 10 of the MRL paper).

    For q = N(qmean, diag(qvar)) and d = N(dmean, diag(dvar)) in k dimensions:

        KL(q || d) = 0.5 * ( sum log dvar - sum log qvar + sum qvar / dvar
                             + sum (qmean - dmean)^2 / dvar - k )

    Args:
        qmean, qvar: query mean / variance, 2-D with k columns.
        dmean, dvar: document mean / variance, 2-D with k columns. Query and document
            rows are paired by broadcasting, so typically one side has a single row.
            If both sides have the same number of rows, row i is paired with row i.
        remove_constants: if False, return exactly -KL(q || d), the same values as
            ``pytorch_kld``. If True, drop the terms that do not depend on the document
            (the factor 0.5, ``k`` and ``sum log qvar``), which gives Eq. 10. The result
            is then -(2 * KL + sum log qvar + k): a monotone transform of -KL for a fixed
            query, so the document ranking is unchanged.

    Returns:
        1-D tensor with one score per broadcast row.
    """
    k = qmean.size(1)
    logvar_ratio_term = torch.log(dvar).sum(1)
    if not remove_constants:
        # Out-of-place on purpose: an in-place `-=` cannot grow its left operand, so it
        # fails when qvar has more rows than dvar, whereas every other term broadcasts.
        logvar_ratio_term = logvar_ratio_term - torch.log(qvar).sum(1)
    trace_term = (qvar / dvar).sum(1)
    square_term = ((qmean - dmean) ** 2 / dvar).sum(1)
    kld = logvar_ratio_term + trace_term + square_term
    if not remove_constants:
        kld = 0.5 * (kld - k)
    return -kld

eq10_scores = mrl_eq10_kld(qm, qv, dm, dv)
eq10_scores_no_const = mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)
# Check the claims instead of eyeballing them. pytorch_kld returns (n_queries, n_docs),
# so drop the single-query axis before comparing.
assert torch.allclose(eq10_scores, pytorch_kld(qm, qv, dm, dv).squeeze(0), atol=1e-5)
# Removing the query-only constants must not change the document ranking.
assert torch.equal(eq10_scores.argsort(), eq10_scores_no_const.argsort())
print("Same output as 'pytorch_kld(qm, qv, dm, dv)':", eq10_scores)
print("Constants removed:", eq10_scores_no_const)

Same output as 'pytorch_kld(qm, qv, dm, dv)': tensor([-4.1362, -4.1960, -3.5086])
Constants removed: tensor([-11.4154, -11.5351, -10.1603])


In [5]:
def mrl_query_rep(qmean, qvar):
    """Expand each query Gaussian into a flat feature row for dot-product scoring.

    Row layout, with k = qmean.size(1) and 1 + 3k columns in total:
        [ 1 | qvar (k) | qmean**2 (k) | qmean (k) ]

    The tensor is allocated on ``qmean``'s device with torch's default dtype.
    """
    n_rows = qmean.size(0)
    dim = qmean.size(1)
    rep = torch.zeros(n_rows, 1 + 3 * dim, device=qmean.device)
    rep[:, 0] = 1  # bias column, pairs with the document's constant term
    # (first column, values) for each k-wide block that follows the bias column.
    column_blocks = (
        (1, qvar),
        (1 + dim, qmean ** 2),
        (1 + 2 * dim, qmean),
    )
    for start, values in column_blocks:
        rep[:, start:start + dim] = values
    return rep

def mrl_document_rep(dmean, dvar):
    """Expand each document Gaussian into a flat feature row matching ``mrl_query_rep``.

    Row layout, with k = dmean.size(1) and 1 + 3k columns in total:
        [ -sum(log dvar + dmean**2 / dvar) | -1/dvar (k) | -1/dvar (k) | 2*dmean/dvar (k) ]

    Dotted with a query row [1 | qvar | qmean**2 | qmean], this gives
    -(sum log dvar + sum qvar/dvar + sum (qmean - dmean)**2 / dvar).
    The tensor is allocated on ``dmean``'s device with torch's default dtype.
    """
    n_rows = dmean.size(0)
    dim = dmean.size(1)
    constant = -(torch.log(dvar) + dmean ** 2 / dvar).sum(1)
    neg_inv_var = -1 / dvar  # coefficient shared by the qvar and qmean**2 blocks
    linear = (2 * dmean) / dvar  # cross term from expanding (qmean - dmean)**2
    rep = torch.zeros(n_rows, 1 + 3 * dim, device=dmean.device)
    rep[:, 0] = constant
    rep[:, 1:1 + dim] = neg_inv_var
    rep[:, 1 + dim:1 + 2 * dim] = neg_inv_var
    rep[:, 1 + 2 * dim:] = linear
    return rep

def mrl_dot_product(qmean, qvar, dmean, dvar):
    """Score every query against every document with a single matrix product.

    Builds the query and document feature rows with ``mrl_query_rep`` and
    ``mrl_document_rep``. Entry (i, j) of the result is the dot product of query
    row i with document row j, i.e.
    -(sum log dvar_j + sum qvar_i/dvar_j + sum (qmean_i - dmean_j)**2 / dvar_j).

    Returns:
        Tensor of shape (n_queries, n_docs).
    """
    query_matrix = mrl_query_rep(qmean, qvar)
    document_matrix = mrl_document_rep(dmean, dvar)
    return torch.matmul(query_matrix, document_matrix.T)

dot_scores = mrl_dot_product(qm, qv, dm, dv)
# Verify the claim: the factorised dot product must reproduce Eq. 10 without constants.
# mrl_dot_product returns (n_queries, n_docs); drop the single-query axis to compare.
assert torch.allclose(dot_scores.squeeze(0), mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True), atol=1e-5)
print("Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)':", dot_scores)

Same output as 'mrl_eq10_kld(qm, qv, dm, dv, remove_constants=True)': tensor([[-11.4154, -11.5351, -10.1603]])
