In [None]:
import torch
import math

# Preferred compute device: CUDA when available, otherwise the CPU.
# NOTE(review): none of the functions visible in this notebook read this
# global; they take the device from their tensor arguments. Confirm it is
# still needed.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

#IPFP algo (for comparison)
def ipfp(M, a, b, its = 40):
  """Iterative proportional fitting of ``M`` to row sums ``a`` and column sums ``b``.

  Every sweep multiplies row ``i`` by ``a[i] / row_sum[i]`` and then column
  ``j`` by ``b[j] / col_sum[j]``.

  Args:
    M: 2-D tensor to rescale. It is not modified in place.
    a: 1-D tensor of target row sums, one entry per row of ``M``.
    b: 1-D tensor of target column sums, one entry per column of ``M``.
    its: number of row-then-column sweeps; ``0`` returns ``M`` unchanged.

  Returns:
    The rescaled tensor, with the same shape as ``M``.
  """
  P = M
  for _ in range(its):
    row_margs = P.sum(dim = 1, keepdim = True)
    row_scale = a.unsqueeze(1) / row_margs
    # A row that sums to zero would give 0/0 or x/0 here and the resulting
    # NaN would spread to every cell on the next step; such a row cannot be
    # rescaled, so it is left at zero instead.
    P = P * torch.where(row_margs != 0, row_scale, torch.zeros_like(row_scale))

    col_margs = P.sum(dim = 0, keepdim = True)
    col_scale = b.unsqueeze(0) / col_margs
    # Same guard for columns that sum to zero.
    P = P * torch.where(col_margs != 0, col_scale, torch.zeros_like(col_scale))
  return P

def _piecewise_linear_cdf(x, pmf, pmf_cumsum):
    """Piecewise-linear CDF on ``[0, l]`` built from a length-``l`` pmf.

    The CDF rises linearly by ``pmf[k]`` over each unit interval
    ``[k, k + 1)``, so its value at integer ``k >= 1`` is ``pmf_cumsum[k - 1]``
    and its value at ``0`` is ``0``. At ``x == l`` the result is exactly ``1.0``
    (not ``pmf_cumsum[-1]``, which may differ by rounding).

    Coordinates outside ``[0, l]`` are clamped to the nearest end, so the CDF
    is ``0`` to the left and ``1`` to the right. Without the clamp a negative
    coordinate would index ``pmf`` / ``pmf_cumsum`` from the end and return a
    meaningless value.

    Args:
        x: 1-D tensor of N coordinates (or an already column-shaped tensor).
        pmf: 1-D tensor of bin masses.
        pmf_cumsum: 1-D cumulative sum of ``pmf``.

    Returns:
        float32 tensor of CDF values; shape ``(N, 1)`` for 1-D input.
    """
    n_bins = pmf.shape[0]
    pmf = pmf.to(dtype=torch.float32)
    pmf_cumsum = pmf_cumsum.to(device=pmf.device, dtype=torch.float32)
    x = x.to(pmf.device).float()
    if x.dim() == 1:
        x = x.unsqueeze(-1)

    x = x.clamp(min=0.0, max=float(n_bins))
    at_end = x == n_bins

    # knots[k] is the CDF value at integer coordinate k (knots[0] == 0).
    knots = torch.cat([pmf_cumsum.new_zeros(1), pmf_cumsum])
    # x == l has no bin of its own; it is pinned to the last bin here and
    # overwritten with exactly 1.0 below.
    cell = torch.floor(x).long().clamp(max=n_bins - 1)
    frac = x - cell.to(x.dtype)

    vals = knots[cell] + pmf[cell] * frac
    return torch.where(at_end, torch.ones_like(vals), vals)


def G1(x1, R, R_cumsum):
    """Piecewise-linear CDF of the row marginal ``R`` evaluated at ``x1``.

    Args:
        x1: 1-D tensor of N coordinates; values outside ``[0, len(R)]`` are
            clamped.
        R: 1-D tensor of row masses.
        R_cumsum: 1-D cumulative sum of ``R``.

    Returns:
        float32 tensor of shape ``(N, 1)``.
    """
    return _piecewise_linear_cdf(x1, R, R_cumsum)


def G2(x2, C, C_cumsum):
    """Piecewise-linear CDF of the column marginal ``C`` evaluated at ``x2``.

    Args:
        x2: 1-D tensor of N coordinates; values outside ``[0, len(C)]`` are
            clamped.
        C: 1-D tensor of column masses.
        C_cumsum: 1-D cumulative sum of ``C``.

    Returns:
        float32 tensor of shape ``(N, 1)``.
    """
    return _piecewise_linear_cdf(x2, C, C_cumsum)
def inv_cdf_batch(u, cumsum, pmf):
    """Invert the piecewise-linear CDF defined by ``pmf`` / ``cumsum``.

    For each probability level the containing bin ``k`` is located with a
    right-sided search in ``cumsum`` and the result is
    ``k + (u - cumsum[k - 1]) / pmf[k]`` (with ``cumsum[-1]`` read as ``0``
    for ``k == 0``).

    Args:
        u: tensor of N probability levels, shape ``(N,)`` or ``(N, 1)``.
        cumsum: 1-D cumulative sum of ``pmf``.
        pmf: 1-D tensor of bin masses.

    Returns:
        Tensor of shape ``(N,)`` with the continuous coordinates.
    """
    # Move to the table's device regardless of the input rank (previously
    # only 1-D input was moved).
    u = u.to(cumsum.device)
    if u.dim() == 1:
        u = u.unsqueeze(-1)  # shape (N, 1)

    idx = torch.searchsorted(cumsum, u, right=True)
    # Levels at or above cumsum[-1] fall past the last bin; pin them to it.
    idx = idx.clamp(max=len(pmf) - 1)

    # idx - 1 is -1 for the first bin; that lookup is discarded by the where.
    below = torch.where(idx > 0, cumsum[idx - 1], torch.zeros_like(u))
    delta = pmf[idx]

    # The clamp above can land on a bin with zero mass, which would give
    # 0/0 = NaN. Such a bin has no width in probability, so use offset 0.
    has_mass = delta != 0
    safe_delta = torch.where(has_mass, delta, torch.ones_like(delta))
    offset = torch.where(has_mass, (u - below) / safe_delta, torch.zeros_like(u))
    return idx.squeeze(-1) + offset.squeeze(-1)
def F_batch(x1, x2, E, E_cumsum):
    """Bilinear CDF of the table ``E`` at continuous coordinates ``(x1, x2)``.

    With ``m = floor(x1)``, ``n = floor(x2)`` (each capped at the last index)
    and fractional parts ``dx``, ``dy``, the value is interpolated inside
    cell ``(m, n)`` from the 2-D prefix sums at its lower corners plus
    ``E[m, n] * dx * dy``.

    Args:
        x1: tensor of N row coordinates, shape ``(N,)`` or ``(N, 1)``.
        x2: tensor of N column coordinates, same layout as ``x1``.
        E: 2-D table of cell masses.
        E_cumsum: 2-D prefix sum of ``E`` (cumsum over both axes).

    Returns:
        Tensor of shape ``(N, 1)``.
    """
    target = E.device
    if x2.dim() == 1:
        x2 = x2.unsqueeze(-1).to(target)
    if x1.dim() == 1:
        x1 = x1.unsqueeze(-1).to(target)

    last_row = E.size(0) - 1
    last_col = E.size(1) - 1
    row = x1.floor().long().clamp(max=last_row)
    col = x2.floor().long().clamp(max=last_col)
    frac_row = x1 - row.float()
    frac_col = x2 - col.float()

    def prefix_at(table, r, c):
        # Prefix sum at (r, c); any index of -1 or less means "nothing summed
        # yet" and yields 0.
        inside = (r >= 0) & (c >= 0)
        out = torch.zeros((r.shape[0], 1), dtype=table.dtype, device=table.device)
        if inside.any():
            picked = table[r[inside], c[inside]]
            out[inside] = picked.to(dtype=out.dtype)
        return out

    corner = prefix_at(E_cumsum, row - 1, col - 1)
    along_col = prefix_at(E_cumsum, row - 1, col)
    along_row = prefix_at(E_cumsum, row, col - 1)
    return (
        corner
        + (along_row - corner) * frac_row
        + (along_col - corner) * frac_col
        + E[row, col] * frac_row * frac_col
    )

def Cop_batch(u1, u2, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, E, E_cumsum):
    """Evaluate the copula of ``E`` at probability levels ``(u1, u2)``.

    Each level is mapped back to a coordinate through the inverse CDF of the
    corresponding marginal of ``E``, and the joint CDF ``F_batch`` is then
    evaluated at those coordinates.
    """
    x1 = inv_cdf_batch(u1, E_rows_cumsum, E_rows)
    x2 = inv_cdf_batch(u2, E_cols_cumsum, E_cols)
    return F_batch(x1, x2, E, E_cumsum)

def G(x1, x2, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, R, R_cumsum, C, C_cumsum, E, E_cumsum):
  """Joint CDF with target marginals ``R`` / ``C`` and the copula of ``E``.

  The coordinates are first turned into probability levels with the target
  marginal CDFs ``G1`` / ``G2``; those levels are then fed to ``Cop_batch``.
  """
  u1 = G1(x1, R, R_cumsum)
  u2 = G2(x2, C, C_cumsum)
  return Cop_batch(u1, u2, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, E, E_cumsum)

def G_batch(E, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, R, R_cumsum, C, C_cumsum, E_cumsum):
  """Discretise the joint CDF ``G`` back into an ``l x l`` table of cell masses.

  Cell ``(i, j)`` receives ``G(i+1, j+1) - G(i, j+1) - G(i+1, j) + G(i, j)``,
  evaluated for all cells at once. ``l`` is taken from ``E.shape[0]``.
  """
  l = E.shape[0]
  ticks = torch.arange(l)
  grid_i, grid_j = torch.meshgrid(ticks, ticks, indexing='ij')
  lo_i = grid_i.reshape(-1)  # shape (l^2,)
  lo_j = grid_j.reshape(-1)
  hi_i = lo_i + 1
  hi_j = lo_j + 1

  def cdf_at(xi, xj):
    # G at one set of cell corners; everything else is fixed for the batch.
    return G(xi, xj, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, R, R_cumsum, C, C_cumsum, E, E_cumsum)

  upper_right = cdf_at(hi_i, hi_j)  # G(i+1, j+1)
  upper_left = cdf_at(lo_i, hi_j)   # G(i, j+1)
  lower_right = cdf_at(hi_i, lo_j)  # G(i+1, j)
  lower_left = cdf_at(lo_i, lo_j)   # G(i, j)

  cell_mass = upper_right - upper_left - lower_right + lower_left
  return cell_mass.view(l, l)

def KLD(P1, P2):
    """Sum of ``P1 * log(P1 / P2)`` after flooring both inputs at ``1e-7``.

    The floor keeps the logarithm and the ratio finite when either input
    contains zeros. The inputs are not re-normalised after flooring.
    """
    floor = 1e-7
    p = P1.clamp(min=floor)
    q = P2.clamp(min=floor)
    log_ratio = torch.log(p / q)
    return torch.sum(p * log_ratio)

def MI(E, E_rows, E_cols):
  """``KLD`` between ``E`` and the outer product of the supplied marginals.

  ``E_rows`` and ``E_cols`` are floored at ``1e-7``; 1-D inputs are reshaped
  to a column and a row vector so that their matrix product has the shape
  of ``E``.
  """
  rows = torch.clamp(E_rows, min=1e-7)
  cols = torch.clamp(E_cols, min=1e-7)
  rows = rows.unsqueeze(-1) if rows.dim() == 1 else rows
  cols = cols.unsqueeze(0) if cols.dim() == 1 else cols
  independent = rows @ cols
  return KLD(E, independent)

In [None]:

# NOTE(review): `divs` is never read or appended to in the visible cells;
# confirm it is still needed.
divs = []

def generate_sparse_E(l, skewness=1.0):
    """Draw a random ``l x l`` joint probability table.

    Standard-normal logits are scaled by ``skewness`` and passed through a
    single softmax over all ``l * l`` entries, so the table sums to 1.
    Larger ``skewness`` spreads the logits further apart.
    """
    scaled_logits = skewness * torch.randn(l, l)
    flat_probs = torch.softmax(scaled_logits.reshape(-1), dim=0)
    return flat_probs.view(l, l)
def generate_sparse_RC(l, skewness=1.0):
    """Draw a random length-``l`` probability vector.

    Standard-normal logits are scaled by ``skewness`` and passed through a
    softmax, so the vector sums to 1. Larger ``skewness`` spreads the logits
    further apart.
    """
    scaled_logits = skewness * torch.randn(l)
    return torch.softmax(scaled_logits, dim=0)
def marg_entropy(V):
  """Entropy ``-sum(V * log V)`` in nats, with ``V`` floored at ``1e-7`` first."""
  probs = V.clamp(min=1e-7)
  return -torch.sum(probs * torch.log(probs))
# Side length of the square joint table and length of both target marginals.
l = 10

# Target marginals R (rows) and C (columns).
# NOTE(review): no torch.manual_seed call is visible, so R, C and E differ on
# every run and the recorded outputs cannot be reproduced exactly.
R = generate_sparse_RC(l, 6)
C = generate_sparse_RC(l, 4)

# Random joint table; the second argument is the logit scale ("skewness").
E = generate_sparse_E(l, 2.5)
E_orig = E

# Preliminary IPFP (before the Sklar/copula step). ipfp builds a new tensor,
# so E_orig keeps the unscaled table.
E = ipfp(E, R, C, its = 4)
# NOTE(review): E_prelim is only referenced from a commented-out print in the
# visible cells.
E_prelim = E

# Marginals of the pre-fitted table and their prefix sums.
E_rows = E.sum(dim = 1, keepdim = False)
E_cols = E.sum(dim = 0, keepdim = False)
E_cols_cumsum = torch.cumsum(E_cols, dim=0)  # shape: (l,)
E_rows_cumsum = torch.cumsum(E_rows, dim=0)
E_cumsum = torch.cumsum(torch.cumsum(E, dim=0), dim = 1)  # 2-D prefix sum
# The string literal below is a disabled alternative that would derive R and C
# from the current marginals plus a small random perturbation; as a bare
# string it has no effect.
'''
R = E_rows + 0.0001*torch.rand(l)
R = R / R.sum()
C = E_cols + 0.0001*torch.rand(l)
C = C / C.sum()
'''
R_cumsum = torch.cumsum(R, dim=0)
C_cumsum = torch.cumsum(C, dim=0)
print("Current row marginals entropy: ", marg_entropy(E_rows))
print("Current col marginals entropy: ", marg_entropy(E_cols))
print("Target row entropy: ", marg_entropy(R))
print("Target col entropy: ", marg_entropy(C))
# NOTE(review): this pairs the unscaled E_orig with the marginals of the
# post-IPFP table (E_rows / E_cols), not with E_orig's own marginals. Confirm
# that mix is intended for a value labelled "Original joint MI".
print("Original joint MI: ", MI(E_orig, E_rows, E_cols))

Current row marginals entropy:  tensor(0.4957)
Current col marginals entropy:  tensor(0.0498)
Target row entropy:  tensor(0.4957)
Target col entropy:  tensor(0.0498)
Original joint MI:  tensor(12.2264)


In [None]:
# Fit the target marginals two ways: the copula-based construction (G_batch)
# and plain IPFP with its default 40 sweeps.
E_cop = G_batch(E, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, R, R_cumsum, C, C_cumsum, E_cumsum)
E_ipfp = ipfp(E, R, C)
# Divergence of each fitted table from the unscaled table E_orig.
print("KLD (E_original, E_copula) ", KLD(E_orig, E_cop))
print("KLD (E_original, E_ipfp)", KLD(E_orig, E_ipfp))
# Both argument orders are printed because KLD is not symmetric. KLD floors
# its inputs at 1e-7 without re-normalising, so tiny negative values can
# appear when the two tables are nearly identical.
print("KLD (E_ipfp, E_copula)", KLD(E_ipfp, E_cop), KLD(E_cop, E_ipfp))
#print("KLD (E_ipfp, E_prelim)", KLD(E_ipfp, E_prelim), KLD(E_prelim, E_ipfp))


KLD (E_original, E_copula)  tensor(10.9968)
KLD (E_original, E_ipfp) tensor(11.0085)
KLD (E_ipfp, E_copula) tensor(4.7415e-07) tensor(-3.6090e-07)


In [None]:
  # Hand-written 4x4 example. This rebinds E, R, C and every derived global
  # from the random 10x10 setup, so later cells see these 4x4 values.
  # NOTE(review): the entries of E sum to 0.9, not 1, so E_rows_cumsum and
  # E_cols_cumsum end at 0.9 while R and C each sum to 1. Confirm whether E
  # should be normalised before it is used as a joint pmf.
  E = torch.tensor([[0.1, 0.1, 0.03, 0.02], [0.02, 0.05, 0.1, 0.08], [0.2, 0.01, 0.01, 0.08], [0.025, 0.05, 0.015, 0.01]])
  R = torch.tensor([0.5, 0.4, 0.05, 0.05])
  C = torch.tensor([0.3, 0.25, 0.3, 0.15 ])
  # Marginals of E and the prefix sums consumed by G_batch.
  E_rows = E.sum(dim = 1, keepdim = False)
  E_cols = E.sum(dim = 0, keepdim = False)
  E_cols_cumsum = torch.cumsum(E_cols, dim=0)  # shape: (l,)
  E_rows_cumsum = torch.cumsum(E_rows, dim=0)
  E_cumsum = torch.cumsum(torch.cumsum(E, dim=0), dim = 1)  # 2-D prefix sum
  R_cumsum = torch.cumsum(R, dim=0)
  C_cumsum = torch.cumsum(C, dim=0)

  # Copula-based fit and IPFP fit (default 40 sweeps) of the same targets.
  E_cop = G_batch(E, E_rows, E_rows_cumsum, E_cols, E_cols_cumsum, R, R_cumsum, C, C_cumsum, E_cumsum)
  E_ipfp = ipfp(E, R, C)


In [None]:
# Display the current IPFP-fitted table.
E_ipfp

tensor([[0.2181, 0.1657, 0.0824, 0.0338],
        [0.0325, 0.0618, 0.2048, 0.1009],
        [0.0354, 0.0013, 0.0022, 0.0110],
        [0.0139, 0.0212, 0.0105, 0.0043]])

In [None]:
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from scipy.optimize import bisect

class MixtureOfGaussians:
    """One-dimensional Gaussian mixture with pdf, cdf and a numerical inverse cdf.

    NOTE(review): ``weights`` are passed through a softmax, i.e. they are
    treated as logits. Passing already-normalised probabilities ``p`` gives
    mixture weights ``softmax(p)`` rather than ``p``. Confirm this is the
    intended parameterisation.
    """

    # Upper bound on how many times icdf widens its search bracket.
    _MAX_BRACKET_EXPANSIONS = 64

    def __init__(self, weights, mus, sigmas):
        """
        weights: tensor of shape (K,); softmax-ed into the mixture weights
        mus: tensor of shape (K,), component means
        sigmas: tensor of shape (K,), component standard deviations
        """
        self.weights = F.softmax(weights, dim=0)
        self.mus = mus
        self.sigmas = sigmas
        self.components = [Normal(mu, sigma) for mu, sigma in zip(mus, sigmas)]

    def pdf(self, x):
        """Mixture density at ``x`` (a number or a tensor)."""
        x = torch.as_tensor(x)
        return sum(w * comp.log_prob(x).exp() for w, comp in zip(self.weights, self.components))

    def cdf(self, x):
        """Mixture CDF at ``x`` (a number or a tensor)."""
        x = torch.as_tensor(x)
        return sum(w * comp.cdf(x) for w, comp in zip(self.weights, self.components))

    def icdf(self, p, low=-10.0, high=10.0, tol=1e-5):
        """Numerically invert the CDF by bisection.

        Args:
            p: probability level(s): a number, a 0-dim tensor, or a tensor of
                any shape (each element is inverted independently).
            low, high: initial search bracket. It is widened automatically
                when the requested quantile lies outside of it.
            tol: absolute tolerance on the returned coordinate.

        Returns:
            A tensor of quantiles with the shape of ``p`` (0-dim for scalars).

        Raises:
            ValueError: if a level is outside ``[0, 1]`` or no bracket
                containing the quantile can be found.
        """
        if isinstance(p, torch.Tensor) and p.dim() > 0:
            quantiles = [
                self._inverse_cdf_scalar(float(pi), low, high, tol)
                for pi in p.reshape(-1)
            ]
            return torch.tensor(quantiles).reshape(p.shape)
        return torch.tensor(self._inverse_cdf_scalar(float(p), low, high, tol))

    def _inverse_cdf_scalar(self, target_p, low, high, tol):
        """Return ``x`` with ``cdf(x) == target_p`` (to within ``tol``)."""
        if not 0.0 <= target_p <= 1.0:
            raise ValueError(f"probability level must be in [0, 1], got {target_p}")

        def f(x):
            # Plain float so that the sign tests below and SciPy both see a
            # Python number rather than a 0-dim tensor.
            return float(self.cdf(torch.tensor(x))) - target_p

        lo, hi = float(low), float(high)
        f_lo, f_hi = f(lo), f(hi)
        step = max(abs(hi - lo), 1.0)
        tries = 0
        # bisect needs a sign change. The CDF is non-decreasing, so when both
        # ends are above the target the root is to the left, otherwise to the
        # right. Widen that side, doubling the step each time.
        while not f_lo * f_hi <= 0.0:
            if tries == self._MAX_BRACKET_EXPANSIONS:
                raise ValueError(
                    f"could not bracket the quantile for p={target_p} starting from [{low}, {high}]"
                )
            if f_lo > 0.0:
                lo -= step
                f_lo = f(lo)
            else:
                hi += step
                f_hi = f(hi)
            step *= 2.0
            tries += 1
        return bisect(f, lo, hi, xtol=tol)


In [None]:
# Weights handed to the target-marginal mixtures.
# NOTE(review): MixtureOfGaussians applies a softmax to its weights, so the
# mixtures below use softmax(R) / softmax(C) as component weights, not
# R / C themselves. Confirm this is intended.
R_ws = R       # passed through softmax inside MixtureOfGaussians
C_ws = C

mus = torch.tensor([0.0, 2, 3, 4])           # means of Gaussians
sigmas = torch.tensor([2.5, 2, 1, 1])        # std deviations
# One Normal per component; shared by the row and the column axis in F_dis.
gaussians = [Normal(mu, sigma) for mu, sigma in zip(mus, sigmas)]

# Target marginal distributions for the row and column coordinate.
mog_r = MixtureOfGaussians(R_ws, mus, sigmas)
mog_c = MixtureOfGaussians(C_ws, mus, sigmas)


In [None]:
def F_dis(x1, x2, E, gaussians):
    """Joint CDF ``sum_ij E[i, j] * Phi_i(x1) * Phi_j(x2)``.

    ``Phi_k`` is the CDF of ``gaussians[k]``; the same component list is used
    for both coordinates. ``x1`` and ``x2`` may be numbers or tensors and are
    converted to float32. The result has the shape of the inputs.
    """
    pt1 = torch.as_tensor(x1, dtype=torch.float32)
    pt2 = torch.as_tensor(x2, dtype=torch.float32)

    def component_cdfs(pt):
        # One CDF value (or tensor of values) per component: shape [k, ...].
        return torch.stack([comp.cdf(pt) for comp in gaussians])

    first = component_cdfs(pt1)
    second = component_cdfs(pt2)
    # Weighted sum over both component indices; trailing dims are kept.
    return torch.einsum('ij,i...,j...->...', E, first, second)

# Mixtures built from the marginals of E; fr_sklar inverts their CDFs.
# NOTE(review): MixtureOfGaussians softmaxes its weights, whereas F_dis
# weights the components with the entries of E directly. orig_r / orig_c are
# therefore weighted by softmax(E_rows) / softmax(E_cols), not by
# E_rows / E_cols. Confirm these are the marginals intended for the inverse
# step.
orig_r = MixtureOfGaussians(E_rows, mus, sigmas)
orig_c = MixtureOfGaussians(E_cols, mus, sigmas)

def fr_sklar(x1, x2, E, gaussians):
  """Joint CDF obtained by re-mapping the marginals and reusing ``F_dis``.

  The point is pushed through the target marginal CDFs (globals ``mog_r`` /
  ``mog_c``), the resulting levels are mapped back through the inverse CDFs
  of ``orig_r`` / ``orig_c``, and ``F_dis`` is evaluated with table ``E`` at
  the mapped point.
  """
  pt1, pt2 = torch.tensor(x1), torch.tensor(x2)
  level_r, level_c = mog_r.cdf(pt1), mog_c.cdf(pt2)
  mapped_r = orig_r.icdf(level_r)
  mapped_c = orig_c.icdf(level_c)
  return F_dis(mapped_r, mapped_c, E, gaussians)
import numpy as np
# Check how closely the Sklar-transformed CDF (fr_sklar with E) agrees with
# the mixture CDF built from the IPFP table (F_dis with E_ipfp) at random points.
dists = []
for _ in range(500):
  x1 = (torch.rand(1) * 6.0 - 3.0).item()  # uniform sample from [-3.0, 3.0)
  x2 = (torch.rand(1) * 6.0 - 3.0).item()
  dists.append(torch.abs(fr_sklar(x1, x2, E, gaussians) - F_dis(x1, x2, E_ipfp, gaussians)))
# Mean and median absolute CDF difference over the sampled points.
print(np.array(dists).mean(), np.median(np.array(dists)))

0.028965749 0.021127956


In [17]:

import torch
import numpy as np

def numerical_pdf(cdf_fn, x1, x2, delta, *args):
    """Approximate the mixed partial d^2 F / dx1 dx2 of ``cdf_fn`` at ``(x1, x2)``.

    Uses a central difference over the four corners of a square with half
    side ``delta``. Extra ``*args`` are forwarded to ``cdf_fn`` unchanged.
    """
    right, left = x1 + delta, x1 - delta
    upper, lower = x2 + delta, x2 - delta

    at_right_upper = cdf_fn(right, upper, *args)
    at_right_lower = cdf_fn(right, lower, *args)
    at_left_upper = cdf_fn(left, upper, *args)
    at_left_lower = cdf_fn(left, lower, *args)

    square_area = 4 * delta**2
    return (at_right_upper - at_right_lower - at_left_upper + at_left_lower) / square_area

def estimate_kl(fr_sklar, F_dis, E, E_ipfp, gaussians,
                x_min=-9, x_max=9, steps=50, delta=1e-2):
    """Grid estimate of KL(p || q) between two densities given by their CDFs.

    ``p`` is the numerical density of ``fr_sklar(x1, x2, E, gaussians)`` and
    ``q`` the numerical density of ``F_dis(x1, x2, E_ipfp, gaussians)``. Both
    are evaluated with ``numerical_pdf`` on a ``steps x steps`` grid over
    ``[x_min, x_max]^2`` and combined as a Riemann sum of ``p * log(p / q)``.

    Args:
        fr_sklar: callable ``(x1, x2, E, gaussians) -> 0-dim tensor`` CDF for p.
        F_dis: callable ``(x1, x2, E_ipfp, gaussians) -> 0-dim tensor`` CDF for q.
        E, E_ipfp, gaussians: forwarded to the two CDF callables.
        x_min, x_max: grid bounds, used for both axes.
        steps: number of grid points per axis (at least 2).
        delta: half-width of the finite-difference stencil.

    Returns:
        The estimate as a 0-dim tensor (the grid spacing is a tensor).

    Raises:
        ValueError: if ``steps`` is smaller than 2.
    """
    if steps < 2:
        raise ValueError("steps must be at least 2 to define a grid spacing")

    x_vals = torch.linspace(x_min, x_max, steps)
    y_vals = torch.linspace(x_min, x_max, steps)
    dx = x_vals[1] - x_vals[0]
    dy = y_vals[1] - y_vals[0]
    kl = 0.0
    # Floor for both densities: finite differences can return zero or
    # slightly negative values.
    eps = 1e-8

    for x1 in x_vals.tolist():
        for x2 in y_vals.tolist():
            p = numerical_pdf(fr_sklar, x1, x2, delta, E, gaussians).item()
            q = numerical_pdf(F_dis, x1, x2, delta, E_ipfp, gaussians).item()
            p = max(p, eps)
            q = max(q, eps)

            # q is already floored at eps; adding eps a second time would
            # bias the ratio and make KL(p || p) non-zero.
            kl += p * np.log(p / q) * dx * dy

    return kl

# Grid KL estimate with the default settings (50 x 50 grid on [-9, 9]^2).
# Each grid point evaluates fr_sklar four times, and every evaluation runs two
# bisection root-finds, so this cell is slow.
kl_value = estimate_kl(fr_sklar, F_dis, E, E_ipfp, gaussians)
print(f"Estimated KL divergence: {kl_value:.6f}")


In [None]:
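# Recorded result: Estimated KL divergence: -0.040498 (50 steps, range -9 to 9)
# Recorded result: Estimated KL divergence: -0.040350 (settings not recorded)
# NOTE: an exact KL divergence is never negative, so these values indicate
# error in the estimate (for example grid/finite-difference error, or the two
# numerical densities not integrating to the same total), not a true divergence.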