In [1]:
import numpy as np
import matplotlib.pyplot as plt

from src.rot.problem import ROT, EntropicROT
from src.rot.sinkhorn import calc_U, calc_k_formula, robust_sinkhorn_eps
from src.utils import norm_inf

import time

In [2]:
# Dimension
n = 100

# Regularization
tau = np.float64(1.0)

In [3]:
np.random.seed(3698)

# Cost matrix
C = np.random.uniform(1.0, 50.0, (n, n)).astype(np.float64)

# Marginal vectors
a = np.random.uniform(0.1, 1.0, n).astype(np.float64)
b = np.random.uniform(0.1, 1.0, n).astype(np.float64)

a = a / a.sum()
b = b / b.sum()

# Varying $\epsilon$

In [None]:
# Original UOT problem
rsot = RSOT(C, a, b, tau)

# Optimal solution
f_optimal, X_optimal = exact_rsot(rsot)

print('Optimal:', f_optimal)

In [None]:
# Number of eps
neps = 30

# Epsilons
eps_arr = np.logspace(start=-1, stop=-5, num=neps).astype(np.float64)

kfs, kcs = np.zeros((2, neps))

for i, eps in enumerate(eps_arr):
    print(f'Epsilon {i}:', eps)
    start = time.time()

    # Entropic regularization parameter
    U = calc_U(rsot, eps)
    eta = eps / U

    # Convert to Entropic Regularized UOT
    ersot = EntropicRSOT(C, a, b, tau, eta)

    # Sinkhorn
    _, log = robust_semisinkhorn_eps(ersot, f_optimal, eps, 
                                     patience=1000,
                                     save_uv=False, verbose=True)
    
    plt.plot(log['f'][1:])
    plt.axhline(f_optimal, color='red')
    plt.show()

    # Find k
    kfs[i] = calc_k_formula(ersot, eps)
    kcs[i] = len(log['f']) - 1000

    print('Time elapsed:', time.time() - start)
    print('=================')

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5), dpi=150)

ax[0].plot(np.log(1 / eps_arr), np.log(kcs), marker='.', label='$\log k_c$')
ax[0].plot(np.log(1 / eps_arr), np.log(kfs), marker='.', label='$\log k_f$')
ax[0].legend()

ax[1].plot(np.log(1 / eps_arr), kfs / kcs, marker='.', label='$k_f$ / $k_c$')
ax[1].legend()

# Theory checking

In [4]:
# Original UOT problem
rot = ROT(C, a, b, tau)

In [5]:
# Optimal solution for ROT
X_hat = rot.optimize_f('SCS')

f_hat = rot.calc_f(X_hat)

print('Optimal:', f_hat)

Optimal: 1.673249265519985


In [6]:
# Entropic regularization parameter
eta = np.float64(0.0)

# Convert to Entropic Regularized ROT
erot = rot.entropic_regularize(eta)

In [7]:
# Optimal solution for EntropicROT
X_star = erot.optimize_g('SCS')

f_star = erot.calc_f(X_star)
g_star = erot.calc_g(X_star)

print('Optimal:', f_star, g_star)

Optimal: 2.238064837665014 2.238064837665014


In [11]:
# Optimal solution for EntropicROT (dual)
u_star, v_star = erot.optimize_h(solver='SCS')
X_star_dual = erot.calc_B(u_star, v_star)

f_star_dual = erot.calc_f(X_star_dual)
g_star_dual = erot.calc_g(X_star_dual)
h_star_dual = erot.calc_h(u_star, v_star)

print('Optimal:', f_star_dual, g_star_dual)

Optimal: nan nan


In [None]:
eps = 0.001

# Entropic regularization parameter
U = calc_U(rsot, eps)
eta = eps / U

# Convert to Entropic Regularized UOT
ersot = rsot.entropic_regularize(eta)

start = time.time()

# Sinkhorn
_, log = robust_semisinkhorn_eps(ersot, f_star, eps, 
                                 save_uv=False,
                                 patience=1000, verbose=True)

print('Time elapsed:', time.time() - start)

In [None]:
plt.figure(figsize=(10, 5), dpi=150)
plt.plot(log['f'][1:])
plt.axhline(f_star, color='red')
plt.axhline(f_hat, color='green')
plt.show()

In [None]:
y = []

for uk, vk in zip(log['u'], log['v']):
    Xk = calc_B(ersot, uk, vk)
    # y.append(norm_inf(Xk - X_hat))
    y.append(calc_g_rsot(ersot, Xk))

plt.plot(y)