## Data

In [1]:
# !pip install --upgrade -q git+https://github.com/shuiruge/neural-ode.git@master

import numpy as np
import tensorflow as tf
from node.core import get_dynamical_node_function
from node.solvers.runge_kutta import RK4Solver, RKF56Solver
from node.solvers.dynamical_runge_kutta import DynamicalRK4Solver, DynamicalRKF56Solver
from node.hopfield import StopCondition, ContinuousTimeHopfieldLayer, DiscreteTimeHopfieldLayer

tf.random.set_seed(42)

tf.keras.backend.clear_session()

print(tf.__version__)

2.4.0-dev20200901


In [2]:
IMAGE_SIZE = (32, 32)
# IMAGE_SIZE = (8, 8)  # XXX: test!
BINARIZE = True
# BINARIZE = False



def pooling(x, size):
    # x shape: [None, width, height]
    x = tf.expand_dims(x, axis=-1)
    x = tf.image.resize(x, size)
    return x  # shape: [None, size[0], size[1], 1]


def process_data(X, y, image_size, binarize):
    X = pooling(X, image_size)
    X = X / 255.
    if binarize:
        X = tf.where(X < 0.5, -1., 1.)
    else:
        X = X * 2 - 1
    X = tf.reshape(X, [-1, image_size[0] * image_size[1]])
    y = tf.one_hot(y, 10)
    return tf.cast(X, tf.float32), tf.cast(y, tf.float32)


def display_reconstruction_accuracy(model, X, benchmark):
    error = np.abs(X - model.predict(X))
    benchmark_error = np.abs(X - benchmark)
    # plot
    plt.hist(error.reshape([-1]), label='error',
             alpha=0.5, bins=100, range=(0.1, 3))
    plt.hist(benchmark_error.reshape([-1]), label='benchmark_error',
             alpha=0.5, bins=100, range=(0.1, 3))
    plt.legend()
    plt.show()


def evaluate(model, X, y):
    yhat = model.predict(X)
    acc = np.mean(np.argmax(y, axis=-1) == np.argmax(yhat, axis=-1))
    return acc


(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train, y_train = process_data(x_train, y_train, IMAGE_SIZE, BINARIZE)

## Hopfield Layer with CNN AE

In [3]:
def get_center_mask(dim):
    assert dim % 2 == 1
    center = int(dim / 2) + 1
    mask = np.ones([dim, dim, 1, 1])
    mask[center, center, 0, 0] = 0
    return mask
    

def mask_center(kernel):
    # kernel shape: [dim, dim, n_channels, n_filters]
    dim, *_ = kernel.get_shape().as_list()
    mask = tf.constant(get_center_mask(dim), dtype=kernel.dtype)
    return kernel * mask


class HAELayer(tf.keras.layers.Layer):

    def __init__(self,
                 filters,
                 kernel_size,
                 activation=tf.nn.tanh,
                 tau=1,
                 static_solver=RKF56Solver(
                     dt=1e-1, tol=1e-3, min_dt=1e-2),
                 dynamical_solver=DynamicalRKF56Solver(
                     dt=1e-1, tol=1e-3, min_dt=1e-2),
                 max_time=1e+3,
                 relax_tol=1e-3,
                 reg_factor=1,
                 **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.tau = float(tau)
        self.static_solver = static_solver
        self.dynamical_solver = dynamical_solver
        self.max_time = float(max_time)
        self.relax_tol = float(relax_tol)
        self.reg_factor = float(reg_factor)

    def build(self, input_shape):
        depth = input_shape[-1]
        two_dim_shape = [int(np.sqrt(depth))] * 2 + [1]

        ae = tf.keras.Sequential([
            tf.keras.layers.Reshape(two_dim_shape),
            tf.keras.layers.Conv2D(
                filters=self.filters,
                kernel_size=self.kernel_size,
                activation='relu',
                padding='same',
                kernel_constraint=mask_center),
            tf.keras.layers.Conv2D(1, 1, activation=self.activation),
            tf.keras.layers.Reshape([depth])
        ])
        ae.build(input_shape)

        def dynamics(t, x):
            return (-x + ae(x)) / self.tau

        stop_condition = StopCondition(dynamics, self.max_time, self.relax_tol)
        node_fn = get_dynamical_node_function(
          self.dynamical_solver, self.static_solver, dynamics, stop_condition)

        self._ae = ae
        self._stop_condition = stop_condition
        self._node_fn = node_fn

        super().build(input_shape)

    def call(self, x, training=None):
        if training:
            r = self._ae(x)
            loss = tf.reduce_mean(tf.abs(x - r))
            self.add_loss(self.reg_factor * loss)
            return r
        else:
            t0 = tf.constant(0.)
            return self._node_fn(t0, x)
        

def decompose_model(model):
    for i, layer in enumerate(model.layers):
        if isinstance(layer, BaseAELayer):
            break
    encoding_part = tf.keras.Sequential(model.layers[:i])

    for j in range(i, len(model.layers)):
        if not isinstance(layer, BaseAELayer):
            break
    ae_part = tf.keras.Sequential(model.layers[i:j])

    output_part = tf.keras.Sequential(model.layers[j:])

    return encoding_part, ae_part, output_part

In [4]:
def create_model():
    model = tf.keras.Sequential([
        HAELayer(filters=16, kernel_size=3, reg_factor=1),
    ])
    model.compile(optimizer='adam')
    return model

In [5]:
model = create_model()
X = x_train[:100].numpy()
ds0 = tf.data.Dataset.from_tensor_slices(X)
ds = ds0.shuffle(10000).repeat(10000).batch(128)
model.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7f868e716490>

In [6]:
noised_X = X + np.random.normal(size=X.shape) * 0.1
recon_X = model.predict(noised_X)
try:
    print('Relax time:', model.layers[-1]._stop_condition.relax_time.numpy())
except Exception as e:
    print(e)

Relax time: 7.771289


In [7]:
orig_err = noised_X - X
err = recon_X - X
rel_err = err / X
print(f'{np.quantile(np.abs(orig_err), 0.999)} => '
      f'{np.quantile(np.abs(err), 0.999)} '
      f'({np.quantile(np.abs(rel_err), 0.999)})')

0.3277522240637046 => 0.0002702606916427548 (0.0002702606916427548)


In [8]:
model.summary()  # XXX: Much much less number of parameters than the
                 # fully connected version.

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
hae_layer (HAELayer)         (None, 1024)              177       
Total params: 177
Trainable params: 177
Non-trainable params: 0
_________________________________________________________________


* CNN version needs only ~ 10^2 parameters. Recall that dense version needs 10^7 parameters.