In [81]:
import os
import sys

import numpy as np

# Expose only the CUDA device with index 15 to this process. The assignment
# is made before the Keras backend import below.
# NOTE(review): the index is machine-specific -- confirm it exists on the
# host this notebook is run on.
os.environ['CUDA_VISIBLE_DEVICES'] = '15'

from keras import backend as K

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
from inspect import signature

def _calc_gaussian_log_density(data, mu, log_sigma):
    """
    Element-wise log density of `data` under a Gaussian whose mean is `mu`
    and whose standard deviation is given in log space as `log_sigma`.

    Gaussian probability density function:

      p(x) = (1 / (sigma * sqrt(2 * pi))) * exp(-0.5 * ((x - mu) / sigma)^2)

    Taking the log:

      log[p(x)] = -log[sigma] - 0.5 * log[2 * pi] - 0.5 * ((x - mu) / sigma)^2
                = -0.5 * [((x - mu) / sigma)^2 + 2 * log[sigma] + log[2 * pi]]

    Because the scale arrives as log[sigma], 1 / sigma is computed as
    exp(-log[sigma]) rather than by exponentiating and then dividing.

    Args:
      data: Samples to evaluate.
      mu: Means of the Gaussians; must broadcast against `data`.
      log_sigma: Log standard deviations; must broadcast against `data`.

    Returns:
      The log density for every element. No reduction over any axis is
      performed here.
    """

    log_two_pi = np.log(2 * np.pi)
    standardized = (data - mu) * K.exp(-log_sigma)
    squared = standardized * standardized
    return -0.5 * (squared + 2 * log_sigma + log_two_pi)


def _logsumexp(value, axis, keepdims=False):
    """
    Numerically stable log[sum_i(exp(x_i))] along `axis`.

    The maximum m along the axis is subtracted before exponentiating and
    added back afterwards:

        log[sum_i(exp(x_i))]
        = m + log[sum_i(exp(x_i)) / exp(m)]
        = m + log[sum_i(exp(x_i - m))]

    so the largest exponent that is ever evaluated is exp(0) = 1.

    Args:
      value: Tensor to reduce.
      axis: Axis along which the reduction is carried out.
      keepdims: If True the reduced axis is kept with length 1, otherwise
        it is removed from the result.

    Returns:
      The log-sum-exp of `value` along `axis`.
    """

    # The maximum is always taken with keepdims=True so it can be subtracted
    # from `value` by broadcasting.
    peak = K.max(value, axis=axis, keepdims=True)
    centered = value - peak

    # Drop the reduced axis from the offset only when the caller asked for it
    # to be dropped from the result.
    offset = peak if keepdims else K.squeeze(peak, axis=axis)

    summed = K.sum(K.exp(centered), axis=axis, keepdims=keepdims)
    return offset + K.log(summed)


def _log_importance_weight_matrix(batch_size, dataset_size):
    """
    Returns a weight matrix that is used to estimate the unconditional latent
    distribution q(z).
    """
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = K.ones(shape=(batch_size, batch_size))
    W.fill(1 / M)
    W = W * 1/m
    W[:, 0] = 1 / N
    W[:, 1] = strat_weight
    W[M-1, 0] = strat_weight
    W = np.log(W)
    return W


def mutual_information_index(batch_size, dataset_size):
    """
    Builds the mutual-information-index loss term for a fixed minibatch and
    dataset size.

    Args:
      batch_size: Number of samples per minibatch.
      dataset_size: Number of samples in the whole dataset.

    Returns:
      A function `(z, z_mu, z_log_sigma, **_)` that returns a scalar
      tensor. The inner function deliberately carries the same name as this
      factory.
    """

    def mutual_information_index(z, z_mu, z_log_sigma, **_):
        """
        Args:
          z: Sampled latent vectors, indexed as (batch, latent_dim).
          z_mu: Means of q(z|x), same shape as `z`.
          z_log_sigma: Log standard deviations of q(z|x), same shape as `z`.
          **_: Extra keyword arguments are accepted and ignored.

        Returns:
          Scalar tensor: the batch mean of log q(z|x) - log q(z).
        """
        latent_dim = int(z.shape[1])

        # Calculate log densities for sampled latent vectors given the
        # distributions generated by the encoder network from the input
        # images (aka q(z|x)), summed over the latent dimensions.
        logqz_condx = _calc_gaussian_log_density(
            z, z_mu, z_log_sigma)
        logqz_condx = K.sum(logqz_condx, axis=1)

        # Pairwise log densities log q(z_i | x_m) of every sample under every
        # posterior in the minibatch. Samples are placed on axis 0 and the
        # distribution parameters on axis 1 so that broadcasting yields a
        # (batch, batch, latent_dim) tensor. Reshaping all three inputs to
        # (batch, 1, latent_dim) would only evaluate each sample under its
        # own posterior.
        _logqz = _calc_gaussian_log_density(
            K.reshape(z, shape=(-1, 1, latent_dim)),
            K.reshape(z_mu, shape=(1, -1, latent_dim)),
            K.reshape(z_log_sigma, shape=(1, -1, latent_dim)))

        # Estimate log[q(z)]: weight the joint (summed over latent
        # dimensions) pairwise densities and reduce over the posteriors
        # (axis 1) in log space.
        logiw_matrix = _log_importance_weight_matrix(batch_size, dataset_size)
        logqz = _logsumexp(
            logiw_matrix + K.sum(_logqz, axis=2), axis=1, keepdims=False)

        # This tensor corresponds to but is not equivalent to the mi_index term
        # of the decomposed divergence penalty. However, minimizing this
        # serves to minimize the true penalty term.
        mi_index = K.mean(logqz_condx - logqz)
        return mi_index

    return mutual_information_index


def total_correlation(batch_size, dataset_size):
    """
    Builds the total-correlation loss term for a fixed minibatch and dataset
    size.

    Args:
      batch_size: Number of samples per minibatch. The inner function
        reshapes the weight matrix to (batch_size, batch_size, 1), so the
        tensors it receives must contain exactly this many samples.
      dataset_size: Number of samples in the whole dataset.

    Returns:
      A function `(z, z_mu, z_log_sigma, **_)` that returns a scalar
      tensor. The inner function deliberately carries the same name as this
      factory.
    """

    def total_correlation(z, z_mu, z_log_sigma, **_):
        """
        Args:
          z: Sampled latent vectors, indexed as (batch, latent_dim).
          z_mu: Means of q(z|x), same shape as `z`.
          z_log_sigma: Log standard deviations of q(z|x), same shape as `z`.
          **_: Extra keyword arguments are accepted and ignored.

        Returns:
          Scalar tensor: the batch mean of log q(z) - log prod_j q(z_j).
        """
        latent_dim = int(z.shape[1])

        # Pairwise log densities log q(z_i | x_m) of every sample under every
        # posterior in the minibatch. Samples are placed on axis 0 and the
        # distribution parameters on axis 1 so that broadcasting yields a
        # (batch, batch, latent_dim) tensor. Reshaping all three inputs to
        # (batch, 1, latent_dim) would only evaluate each sample under its
        # own posterior.
        _logqz = _calc_gaussian_log_density(
            K.reshape(z, shape=(-1, 1, latent_dim)),
            K.reshape(z_mu, shape=(1, -1, latent_dim)),
            K.reshape(z_log_sigma, shape=(1, -1, latent_dim)))

        # Estimate log[q(z)]: weight the joint (summed over latent
        # dimensions) pairwise densities and reduce over the posteriors
        # (axis 1) in log space.
        logiw_matrix = _log_importance_weight_matrix(batch_size, dataset_size)
        logqz = _logsumexp(
            logiw_matrix + K.sum(_logqz, axis=2), axis=1, keepdims=False)

        # Estimate log[prod_j[q(z_j)]]: apply the same weights to each latent
        # dimension separately, reduce over the posteriors, then sum the
        # per-dimension marginals.
        shape = (batch_size, batch_size, 1)
        logqz_prod_marginals = _logsumexp(
            K.reshape(logiw_matrix, shape=shape) + _logqz,
            axis=1, keepdims=False)
        logqz_prod_marginals = K.sum(logqz_prod_marginals, axis=1)

        # This tensor corresponds to but is not equivalent to the total
        # correlation term of the decomposed divergence penalty. However,
        # minimizing this serves to minimize the true penalty term.
        tc = K.mean(logqz - logqz_prod_marginals)
        return tc

    return total_correlation


def dimensional_kl(batch_size, dataset_size):
    """
    Builds the dimension-wise KL loss term for a fixed minibatch and dataset
    size.

    Args:
      batch_size: Number of samples per minibatch. The inner function
        reshapes the weight matrix to (batch_size, batch_size, 1), so the
        tensors it receives must contain exactly this many samples.
      dataset_size: Number of samples in the whole dataset.

    Returns:
      A function `(z, z_mu, z_log_sigma, **_)` that returns a scalar
      tensor. The inner function deliberately carries the same name as this
      factory.
    """

    def dimensional_kl(z, z_mu, z_log_sigma, **_):
        """
        Args:
          z: Sampled latent vectors, indexed as (batch, latent_dim).
          z_mu: Means of q(z|x), same shape as `z`.
          z_log_sigma: Log standard deviations of q(z|x), same shape as `z`.
          **_: Extra keyword arguments are accepted and ignored.

        Returns:
          Scalar tensor: the batch mean of log prod_j q(z_j) - log p(z).
        """
        latent_dim = int(z.shape[1])

        # Calculate log densities for sampled latent vectors in a standard
        # Gaussian distribution (zero mean, unit variance, i.e. log sigma of
        # zero), summed over the latent dimensions.
        logpz = _calc_gaussian_log_density(
            z, K.zeros_like(z_mu), K.zeros_like(z_log_sigma))
        logpz = K.sum(logpz, axis=1)

        # Pairwise log densities log q(z_i | x_m) of every sample under every
        # posterior in the minibatch. Samples are placed on axis 0 and the
        # distribution parameters on axis 1 so that broadcasting yields a
        # (batch, batch, latent_dim) tensor. Reshaping all three inputs to
        # (batch, 1, latent_dim) would only evaluate each sample under its
        # own posterior.
        _logqz = _calc_gaussian_log_density(
            K.reshape(z, shape=(-1, 1, latent_dim)),
            K.reshape(z_mu, shape=(1, -1, latent_dim)),
            K.reshape(z_log_sigma, shape=(1, -1, latent_dim)))

        # Estimate log[prod_j[q(z_j)]]: weight each latent dimension
        # separately, reduce over the posteriors (axis 1) in log space, then
        # sum the per-dimension marginals.
        logiw_matrix = _log_importance_weight_matrix(batch_size, dataset_size)
        shape = (batch_size, batch_size, 1)
        logqz_prod_marginals = _logsumexp(
            K.reshape(logiw_matrix, shape=shape) + _logqz,
            axis=1, keepdims=False)
        logqz_prod_marginals = K.sum(logqz_prod_marginals, axis=1)

        # This tensor corresponds to but is not equivalent to the dimensional
        # KL term of the decomposed divergence penalty. However,
        # minimizing this serves to minimize the true penalty term.
        dim_kl = K.mean(logqz_prod_marginals - logpz)
        return dim_kl

    return dimensional_kl

In [74]:
def _check_loss_fn(loss_fn, batch_size, dataset_size):
    """
    Converts TCVAE latent loss functions into a format similar to that
    of reconstruction and VAE latent loss functions.
    """
    
    fn_signature = tuple(signature(loss_fn).parameters.keys())
    if 'batch_size' in fn_signature or 'dataset_size' in fn_signature:
        return loss_fn(batch_size, dataset_size)
    else:
        return loss_fn

In [97]:
# Scratch check: K.reshape accepts a NumPy array directly; the resulting
# tensor keeps the array's float64 dtype (see the output below).
K.reshape(np.random.randn(100, 32), shape=(32, 100))

<tf.Tensor 'Reshape_5:0' shape=(32, 100) dtype=float64>

In [103]:
# Scratch check: K.cast converts a NumPy array with an added trailing unit
# axis into a float32 tensor of shape (32, 32, 1) (see the output below).
K.cast(np.expand_dims(np.random.randn(32, 32), axis=-1), dtype='float32')

<tf.Tensor 'Cast_3:0' shape=(32, 32, 1) dtype=float32>