In [1]:
pip install pot

Collecting pot
  Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (32 kB)
Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (835 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m835.4/835.4 kB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pot
Successfully installed pot-0.9.4


In [2]:
import numpy as np
import ot
import matplotlib.pyplot as plt
import time

In [10]:
def sinkhorn_knopp(mu, nu, C, reg=0.1, n_iter=1000, tol=1e-9):
    """Entropy-regularized optimal transport via Sinkhorn-Knopp matrix scaling.

    Parameters
    ----------
    mu : array of length m
        Source marginal weights.
    nu : array of length n
        Target marginal weights.
    C : (m, n) array
        Cost matrix.
    reg : float
        Entropic regularization strength; the Gibbs kernel is ``exp(-C / reg)``.
        Very small ``reg`` relative to the entries of ``C`` can underflow the
        kernel to 0 and yield inf/nan scalings.
    n_iter : int
        Maximum number of scaling iterations.
    tol : float
        Stop once the L1 change of ``u`` between two iterations falls below this.

    Returns
    -------
    Gamma : (m, n) array
        Transport plan ``diag(u) @ K @ diag(v)``; its column sums equal ``nu``
        because ``v`` is updated last.
    count : int
        Number of iterations performed.
    """
    K = np.exp(-C / reg)
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    count = 0
    for _ in range(n_iter):
        count += 1
        # `u` is rebound (never mutated in place), so no copy is needed.
        u_prev = u
        u = mu / (K @ v)
        v = nu / (K.T @ u)
        if np.linalg.norm(u - u_prev, 1) < tol:
            break
    # Equivalent to np.diag(u) @ K @ np.diag(v), but O(mn) via broadcasting
    # instead of two dense O(m^2 n + m n^2) matrix products.
    Gamma = u[:, None] * K * v[None, :]
    return Gamma, count

def IPOT(mu, nu, C, beta=2, tol=1e-9, max_iter=1000, L=1):
    """Inexact proximal point OT (IPOT), dense-diagonal formulation.

    Each outer step forms ``Q = exp(-C / beta) * Gamma`` from the current plan.
    It then runs ``L`` Sinkhorn scaling iterations on ``Q`` and rescales it into
    the next plan.

    Parameters
    ----------
    mu : array of length m
        Source marginal weights.
    nu : array of length n
        Target marginal weights.
    C : (m, n) array
        Cost matrix.
    beta : float
        Proximal parameter; the kernel is ``exp(-C / beta)``.
    tol : float
        Stop once the Frobenius norm of the change between successive plans
        falls below this.
    max_iter : int
        Maximum number of outer (proximal) iterations.
    L : int
        Number of inner Sinkhorn iterations per outer step; must be >= 1.

    Returns
    -------
    Gamma : (m, n) array
        Transport plan after the last outer step; its column sums equal ``nu``
        because ``b`` is updated last.
    count : int
        Number of outer iterations performed.

    Raises
    ------
    ValueError
        If ``L < 1`` (the scaling vector ``a`` would never be computed).
    """
    if L < 1:
        raise ValueError("L must be >= 1 (number of inner Sinkhorn iterations)")
    m, n = C.shape
    b = np.ones(n) / n
    G = np.exp(-C / beta)
    Gamma = np.ones((m, n))
    count = 0
    for _ in range(max_iter):
        count += 1
        Q = G * Gamma
        for _ in range(L):
            a = mu / (Q @ b)
            b = nu / (Q.T @ a)
        # Explicit diagonal matmuls (O(m^2 n + m n^2) per step), kept on purpose;
        # IPOT_1 computes the same plan with O(mn) broadcasting.
        Gamma_next = np.diag(a) @ Q @ np.diag(b)
        # Compare successive plans. Comparing against Q = G * Gamma never converges:
        # at a fixed point Q - Gamma = (G - 1) * Gamma, nonzero wherever C > 0.
        delta = np.linalg.norm(Gamma_next - Gamma)
        Gamma = Gamma_next
        if delta < tol:
            break
    return Gamma, count

def IPOT_1(mu, nu, C, beta=2, tol=1e-9, max_iter=1000, L=1):
    """Inexact proximal point OT (IPOT), broadcasting formulation.

    This is the same algorithm as ``IPOT``. The plan rescaling
    ``diag(a) @ Q @ diag(b)`` is computed with elementwise broadcasting in O(mn)
    per step instead of dense matrix products.

    Parameters
    ----------
    mu : array of length m
        Source marginal weights.
    nu : array of length n
        Target marginal weights.
    C : (m, n) array
        Cost matrix.
    beta : float
        Proximal parameter; the kernel is ``exp(-C / beta)``.
    tol : float
        Stop once the Frobenius norm of the change between successive plans
        falls below this.
    max_iter : int
        Maximum number of outer (proximal) iterations.
    L : int
        Number of inner Sinkhorn iterations per outer step; must be >= 1.

    Returns
    -------
    Gamma : (m, n) array
        Transport plan after the last outer step; its column sums equal ``nu``
        because ``b`` is updated last.
    count : int
        Number of outer iterations performed.

    Raises
    ------
    ValueError
        If ``L < 1`` (no scaling would be applied and the plan just decays).
    """
    if L < 1:
        raise ValueError("L must be >= 1 (number of inner Sinkhorn iterations)")
    m = len(mu)
    n = len(nu)
    b = np.ones(n)
    Gamma = np.ones((m, n)) / (m * n)
    G = np.exp(-(C / beta))
    count = 0
    for _ in range(max_iter):
        count += 1
        Q = G * Gamma
        for _ in range(L):
            a = mu / (Q @ b)
            b = nu / (Q.T @ a)
        Gamma_next = a[:, None] * Q * b[None, :]
        # Compare successive plans. Comparing against Q = G * Gamma never converges:
        # at a fixed point Q - Gamma = (G - 1) * Gamma, nonzero wherever C > 0.
        delta = np.linalg.norm(Gamma_next - Gamma)
        Gamma = Gamma_next
        if delta < tol:
            break
    return Gamma, count

In [4]:
# Fixed seed so the random vectors (and every number reported below) are reproducible.
np.random.seed(42)

num = 1000  # number of points on each side

# Uniform random 300-d vectors standing in for the English / Vietnamese embeddings.
english_embeddings = np.random.rand(num, 300)
vietnamese_embeddings = np.random.rand(num, 300)

# Pairwise Euclidean distances; the rescaled copy lies in [0, 1] and is what the solvers see.
cost_matrix = ot.dist(english_embeddings, vietnamese_embeddings, metric='euclidean')
cost_matrix_new = cost_matrix / np.max(cost_matrix)

In [5]:
# Sinkhorn on the [0, 1]-normalized costs for several regularization strengths.
# Plans are scored against the unnormalized `cost_matrix`, so costs are in original units.
reg_list = [0.1, 0.01, 0.05]
marginal = np.ones(num) / num  # uniform weights, used for both sides

for reg in reg_list:
    # perf_counter is monotonic and high-resolution, unlike time.time().
    ot_start = time.perf_counter()
    ot_plan, ot_count = sinkhorn_knopp(marginal, marginal, cost_matrix_new, reg)
    ot_end = time.perf_counter()
    print(f"\nSinkhorn reg={reg}")
    total_ot_cost = (ot_plan * cost_matrix).sum()
    print("Total OT Cost:", total_ot_cost)
    print("Total OT Time:", ot_end - ot_start)
    print("Num iter:", ot_count)


Total OT Cost: 7.012189574275423
Total OT Time: 0.0862886905670166
Num iter: 4

Total OT Cost: 6.575419710861599
Total OT Time: 0.1262683868408203
Num iter: 67

Total OT Cost: 6.9600939689426955
Total OT Time: 0.10007762908935547
Num iter: 7


In [6]:
# IPOT (dense np.diag rescaling) on the normalized costs; scored on the unnormalized ones.
marginal = np.ones(num) / num  # uniform weights, used for both sides

# perf_counter is monotonic and high-resolution, unlike time.time().
ipot_start = time.perf_counter()
ipot_plan, ipot_count = IPOT(marginal, marginal, cost_matrix_new)
ipot_end = time.perf_counter()

print("\nIPOT (np.diag rescaling)")
total_ipot_cost = (ipot_plan * cost_matrix).sum()
print("Total IPOT Cost:", total_ipot_cost)
print("Total IPOT Time:", ipot_end - ipot_start)
print("Num iter:", ipot_count)


Total IPOT Cost: 6.410761435730385
Total IPOT Time: 86.92793369293213
Num iter: 1000


In [7]:
# IPOT_1 (broadcasting rescaling) on the normalized costs; scored on the unnormalized ones.
marginal = np.ones(num) / num  # uniform weights, used for both sides

# perf_counter is monotonic and high-resolution, unlike time.time().
ipot_start = time.perf_counter()
ipot_plan, ipot_count = IPOT_1(marginal, marginal, cost_matrix_new)
ipot_end = time.perf_counter()

print("\nIPOT_1 (broadcasting rescaling)")
total_ipot_cost = (ipot_plan * cost_matrix).sum()
print("Total IPOT Cost:", total_ipot_cost)
print("Total IPOT Time:", ipot_end - ipot_start)
print("Num iter:", ipot_count)


Total IPOT Cost: 6.410761435730386
Total IPOT Time: 8.292207717895508
Num iter: 1000


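**CPU runtime**

Sinkhorn rows use `sinkhorn_knopp` with the given `reg`. `IPOT` is the implementation that rescales with dense `np.diag` products, and `IPOT_1` is the broadcasting one; both use the default `beta=2`. The solvers run on the normalized `cost_matrix_new`, while *Total Cost* is $\langle \Gamma, C \rangle$ evaluated on the unnormalized `cost_matrix`.

| Method                   | Total Cost         | Total Time (seconds) | Num Iter |
|--------------------------|--------------------|----------------------|----------|
| **Sinkhorn (reg = 0.1)**  | 7.012189574275424  | 0.1614372730255127   | 4        |
| **Sinkhorn (reg = 0.01)** | 6.575419710861599  | 0.3141932487487793   | 192      |
| **Sinkhorn (reg = 0.05)** | 6.960093968942695  | 0.1728804111480713   | 7        |
| **IPOT (`IPOT`)**         | 6.410761435730382  | 162.2253954410553    | 1000     |
| **IPOT (`IPOT_1`)**       | 6.410761435730384  | 16.870448112487793   | 1000     |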

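**GPU runtime**

All solvers here are plain NumPy and execute on the CPU even in a GPU runtime. Timing differences from the CPU table therefore most likely reflect the host machine, not GPU acceleration.

| Method                   | Total Cost         | Total Time (seconds) | Num Iter |
|--------------------------|--------------------|----------------------|----------|
| **Sinkhorn (reg = 0.1)**  | 7.012189574275423  | 0.0862886905670166   | 4        |
| **Sinkhorn (reg = 0.01)** | 6.575419710861599  | 0.1262683868408203   | 67       |
| **Sinkhorn (reg = 0.05)** | 6.9600939689426955 | 0.10007762908935547  | 7        |
| **IPOT (`IPOT`)**         | 6.410761435730385  | 86.92793369293213    | 1000     |
| **IPOT (`IPOT_1`)**       | 6.410761435730386  | 8.292207717895508    | 1000     |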