In [36]:
import tensorflow as tf
from tensorflow.keras import Model, layers
import numpy as np

In [37]:
def dist(a, b):
    M1 = tf.tile(tf.expand_dims(a, axis=1), [1, y.shape[0], 1])  # (na, nb, 2)
    M2 = tf.tile(tf.expand_dims(b, axis=0), [x.shape[0], 1, 1])  # (na, nb, 2)
    M = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(M1, M2)), axis=2))  # (na, nb)
    return M

In [38]:
def sinkhorn(a, b, M,  reg, L):
    gamma = tf.cast(1/reg, dtype=tf.float32)
    K = tf.exp(-tf.math.scalar_mul(gamma, M))
    v = tf.cast(np.ones((b.shape[0],)), dtype = tf.float32)
    for _ in range(L):
        u = tf.math.divide(a, tf.linalg.matvec(K, v))
        v = tf.math.divide(b, tf.linalg.matvec(K, u, transpose_a = True))
    P = tf.linalg.matmul(tf.linalg.tensor_diag(u), K)
    P = tf.linalg.matmul(P, tf.linalg.tensor_diag(v))
    return tf.reduce_sum(tf.math.multiply(P, M))

In [39]:
x = np.random.normal(0, 1, (20, 3))
y = np.random.normal(0, 1, (20, 3))
x = tf.cast(x, dtype=tf.float32)
y = tf.cast(y, dtype=tf.float32)
M = dist(x, y)
reg = 2
a = tf.cast(np.ones((20,))/20, dtype = tf.float32)
b = tf.cast(np.ones((20,))/20, dtype = tf.float32)
sinkhorn(a, b, M, 0.1, 5)

<tf.Tensor: id=1042, shape=(), dtype=float32, numpy=1.2682142>

In [14]:
gamma

<tf.Tensor: id=78, shape=(), dtype=float32, numpy=0.5>

In [15]:
D = dist(a, b)
D

In [24]:
u = tf.cast([1,2], dtype = tf.float32)
v = tf.cast([[1,2],[3,4]], dtype = tf.float32)
tf.linalg.matvec(v,u)

<tf.Tensor: id=145, shape=(2,), dtype=float32, numpy=array([ 5., 11.], dtype=float32)>