## Data

In [1]:
# !pip install --upgrade -q git+https://github.com/shuiruge/neural-ode.git@master

import numpy as np
import tensorflow as tf

from node.core import get_dynamical_node_function
from node.solvers.runge_kutta import RKF56Solver
from node.solvers.dynamical_runge_kutta import DynamicalRKF56Solver
from node.hopfield import StopCondition

# Fix the global TensorFlow random seed for reproducibility.
tf.random.set_seed(42)

# Start from a fresh Keras session.
tf.keras.backend.clear_session()

print(tf.__version__)

2.4.0-dev20200804


In [2]:
# Target spatial size the MNIST images are resized to in `process_data`;
# each image then becomes a vector of IMAGE_SIZE[0] * IMAGE_SIZE[1] entries.
IMAGE_SIZE = (32, 32)
# IMAGE_SIZE = (8, 8)  # XXX: test!
# True: threshold pixels to {-1, 1}; False: rescale them linearly to [-1, 1].
BINARIZE = True
# BINARIZE = False


def pooling(x, size):
    """Resize a batch of single-channel images to `size`.

    Parameters
    ----------
    x : tensor or array
        Shape ``[None, width, height]``.
    size : tuple of int
        Target spatial size passed to `tf.image.resize`.

    Returns
    -------
    tensor
        Shape ``[None, size[0], size[1], 1]``.
    """
    return tf.image.resize(tf.expand_dims(x, axis=-1), size)


def process_data(X, y, image_size, binarize):
    """Turn raw images and integer labels into model-ready tensors.

    Images are resized to `image_size` and scaled to [0, 1]. They are then
    either thresholded at 0.5 to {-1, 1} (`binarize` truthy) or linearly
    mapped to [-1, 1], and finally flattened. Labels are one-hot encoded over
    10 classes.

    Returns
    -------
    (tensor, tensor)
        float32 features of shape ``[batch, image_size[0] * image_size[1]]``
        and float32 one-hot labels (``[batch, 10]`` for a 1-D `y`).
    """
    scaled = pooling(X, image_size) / 255.
    if binarize:
        signed = tf.where(scaled < 0.5, -1., 1.)
    else:
        signed = scaled * 2 - 1
    n_pixels = image_size[0] * image_size[1]
    features = tf.reshape(signed, [-1, n_pixels])
    labels = tf.one_hot(y, 10)
    return tf.cast(features, tf.float32), tf.cast(labels, tf.float32)


def evaluate(model, X, y):
    """Return the fraction of samples whose predicted class matches `y`.

    Classes are taken as the arg-max along the last axis of both `y` and
    ``model.predict(X)``.
    """
    predicted = np.argmax(model.predict(X), axis=-1)
    expected = np.argmax(y, axis=-1)
    return np.mean(expected == predicted)


# Only the MNIST training split is used. After `process_data`, x_train is a
# float32 tensor of flattened images (values in {-1, 1} when BINARIZE is True)
# and y_train holds float32 one-hot labels.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train, y_train = process_data(x_train, y_train, IMAGE_SIZE, BINARIZE)

## Model

In [3]:
def sign(x):
    """Map strictly positive entries to 1. and all others (including 0) to -1."""
    positive = x > 0
    return tf.where(positive, 1., -1.)


class NonidentityRecon(tf.keras.layers.Layer):
    """Base class of re-constructors that are further constrained so that
    they cannot learn to be the identity map.

    Subclasses constrain their weights (e.g. a zero kernel diagonal, or a
    masked convolution center) so that an output cannot simply copy the
    corresponding input.
    """


class DenseRecon(NonidentityRecon):
    """Fully connected re-constructor that is kept away from the identity map.

    The dense kernel is constrained by `symmetrize_and_mask_diagonal`
    (symmetric, zero diagonal), so no unit is connected to itself.

    Parameters
    ----------
    activation : callable
        Activation of the dense layer.
    binarize : callable, optional
        Applied to the output outside training. If `None`, the output is
        returned unchanged.
    use_bias : bool, optional
        Whether the dense layer has a bias; off by default for simplicity.
    """

    def __init__(self, activation, binarize=None, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.activation = activation
        self.binarize = binarize
        self.use_bias = use_bias

    def build(self, input_shape):
        n_units = input_shape[-1]  # output width equals input width
        self._dense = tf.keras.layers.Dense(
            n_units,
            activation=self.activation,
            use_bias=self.use_bias,
            kernel_constraint=symmetrize_and_mask_diagonal,
        )
        super().build(input_shape)

    def call(self, x, training=None):
        out = self._dense(x)
        # Binarize only outside training, and only if a binarizer was given.
        if not training and self.binarize is not None:
            out = self.binarize(out)
        return out


def symmetrize_and_mask_diagonal(kernel):
    """Symmetric kernel with vanishing diagonal.

    Used as a Keras ``kernel_constraint``: the kernel is replaced by its
    symmetric part ``(K + K^T) / 2`` with the diagonal set to zero, so no
    unit feeds back onto itself.

    Parameters
    ----------
    kernel : tensor
        Shape (N, N) for a positive integer N.

    Returns
    -------
    tensor
        The shape and dtype as the input.
    """
    w = (kernel + tf.transpose(kernel)) / 2
    # Derive the zero diagonal from `w` itself so that it matches the kernel's
    # dtype (e.g. float16/float64) and also works when the static shape is
    # unknown; `tf.zeros(kernel.shape[0:-1])` is always float32.
    zero_diag = tf.zeros_like(tf.linalg.diag_part(w))
    return tf.linalg.set_diag(w, zero_diag)


class Conv2dRecon(NonidentityRecon):
    """Re-constructor modelled on a cellular automaton, built from convolutions.

    A `kernel_size` x `kernel_size` convolution (ReLU, 'same' padding, kernel
    passed through `mask_center`) is followed by a 1x1 convolution down to
    a single channel with `activation`.

    References
    ----------
    1. Cellular automata as convolutional neural networks (arXiv: 1809.02942).

    Parameters
    ----------
    filters : int
        Number of filters of the neighborhood convolution.
    kernel_size : int
        Spatial size of the neighborhood convolution.
    activation : callable
        Activation of the final 1x1 convolution.
    binarize : callable, optional
        Applied to the output outside training. If `None`, the output is
        returned unchanged.
    flatten : bool, optional
        If true, inputs are flat vectors of length ``depth`` that are reshaped
        to ``[sqrt(depth), sqrt(depth), 1]`` images and flattened back after
        the convolutions.
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 activation,
                 binarize=None,
                 flatten=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.filters = int(filters)
        self.kernel_size = int(kernel_size)
        self.activation = activation
        self.binarize = binarize
        self.flatten = flatten

    def build(self, input_shape):
        neighborhood = tf.keras.layers.Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            activation='relu',
            padding='same',
            kernel_constraint=mask_center,
        )
        readout = tf.keras.layers.Conv2D(1, 1, activation=self.activation)

        if self.flatten:
            depth = input_shape[-1]
            side = int(np.sqrt(depth))  # assumes a square image
            stack = [
                tf.keras.layers.Reshape([side, side, 1]),
                neighborhood,
                readout,
                tf.keras.layers.Reshape([depth]),
            ]
        else:
            stack = [neighborhood, readout]

        self._recon = tf.keras.Sequential(stack)
        self._recon.build(input_shape)

        super().build(input_shape)

    def call(self, x, training=None):
        out = self._recon(x)
        # Binarize only outside training, and only if a binarizer was given.
        if not training and self.binarize is not None:
            out = self.binarize(out)
        return out


def _get_center_mask(dim: int) -> np.ndarray:
    """Return a ``[dim, dim, 1, 1]`` mask of ones with a zero at the center.

    Multiplying a ``[dim, dim, n_channels, n_filters]`` convolution kernel by
    this mask (broadcast over the last two axes) removes the weight linking
    each cell to itself, so the convolution cannot learn the identity map.

    Parameters
    ----------
    dim : int
        Kernel size; must be a positive odd integer so a center exists.

    Returns
    -------
    np.ndarray
        Float array of shape ``(dim, dim, 1, 1)``.

    Raises
    ------
    ValueError
        If `dim` is not a positive odd integer.
    """
    if dim < 1 or dim % 2 != 1:
        raise ValueError(f'dim must be a positive odd integer, got {dim}.')
    # 0-based index of the central cell (2 for dim=5). `int(dim / 2) + 1`
    # would point one cell off-center and overflow the array for dim=1.
    center = dim // 2
    mask = np.ones([dim, dim, 1, 1])
    mask[center, center, 0, 0] = 0
    return mask


def mask_center(kernel):
    """Kernel constraint: multiply `kernel` by the mask from `_get_center_mask`.

    The mask is intended to zero the central spatial tap so the convolution
    cannot copy a cell into itself.

    Parameters
    ----------
    kernel : tensor
        Shape ``[dim, dim, n_channels, n_filters]``; only ``dim`` (the first
        axis) is read, so the spatial extent is assumed square.

    Returns
    -------
    tensor
        Same shape and dtype as `kernel`.
    """
    side, *_rest = kernel.get_shape().as_list()
    center_mask = _get_center_mask(side)
    return kernel * tf.constant(center_mask, dtype=kernel.dtype)


class ContinuousTimeHopfieldLayer(tf.keras.layers.Layer):
    """Continuous-time Hopfield layer built on a non-identity re-constructor.

    In training, the layer returns the re-construction of its input and adds
    ``reg_factor * mean(|x - recon(x)|)`` via `add_loss`. Otherwise, it
    evolves the input under ``dx/dt = (-x + recon(x)) / tau`` from ``t = 0``
    using the node function built from the two solvers and a `StopCondition`
    (presumably integrating until the state relaxes or `max_time` is hit —
    see `node.hopfield.StopCondition`).

    Parameters
    ----------
    non_identity_recon : NonidentityRecon
    tau : float, optional
        Time constant of the dynamics.
    static_solver : optional
        Defaults to ``RKF56Solver(dt=1e-1, tol=1e-3, min_dt=1e-2)``, created
        separately for each layer.
    dynamical_solver : optional
        Defaults to ``DynamicalRKF56Solver(dt=1e-1, tol=1e-3, min_dt=1e-2)``,
        created separately for each layer.
    max_time : float, optional
        Passed to `StopCondition`.
    relax_tol : float, optional
        Passed to `StopCondition`.
    reg_factor : float, optional
        Weight of the re-construction loss added in training.
    """

    def __init__(self,
                 non_identity_recon: NonidentityRecon,
                 tau=1,
                 static_solver=None,
                 dynamical_solver=None,
                 max_time=1e+3,
                 relax_tol=1e-3,
                 reg_factor=1e-0,
                 **kwargs):
        super().__init__(**kwargs)
        # Create default solvers per layer; default-argument instances would
        # be evaluated once and shared by every layer ever constructed.
        if static_solver is None:
            static_solver = RKF56Solver(dt=1e-1, tol=1e-3, min_dt=1e-2)
        if dynamical_solver is None:
            dynamical_solver = DynamicalRKF56Solver(
                dt=1e-1, tol=1e-3, min_dt=1e-2)
        self.non_identity_recon = non_identity_recon
        self.tau = float(tau)
        self.static_solver = static_solver
        self.dynamical_solver = dynamical_solver
        self.max_time = float(max_time)
        self.relax_tol = float(relax_tol)
        self.reg_factor = float(reg_factor)

    def build(self, input_shape):
        f = self.non_identity_recon

        def dynamics(t, x):
            # Relaxation towards the re-construction; inference mode, matching
            # DiscreteTimeHopfieldLayer's update.
            return (-x + f(x, training=False)) / self.tau

        stop_condition = StopCondition(dynamics, self.max_time, self.relax_tol)
        node_fn = get_dynamical_node_function(
            self.dynamical_solver, self.static_solver, dynamics, stop_condition)

        # Kept as attributes so diagnostics (e.g. relax time) can be read
        # after a prediction.
        self._stop_condition = stop_condition
        self._node_fn = node_fn

        super().build(input_shape)

    def call(self, x, training=None):
        if training:
            # Pass `training` explicitly (as DiscreteTimeHopfieldLayer does)
            # so the re-constructor's output is not binarized in training.
            r = self.non_identity_recon(x, training=True)
            loss = tf.reduce_mean(tf.abs(x - r))
            self.add_loss(self.reg_factor * loss)
            return r
        t0 = tf.constant(0.)
        return self._node_fn(t0, x)


class DiscreteTimeHopfieldLayer(tf.keras.layers.Layer):
    """Discrete-time Hopfield layer built on a non-identity re-constructor.

    In training, the layer returns the re-construction of its input and adds
    ``reg_factor * mean(|x - recon(x)|)`` via `add_loss`. Otherwise, it
    repeatedly replaces the state with its re-construction until the largest
    element-wise change falls below `relax_tol` or `max_steps` iterations
    have run.

    References
    ----------
    1. Information Theory, Inference, and Learning Algorithms (D. MacKay),
       chapter 42.

    Parameters
    ----------
    non_identity_recon : NonidentityRecon
        Called with ``training=True`` in training and ``training=False`` when
        iterating.
    max_steps : int
        Maximum number of update iterations at inference.
    async_ratio : float, optional
        Probability that each "bit" is masked (set to 0) in an update. One
        random mask is shared by all samples of the batch. No masking when 0.
    relax_tol : float, optional
        Iteration stops once the maximum absolute change between consecutive
        states is below this value.
    reg_factor : float, optional
        Weight of the re-construction loss added in training.

    Attributes
    ----------
    final_step : int32 scalar
        Non-trainable `tf.Variable` set to ``step + 1`` after each inference
        loop, i.e. the number of updates computed in the last call.
    """

    def __init__(self,
                 non_identity_recon: NonidentityRecon,
                 max_steps: int,
                 async_ratio=0.,
                 relax_tol=1e-3,
                 reg_factor=1e-0,
                 **kwargs):
        super().__init__(**kwargs)
        self.non_identity_recon = non_identity_recon
        self.max_steps = max_steps
        self.async_ratio = float(async_ratio)
        self.relax_tol = float(relax_tol)
        self.reg_factor = float(reg_factor)

        # Diagnostic counter (not a weight); written at the end of inference.
        self.final_step = tf.Variable(0, trainable=False)

    @tf.function
    def _update(self, x):
        """Return one update of the state `x`: its re-construction, optionally
        with a random subset of positions zeroed."""
        # For the re-constructors in this file, `training=False` means the
        # `binarize` function (if any) is applied to the output.
        y = self.non_identity_recon(x, training=False)
        if self.async_ratio > 0:
            # mask has no batch dim
            # Each position is zeroed with probability `async_ratio`.
            # NOTE(review): masked positions become 0 instead of keeping their
            # previous value from `x`, unlike the usual asynchronous Hopfield
            # update — confirm this is intended.
            mask = tf.where(
                tf.random.uniform(y.shape[1:]) < self.async_ratio,
                0., 1.)
            y *= mask[tf.newaxis, ...]
        return y

    def call(self, x, training=None):
        if training:
            r = self.non_identity_recon(x, training=True)
            loss = tf.reduce_mean(tf.abs(x - r))
            self.add_loss(self.reg_factor * loss)
            return r

        else:
            # Iterate until the max absolute change is below `relax_tol` or
            # `max_steps` is reached. On convergence the pre-update state `x`
            # is returned (it differs from `next_x` by less than `relax_tol`).
            for step in tf.range(self.max_steps):
                next_x = self._update(x)
                if diff(next_x, x) < self.relax_tol:
                    break
                x = next_x
            # NOTE(review): `step` is only bound inside the loop, so
            # `max_steps=0` presumably fails here — confirm.
            self.final_step.assign(step + 1)
            return x


def diff(x, y):
    """Return the largest absolute element-wise difference between `x` and `y`."""
    abs_gap = tf.abs(x - y)
    return tf.reduce_max(abs_gap)

In [4]:
def create_model(model_type):
    """Build and compile a single-layer Hopfield model.

    Parameters
    ----------
    model_type : str
        One of ``'continuous_dense'``, ``'continuous_cnn'``,
        ``'discrete_dense'`` or ``'discrete_cnn'``, selecting the time model
        (continuous / discrete) and the re-constructor (dense / CNN).

    Returns
    -------
    tf.keras.Sequential
        Compiled with the Adam optimizer and no explicit loss; the Hopfield
        layer contributes its re-construction loss via `add_loss`.

    Raises
    ------
    ValueError
        If `model_type` is not one of the values above.
    """
    # Use the layer classes defined above; the short names `CTHLayer` and
    # `DTHLayer` are not defined in this notebook and raised NameError.
    if model_type == 'continuous_dense':
        layer = ContinuousTimeHopfieldLayer(
            DenseRecon(activation=tf.tanh),
            reg_factor=1)
    elif model_type == 'continuous_cnn':
        layer = ContinuousTimeHopfieldLayer(
            Conv2dRecon(filters=16, kernel_size=5, activation=tf.tanh,
                        flatten=True),
            reg_factor=1)
    elif model_type == 'discrete_dense':
        layer = DiscreteTimeHopfieldLayer(
            DenseRecon(activation=tf.tanh, binarize=sign),
            max_steps=100,
            reg_factor=1)
    elif model_type == 'discrete_cnn':
        layer = DiscreteTimeHopfieldLayer(
            Conv2dRecon(filters=16, kernel_size=5,
                        activation=tf.tanh, binarize=sign,
                        flatten=True),
            max_steps=100,
            reg_factor=1)
    else:
        raise ValueError(
            f'Unknown model_type {model_type!r}; expected one of '
            "'continuous_dense', 'continuous_cnn', 'discrete_dense', "
            "'discrete_cnn'.")
    model = tf.keras.Sequential([layer])
    model.compile(optimizer='adam')
    return model

In [5]:
# Train on the first 100 training images only. The loss is added inside the
# Hopfield layer via `add_loss`, so the dataset yields inputs without targets.
model = create_model('discrete_dense')
X = x_train[:100].numpy()
ds0 = tf.data.Dataset.from_tensor_slices(X)
# 100 samples repeated 10000 times in batches of 128: ~7800 steps in one epoch.
ds = ds0.shuffle(10000).repeat(10000).batch(128)
model.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7f9fe45ab990>

In [6]:
# Inspect the layer structure and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dth_layer (DTHLayer)         (None, 1024)              1048577   
Total params: 1,048,577
Trainable params: 1,048,576
Non-trainable params: 1
_________________________________________________________________


In [15]:
# Alternative noise model: additive Gaussian noise.
# noised_X = X + np.random.normal(size=X.shape) * 0.1
# Flip the sign of each pixel with probability 0.3.
noised_X = np.where(np.random.random(size=X.shape) < 0.3, -X, X)
recon_X = model.predict(noised_X)

# Each diagnostic exists on only one layer type (`_stop_condition` on the
# continuous layer, `final_step` on the discrete one). Skip a missing
# attribute, but let any other error surface instead of hiding it.
try:
    print('Relax time:', model.layers[-1]._stop_condition.relax_time.numpy())
except AttributeError:
    pass
try:
    print('Relax steps:', model.layers[-1].final_step.numpy())
except AttributeError:
    pass

# Compare the 0.99-quantile of the absolute error before and after
# re-construction.
orig_err = noised_X - X
err = recon_X - X
print(f'{np.quantile(np.abs(orig_err), 0.99)} => '
      f'{np.quantile(np.abs(err), 0.99)}')

Relax steps: 6
2.0 => 0.0


## Conclusions and Discussions

### Resource Occupations

#### Time

1. The dense version is much faster than the CNN version.

#### Space

1. The CNN version needs only ~10^2 parameters (433 for `filters=16, kernel_size=5`), whereas the dense version needs ~10^6 (1024 × 1024 = 1,048,576 kernel weights for 32 × 32 images).

1. To reduce the number of variables in the dense version, use [pruning](https://stackoverflow.com/a/56451791/1218716) after training.

### De-noising

1. However, the CNN version is not robust to bit-flipping, while the dense version still is. The failure of the CNN version under bit-flipping hints that the information is not stored in a sparse (distributed) way, so a bit cannot be reconstructed from the information stored in its local neighbors alone. (Note that bit-flipping creates non-smooth, and thus always large, differences.) To see this, run the re-constructor on bit-flipped inputs and compare the 0.99-quantile of the re-construction error between the dense and CNN versions.

1. The dense version still achieves 99% re-construction even when 40% of the bits are flipped.

### Binarization

1. Binarization is also essential to the CNN version: non-binarized inputs are not de-noised. The essence of binarization may be traced to the simplicity it leads to. Indeed, the final loss without binarization is greater (0.03X -> 0.04X).

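1. Changing X from {-1, 1} to {0, 1} causes errors in de-noising; the reason is not yet clear. (A possible cause for the discrete version: `sign` maps every output to {-1, 1}, so a {0, 1}-valued pattern can never be a fixed point of the update.)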

### Discrete Time

1. The discrete-time version is much faster at prediction, without losing the properties of the continuous version.

1. Async update decreases the performance.

### Discrete State

1. Using discrete states (binarization) together with discrete time improves performance significantly.

## References

1. Cellular automata as convolutional neural networks (arXiv: 1809.02942).