In [1]:
import tensorflow as tf
from mnist import load_mnist, evaluate

tf.random.set_seed(42)

tf.keras.backend.clear_session()

print(tf.__version__)

2.3.0


In [2]:
(x_train, _), _ = load_mnist(image_size=(16, 16), binarize=True)
x_train = (x_train + 1) / 2

In [3]:
def binarize(x, threshold):
    """
    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        The same shape and dtype as x
    """
    y = tf.where(x > threshold, 1, 0)
    y = tf.cast(y, x.dtype)
    return y


def softly_binarize(x, threshold):

    def identity(dy):
        return dy

    @tf.custom_gradient
    def fn(x):
        y = binarize(x, threshold)
        return y, identity

    return fn(x)


class SoftBinarization(tf.keras.layers.Layer):

    def __init__(self, threshold, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def call(self, x):
        return softly_binarize(x, self.threshold)

In [4]:
class FeedForwardNetwork(tf.keras.layers.Layer):

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def build(self, batch_input_shape):
        layers = [tf.keras.layers.Dense(n, 'relu') for n in self.units[:-1]]
        layers.append(tf.keras.layers.Dense(self.units[-1], self.activation))
        self._ffn = tf.keras.Sequential(layers)
        self._ffn.build(batch_input_shape)
        super().build(batch_input_shape)

    def call(self, x):
        y = self._ffn(x)
        return y


class LatentBernoulliVanillaAutoencoder(tf.keras.layers.Layer):
    """
    References
    ----------
    1. https://davidstutz.de/bernoulli-variational-auto-encoder-in-torch/
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

        self._encoder = FeedForwardNetwork(units, 'sigmoid')
        self._soft_binarization = SoftBinarization(0.5)

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = FeedForwardNetwork(units, self.activation)
        super().build(batch_input_shape)

    def encode(self, x):
        z = self._encoder(x)
        z = self._soft_binarization(z)
        return z

    def decode(self, z, training=None):
        x = self._decoder(z)
        if not training:
            x = self._soft_binarization(x)
        return x

    def call(self, x, training=None):
        return self.decode(self.encode(x), training)

In [5]:
class LatentBernoulliVariationalAutoencoder(tf.keras.Model):

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

        self._encoder = FeedForwardNetwork(units)
        self._soft_binarization = SoftBinarization(0.5)

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = FeedForwardNetwork(units, self.activation)
        super().build(batch_input_shape)

    def call(self, x, training=None):
        logits = self._encoder(x)
        z = self._reparam_trick(logits)
        x_recon = self._decoder(z)
        if training:
            self.add_loss(self._kl_div(z, logits))
        else:
            x_recon = self._soft_binarization(x_recon)
        return x_recon

    def _reparam_trick(self, logits):
        eps = 1e-8
        s = tf.random.uniform(logits.shape[1:], minval=eps, maxval=1-eps)
        s = s[tf.newaxis, ...]
        z = tf.nn.sigmoid(
            tf.math.log(s) - tf.math.log(1 - s)
            + logits)
        return z

    def _kl_div(self, z, logits):
        """KL-divergence between Q-distribution and latent prior."""
        log_z = -tf.nn.softplus(-logits)
        log_1mz = log_z - logits
        return tf.reduce_mean(
            tf.where(z > 0.5, log_z, log_1mz)
        )

In [6]:
ae = tf.keras.Sequential([
    # LatentBernoulliVanillaAutoencoder([128], 'sigmoid'),
    LatentBernoulliVariationalAutoencoder([128], 'sigmoid'),
])
ae.compile(loss='binary_crossentropy', optimizer='adam')

In [7]:
ds = tf.data.Dataset.from_tensor_slices((x_train[:10000], x_train[:10000]))
ds = ds.shuffle(10000).repeat(100).batch(128)
# ae.fit(x_train[:1000], x_train[:1000], epochs=1)
ae.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7fa098294210>

In [8]:
evaluate(ae, x_train[:100])

0.953671875