In [1]:
import tensorflow as tf
from mnist import load_mnist, evaluate

tf.random.set_seed(42)

tf.keras.backend.clear_session()

print(tf.__version__)

2.3.0


In [2]:
# configurations

IMAGE_SIZE = (16, 16)
BINARIZE = True
AUTOENCODER_TYPE = ('vanilla', 'variational')[1]
NUM_DATA = 100000

In [3]:
(x_train, _), _ = load_mnist(image_size=IMAGE_SIZE, binarize=BINARIZE)
x_train = (x_train + 1) / 2  # x \in [0, 1]

In [4]:
def binarize(x, threshold):
    """Returns 1 if x > threshold else 0, element-wisely.

    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        The same shape and dtype as x.
    """
    y = tf.where(x > threshold, 1, 0)
    y = tf.cast(y, x.dtype)
    return y


def softly_binarize(x, threshold):
    """Returns 1 if x > threshold else 0, element-wisely, with the gradients
    :math:`\partial f_i / \partial x_j = \delta_{i j}`, i.e. an unit Jacobian.

    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        The same shape and dtype as x.
    """

    def identity(dy):
        return dy

    @tf.custom_gradient
    def fn(x):
        y = binarize(x, threshold)
        return y, identity

    return fn(x)


class SoftBinarization(tf.keras.layers.Layer):
    """For using in tf.keras.Sequential.

    If in training phase, then do nothing. Otherwise, make soft binarization.

    Parameters
    ----------
    threshold : float
    """

    def __init__(self, threshold, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def call(self, x, training=None):
        if training:
            return x
        return softly_binarize(x, self.threshold)

In [5]:
class MLP(tf.keras.layers.Layer):
    """Multi-layer perceptron (MLP).

    Parameters
    ----------
    units : [int]
    activation : callable or string, optional
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def build(self, batch_input_shape):
        layers = [tf.keras.layers.Dense(n, 'relu') for n in self.units[:-1]]
        layers.append(tf.keras.layers.Dense(self.units[-1], self.activation))
        self._ffn = tf.keras.Sequential(layers)
        self._ffn.build(batch_input_shape)
        super().build(batch_input_shape)

    def call(self, x):
        y = self._ffn(x)
        return y

In [6]:
class LatentBernoulliVanillaAutoencoder(tf.keras.layers.Layer):
    """
    Parameters
    ----------
    units : [int]
    activation : callable or string, optional
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

        self._encoder = MLP(units, 'sigmoid')

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = MLP(units, self.activation)
        super().build(batch_input_shape)

    def encode(self, x):
        z = self._encoder(x)
        z = softly_binarize(z, 0.5)
        return z

    def decode(self, z):
        x = self._decoder(z)
        return x

    def call(self, x):
        z = self.encode(x)
        x_recon = self.decode(z)
        return x_recon

In [7]:
class LatentBernoulliVariationalAutoencoder(tf.keras.layers.Layer):
    """
    References
    ----------
    1. https://davidstutz.de/bernoulli-variational-auto-encoder-in-torch/
    2. Information Theory, Inference, and Learning Algorithms, D. Mackay,
       section 33.2

    Parameters
    ----------
    units : [int]
    activation : callable or string, optional
    temperature : float, optional
    """

    def __init__(self, units, activation=None, temperature=0, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation
        self.temperature = float(temperature)

        self._encoder = MLP(units, name='encoder')

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = MLP(units, self.activation, name='decoder')
        super().build(batch_input_shape)

    def encode(self, x, training=None):
        """Add regularizer if training."""
        latent_logits = self._encoder(x)
        z = self._reparam_trick(latent_logits)
        z = softly_binarize(z, 0.5)

        if training and self.temperature:
            latent_entropy = self._latent_entropy(latent_logits)
            self.add_loss(- self.temperature * latent_entropy)

        return z

    def decode(self, z):
        return self._decoder(z)

    def call(self, x, training=None):
        z = self.encode(x, training)
        x_recon = self.decode(z)
        return x_recon

    @staticmethod
    def _reparam_trick(logits):
        """
        Notes
        -----
        s: random seed
        p: Bernoulli probability
        g(s, p): re-parameterization trick for Bernoulli distribution
            E_{z ~ bernoulli(p=f(x; w))} [...(z)]
            -> E_{s ~ uniform(0, 1)} [...(z=g(s, p=f(x;w)))],
            where gradient(g(s, p), p) exists.

        Lemma:
            s ~ uniform(0, 1)
            a = s / (1 - s) * p / (1 - p)
            z = 1 if a > 1 else 0
            => z ~ bernoulli(p)
        """
        # seed
        eps = 1e-8
        s = tf.random.uniform(logits.shape[1:], minval=eps, maxval=1-eps)
        s = s[tf.newaxis, ...]

        # employ log(sigmoid(x)) - log(1 - sigmoid(x)) = x
        a = tf.math.log(s) - tf.math.log(1 - s) + logits

        z = tf.nn.sigmoid(a)
        return z

    @staticmethod
    def _latent_entropy(latent_logits):
        """Entropy of the latent variable."""
        p = tf.nn.sigmoid(latent_logits)
        log_p = log_sigmoid(latent_logits)
        log_1mp = log_1m_sigmoid(latent_logits)
        entropy = -p * log_p - (1 - p) * log_1mp
        return tf.reduce_mean(entropy)
    
    def latent_entropy(self, x):
        latent_logits = self._encoder(x)
        return self._latent_entropy(latent_logits)


def log_sigmoid(x):
    """log(sigmoid(x)) = x - softplus(x)"""
    return x - tf.nn.softplus(x)


def log_1m_sigmoid(x):
    """log(1 - sigmoid(x)) = - softplus(x)"""
    return - tf.nn.softplus(x)

In [8]:
layers = [tf.keras.Input([IMAGE_SIZE[0] * IMAGE_SIZE[1]])]
if AUTOENCODER_TYPE == 'vanilla':
    layers += [
        LatentBernoulliVanillaAutoencoder([128], 'sigmoid'),
    ]
elif AUTOENCODER_TYPE == 'variational':
    layers += [
        LatentBernoulliVariationalAutoencoder([128], 'sigmoid', temperature=1e-1),
    ]
else:
    raise ValueError()
if BINARIZE:
    layers.append(SoftBinarization(0.5))
ae = tf.keras.Sequential(layers)
ae.compile(loss='binary_crossentropy', optimizer='adam')

In [9]:
if AUTOENCODER_TYPE == 'variational':
    print(ae.layers[0].latent_entropy(x_train[:1000]))

tf.Tensor(0.67251235, shape=(), dtype=float32)


In [10]:
ds = tf.data.Dataset.from_tensor_slices((x_train[:NUM_DATA], x_train[:NUM_DATA]))
ds = ds.shuffle(10000).repeat(20).batch(128)
ae.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7fb922dbab50>

In [11]:
evaluate(ae, x_train[:1000])

0.96283984375

In [12]:
if AUTOENCODER_TYPE == 'variational':
    print(ae.layers[0].latent_entropy(x_train[:1000]))

tf.Tensor(0.39440462, shape=(), dtype=float32)
