In [1]:
# !pip install tensorflow
# For integer quantization:
# !pip install tensorflow_model_optimization

In [2]:
import tensorflow as tf
from mnist import load_mnist, evaluate

# Fix TensorFlow's global random seed before any random op is created.
tf.random.set_seed(42)

# Reset Keras' global state so re-running this cell starts from a clean slate.
tf.keras.backend.clear_session()

# Record the TensorFlow version the notebook was executed with.
print(tf.__version__)

2.3.0


In [3]:
# global configurations

# Image size passed to load_mnist; the model input width used later is
# IMAGE_SIZE[0] * IMAGE_SIZE[1].
IMAGE_SIZE = (16, 16)
# Passed to load_mnist, and also decides whether a SoftBinarization layer is
# appended to the model.
BINARIZE = True
# Change the trailing index to switch between the two autoencoder variants.
AUTOENCODER_TYPE = ('vanilla', 'variational')[0]
# Upper bound on the number of training examples (used as x_train[:NUM_DATA]).
NUM_DATA = 100000

In [4]:
# Only the training inputs are kept; the other values returned by load_mnist
# are discarded.
(x_train, _), _ = load_mnist(image_size=IMAGE_SIZE, binarize=BINARIZE)
# NOTE(review): this rescaling assumes load_mnist returns values in [-1, 1]
# -- confirm against the mnist module.
x_train = (x_train + 1) / 2  # x \in [0, 1]

In [5]:
def binarize(x, threshold):
    """Hard-threshold a tensor, element by element.

    An entry maps to 1 when it is strictly greater than ``threshold`` and
    to 0 otherwise.

    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        Has the shape and dtype of x.
    """
    indicator = tf.where(tf.greater(x, threshold), 1, 0)
    return tf.cast(indicator, x.dtype)


def softly_binarize(x, threshold, from_logits=False):
    r"""Binarize element-wise while passing gradients straight through.

    The forward pass returns 1 if x > threshold else 0. The backward pass
    uses :math:`\partial f_i / \partial x_j = \delta_{i j}`, i.e. a unit
    Jacobian, so the upstream gradient reaches the input unchanged.

    Parameters
    ----------
    x : tensor
    threshold : float
    from_logits : bool, optional
        If true, then softly binarize sigmoid(x) instead of x.

    Returns
    -------
    tensor
        The same shape and dtype as x.
    """

    @tf.custom_gradient
    def straight_through(inputs):
        def grad(upstream):
            # Unit Jacobian: hand the upstream gradient back untouched.
            return upstream

        return binarize(inputs, threshold), grad

    if from_logits:
        # The threshold is compared against probabilities, not raw logits.
        return straight_through(tf.nn.sigmoid(x))
    return straight_through(x)


class SoftBinarization(tf.keras.layers.Layer):
    """Layer wrapper around ``softly_binarize`` for tf.keras.Sequential.

    When called with a truthy ``training`` flag the input is passed through
    (after a sigmoid if ``from_logits`` is set). In every other case the
    result is additionally thresholded with ``softly_binarize``.

    Parameters
    ----------
    threshold : float
    from_logits : bool, optional
        If true, a sigmoid is applied to the input first.
    """

    def __init__(self, threshold, from_logits=False, **kwargs):
        super().__init__(**kwargs)
        self.from_logits = from_logits
        self.threshold = threshold

    def get_config(self):
        """Return the base layer config extended with this layer's options."""
        config = super().get_config()
        config.update(threshold=self.threshold,
                      from_logits=self.from_logits)
        return config

    def call(self, x, training=None):
        """Apply the optional sigmoid, then binarize unless training."""
        values = tf.nn.sigmoid(x) if self.from_logits else x
        if not training:
            return softly_binarize(values, self.threshold)
        return values

In [6]:
class MLP(tf.keras.layers.Layer):
    """Stack of dense layers (multi-layer perceptron).

    Every layer except the last uses a ReLU activation; the last layer uses
    ``activation``.

    Parameters
    ----------
    units : [int]
        Width of each dense layer, in order.
    activation : callable or string, optional
        Activation of the final dense layer.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.activation = activation
        self.units = units

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config.update(units=self.units, activation=self.activation)
        return config

    def build(self, batch_input_shape):
        """Create the dense stack and build it for the given input shape."""
        hidden = [
            tf.keras.layers.Dense(width, 'relu')
            for width in self.units[:-1]
        ]
        head = tf.keras.layers.Dense(self.units[-1], self.activation)
        self._ffn = tf.keras.Sequential(hidden + [head])
        self._ffn.build(batch_input_shape)
        super().build(batch_input_shape)

    def call(self, x):
        """Run the input through the dense stack."""
        return self._ffn(x)

In [7]:
class LatentBernoulliVanillaAutoencoder(tf.keras.layers.Layer):
    """Deterministic autoencoder whose latent code is binarized.

    The encoder is an MLP with a sigmoid output whose result is thresholded
    at 0.5 by ``softly_binarize``. The decoder is an MLP created at build
    time.

    Parameters
    ----------
    units : [int]
        Hidden units along the encoder. The decoder uses the mirrored
        structure and ends with the input's last dimension.
    activation : callable or string, optional
        Final output activation.
    """

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.activation = activation
        self.units = units

        self._encoder = MLP(units, 'sigmoid')

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config.update(units=self.units, activation=self.activation)
        return config

    def build(self, batch_input_shape):
        """Create the decoder once the width of the input is known."""
        output_width = batch_input_shape[-1]
        # Mirror the encoder widths, skip the latent width itself, and end
        # on the input width.
        mirrored = self.units[::-1][1:]
        self._decoder = MLP(mirrored + [output_width], self.activation)
        super().build(batch_input_shape)

    def encode(self, x):
        """Map inputs to the binarized latent code."""
        return softly_binarize(self._encoder(x), threshold=0.5)

    def decode(self, z):
        """Map a latent code back to input space."""
        return self._decoder(z)

    def call(self, x):
        """Encode, then decode."""
        return self.decode(self.encode(x))

In [8]:
class LatentBernoulliVariationalAutoencoder(tf.keras.layers.Layer):
    """Autoencoder with a sampled, binarized (Bernoulli) latent code.

    The encoder produces latent logits. Noise is added to them through a
    re-parameterization trick, and the noisy logits are squashed by a
    sigmoid and thresholded at 0.5 to obtain the latent code. While
    training, the negative mean latent entropy scaled by ``temperature`` is
    added as a layer loss.

    References
    ----------
    1. https://davidstutz.de/bernoulli-variational-auto-encoder-in-torch/
    2. Information Theory, Inference, and Learning Algorithms, D. Mackay,
       section 33.2

    Parameters
    ----------
    units : [int]
        Hidden units along the encoder. The decoder uses the mirrored
        structure and ends with the input's last dimension.
    activation : callable or string, optional
        Final output activation.
    temperature : float, optional
        Coefficient of the entropy term added to the loss during training.
    num_samples : int, optional
        Number of samples for Monte-Carlo integral.
    """

    def __init__(self, units, activation=None, temperature=0,
                 num_samples=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation
        self.temperature = float(temperature)
        self.num_samples = num_samples

        self._encoder = MLP(units, name='encoder')

    def get_config(self):
        """Return the base layer config extended with the constructor args."""
        config = super().get_config()
        config['units'] = self.units
        config['activation'] = self.activation
        config['temperature'] = self.temperature
        # Fixed: this used to read a non-existent attribute and raised
        # AttributeError.
        config['num_samples'] = self.num_samples
        return config

    def build(self, batch_input_shape):
        """Create the decoder once the width of the input is known."""
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = MLP(units, self.activation, name='decoder')
        super().build(batch_input_shape)

    def encode(self, x, training=None):
        """Sample a binarized latent code for x.

        If ``training`` is truthy, ``num_samples`` noisy logits are drawn and
        the entropy regularizer is added via ``add_loss``; otherwise a single
        sample is drawn and no loss is added.
        """
        latent_logits = self._encoder(x)

        if training:
            sampled_latent_logits = self._reparam_trick(
                latent_logits, self.num_samples)
            latent_entropy = self._latent_entropy(sampled_latent_logits)
            # Negative sign: minimizing the loss maximizes the entropy.
            self.add_loss(- self.temperature * tf.reduce_mean(latent_entropy))

        else:
            sampled_latent_logits = self._reparam_trick(latent_logits, 1)

        # use the first sample as the result to return
        z = softly_binarize(tf.nn.sigmoid(sampled_latent_logits[0]),
                            threshold=0.5)
        return z

    def decode(self, z):
        """Map a latent code back to input space."""
        return self._decoder(z)

    def call(self, x, training=None):
        """Encode (sampling the latent code), then decode."""
        z = self.encode(x, training)
        x_recon = self.decode(z)
        return x_recon

    @staticmethod
    def _reparam_trick(logits, num_samples):
        """Add logistic noise to the logits (Bernoulli re-parameterization).

        Notes
        -----
        s: random seed
        p: Bernoulli probability
        g(s, p): re-parameterization trick for Bernoulli distribution
            E_{z ~ bernoulli(p=f(x; w))} [...(z)]
            -> E_{s ~ uniform(0, 1)} [...(z=g(s, p=f(x;w)))],
            where gradient(g(s, p), p) exists.

        Lemma:
            s ~ uniform(0, 1)
            a = s / (1 - s) * p / (1 - p)
            z = 1 if a > 1 else 0
            => z ~ bernoulli(p)

        The returned value is log(a), so thresholding its sigmoid at 0.5 is
        the same test as a > 1.

        Parameters
        ----------
        logits : tensor
            Shape [batch_size, depth...].
        num_samples : int

        Returns
        -------
        tensor
            Shape [num_samples, batch_size, depth...].
        """
        # seed; kept away from 0 and 1 so that both logs below stay finite
        eps = 1e-8
        # The noise has a batch axis of size 1, so one draw is shared by
        # every example in the batch and broadcast against the logits.
        # NOTE(review): confirm that sharing noise across the batch is
        # intended rather than drawing independently per example.
        s = tf.random.uniform(shape=([num_samples, 1] + logits.shape[1:]),
                              minval=eps, maxval=1-eps)  # [S, 1, D...]

        logits = logits[tf.newaxis, ...]  # [1, B, D...]
        # employ the relation log(sigmoid(x)) - log(1 - sigmoid(x)) = x
        sampled_logits = tf.math.log(s) - tf.math.log(1 - s) + logits
        return sampled_logits  # [S, B, D...]

    @staticmethod
    def _latent_entropy(latent_logits):
        """Entropy of the latent variable.

        Computes the Bernoulli entropy -p log(p) - (1 - p) log(1 - p) with
        p = sigmoid(latent_logits) and averages it over the sample axis.

        Parameters
        ----------
        latent_logits : tensor
            Shape [num_samples, batch_size, depth...].

        Returns
        -------
        tensor
            Shape [batch_size, depth...].
        """
        p = tf.nn.sigmoid(latent_logits)
        log_p = log_sigmoid(latent_logits)
        log_1mp = log_1m_sigmoid(latent_logits)
        entropy = tf.reduce_mean(
            -p * log_p - (1 - p) * log_1mp,
            axis=0)
        return entropy

    def latent_entropy(self, x):
        """Monte-Carlo estimate of the latent entropy for inputs x.

        Draws ``num_samples`` noisy logits from the encoder output and
        returns their mean entropy. No loss is added.
        """
        latent_logits = self._encoder(x)
        sampled_latent_logits = self._reparam_trick(
            latent_logits, self.num_samples)
        return self._latent_entropy(sampled_latent_logits)


def log_sigmoid(x):
    """Evaluate log(sigmoid(x)) through the identity x - softplus(x)."""
    softplus_x = tf.nn.softplus(x)
    return x - softplus_x


def log_1m_sigmoid(x):
    """Evaluate log(1 - sigmoid(x)) through the identity -softplus(x)."""
    softplus_x = tf.nn.softplus(x)
    return -softplus_x

In [9]:
# Assemble the model: flattened-image input, the selected autoencoder, and
# (optionally) a final binarization layer.
layers = [tf.keras.Input([IMAGE_SIZE[0] * IMAGE_SIZE[1]])]
if AUTOENCODER_TYPE == 'vanilla':
    layers.append(LatentBernoulliVanillaAutoencoder([64], 'sigmoid'))
elif AUTOENCODER_TYPE == 'variational':
    layers.append(
        LatentBernoulliVariationalAutoencoder(
            [64], 'sigmoid', temperature=1e-1))
else:
    raise ValueError()
if BINARIZE:
    # Binarized data get binarized reconstructions as well.
    layers.append(SoftBinarization(0.5))
ae = tf.keras.Sequential(layers)
ae.compile(loss='binary_crossentropy', optimizer='adam')

In [10]:
# Latent entropy on the first 1000 training examples, before training.
# Only the variational autoencoder defines latent_entropy.
if AUTOENCODER_TYPE == 'variational':
    print(ae.layers[0].latent_entropy(x_train[:1000]))

In [11]:
# Autoencoder training: the inputs are also the targets.
ds = tf.data.Dataset.from_tensor_slices(
    (x_train[:NUM_DATA], x_train[:NUM_DATA]))
# The data are repeated 30 times inside the dataset, so the single fit()
# epoch below iterates over all 30 passes.
ds = ds.shuffle(10000).repeat(30).batch(128)
ae.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7f764c2d3410>

In [12]:
# NOTE(review): evaluate comes from the local mnist module; presumably it
# scores the model's reconstructions of x_train -- confirm what the returned
# number measures.
evaluate(ae, x_train)

0.9706830729166667

In [13]:
# Latent entropy on the first 1000 training examples, after training.
# Only the variational autoencoder defines latent_entropy.
if AUTOENCODER_TYPE == 'variational':
    print(ae.layers[0].latent_entropy(x_train[:1000]))

In [14]:
import pickle

training_z = ae.layers[0].encode(x_train)

!mkdir ../dat/
with open('../dat/training_z.pkl', 'wb') as f:
    pickle.dump(training_z, f)
!ls -lhtr ../dat/

mkdir: cannot create directory ‘../dat/’: File exists
total 15M
-rwxrwxrwx 1 pxj pxj 15M Sep 22 10:35 training_z.pkl


In [15]:
# Quantization fails:

# import tensorflow_model_optimization as tfmot
# import re

# def get_class_name(layer):
#     s = str(type(layer))
#     m = re.search(r"<class '.*\.([a-zA-Z]+)'>", s)
#     return m.group(1)


# def clone_layer(layer):
#     return type(layer)(**layer.get_config())


# LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
# MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer

# class DefaultDenseQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
#     # Configure how to quantize weights.
#     def get_weights_and_quantizers(self, layer):
#         return [(layer.kernel, LastValueQuantizer(num_bits=8, symmetric=True, narrow_range=False, per_axis=False))]

#     # Configure how to quantize activations.
#     def get_activations_and_quantizers(self, layer):
#         return [(layer.activation, MovingAverageQuantizer(num_bits=8, symmetric=False, narrow_range=False, per_axis=False))]

#     def set_quantize_weights(self, layer, quantize_weights):
#         # Add this line for each item returned in `get_weights_and_quantizers`
#         # , in the same order
#         layer.kernel = quantize_weights[0]

#     def set_quantize_activations(self, layer, quantize_activations):
#         # Add this line for each item returned in `get_activations_and_quantizers`
#         # , in the same order.
#         layer.activation = quantize_activations[0]

#     # Configure how to quantize outputs (may be equivalent to activations).
#     def get_output_quantizers(self, layer):
#         return []

#     def get_config(self):
#         return {}


# quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
# quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
# quantize_scope = tfmot.quantization.keras.quantize_scope

# quant_layers = [tf.keras.Input([IMAGE_SIZE[0] * IMAGE_SIZE[1]])]
# quant_scope = {
#     'DefaultDenseQuantizeConfig': DefaultDenseQuantizeConfig,
# }
# for layer in ae.layers:
#     quant_scope[get_class_name(layer)] = type(layer)
#     quant_layers.append(
#         quantize_annotate_layer(
#             clone_layer(layer),
#             DefaultDenseQuantizeConfig()))

# model = quantize_annotate_model(tf.keras.Sequential(quant_layers))

# # `quantize_apply` requires mentioning `DefaultDenseQuantizeConfig` with `quantize_scope`
# # as well as the custom Keras layer.
# with quantize_scope(quant_scope):
#     # Use `quantize_apply` to actually make the model quantization aware.
#     quant_aware_model = tfmot.quantization.keras.quantize_apply(model)

# quant_aware_model.summary()