In [1]:
import numpy as np
import tensorflow as tf
from node.core import get_node_function
from node.fix_grid import RKSolver
from node.utils.initializers import GlorotUniform
from node.utils.train import print_status_bar
from node.utils.trajectory import tracer
from node.energy_based import Energy, energy_based, rescale


# Reproducibility: seed NumPy's global RNG and TensorFlow's global seed with
# one shared value so weight init and shuffling repeat across runs.
SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)


class MyLayer(tf.keras.layers.Layer):
    """Neural-ODE layer driven by an energy-based vector field over a small MLP.

    The vector field is ``energy_based(rescale(1e-1), rescale(1e+1), f)`` with
    ``f(t, x) = self._model(x)`` (``t`` is ignored). ``self._model`` is
    ReLU -> Dense(256, relu) -> Dense(units). The ODE is solved by a
    fixed-step ``RKSolver(dt)`` starting at t = 0.

    Args:
        units: Last dimension of the input state, and of the MLP output.
        dt: Step size handed to ``RKSolver``.
        num_grids: Number of steps; the end time is ``tN = num_grids * dt``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, units, dt, num_grids, **kwargs):
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config() can rebuild the layer.
        self.units = units
        self.dt = dt
        self.num_grids = num_grids

        # One shared start time for both the horizon and the solver.
        self.t0 = tf.constant(0.)
        self.tN = self.t0 + num_grids * dt

        self._model = tf.keras.Sequential([
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(units),
        ])
        # Build up front so the MLP's weights exist before the first call.
        self._model.build([None, units])

        self._pvf = energy_based(rescale(1e-1),
                                 rescale(1e+1),
                                 lambda _, x: self._model(x))
        self._node_fn = get_node_function(RKSolver(self.dt),
                                          self.t0,
                                          self._pvf)

    def call(self, x):
        """Run the ODE from ``t0`` to ``tN`` starting at ``x``.

        Returns ``self._node_fn(self.tN, x)``; presumably the state at ``tN``
        (TODO confirm against ``node.core.get_node_function``).
        """
        return self._node_fn(self.tN, x)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create this layer.

        The base ``Layer.get_config`` raises ``NotImplementedError`` for layers
        whose ``__init__`` takes extra arguments. Without this override,
        ``model.save`` / ``clone_model`` fail on any model containing this layer.
        ``dt`` must be a plain Python number for the config to be serialisable.
        """
        config = super().get_config()
        config.update({
            'units': self.units,
            'dt': self.dt,
            'num_grids': self.num_grids,
        })
        return config


def process(X, y, num_classes=10):
    """Turn raw image pixels and integer labels into model-ready tensors.

    Args:
        X: Pixel values in [0, 255]; every 28 * 28 consecutive values form one row.
        y: Integer class labels.
        num_classes: Depth of the one-hot encoding (10 digits for MNIST).

    Returns:
        ``(X, y)`` as float32 tensors: ``X`` reshaped to ``[-1, 784]`` and
        scaled to [0, 1], ``y`` one-hot encoded with ``num_classes`` columns.
    """
    # Cast before scaling. For an integer NumPy array, ``X / 255.`` first builds a
    # float64 copy (twice the memory of float32) that was only cast down at the end.
    X = tf.cast(X, tf.float32) / 255.
    X = tf.reshape(X, [-1, 28 * 28])
    y = tf.one_hot(y, num_classes)
    return X, tf.cast(y, tf.float32)


# Load MNIST and convert each split into flat float32 features + one-hot targets.
mnist = tf.keras.datasets.mnist
train_split, test_split = mnist.load_data()
x_train, y_train = process(*train_split)
x_test, y_test = process(*test_split)

# Dense(64) maps the 784 pixels into the 64-dim state that MyLayer evolves for
# num_grids * dt = 10 * 0.1 = 1.0 time units, followed by a dropout + softmax head.
model = tf.keras.Sequential([
    tf.keras.layers.Input([28 * 28]),
    tf.keras.layers.Dense(64),
    MyLayer(64, dt=1e-1, num_grids=10),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(loss='categorical_crossentropy',
              optimizer=tf.optimizers.Adam(1e-3, clipvalue=1.),
              metrics=['accuracy'])
# x_test / y_test were loaded but never used; report held-out accuracy each
# epoch and keep the History so the curves can be inspected afterwards.
history = model.fit(x_train, y_train, epochs=12,
                    validation_data=(x_test, y_test))

Train on 60000 samples
Epoch 1/12
Epoch 2/12

KeyboardInterrupt: 

In [None]:
# Show the layer stack in order; the next cell picks the MyLayer entry from it.
layers_in_order = list(model.layers)
layers_in_order

In [None]:
# Locate the NODE layer by type instead of a hard-coded index, so this cell
# keeps working if layers are added/removed (or an InputLayer is listed) before it.
node_layer_ids = [idx for idx, layer in enumerate(model.layers)
                  if isinstance(layer, MyLayer)]
assert len(node_layer_ids) == 1, f'expected one MyLayer, found {node_layer_ids}'
my_layer_id = node_layer_ids[0]
my_layer = model.layers[my_layer_id]

# Trace with the same step size the layer was trained with.
trace = tracer(RKSolver(my_layer.dt), my_layer._pvf)
# FIXME(review): `get_energy` is neither imported nor defined anywhere in this
# notebook, so this line raises NameError on a fresh kernel. The first cell
# imports `Energy` from node.energy_based but never uses it; confirm which
# helper builds the energy function and import it in the first cell.
energy_fn = get_energy(my_layer._model)

# Sub-model producing exactly the input MyLayer sees (every layer before it).
truncated_model = tf.keras.Sequential(model.layers[:my_layer_id])
num_trace_samples = 100
hidden = truncated_model(x_train[:num_trace_samples])
labels = y_train[:num_trace_samples]
# t1 = 5 follows the dynamics well past the layer's own horizon my_layer.tN.
trajectory = trace(t0=tf.constant(0.),
                   t1=tf.constant(5.),
                   dt=tf.constant(my_layer.dt),
                   x=hidden)


def energy_along_trajectory(trajectory, fn=None):
    """Evaluate an energy function at every point of a batch of trajectories.

    Args:
        trajectory: Tensor whose first two axes are (batch, time step); all
            remaining axes form one phase-space point. This layout is assumed
            to match what ``trace`` returns (TODO confirm against
            ``node.utils.trajectory.tracer``).
        fn: Callable mapping a stack of points ``[N, *phase_dims]`` to exactly
            ``N`` energy values (e.g. shape ``[N]`` or ``[N, 1]``). Defaults to
            the notebook-level ``energy_fn``, looked up at call time.

    Returns:
        Tensor of shape ``[batch, time]`` holding the energy of each point.
    """
    if fn is None:
        fn = energy_fn
    # Dynamic shape: unlike get_shape().as_list(), this also works when a
    # static dimension is unknown (None), e.g. the batch axis inside tf.function.
    shape = tf.shape(trajectory)
    phase_points = tf.reshape(trajectory, tf.concat([[-1], shape[2:]], axis=0))
    e = fn(phase_points)
    return tf.reshape(e, shape[:2])


# Energy of every traced point, shaped [batch, time].
energy = energy_along_trajectory(trajectory=trajectory)

In [None]:
from node.utils.trajectory import visualize_trajectory
from IPython.display import HTML

In [None]:
# Decode one-hot labels to digits and list them with their sample index, so
# interesting samples can be chosen for the visualisation cells.
labels_ = labels.numpy().argmax(axis=-1)
print([*enumerate(labels_)])

In [None]:
def visualize(trajectory, label, grid_shape=(8, 8)):
    """Render one sample's trajectory as an inline HTML5 video.

    Args:
        trajectory: Array for a single sample, assumed ``[time, features]``
            (e.g. ``trajectory.numpy()[i]``).
        label: Ground-truth digit, printed above the video.
        grid_shape: ``(rows, cols)`` used to lay each feature vector out as an
            image. ``rows * cols`` must equal the feature size; the default
            8 x 8 matches the 64-unit ``MyLayer`` used in this notebook.

    Returns:
        ``HTML`` object embedding the animation's HTML5 video.

    Raises:
        ValueError: If the feature size does not equal ``rows * cols``.
    """
    print(f'label: {label}')
    rows, cols = grid_shape
    num_features = trajectory.shape[-1]
    if num_features != rows * cols:
        # A bare reshape([-1, rows, cols]) would silently fold extra features
        # into the time axis whenever num_features is a multiple of rows * cols.
        raise ValueError(f'feature size {num_features} does not match '
                         f'grid_shape {tuple(grid_shape)}')
    anim = visualize_trajectory(trajectory.reshape([-1, rows, cols]))
    return HTML(anim.to_html5_video())

In [None]:
i = 0  # sample index; pick from the printed (index, label) list
sample_trajectory = trajectory[i].numpy()
visualize(sample_trajectory, labels_[i])

In [None]:
i = 11  # sample index; pick from the printed (index, label) list
sample_trajectory = trajectory[i].numpy()
visualize(sample_trajectory, labels_[i])

In [None]:
i = 35  # sample index; pick from the printed (index, label) list
sample_trajectory = trajectory[i].numpy()
visualize(sample_trajectory, labels_[i])

In [None]:
i = 36  # sample index; pick from the printed (index, label) list
sample_trajectory = trajectory[i].numpy()
visualize(sample_trajectory, labels_[i])

In [None]:
i = 37  # sample index; pick from the printed (index, label) list
sample_trajectory = trajectory[i].numpy()
visualize(sample_trajectory, labels_[i])

In [None]:
# Mean and variance over the last axis of every traced point; assuming
# trajectory is [batch, time, features], this yields one value per (sample, step).
trajectory_moments = tf.nn.moments(trajectory, axes=[-1])
mean, variance = trajectory_moments

In [None]:
# Feature variance of sample 10 at each traced step (remaining axes kept).
variance[10, ...]