In [1]:
import numpy as np
import tensorflow as tf
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from sklearn.preprocessing import StandardScaler
from node.core import get_node_function
from node.solvers import RK4Solver
from node.hopfield import hopfield, identity, rescale


# for reproducibility
np.random.seed(42)
tf.random.set_seed(42)


def process(X, y):
    X = X / 255.
    X = np.reshape(X, [-1, 28 * 28])
    y = np.eye(10)[y]
    return X.astype('float32'), y.astype('float32')


mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = process(x_train, y_train)
x_test, y_test = process(x_test, y_test)

scalar = StandardScaler()
scalar.fit(x_train)
x_train = scalar.transform(x_train)
x_test = scalar.transform(x_test)

In [19]:
class BasicNodeLayer(tf.keras.layers.Layer):

    def __init__(self, units, dt, num_grids, **kwargs):
        super().__init__(**kwargs)
        self.dt = dt
        self.num_grids = num_grids

        t0 = tf.constant(0.)
        self.tN = t0 + num_grids * dt

        self._model = tf.keras.Sequential([
            tf.keras.layers.Dense(1024, activation='relu'),
            tf.keras.layers.Dense(units),
        ])
        self._model.build([None, units])
        self._pvf = lambda _, x: self._model(x)

        solver = RK4Solver(self.dt)
        self._node_fn = get_node_function(solver, tf.constant(0.), self._pvf)

    def call(self, x):
        return self._node_fn(self.tN, x)

    def get_config(self):
        return super().get_config().copy()


def get_trained_basic_model(logdir=None, epochs=10, lr=1e-3):
    tf.keras.backend.clear_session()

    basic_model = tf.keras.Sequential([
        tf.keras.layers.Input([28 * 28]),
        tf.keras.layers.LayerNormalization(scale=False, center=False),
        tf.keras.layers.Dense(64, use_bias=False),  # down-sampling
        BasicNodeLayer(64, dt=1e-1, num_grids=10),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    basic_model.compile(
        loss='categorical_crossentropy',
        optimizer=tf.optimizers.Nadam(lr),
        metrics=['accuracy'])

    callbacks = []
    if logdir:
        tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=logdir, histogram_freq=1, write_images=True, update_freq=10)
        callbacks.append(tensorboard)

    basic_model.fit(x_train, y_train, epochs=epochs, batch_size=128, callbacks=callbacks)
    return basic_model

In [20]:
basic_model = get_trained_basic_model()

Train on 60000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [21]:
longer_trajectory_basic_model = tf.keras.Sequential([
    tf.keras.layers.Input([28 * 28]),
    tf.keras.layers.Dense(64, use_bias=False),  # down-sampling
    tf.keras.layers.LayerNormalization(scale=False, center=False),
    BasicNodeLayer(64, dt=1e-1, num_grids=50),
    tf.keras.layers.Dense(10, activation='softmax')
])

longer_trajectory_basic_model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'])

longer_trajectory_basic_model.set_weights(basic_model.get_weights())
longer_trajectory_basic_model.evaluate(x_train, y_train, batch_size=128)



[0.30558200271924335, 0.90363336]

In [8]:
from node.utils.nest import nest_map

@nest_map
@tf.function
def layer_normalize(x, axes=None, eps=1e-8):
    """
    Reference:
        1. [Layer Normalization](https://arxiv.org/abs/1607.06450)
        2. [TF implementation](https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/normalization.py#L910-L1158)
    """
    if axes is None:
        non_batch_axes = tf.range(1, len(x.shape))
        axes = non_batch_axes
    mean, variance = tf.nn.moments(x, axes, keepdims=True)
    return (x - mean) / tf.sqrt(variance + eps)

In [14]:
class NormalizedNodeLayer(tf.keras.layers.Layer):

    def __init__(self, units, dt, num_grids, **kwargs):
        super().__init__(**kwargs)
        self.dt = dt
        self.num_grids = num_grids

        t0 = tf.constant(0.)
        self.tN = t0 + num_grids * dt

        self._model = tf.keras.Sequential([
            tf.keras.layers.Dense(1024, activation='relu'),
            tf.keras.layers.Dense(units),
        ])
        self._model.build([None, units])
        
        def pvf(t, x):
            v = self._model(x)
            with tf.GradientTape() as g:
                g.watch(x)
                norm = layer_normalize(x)
            return g.gradient(norm, x, v)

        self._pvf = pvf
        solver = RK4Solver(self.dt)
        self._node_fn = get_node_function(
            solver, tf.constant(0.), self._pvf)

    def call(self, x):
        y = self._node_fn(self.tN, x)
        return y

    def get_config(self):
        return super().get_config().copy()


def get_trained_normalized_model(logdir=None, epochs=10, lr=1e-3):
    tf.keras.backend.clear_session()

    normalized_model = tf.keras.Sequential([
        tf.keras.layers.Input([28 * 28]),
        tf.keras.layers.Dense(64, use_bias=False),  # down-sampling
        tf.keras.layers.LayerNormalization(scale=False, center=False),
        NormalizedNodeLayer(64, dt=1e-1, num_grids=10),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    normalized_model.compile(
        loss='categorical_crossentropy',
        optimizer=tf.optimizers.Nadam(lr),
        metrics=['accuracy'])

    callbacks = []
    if logdir is not None:
        tensorboard = tf.keras.callbacks.TensorBoard(
            log_dir=logdir, histogram_freq=1, write_images=True, update_freq=10)
        callbacks.append(tensorboard)
    normalized_model.fit(x_train, y_train, epochs=epochs, batch_size=128, callbacks=callbacks)
    return normalized_model

In [15]:
normalized_model = get_trained_normalized_model()

Train on 60000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [17]:
longer_trajectory_normalized_model = tf.keras.Sequential([
    tf.keras.layers.Input([28 * 28]),
    tf.keras.layers.Dense(64, use_bias=False),  # down-sampling
    tf.keras.layers.LayerNormalization(scale=False, center=False),
    NormalizedNodeLayer(64, dt=1e-1, num_grids=50),
    tf.keras.layers.Dense(10, activation='softmax')
])

longer_trajectory_normalized_model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'])

longer_trajectory_normalized_model.set_weights(normalized_model.get_weights())
longer_trajectory_normalized_model.evaluate(x_train, y_train, batch_size=128)



[0.5223327584107716, 0.87726665]

### Conclusion

With layer normalization, no convergence (relaxation) is observed.