In [1]:
# !pip install tensorflow
# For integer quantization:
# !pip install tensorflow_model_optimization

In [2]:
import tensorflow as tf
from mnist import load_mnist, evaluate
from hopfield.utils import softly_binarize, SoftBinarization

tf.random.set_seed(42)

tf.keras.backend.clear_session()

print(tf.__version__)

2.3.0


In [3]:
# global configurations

IMAGE_SIZE = (16, 16)
BINARIZE = True
AUTOENCODER_TYPE = ('vanilla', 'variational')[1]
NUM_DATA = 100000

In [4]:
(x_train, _), _ = load_mnist(image_size=IMAGE_SIZE, binarize=BINARIZE)
x_train = (x_train + 1) / 2  # x \in [0, 1]

In [5]:
# def softly(fn):
#     r"""Decorator that returns func(*x, **kwargs), with the gradients on the x
#     :math:`\partial f_i / \partial x_j = \delta_{i j}`, i.e. an unit Jacobian,
#     or say, an identity vector-Jacobian-product.
#     """

#     def identity(*dy):
#         if len(dy) == 1:
#             dy = dy[0]
#         return dy

#     @tf.custom_gradient
#     def softly_fn(*args, **kwargs):
#         y = fn(*args, **kwargs)
#         return y, identity

#     return softly_fn


# def softly_binarize(x, threshold, minval=0, maxval=1, from_logits=False):
#     r"""Returns `maxval` if x > threshold else `minval`, element-wisely, with
#     the gradients :math:`\partial f_i / \partial x_j = \delta_{i j}`, i.e. an
#     unit Jacobian.

#     Parameters
#     ----------
#     x : tensor
#     threshold : float
#     minval : real number
#     maxval : real number
#     from_logits : bool, optional
#         If true, then softly binarize sigmoid(x) instead of x.

#     Returns
#     -------
#     tensor
#         The same shape and dtype as x.
#     """

#     @softly
#     def binarize(x):
#         y = tf.where(x > threshold, maxval, minval)
#         y = tf.cast(y, x.dtype)
#         return y

#     return binarize(tf.nn.sigmoid(x)) if from_logits else binarize(x)


# class SoftBinarization(tf.keras.layers.Layer):
#     """For using in tf.keras.Sequential.

#     If in training phase, then do nothing. Otherwise, make soft binarization.

#     Parameters
#     ----------
#     threshold : float
#     from_logits : bool, optional
#         If true, then softly binarize sigmoid(x) instead of x.
#     """

#     def __init__(self, threshold, from_logits=False, **kwargs):
#         super().__init__(**kwargs)
#         self.threshold = threshold
#         self.from_logits = from_logits

#     def get_config(self):
#         config = super().get_config()
#         config['threshold'] = self.threshold
#         config['from_logits'] = self.from_logits
#         return config

#     def call(self, x, training=None):
#         if self.from_logits:
#             x = tf.nn.sigmoid(x)
#         if training:
#             return x
#         return softly_binarize(x, self.threshold)

In [6]:
class MLP(tf.keras.layers.Layer):
    """Multi-layer perceptron (MLP).

    Parameters
    ----------
    units : [int]
    activation : callable or string, optional
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def get_config(self):
        config = super().get_config()
        config['units'] = self.units
        config['activation'] = self.activation
        return config

    def build(self, batch_input_shape):
        layers = [tf.keras.layers.Dense(n, 'relu') for n in self.units[:-1]]
        layers.append(tf.keras.layers.Dense(self.units[-1], self.activation))
        self._ffn = tf.keras.Sequential(layers)
        self._ffn.build(batch_input_shape)
        super().build(batch_input_shape)

    def call(self, x):
        y = self._ffn(x)
        return y

In [7]:
class LatentBernoulliVanillaAutoencoder(tf.keras.layers.Layer):
    """
    Parameters
    ----------
    units : [int]
        Hidden units along the encoder. The decoder use symmetric structure.
    activation : callable or string, optional
        Final output activation.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

        self._encoder = MLP(units, 'sigmoid')

    def get_config(self):
        config = super().get_config()
        config['units'] = self.units
        config['activation'] = self.activation
        return config

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = MLP(units, self.activation)
        super().build(batch_input_shape)

    def encode(self, x):
        z = self._encoder(x)
        z = softly_binarize(z, threshold=0.5)
        return z

    def decode(self, z):
        x = self._decoder(z)
        return x

    def call(self, x):
        z = self.encode(x)
        x_recon = self.decode(z)
        return x_recon

In [8]:
class LatentBernoulliVariationalAutoencoder(tf.keras.layers.Layer):
    """
    References
    ----------
    1. https://davidstutz.de/bernoulli-variational-auto-encoder-in-torch/
    2. Information Theory, Inference, and Learning Algorithms, D. Mackay,
       section 33.2

    Parameters
    ----------
    units : [int]
        Hidden units along the encoder. The decoder use symmetric structure.
    activation : callable or string, optional
        Final output activation.
    num_samples : int, optional
        Number of samples for Monte-Carlo integral.
    """

    def __init__(self, units, activation=None, num_samples=1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation
        self.num_samples = num_samples

        self._encoder = MLP(units, name='encoder')

    def get_config(self):
        config = super().get_config()
        config['units'] = self.units
        config['activation'] = self.activation
        config['num_samples'] = self.num_samples
        return config

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = MLP(units, self.activation, name='decoder')

        # NOTE:
        # Since losses are computed by mean instead of sum, over all axes,
        # we shall add a factor between the log-likelihood and the entropy.
        # Precisely, the loss should be mean_b(sum_a(...) + sum_l(...)),
        # where b for batch axis, a for ambient axis, and l for latent axis.
        # However, the loss computed is mean_b(mean_a(...) + r * mean_l(...)),
        # where r is a factor to be determined. To make the two losses
        # proportional, we shall set r = #l / #a.
        latent_dim = self.units[-1]
        self._reg_factor = latent_dim / ambient_dim

        super().build(batch_input_shape)

    def encode(self, x, training=None):
        """Add regularizer if training."""
        latent_logits = self._encoder(x)

        if training:
            sampled_latent_logits = self._reparam_trick(
                latent_logits, self.num_samples)
            latent_entropy = self._latent_entropy(sampled_latent_logits)
            self.add_loss(- self._reg_factor * tf.reduce_mean(latent_entropy))

        else:
            sampled_latent_logits = self._reparam_trick(latent_logits, 1)

        # use the first sample as the result to return
        z = softly_binarize(tf.nn.sigmoid(sampled_latent_logits[0]),
                            threshold=0.5)
        return z

    def decode(self, z):
        return self._decoder(z)

    def call(self, x, training=None):
        z = self.encode(x, training)
        x_recon = self.decode(z)
        return x_recon

    @staticmethod
    def _reparam_trick(logits, num_samples):
        """
        Notes
        -----
        s: random seed
        p: Bernoulli probability
        g(s, p): re-parameterization trick for Bernoulli distribution
            E_{z ~ bernoulli(p=f(x; w))} [...(z)]
            -> E_{s ~ uniform(0, 1)} [...(z=g(s, p=f(x;w)))],
            where gradient(g(s, p), p) exists.

        Lemma:
            s ~ uniform(0, 1)
            a = s / (1 - s) * p / (1 - p)
            z = 1 if a > 1 else 0
            => z ~ bernoulli(p)

        Parameters
        ----------
        logits : tensor
            Shape [batch_size, depth].
        num_samples : int

        Returns
        -------
        tensor
            Shape [num_samples, batch_size, depth].
        """
        # seed
        eps = 1e-8
        s = tf.random.uniform(shape=([num_samples, 1] + logits.shape[1:]),
                              minval=eps, maxval=1-eps)  # [S, B, D]

        logits = logits[tf.newaxis, ...]  # [1, B, D]
        # employ the relation log(sigmoid(x)) - log(1 - sigmoid(x)) = x
        sampled_logits = tf.math.log(s) - tf.math.log(1 - s) + logits
        return sampled_logits  # [S, B, D]

    @staticmethod
    def _latent_entropy(latent_logits):
        """Entropy of the latent variable, per dimension.

        Using Monte-Carlo integral.

        Parameters
        ----------
        latent_logits : tensor
            Shape [num_samples, batch_size, depth].

        Returns
        -------
        tensor
            Shape [batch_size, depth].
        """
        # [S, B, D]
        p = tf.nn.sigmoid(latent_logits)
        log_p = log_sigmoid(latent_logits)
        log_1mp = log_1m_sigmoid(latent_logits)
        entropy = tf.reduce_mean(
            -p * log_p - (1 - p) * log_1mp,  # [S, B, D]
            axis=0)  # [B, D]
        return entropy

    def latent_entropy(self, x):
        latent_logits = self._encoder(x)
        sampled_latent_logits = self._reparam_trick(
            latent_logits, self.num_samples)
        return self._latent_entropy(sampled_latent_logits)  # [B, D]


def log_sigmoid(x):
    """log(sigmoid(x)) = x - softplus(x)"""
    return x - tf.nn.softplus(x)


def log_1m_sigmoid(x):
    """log(1 - sigmoid(x)) = - softplus(x)"""
    return - tf.nn.softplus(x)

In [9]:
layers = [tf.keras.Input([IMAGE_SIZE[0] * IMAGE_SIZE[1]])]
if AUTOENCODER_TYPE == 'vanilla':
    layers += [
        LatentBernoulliVanillaAutoencoder([64], 'sigmoid'),
    ]
elif AUTOENCODER_TYPE == 'variational':
    layers += [
        LatentBernoulliVariationalAutoencoder([64], 'sigmoid',
                                              num_samples=10),
    ]
else:
    raise ValueError()
if BINARIZE:
    layers.append(SoftBinarization(0.5))
ae = tf.keras.Sequential(layers)
ae.compile(loss='binary_crossentropy', optimizer='adam')

In [10]:
evaluate(ae, x_train)

0.47646171875

In [11]:
if AUTOENCODER_TYPE == 'variational':
    print(tf.reduce_sum(ae.layers[0].latent_entropy(x_train[:1000]), axis=1))

tf.Tensor(
[31.543896 31.524052 31.53149  31.965534 31.54101  31.55637  31.642353
 31.308805 31.86382  31.670486 31.400383 31.900991 31.452078 31.507177
 31.86161  31.555687 31.520576 31.709038 31.836185 31.489807 31.212074
 31.33173  31.81814  31.879488 31.947119 31.338264 31.722912 31.301363
 31.329525 31.791729 31.463463 31.574093 31.787598 31.548285 31.375692
 31.91607  31.27531  31.2925   31.684343 31.446676 31.688192 31.567421
 31.909021 31.463312 31.869528 31.429607 31.477207 31.585764 31.791092
 31.462646 31.760052 31.316238 31.465881 31.797115 31.379246 31.431948
 31.285156 31.437628 31.185686 31.842068 31.993721 31.832811 31.284138
 30.98376  31.367378 31.889423 31.5312   31.977505 31.53477  31.258032
 31.53238  31.772896 31.785233 31.66428  31.378485 31.450638 31.551704
 31.812004 31.750397 31.769497 31.7119   31.47872  31.466946 31.427832
 31.693497 31.478878 31.856358 31.35709  31.503155 31.619205 31.147844
 31.556229 31.573456 31.657623 31.503632 31.397198 31.68028  31.39

In [12]:
ds = tf.data.Dataset.from_tensor_slices(
    (x_train[:NUM_DATA], x_train[:NUM_DATA]))
ds = ds.shuffle(10000).repeat(30).batch(128)
ae.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7ff82e9f2050>

In [13]:
evaluate(ae, x_train)

0.9384083333333333

In [14]:
if AUTOENCODER_TYPE == 'variational':
    print(tf.reduce_sum(ae.layers[0].latent_entropy(x_train[:1000]), axis=1))

tf.Tensor(
[24.115875 23.449432 24.686804 25.713158 26.370564 24.737907 24.313774
 23.923756 26.921467 25.201267 24.497602 24.8684   19.681847 23.896519
 25.995813 24.782166 27.07545  24.605915 28.056368 26.23744  20.100037
 21.993702 27.763973 25.671062 26.260931 21.10859  27.865786 21.467285
 20.319391 26.330294 24.35022  22.199757 25.53824  24.434954 22.852627
 25.563826 23.516817 22.747967 24.421787 26.06112  26.24363  26.158676
 27.798985 26.474957 27.848686 25.855877 27.258793 24.138905 27.198025
 24.395645 26.659698 19.643288 23.43298  26.087477 23.673374 24.426994
 20.330238 26.128368 22.321009 24.999609 19.720984 26.505764 22.941936
 21.254086 24.284264 25.678848 24.701141 26.678791 22.817497 22.092617
 23.734932 24.93187  26.456257 25.266794 25.315685 23.578976 25.272968
 25.973007 25.515083 25.663837 24.7832   23.09982  22.075518 24.729897
 25.607943 25.264973 22.566319 26.18824  22.075565 23.621403 21.677147
 22.908035 27.433706 26.778246 25.514889 22.735844 26.073597 24.20