# In [22]:
from __future__ import annotations
import torch
from torch import Tensor
from torch.distributions import Categorical, kl_divergence

class Discrete:
    """A discrete distribution: support points ("atoms") plus a categorical
    distribution over the indices of those atoms.

    Index ``i`` of ``idx_dist`` corresponds to atom ``i`` of ``atoms``. The
    last axis of ``atoms`` is treated as the coordinate axis by distance
    computations such as ``torch.cdist``. Presumably ``atoms`` is shaped
    ``(..., n_atoms, dim)``; this is a TODO to confirm.
    """

    atoms: Tensor  # support points
    idx_dist: Categorical  # probabilities over atom indices

    def __init__(self, atoms: Tensor, idx_dist: Categorical):
        self.idx_dist = idx_dist
        self.atoms = atoms

def proj_kl_div(P: Discrete, Q: Discrete):
    """Project ``P`` onto ``Q``'s atoms, then return ``KL(proj_P || Q)``.

    Each atom of ``P`` has its probability mass split between its two
    nearest atoms of ``Q``, using the L1 distance from
    ``torch.cdist(..., p=1)``. With sorted distances ``d0 <= d1``, the
    nearer atom gets weight ``d1 / (d0 + d1)`` and the other gets
    ``d0 / (d0 + d1)``, i.e. inverse-distance weighting of the two.
    For 1-D atoms where the ``P`` atom lies between the two, this equals
    linear interpolation.

    NOTE(review): for a ``P`` atom outside ``Q``'s support, mass is still
    split between the two nearest atoms rather than clamped to the nearest
    one. Confirm this is intended.

    The projected distribution shares ``Q``'s support, so the KL divergence
    is taken between the two index categoricals.

    Args:
        P: Source distribution. Presumably ``P.atoms`` is ``(..., n, d)``
            and ``P.idx_dist.probs`` is ``(..., n)``; TODO confirm.
        Q: Target distribution whose atoms define the projection support.
            Presumably ``Q.atoms`` is ``(..., m, d)``.

    Returns:
        ``kl_divergence(proj_P, Q.idx_dist)`` over the leading batch dims.

    Edge cases:
        * ``m == 1``: all of ``P``'s mass goes to ``Q``'s only atom.
        * Both nearest ``Q`` atoms at distance 0 (duplicate atoms): all mass
          goes to the first one instead of producing ``0/0 = NaN``.
    """
    p_probs = P.idx_dist.probs

    if Q.atoms.shape[-2] == 1:
        # A single target atom receives every source atom's full mass
        # (torch.topk with k=2 would fail here).
        proj_probs = p_probs.sum(dim=-1, keepdim=True)
    else:
        dists = torch.cdist(P.atoms, Q.atoms, p=1)  # [..., n, m]
        # topk returns values sorted ascending when largest=False, so v0 <= v1.
        values, indices = torch.topk(dists, k=2, largest=False)
        v0, v1 = values[..., 0], values[..., 1]
        i0, i1 = indices[..., 0], indices[..., 1]

        denom = v0 + v1
        degenerate = denom == 0
        # Use a denominator of 1 on degenerate entries so no 0/0 is ever
        # evaluated (NaNs inside torch.where still poison gradients).
        safe_denom = torch.where(degenerate, torch.ones_like(denom), denom)
        w0 = torch.where(degenerate, torch.ones_like(denom), v1 / safe_denom)
        w1 = torch.where(degenerate, torch.zeros_like(denom), v0 / safe_denom)

        mass0 = w0 * p_probs
        mass1 = w1 * p_probs
        # Match the scattered values' dtype/device. A bare torch.zeros(...)
        # is CPU float32 and breaks scatter_add_ for CUDA or float64 inputs.
        proj_probs = mass0.new_zeros(Q.atoms.shape[:-1])
        proj_probs.scatter_add_(-1, i0, mass0)
        proj_probs.scatter_add_(-1, i1, mass1)

    return kl_divergence(Categorical(probs=proj_probs), Q.idx_dist)
