In [89]:
from dataclasses import dataclass
from typing import Union

import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt

np.random.seed(3698)

In [90]:
@dataclass()
class UOT:
    C: np.ndarray
    a: np.ndarray
    b: np.ndarray
    tau: float

@dataclass()
class EntRegUOT(UOT):
    eta: float

In [91]:
def calc_R(p: EntRegUOT) -> float:
    n = p.C.shape[0]
    R = max(np.max(np.log(p.a)), np.max(np.log(p.b))) + max(np.log(n), np.max(np.abs(p.C)) / p.eta - np.log(n))
    return R

def calc_U(p: UOT, 
           eps: np.float128) -> float:
    n = p.C.shape[0]
    S = 0.5 * (alpha + beta) + 0.5 + 0.25 / np.log(n)
    T = 0.5 * (alpha + beta) * (np.log(0.5 * (alpha + beta)) + 2 * np.log(n) - 1) + np.log(n) + 2.5
    U = max(S + T, 2 * eps, 4 * eps * np.log(n) / p.tau, 4 * eps * (alpha + beta) * np.log(n) / p.tau)
    return U

def calc_k_stop(p: EntRegUOT, 
                eps: np.float128) -> int:
    R = calc_R(p)
    U = calc_U(p, eps)
    k_float = (p.tau * U + 1) * (np.log(8 * p.eta * R) + np.log(p.tau * (p.tau + 1)) + 3 * np.log(U / eps))
    return int(k_float)

def calc_B(p: EntRegUOT, 
           u: np.float128,
           v: np.float128) -> np.ndarray:
    return np.diag(np.exp(u / p.eta)) * np.exp(- p.C / p.eta) * np.diag(np.exp(v / p.eta))
    

def sinkhorn_approx_entreg_uot(p: EntRegUOT, 
                               eps: float) -> np.ndarray:
    # Find problem dimension
    n = p.C.shape[0]

    # Initialize
    u = np.zeros(n)
    v = np.zeros(n)

    # Find stopping condition
    k_stop = calc_k_stop(p, eps)

    # Loop
    scale = p.eta * p.tau / (p.eta + p.tau)
    for k in range(k_stop + 1):
        # 
        X = calc_B(p, u, v)

        # Update
        if k % 2 == 0:
            ak = X.sum(-1)
            u = (u / p.eta + np.log(p.a) - np.log(ak)) * scale
        else:
            bk = X.sum(0)
            v = (v / p.eta + np.log(p.b / bk)) * scale
    
    # Calculate and return
    return calc_B(p, u, v)

### Configuration

In [92]:
# Dimension
n = 10

# Regularization
tau = 5.0

# Mass normalization
alpha = 2.0
beta = 4.0

# Number of eps
neps = 20


### Generation

In [95]:
np.random.seed(3698)

# Cost matrix
C = np.random.uniform(low=1.0, high=50.0, size=(n, n)).astype(np.float128)
C = (C + C.T) / 2.0

# Marginal vectors
a = np.random.rand(n).astype(np.float128)
b = np.random.rand(n).astype(np.float128)

a = a / a.sum() * alpha
b = b / b.sum() * beta

# Epsilons
eps_arr = np.logspace(start=0, stop=-1, num=neps).astype(np.float128)

In [96]:
for eps in eps_arr:
    # Entropic regularization parameter
    uot_p = UOT(C, a, b, tau)
    U = calc_U(uot_p, eps)
    eta = eps / U
    print(eta)

    # Convert to Entropic Regularized UOT
    p = EntRegUOT(C, a, b, tau, eta)

    # Sinkhorn
    sinkhorn_approx_entreg_uot(p, eps)

0.044400032300922804798
0.03933251410852248201
0.034843368032076103356


  return np.diag(np.exp(u / p.eta)) * np.exp(- p.C / p.eta) * np.diag(np.exp(v / p.eta))
  return np.diag(np.exp(u / p.eta)) * np.exp(- p.C / p.eta) * np.diag(np.exp(v / p.eta))
  return np.diag(np.exp(u / p.eta)) * np.exp(- p.C / p.eta) * np.diag(np.exp(v / p.eta))
  return np.diag(np.exp(u / p.eta)) * np.exp(- p.C / p.eta) * np.diag(np.exp(v / p.eta))
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)
  v = (v / p.eta + np.log(p.b / bk)) * scale
  v = (v / p.eta + np.log(p.b / bk)) * scale


0.030866582605652533494
0.027343680463797096624
0.0242228584504628126
0.021458225870069239871
0.019009129679412841002
0.016839556697590495855
0.01491760404362310334
0.01321501001473287017
0.011706738506988505137
0.010370610867356031119
0.009186979763656611769
0.008138440276792860759
0.0072095739669466349547
0.006386722150323103059
0.005657785052547709833
0.005012043885330579787
0.004440003230092280726
