In [1]:
pip install pot

Collecting pot
  Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (32 kB)
Downloading POT-0.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (835 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m835.4/835.4 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pot
Successfully installed pot-0.9.4


In [2]:
import numpy as np
import ot
import matplotlib.pyplot as plt
import time

In [9]:
def sinkhorn_knopp(mu, nu, C, reg=0.1, n_iter=1000, tol=1e-9):
    """Compute an entropy-regularised transport plan by Sinkhorn-Knopp scaling.

    The Gibbs kernel ``K = exp(-C / reg)`` is alternately rescaled by two
    vectors ``u`` and ``v`` so that the plan ``diag(u) K diag(v)`` approaches
    the marginals ``mu`` (row sums) and ``nu`` (column sums).

    Parameters
    ----------
    mu : array_like, shape (m,)
        Source marginal.
    nu : array_like, shape (n,)
        Target marginal.
    C : array_like, shape (m, n)
        Cost matrix (nested lists are accepted).
    reg : float, optional
        Regularisation strength used in the kernel ``exp(-C / reg)``.
    n_iter : int, optional
        Maximum number of scaling iterations.
    tol : float, optional
        Iteration stops as soon as the L1 norm of the change in ``u`` between
        two consecutive iterations is below this value.

    Returns
    -------
    Gamma : numpy.ndarray, shape (m, n)
        The scaled kernel ``u[:, None] * K * v[None, :]``.
    count : int
        Number of iterations actually performed (0 when ``n_iter`` is 0).

    Notes
    -----
    No log-domain stabilisation is applied: when ``reg`` is small relative to
    the entries of ``C``, ``exp(-C / reg)`` can underflow to zero and the
    divisions below may produce ``inf``/``nan``.
    """
    # Coerce once so that lists / nested lists work for every operation below
    # (unary minus on a plain nested list would raise TypeError).
    mu = np.asarray(mu)
    nu = np.asarray(nu)
    C = np.asarray(C)

    K = np.exp(-C / reg)
    u = np.ones_like(mu)
    v = np.ones_like(nu)

    count = 0
    for count in range(1, n_iter + 1):
        # `u` is rebound to a fresh array below, so no copy is required here.
        u_prev = u
        u = mu / (K @ v)
        v = nu / (K.T @ u)
        if np.linalg.norm(u - u_prev, 1) < tol:
            break

    # Row/column scaling by broadcasting: same values as diag(u) @ K @ diag(v)
    # but O(m*n) instead of two dense matrix products.
    Gamma = u[:, None] * K * v[None, :]
    return Gamma, count

def IPOT(mu, nu, C, beta=2, tol=1e-9, max_iter=1000, L=1):
    """Iteratively refine a transport plan with kernel-weighted scaling steps (IPOT).

    Starting from the uniform plan, every outer iteration multiplies the
    current plan elementwise by ``G = exp(-C / beta)``, performs ``L`` scaling
    updates of the vectors ``a`` and ``b`` against the marginals, and forms the
    new plan ``a[:, None] * Q * b[None, :]``. The scaling vectors are carried
    over from one outer iteration to the next.

    Parameters
    ----------
    mu : array_like, shape (m,)
        Source marginal.
    nu : array_like, shape (n,)
        Target marginal.
    C : array_like, shape (m, n)
        Cost matrix (nested lists are accepted).
    beta : float, optional
        Step-size parameter used in the kernel ``exp(-C / beta)``.
    tol : float, optional
        Iteration stops once the Frobenius norm of the change in the plan
        between two consecutive outer iterations is below this value.
    max_iter : int, optional
        Maximum number of outer iterations.
    L : int, optional
        Number of inner scaling updates per outer iteration.

    Returns
    -------
    Gamma : numpy.ndarray, shape (m, n)
        The final plan.
    count : int
        Number of outer iterations actually performed (0 when ``max_iter`` is 0).
    """
    # Coerce once so that lists / nested lists work for every operation below.
    mu = np.asarray(mu)
    nu = np.asarray(nu)
    C = np.asarray(C)

    m = len(mu)
    n = len(nu)
    a = np.ones(m)
    b = np.ones(n)
    Gamma = np.ones((m, n)) / (m * n)
    G = np.exp(-C / beta)

    count = 0
    for count in range(1, max_iter + 1):
        Gamma_prev = Gamma
        Q = G * Gamma_prev
        for _ in range(L):
            a = mu / (Q @ b)
            b = nu / (Q.T @ a)
        Gamma = a[:, None] * Q * b[None, :]
        # Convergence is measured between successive plans. Comparing against
        # Q (the kernel-weighted previous plan) does not go to zero at a fixed
        # point unless G equals 1 on the plan's support, so that test would
        # almost never fire.
        if np.linalg.norm(Gamma - Gamma_prev) < tol:
            break
    return Gamma, count

In [4]:
# Synthetic benchmark data: despite the variable names, both "embedding" sets
# are uniform random matrices drawn from np.random.rand, not real word vectors.
np.random.seed(42)  # fixed seed so the two random matrices are reproducible

num = 1000           # number of points in each set
embedding_dim = 300  # dimensionality of every point; change here to rerun at another size

english_embeddings = np.random.rand(num, embedding_dim)
vietnamese_embeddings = np.random.rand(num, embedding_dim)

# Pairwise Euclidean distances between the two sets, used as the transport cost.
cost_matrix = ot.dist(english_embeddings, vietnamese_embeddings, metric='euclidean')

In [None]:
# sinkhorn
# Time sinkhorn_knopp on the shared cost matrix for several regularisation values.
reg_list = [0.1, 0.01, 0.05]

# Uniform marginals, built once so their construction is not part of the timing.
uniform_mu = np.ones(num) / num
uniform_nu = np.ones(num) / num

for reg in reg_list:
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    ot_start = time.perf_counter()
    ot_plan, ot_count = sinkhorn_knopp(uniform_mu, uniform_nu, cost_matrix, reg)
    ot_end = time.perf_counter()
    print()
    # Label each result block so the printed numbers can be matched to their reg.
    print("Regularization:", reg)
    total_ot_cost = (ot_plan * cost_matrix).sum()
    print("Total OT Cost:", total_ot_cost)
    print("Total OT Time:", ot_end - ot_start)
    print("Num iter:", ot_count)

In [None]:
# IPOT
# Time a single IPOT run (default beta, tol, max_iter and L) on the shared cost matrix.

# Uniform marginals, built before the timer starts so they are not measured.
ipot_mu = np.ones(num) / num
ipot_nu = np.ones(num) / num

# perf_counter is a monotonic, high-resolution clock intended for measuring
# short durations; time.time() can jump when the system clock is adjusted.
ipot_start = time.perf_counter()
ipot_plan, ipot_count = IPOT(ipot_mu, ipot_nu, cost_matrix)
ipot_end = time.perf_counter()

print()
total_ipot_cost = (ipot_plan * cost_matrix).sum()
print("Total IPOT Cost:", total_ipot_cost)
print("Total IPOT Time:", ipot_end - ipot_start)
print("Num iter:", ipot_count)

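### Dimension 16 - Data 1000

In the tables below, rows labelled **OT (r)** are Sinkhorn-Knopp runs with regularization `reg = r`, and **IPOT** is the IPOT run with its default parameters; "Dimension" is the embedding dimension and "Data" is the number of points per set.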

**With CPU**

| Method       | Total Cost        | Total Time (seconds) | Num Iter |
|---------------|--------------------|----------------------|----------|
| **OT (0.1)**    | 1.1110577541180653 | 0.1927204132080078   |   30     |
| **OT (0.01)**    | 0.8928369460201366 | 1.3399744033813477  |    1000    |
| **OT (0.05)**    | 0.9411637510597916 | 1.3571662902832031   |    355    |
| **IPOT**  | 0.8908035207276896 |  18.85797119140625  |  1000      |

**With GPU**

| Method       | Total Cost        | Total Time (seconds) | Num Iter |
|---------------|--------------------|----------------------|----------|
| **OT (0.1)**    | 1.1110577541180653 | 0.25974488258361816   | 30        |
| **OT (0.01)**    | 0.8928369460201366 | 0.936866044998169   | 1000      |
| **OT (0.05)**    | 0.9411637510597916 | 0.6413471698760986   | 353        |
| **IPOT**  | 0.8908035207276898 | 7.8923821449279785    | 1000     |


### Dimension 300 - Data 1000

**With CPU**

| Method       | Total Cost        | Total Time (seconds) | Num Iter |
|---------------|--------------------|----------------------|----------|
| **OT (0.1)**    | 6.64015543733569 |  0.2649712562561035  |  157       |
| **OT (0.01)**    | 6.407529926342345 |  17.356149911880493  |   1000      |
| **OT (0.05)**    | 6.463452144995755 |  0.7579765319824219  |   440      |
| **IPOT**  | 6.405467976171361 | 13.811148881912231   |  1000       |

**With GPU**

| Method       | Total Cost        | Total Time (seconds) | Num Iter |
|---------------|--------------------|----------------------|----------|
| **OT (0.1)**    | 6.64015543733569 |  0.43781304359436035  |    157     |
| **OT (0.01)**    | 6.407529926342345 | 26.59229016304016   |   1000      |
| **OT (0.05)**    | 6.463452144995755 |   0.44510865211486816  |    440     |
| **IPOT**  | 6.405467976171361 | 10.711238622665405   |    1000     |