In [None]:
import numpy as np
import tensorflow as tf
from datetime import datetime
from IPython.display import HTML
from node.core import get_node_function
from node.fix_grid import RKSolver
from node.utils.trajectory import tracer, visualize_trajectory
from node.energy_based import Energy, energy_based, identity


# Seed both NumPy and TensorFlow with the same value, for reproducibility.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)


# Float type used as the Keras default and for the explicit casts below.
DTYPE = 'float32'
tf.keras.backend.set_floatx(DTYPE)


class MyLayer(tf.keras.layers.Layer):
    """Keras layer that applies a node function built from a small MLP.

    A two-layer dense network (256 ReLU units, then `units` linear units)
    serves as the raw vector field. It is wrapped by `energy_based` and
    handed, together with a fixed-step `RKSolver`, to `get_node_function`.
    Calling the layer evaluates the resulting node function at the end time
    `tN = num_grids * dt`.

    Args:
        units: Width of the layer's input and output.
        dt: Step size passed to `RKSolver`.
        num_grids: Number of steps; the end time is `num_grids * dt`.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, units, dt, num_grids, **kwargs):
        super().__init__(**kwargs)
        # Keep every constructor argument so that `get_config` can report it.
        self.units = units
        self.dt = dt
        self.num_grids = num_grids

        t0 = tf.constant(0., dtype=DTYPE)
        self.tN = t0 + num_grids * dt

        self._model = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu', dtype=DTYPE),
            tf.keras.layers.Dense(units, dtype=DTYPE),
        ])
        # Build eagerly so the weights exist as soon as the layer is created.
        self._model.build([None, units])

        # The raw field ignores the time argument and depends on `x` only.
        self._raw_pvf = lambda _, x: self._model(x)
        self._pvf = energy_based(identity, identity, self._raw_pvf)
        # Pass `self._raw_pvf` instead of `self._pvf` below to compare against
        # the field without the `energy_based` wrapper.
        self._node_fn = get_node_function(RKSolver(self.dt, dtype=DTYPE),
                                          t0,
                                          self._pvf)

    def call(self, x):
        """Returns the node function evaluated at time `tN` for input `x`."""
        return self._node_fn(self.tN, x)

    def get_config(self):
        """Returns the base config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update(units=self.units, dt=self.dt, num_grids=self.num_grids)
        return config


def process(X, y):
    """Prepares raw images and labels for training.

    Divides `X` by 255 and flattens it to rows of 28 * 28 values, one-hot
    encodes `y` with depth 10, and casts both results to `DTYPE`.

    Returns:
        A `(features, labels)` tuple of tensors with dtype `DTYPE`.
    """
    flat_pixels = tf.reshape(X / 255., [-1, 28 * 28])
    one_hot_labels = tf.one_hot(y, 10)
    return tf.cast(flat_pixels, DTYPE), tf.cast(one_hot_labels, DTYPE)


# Load MNIST and run both splits through `process`.
mnist = tf.keras.datasets.mnist
train_split, test_split = mnist.load_data()
x_train, y_train = process(*train_split)
x_test, y_test = process(*test_split)

# Classifier: standardized input -> Dense(64) -> MyLayer -> softmax head.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[28 * 28]),
    tf.keras.layers.BatchNormalization(),  # input standardization.
    tf.keras.layers.Dense(units=64),
    MyLayer(units=64, dt=1e-1, num_grids=10),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(units=10, activation='softmax'),
])

model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.optimizers.Adam(learning_rate=1e-3),
    metrics=['accuracy'],
)

model.fit(x=x_train, y=y_train, batch_size=128, epochs=12)

In [None]:
# Same architecture as `model`, but with num_grids=30 instead of 10 in
# MyLayer. The layer stack must mirror `model` exactly (same widths, same
# BatchNormalization layers); otherwise `set_weights` below raises because
# the weight count and shapes would not match.
model_2 = tf.keras.Sequential([
    tf.keras.layers.Input([28 * 28]),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(64),
    MyLayer(64, dt=1e-1, num_grids=30),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])

model_2.compile(loss='categorical_crossentropy',
                optimizer=tf.optimizers.Adam(2e-4),
                metrics=['accuracy'])

# Reuse the trained weights and evaluate them under the new num_grids.
model_2.set_weights(model.get_weights())
model_2.evaluate(x_train, y_train)

In [None]:
# Index of the MyLayer instance within `model.layers`.
my_layer_id = 2
my_layer = model.layers[my_layer_id]
trace = tracer(RKSolver(0.1), my_layer._pvf)
energy_fn = Energy(identity, my_layer._pvf)

# Sub-model made of every layer *before* MyLayer: its output is what MyLayer
# receives as input, so it is the starting point of the traced trajectory.
truncated_model = tf.keras.Sequential(model.layers[:my_layer_id])
hidden = truncated_model(x_train[:100])
labels = y_train[:100]
trajectory = trace(t0=tf.constant(0.),
                   t1=tf.constant(10.),
                   dt=tf.constant(0.1),
                   x=hidden)


def energy_along_trajectory(trajectory):
    """Evaluates `energy_fn` at every point of every trajectory.

    The first two axes of `trajectory` are read as (batch, trajectory step);
    any remaining axes describe a single phase point. Points are flattened
    into one batch for `energy_fn`, and the result is reshaped to
    `[batch, trajectory step]`.
    """
    n_samples, n_steps, *point_shape = trajectory.get_shape().as_list()
    flat_points = tf.reshape(trajectory, [-1] + point_shape)
    flat_energy = energy_fn(flat_points)
    return tf.reshape(flat_energy, [n_samples, n_steps])


# Value of `energy_fn` at each traced point, shaped [batch, trajectory step].
energy = energy_along_trajectory(trajectory)

In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt


# Histogram, for sample `i`, of the per-component difference between the
# last and the first traced point.
i = 11
diffs = trajectory[:, -1] - trajectory[:, 0]
fig, ax = plt.subplots()
ax.hist(diffs[i], bins=50, range=(-5, 5))
plt.show()