In [1]:
import tensorflow as tf
from tensorflow.keras import Model, layers
import ot
import numpy as np

In [33]:
class InvariantMap(Model):
    
    def __init__(self):
        super(InvariantMap, self).__init__()
        self.layer1 = layers.Dense(64, activation='sigmoid')
        self.layer2 = layers.Dense(32, activation='sigmoid')
        self.layer3 = layers.Dense(4, activation='sigmoid')
        
    def call(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out
    
invariantMap = InvariantMap()

In [34]:
def dmat(x, y):
    """
    :param x: (na, 2)
    :param y: (nb, 2)
    :return:
    """
    mmp1 = tf.tile(tf.expand_dims(x, axis=1), [1, y.shape[0], 1])  # (na, nb, 2)
    mmp2 = tf.tile(tf.expand_dims(y, axis=0), [x.shape[0], 1, 1])  # (na, nb, 2)

    mm = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(mmp1, mmp2)), axis=2))  # (na, nb)
    mm = tf.cast(mm, dtype = tf.float32)
    return mm




def WDist(x, y, wReg):
    nx = x.shape[0]
    ny = y.shape[0]
    M = dmat(x,y)
    M = np.array(M)
    histX = np.ones((nx,))/nx
    histY = np.ones((ny,))/ny
    T = ot.sinkhorn(histX, histY, M, wReg)
    T = tf.cast(T, dtype = tf.float32)
    return tf.reduce_sum(tf.math.multiply(M, T))

In [35]:
x0 = np.random.normal(0, 1, (100, 5))
x1 = np.random.normal(1, 1, (150, 5))
y0 = np.random.normal(0, 1, (100, 5))
y1 = np.random.normal(1, 1, (150, 5))
WDist(x0,x1, 0.5)
dataDict = dict()
dataDict[0] = {0: x0, 1: x1}
dataDict[1] = {0: y0, 1: y1}
dataDict

{0: {0: array([[ 2.86531154e-01,  9.56131292e-01, -6.91267649e-01,
           4.52300278e-01, -8.61606513e-01],
         [ 2.16372827e-01,  9.76896447e-02, -8.86320172e-01,
           1.38413914e+00, -1.45596410e-01],
         [ 9.48442308e-01, -1.37082170e-01,  6.13519751e-02,
          -9.01189518e-01, -1.01989067e-01],
         [ 5.25896310e-01,  2.30325242e-01, -1.67737397e+00,
          -1.03218246e+00, -1.15968109e-01],
         [ 7.69369901e-01,  8.26627457e-01, -2.21111968e+00,
           7.85169029e-01,  1.00158530e+00],
         [ 1.34976231e+00, -1.20321367e+00, -5.14336117e-01,
          -8.49031455e-01, -1.40159493e-01],
         [ 2.58234806e+00,  3.11092265e-01,  1.26001626e-01,
           1.78964333e-01, -1.28111831e+00],
         [ 3.59681887e-01, -1.62752748e-01, -2.73039344e-01,
          -1.52936431e-01, -1.98910986e+00],
         [ 9.54166866e-01,  8.75699930e-01, -1.84017130e+00,
          -1.27095349e-01, -5.18513234e-01],
         [ 4.85280046e-01, -1.05633110e+

In [16]:
def WDistLoss(dataDict, reg):
    loss = [WDist(invariantMap(dataDict[0][x]), invariantMap(dataDict[1][x]), reg) for x in range(2)]
    return np.sum(loss)

random_normal = tf.initializers.RandomNormal()
weight_final = { 'w': tf.Variable(random_normal([4,]))}
biases_final = {'b0': tf.Variable(random_normal([])), 'b1': tf.Variable(random_normal([]))}

In [17]:
biases_final

{'b0': <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0627598>,
 'b1': <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=-0.06851958>}

In [36]:
class FinalLayer(Model, layers.Layer):
    
    def __init__(self):
        super(FinalLayer, self).__init__()
        self.w = self.add_weight(shape=(4, ),
                               initializer='random_normal',
                               trainable=True)
        self.b0 = self.add_weight(shape=(),
                               initializer='random_normal',
                               trainable=True)
        self.b1 = self.add_weight(shape=(),
                               initializer='random_normal',
                               trainable=True)
        
    def call(self, x, env = 0, predict = False):
        if env == 0:
            predict_prob = tf.nn.sigmoid(tf.add(tf.matmul(x, self.w) + self.b0))
            if predict:
                return 1 if predict_prob > 0.5 else 0
            else: 
                return predict_prob
            
        elif env == 1:
            predict_prob = tf.nn.sigmoid(tf.add(tf.matmul(x, self.w) + self.b0))
            if predict:
                return 1 if predict_prob > 0.5 else 0
            else: 
                return predict_prob
            
        else: 
            raise TypeError("Wrong input")
final = FinalLayer()        

In [37]:
final.w

<tf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([ 0.01489473,  0.01834584, -0.03081675, -0.0358052 ], dtype=float32)>

In [38]:
final.trainable_variables
invariantMap.trainable_variables

[]