In [46]:
import torch
import numpy as np

def get_nll(p_samples, q_samples, log_sigma):
    # p_samples : (M, z)
    # q_samples : (N, z)
    M = len(p_samples)
    N = len(q_samples)
    z_dim = p_samples.shape[1]
    
    # (N, M)
    distance = torch.norm(q_samples.unsqueeze(1) - p_samples.unsqueeze(0), dim=2) ** 2
    alpha = -1/(2*np.exp(log_sigma)**2)
    nll = -torch.mean(torch.logsumexp(alpha*distance, dim=0))
    nll = nll + 0.5*z_dim*(2*log_sigma-np.log(np.e)) + np.log(N)
    
    return nll.item()

def get_optimum_log_sigma(p_samples1, p_samples2, min_log_sigma=-3, max_log_sigma=3):
    log_sigmas = np.linspace(-3, 3, 100)
    nlls = np.array([get_nll(p_samples1, p_samples2, log_sigma) for log_sigma in log_sigmas])
    return log_sigmas[np.argmin(nlls)]

def get_cross_nll(p_samples, q_samples, log_sigma):
    nll_p2q = get_nll(p_samples, q_samples, log_sigma)
    nll_q2p = get_nll(q_samples, p_samples, log_sigma)
    return (nll_p2q + nll_q2p) / 2
    

In [47]:
p_samples1 = torch.randn(10000, 2)
p_samples2 = torch.randn(10000, 2)
q_samples = torch.randn(10000, 2)
log_sigma = get_optimum_log_sigma(p_samples1, p_samples2)
print(log_sigma)
cross_nll = get_cross_nll(p_samples, q_samples, log_sigma)
print(cross_nll)

-1.3636363636363635
0.008036613464355469
