In [1]:
import tensorflow as tf
from tensorflow.keras import Model, layers
import ot
import numpy as np

In [2]:
class InvarianceNNGraph(Model, layers.Layer):
    
    def __init__(self):
        super(InvarianceNNGraph, self).__init__()
        self.weight = {'weight1': self.add_weight(shape=(64, 32), initializer='random_normal', trainable=True), 
                        'weight2': self.add_weight(shape=(32, 16), initializer='random_normal', trainable=True), 
                        'weight3': self.add_weight(shape=(16, 6), initializer='random_normal', trainable=True),
                        'weight_final': self.add_weight(shape = (6,1), initializer = 'random_normal', trainable = True)}
        self.bias = {'bias1': self.add_weight(shape=(32, ), initializer='random_normal', trainable=True), 
                        'bias2': self.add_weight(shape=(16, ), initializer='random_normal', trainable=True), 
                        'bias3': self.add_weight(shape=(6, ), initializer='random_normal', trainable=True),
                        'bias_final0': self.add_weight(shape = (1,), initializer = 'random_normal', trainable = True), 
                        'bias_final1': self.add_weight(shape = (1,), initializer = 'random_normal', trainable = True)}
        
    def invariantMap(self, x):
        out = tf.nn.sigmoid(tf.add(tf.matmul(x, self.weight['weight1']), self.bias['bias1']))
        out = tf.nn.sigmoid(tf.add(tf.matmul(out, self.weight['weight2']), self.bias['bias2']))
        out = tf.nn.sigmoid(tf.add(tf.matmul(out, self.weight['weight3']), self.bias['bias3']))
        return out
        
    def call(self, x, env = 0, is_training = False):
        if env == 0:
            out = tf.add(tf.matmul(self.invariantMap(x), self.weight['weight_final']) , self.bias['bias_final0'])
            out = tf.concat([-out, out], axis = 1)
            #predict_prob = tf.nn.sigmoid(tf.add(tf.matmul(self.invariantMap(x), self.weight['weight_final']), self.bias['bias_final0']))
            #predict_prob2 = tf.concat([1- predict_prob, predict_prob], axis = 1)
            if is_training:
                return out#tf.argmax(predict_prob2, axis = 1)
            else: 
                return tf.nn.softmax(out)
            
        elif env == 1:
            out = tf.add(tf.matmul(self.invariantMap(x), self.weight['weight_final']), self.bias['bias_final1'])
            out = tf.concat([-out, out], axis = 1)
            #predict_prob = tf.nn.sigmoid(tf.add(tf.matmul(self.invariantMap(x), self.weight['weight_final']), self.bias['bias_final1']))
            #predict_prob2 = tf.concat([1- predict_prob, predict_prob], axis = 1)
            if is_training:
                return out #tf.argmax(predict_prob2, axis = 1)
            else: 
                return tf.nn.softmax(out) #predict_prob
inv = InvarianceNNGraph()

In [3]:
x = np.random.normal(0, 1, (10, 64))
x = tf.convert_to_tensor(x, dtype=tf.float32)
u = inv(x, is_training = False)
tf.math.log(u)

<tf.Tensor: id=133, shape=(10, 2), dtype=float32, numpy=
array([[-0.6171767 , -0.77536714],
       [-0.6172655 , -0.7752631 ],
       [-0.6171934 , -0.7753476 ],
       [-0.617163  , -0.7753833 ],
       [-0.617281  , -0.7752451 ],
       [-0.6173059 , -0.77521574],
       [-0.6172641 , -0.7752648 ],
       [-0.61726373, -0.7752652 ],
       [-0.617303  , -0.7752194 ],
       [-0.61719525, -0.77534527]], dtype=float32)>

In [4]:
def dmat(x, y):
    """
    :param x: (na, 2)
    :param y: (nb, 2)
    :return:
    """
    mmp1 = tf.tile(tf.expand_dims(x, axis=1), [1, y.shape[0], 1])  # (na, nb, 2)
    mmp2 = tf.tile(tf.expand_dims(y, axis=0), [x.shape[0], 1, 1])  # (na, nb, 2)

    mm = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(mmp1, mmp2)), axis=2))  # (na, nb)
    mm = tf.cast(mm, dtype = tf.float32)
    return mm




def WDist(x, y, wReg):
    nx = x.shape[0]
    ny = y.shape[0]
    M = dmat(x,y)
    M = np.array(M)
    histX = np.ones((nx,))/nx
    histY = np.ones((ny,))/ny
    T = ot.sinkhorn(histX, histY, M, wReg)
    T = tf.cast(T, dtype = tf.float32)
    return tf.reduce_sum(tf.math.multiply(M, T))


def WDistLoss(modifiedDict, reg):
    loss = tf.cast([WDist(inv.invariantMap(modifiedDict[0][x]), inv.invariantMap(modifiedDict[1][x]), reg) 
                            for x in range(2)], dtype=tf.float32)
    return tf.reduce_sum(loss)

def cross_entropy_loss(y, predict_prob):
    y = tf.one_hot(y, 2)
    loss = -y*tf.math.log(predict_prob)
    return 2*tf.reduce_mean(loss)


def makeModifiedDict(dataDict):
    retDict = dict()
    for key in dataDict:
        retDict[key] = dict()
        dHand = retDict[key]
        x = dataDict[key]['x']
        y = dataDict[key]['y']
        dHand[0] = tf.cast(x[y == 0], dtype=tf.float32)
        dHand[1] = tf.cast(x[y == 1], dtype=tf.float32)
    return retDict

In [5]:
x0 = tf.cast(np.random.normal(0, 1, (100, 64)), dtype = tf.float32)
x1 = tf.cast(np.random.normal(1, 1, (150, 64)), dtype=tf.float32)
y0 = tf.cast(np.random.normal(0, 1, (100, 64)), dtype=tf.float32)
y1 = tf.cast(np.random.normal(1, 1, (150, 64)), dtype=tf.float32)
dataDict = dict()
dataDict[0] = {0: x0, 1: x1}
dataDict[1] = {0: y0, 1: y1}

In [6]:
WDistLoss(dataDict, 1)

<tf.Tensor: id=242, shape=(), dtype=float32, numpy=0.0025236334>

In [7]:
y0 = np.random.binomial(1, 0.5, (1200,))
y1 = np.random.binomial(1, 0.8, (1500,))
f = lambda y: np.random.normal(1, 1, (64,)) if y else np.random.normal(0, 1, (64,))
x0 = [f(y) for y in y0]
x1 = [f(y) for y in y1]
#y0 = tf.one_hot(y0, 2)
#y1 = tf.one_hot(y1, 2)
x0 = tf.cast(x0, dtype= tf.float32)
x1 = tf.cast(x1, dtype= tf.float32)
dataDict = {0: {'x': x0, 'y': y0}, 1: {'x': x1, 'y': y1}}
prob = inv(x0)
c = cross_entropy_loss(y0, prob)
c

<tf.Tensor: id=281, shape=(), dtype=float32, numpy=0.69570327>

In [8]:
batch_size = 150
num_steps = 100
for key in dataDict:
    dictHand = dataDict[key]
    batch = tf.data.Dataset.from_tensor_slices((dictHand['x'], dictHand['y']))
    batch = batch.repeat().shuffle(5000).batch(batch_size).prefetch(1)
    dictHand['batch'] = batch.take(num_steps)
    
learningRate = 0.01
optimizer = tf.optimizers.Adam(learningRate)


In [9]:
def RunOptimizer(trainDataDict, lr, reg, lam):
    with tf.GradientTape() as g:
        modifiedDict = makeModifiedDict(trainDataDict)
        loss = lam*WDistLoss(modifiedDict, reg)
        for env in trainDataDict:
            dictHand = trainDataDict[env]
            predict_prob = inv(dictHand['x'])
            loss += cross_entropy_loss(dictHand['y'], predict_prob)
        
    trainable_vars = inv.trainable_variables
    gradients = g.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))
        

In [10]:
with tf.GradientTape() as g:
    trainDataDict = dataDict
    modifiedDict = makeModifiedDict(trainDataDict)
    loss = WDistLoss(modifiedDict, 0.5)
    for env in trainDataDict:
        dictHand = trainDataDict[env]
        logits = inv(dictHand['x'])
        loss += cross_entropy_loss(dictHand['y'], logits)
    

In [11]:
trainable_vars = inv.trainable_variables
gradients = g.gradient(loss, trainable_vars)
gradients

[<tf.Tensor: id=731, shape=(64, 32), dtype=float32, numpy=
 array([[-6.2274839e-06, -2.4451609e-05,  2.7979222e-06, ...,
          3.4066805e-05, -3.4612449e-06,  1.7386556e-05],
        [-6.2019135e-06, -2.3625422e-05,  2.7308733e-06, ...,
          3.3042568e-05, -3.3285392e-06,  1.6836180e-05],
        [-6.0096327e-06, -2.3207926e-05,  2.6405708e-06, ...,
          3.1855998e-05, -3.3150689e-06,  1.6257571e-05],
        ...,
        [-6.3031794e-06, -2.4055231e-05,  2.7643487e-06, ...,
          3.3723732e-05, -3.4139412e-06,  1.7041515e-05],
        [-6.0333723e-06, -2.3516379e-05,  2.6968105e-06, ...,
          3.2740034e-05, -3.3314404e-06,  1.6429136e-05],
        [-6.2785493e-06, -2.4549736e-05,  2.8262405e-06, ...,
          3.4050656e-05, -3.5094276e-06,  1.7581318e-05]], dtype=float32)>,
 <tf.Tensor: id=732, shape=(32, 16), dtype=float32, numpy=
 array([[-1.77804177e-04,  2.10071958e-04,  4.03544283e-04,
         -3.02262924e-04, -1.09264520e-05,  8.84566689e-05,
         -6

In [12]:
optimizer.apply_gradients(zip(gradients, trainable_vars))



<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=1>

In [65]:
optimizer

<tensorflow.python.keras.optimizer_v2.adam.Adam at 0x13374ada408>

In [66]:
inv.trainable_variables

[<tf.Variable 'Variable:0' shape=(64, 32) dtype=float32, numpy=
 array([[ 0.1499512 ,  0.03687894,  0.04996171, ...,  0.04847898,
          0.06829581,  0.04833071],
        [-0.0399543 ,  0.00394524,  0.02390912, ...,  0.01155864,
         -0.03722502, -0.03666129],
        [-0.06393464,  0.06313321, -0.02757789, ...,  0.02294632,
          0.06605469, -0.10789549],
        ...,
        [-0.05742737,  0.07109744, -0.00678881, ...,  0.03579374,
         -0.04743608, -0.05193434],
        [-0.00542439,  0.01923782,  0.01418852, ...,  0.0790406 ,
          0.03077881, -0.00616618],
        [ 0.03691861, -0.0527195 , -0.01191864, ...,  0.03356226,
          0.07611102, -0.06887035]], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(32, 16) dtype=float32, numpy=
 array([[-4.18962725e-02, -2.17476636e-02, -1.11166285e-02,
         -4.29358110e-02,  7.19985366e-02, -7.79246092e-02,
          9.76048131e-03,  9.30439681e-03, -3.90602984e-02,
          4.03320305e-02,  6.34792224e-02,  1.044

In [67]:
v = tf.Variable(0)

In [68]:
v

<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=0>

In [69]:
for step, ((x0, y0), (x1, y1)) in enumerate(zip(dataDict[0]['batch'], dataDict[1]['batch']), 1):
    batchDataDict = {0: {'x': x0, 'y': y0}, 1: {'x': x1, 'y': y1}}
    RunOptimizer(batchDataDict, 0.01, 0.5, 0.5)





In [70]:
inv.trainable_variables

[<tf.Variable 'Variable:0' shape=(64, 32) dtype=float32, numpy=
 array([[ 0.27501473,  0.13654505,  0.15078057, ...,  0.32247686,
          0.00610315,  0.3900001 ],
        [ 0.00750037,  0.01874435,  0.06969524, ...,  0.2695772 ,
         -0.05409102,  0.27882195],
        [-0.05048328,  0.11909208, -0.04668152, ...,  0.19153905,
         -0.0331324 ,  0.22138214],
        ...,
        [-0.01988893,  0.11323827,  0.04747447, ...,  0.25084564,
         -0.08388431,  0.3066712 ],
        [ 0.17226703,  0.10457055,  0.12667607, ...,  0.19228283,
         -0.06529453,  0.22233121],
        [ 0.02620078, -0.0252241 , -0.01888999, ...,  0.3003206 ,
          0.08190741,  0.2266776 ]], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(32, 16) dtype=float32, numpy=
 array([[ 0.21524851,  0.19695066,  0.40024558, -0.24085239,  0.27304652,
          0.13452767, -0.17223547,  0.21960135,  0.20600691,  0.29083976,
          0.2635001 ,  0.2927746 ,  0.20807678,  0.3334041 ,  0.25778523,
       