## A demonstration of enforcing invariance

Here we enforce invariance through the Wasserstein distance.

### Required modules

In [1]:
import data_load as data
import numpy as np
import tensorflow as tf
import setup
import utils
#tf.debugging.set_log_device_placement(True)

### Load and preprocess coloured MNIST data

In [2]:
# Seed every random number generator this notebook relies on.
# Seeding NumPy alone is not enough: TensorFlow keeps its own global seed,
# which governs TF-side random ops (dataset shuffling, random initialisers).
SEED = 1
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Number of label classes used for the one-hot encoding below.
NUM_CLASSES = 2

# Load data: two training environments plus one test set.
# red_0_corrs carries one value per training environment
# (presumably the colour/label correlation -- confirm in data_load).
data_train, data_test = data.make_environments(path='../MNIST', red_0_corrs=[0.8, 0.9])

# Unpack each split into its (inputs, labels) pair
x0_train, y0_train = data_train[0]
x1_train, y1_train = data_train[1]
x_test, y_test = data_test

# Cast inputs to float32 tensors and one-hot encode the labels
x0_train = tf.cast(x0_train, dtype=tf.float32)
x1_train = tf.cast(x1_train, dtype=tf.float32)
y0_train = tf.one_hot(y0_train, NUM_CLASSES)
y1_train = tf.one_hot(y1_train, NUM_CLASSES)

x_test = tf.cast(x_test, dtype=tf.float32)
y_test = tf.one_hot(y_test, NUM_CLASSES)

# Repack in the same nested layout the loader returned:
# a list of [inputs, labels] per training environment, and one [inputs, labels] for test.
data_train = [[x0_train, y0_train], [x1_train, y1_train]]
data_test = [x_test, y_test]

Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op OneHot in device /job:localhost/replica:0/task:0/device:CPU:0


### Setting up graphs and experiments

In [3]:
# Hyperparameters for this run, named so they are easy to locate and tune.
EPOCHS = 10000
LEARNING_RATE = 5e-4

# Instantiate the invariance experiment, then invoke its two graph-building
# steps in order (what each graph contains is defined in the `setup` module).
expt = setup.Invariance(epoch=EPOCHS, learning_rate=LEARNING_RATE)
expt.create_graph()
expt.potential_graph()

In [4]:
# Both data entry points take the same (train, test) pair, so bundle it once.
splits = (data_train, data_test)

# Register the full datasets, then set up the data stream
# (`data_steam` is the method's actual name in `setup`).
expt.load_full_data(*splits)
expt.data_steam(*splits)

# Set up TensorBoard for this experiment.
expt.create_tensorboard()

Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Equal in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op LogicalAnd in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op SelectV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AnonymousRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ShuffleDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TakeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Exe

In [None]:
# Run the experiment's fit routine (presumably the training loop for the
# configured number of epochs -- confirm in `setup`). With TensorFlow device
# placement logging enabled, this prints one line per executed op.
expt.fit()

Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MakeIterator in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op OptimizeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ModelDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op IteratorGetNextSync in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RandomUniform in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Sub in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Mul in d

Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op LogicalNot in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Assert in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op AddV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Pow in device /job:localhost/replica:0/task:0