In [15]:
import tensorflow as tf
from tensorflow.keras import Model, layers
import ot
import numpy as np

In [2]:
class InvarianceTestGraph(Model, layers.Layer):
    """Small fully connected network with an environment-dependent output bias.

    Three sigmoid layers map 64 -> 32 -> 16 -> 6 features (``invariantMap``).
    A final linear layer maps 6 -> 2; its weight matrix is shared, while the
    bias added on top is chosen by the ``env`` argument of ``call``.

    All parameters are stored in the ``weight`` and ``bias`` dictionaries.
    """

    def __init__(self):
        super(InvarianceTestGraph, self).__init__()

        def new_param(shape):
            # Every parameter is trainable and uses the 'random_normal' initializer.
            return self.add_weight(shape=shape, initializer='random_normal', trainable=True)

        # Dict literals evaluate in order, so variables are created in the
        # order they are listed here.
        self.weight = {
            'weight1': new_param((64, 32)),
            'weight2': new_param((32, 16)),
            'weight3': new_param((16, 6)),
            'weight_final': new_param((6, 2)),
        }

        self.bias = {
            'bias1': new_param((32,)),
            'bias2': new_param((16,)),
            'bias3': new_param((6,)),
            'bias_final0': new_param((2,)),
            'bias_final1': new_param((2,)),
        }

    def invariantMap(self, x):
        """Apply the three sigmoid layers to ``x`` and return the 6-feature output.

        :param x: input whose last dimension is 64 (rows of ``weight1``)
        :return: output of the third sigmoid layer
        """
        hidden = x
        for idx in ('1', '2', '3'):
            affine = tf.add(tf.matmul(hidden, self.weight['weight' + idx]),
                            self.bias['bias' + idx])
            hidden = tf.nn.sigmoid(affine)
        return hidden

    def call(self, x, env = 0):
        """Return the 2-dimensional linear output for ``x``.

        :param x: input passed to ``invariantMap``
        :param env: ``0`` selects ``bias_final0``; any other value selects ``bias_final1``
        :return: ``invariantMap(x) @ weight_final + <selected bias>`` (no activation)
        """
        features = self.invariantMap(x)
        bias_key = 'bias_final0' if env == 0 else 'bias_final1'
        return tf.add(tf.matmul(features, self.weight['weight_final']), self.bias[bias_key])


inv = InvarianceTestGraph()

In [3]:
# Binary labels for the two datasets: P(y=1) is 0.5 for the first (1200 samples)
# and 0.8 for the second (1500 samples).
y0 = np.random.binomial(1, 0.5, (1200,))
y1 = np.random.binomial(1, 0.8, (1500,))


def f(y):
    """Draw one 64-dimensional feature vector: N(1, 1) entries if ``y`` is truthy, else N(0, 1)."""
    mean = 1 if y else 0
    return np.random.normal(mean, 1, (64,))


# Features are sampled one label at a time, first for y0 and then for y1.
x0 = [f(label) for label in y0]
x1 = [f(label) for label in y1]

# Convert features to float32 tensors and labels to 2-class one-hot vectors.
x0 = tf.cast(x0, dtype=tf.float32)
x1 = tf.cast(x1, dtype=tf.float32)
y0 = tf.one_hot(y0, 2)
y1 = tf.one_hot(y1, 2)

# Endless, shuffled streams of 200-sample batches, one per dataset.
batch0 = (
    tf.data.Dataset.from_tensor_slices((x0, y0))
    .repeat()
    .shuffle(5000)
    .batch(200)
    .prefetch(1)
)
batch1 = (
    tf.data.Dataset.from_tensor_slices((x1, y1))
    .repeat()
    .shuffle(5000)
    .batch(200)
    .prefetch(1)
)

learningRate = 0.01
optimizer = tf.optimizers.Adam(learningRate)

In [4]:
# Joint training on both datasets. zip() stops with the shorter iterator, so
# the loop runs for 120 steps even though batch1 offers 150 batches.
for step, (pair0, pair1) in enumerate(zip(batch0.take(120), batch1.take(150)), 1):
    bx0, by0 = pair0
    bx1, by1 = pair1

    with tf.GradientTape() as g:
        # Cross-entropy of each dataset under its own env-specific output bias.
        loss_env0 = tf.nn.softmax_cross_entropy_with_logits(labels=by0, logits=inv(bx0, 0))
        loss_env1 = tf.nn.softmax_cross_entropy_with_logits(labels=by1, logits=inv(bx1, 1))
        # The loss is left unreduced (not averaged) before differentiation.
        loss = loss_env0 + loss_env1

    trainable_variables = inv.trainable_variables
    gradients = g.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

In [5]:
# Sanity check: one 64-dimensional feature row per sampled label.
x0.shape

TensorShape([1200, 64])

In [11]:
# Two independent sets of 10 standard-normal points in 3 dimensions.
x = np.random.normal(loc=0, scale=1, size=(10, 3))
y = np.random.normal(loc=0, scale=1, size=(10, 3))

In [12]:
def WDist(x, y, reg):
    """
    Entropy-regularised (Sinkhorn) transport cost between two point clouds.

    Every point carries uniform mass (1/na for x, 1/nb for y) and the ground
    cost is the pairwise Euclidean distance between points.

    :param x: (na, d) array or tensor of points
    :param y: (nb, d) array or tensor of points, same d and dtype as x
    :param reg: regularisation value forwarded to ``ot.sinkhorn``
    :return: scalar float32 tensor ``sum_ij T[i, j] * mm[i, j]``, where ``mm``
        is the (na, nb) cost matrix and ``T`` the plan returned by ``ot.sinkhorn``
    """
    nx = x.shape[0]
    ny = y.shape[0]
    # Uniform histograms over the two point sets.
    histX = tf.cast(np.ones((nx,)) / nx, dtype=tf.float32)
    histY = tf.cast(np.ones((ny,)) / ny, dtype=tf.float32)

    # Pairwise differences via broadcasting: (na, 1, d) - (1, nb, d) -> (na, nb, d).
    diff = tf.subtract(tf.expand_dims(x, axis=1), tf.expand_dims(y, axis=0))
    mm = tf.sqrt(tf.reduce_sum(tf.square(diff), axis=2))  # (na, nb)
    mm = tf.cast(mm, dtype=tf.float32)

    T = ot.sinkhorn(histX, histY, mm, reg)
    # Cast the plan so it can be multiplied with the float32 cost matrix
    # whether ot.sinkhorn returned a NumPy array or a tensor.
    T = tf.cast(T, dtype=tf.float32)

    # The distance is the cost of moving mass along the plan; the cost matrix
    # alone (which used to be returned) is not a distance between the clouds.
    return tf.reduce_sum(T * mm)

In [18]:
# Sinkhorn plan between the toy clouds x and y with uniform weights.
nx, ny = x.shape[0], y.shape[0]

# NOTE(review): dmat is not defined in the cells shown here; presumably it
# returns the pairwise cost matrix between x and y. Confirm before a fresh run.
M = dmat(x,y)

# Uniform histograms: each point carries mass 1/n.
histX = tf.cast(np.ones((nx,)) / nx, dtype=tf.float32)
histY = tf.cast(np.ones((ny,)) / ny, dtype=tf.float32)

# Last argument is the regularisation value passed to ot.sinkhorn.
T = ot.sinkhorn(histX, histY, M, 1)

In [19]:
# Display the matrix returned by ot.sinkhorn for the toy clouds.
T

array([[0.00791357, 0.0095655 , 0.01831183, 0.00408009, 0.00544475,
        0.00674669, 0.00844901, 0.00438455, 0.01234625, 0.02275776],
       [0.02077406, 0.01293324, 0.00400238, 0.0058651 , 0.00550861,
        0.00989703, 0.00603599, 0.02135876, 0.01021414, 0.0034107 ],
       [0.00797319, 0.01058248, 0.01230927, 0.00668704, 0.01043484,
        0.0107341 , 0.01400176, 0.00784463, 0.01025335, 0.00917935],
       [0.00627735, 0.00704867, 0.00749176, 0.00932166, 0.01422727,
        0.01447387, 0.01475634, 0.00710511, 0.01165989, 0.00763807],
       [0.00746371, 0.00995511, 0.01382115, 0.00680734, 0.01031016,
        0.00994715, 0.01237634, 0.00651472, 0.01065249, 0.01215183],
       [0.01595954, 0.00972   , 0.00280238, 0.01308259, 0.00941025,
        0.0122434 , 0.005273  , 0.02003469, 0.00879487, 0.00267928],
       [0.01099474, 0.01159207, 0.02220033, 0.00207922, 0.00245843,
        0.00401915, 0.00494281, 0.00364239, 0.01149988, 0.02657098],
       [0.00199007, 0.00168718, 0.0020192