In [None]:
num_mode=10  # number of mixture components (modes) in the learned prior
import numpy as np

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import tensorflow_datasets as tfds
import tensorflow_probability as tfp

# Fix global seeds: prior initialisation, random binarisation, shuffling and
# every sampling step below are stochastic, so results are only reproducible
# under Restart & Run All if the seeds are set up front.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

tfk = tf.keras
tfkl = tf.keras.layers
tfpl = tfp.layers
tfd = tfp.distributions
# Load MNIST as feature dicts (as_supervised=False) so the preprocessing
# functions can read sample['image'] and sample['label'].
datasets, datasets_info = tfds.load(name='mnist',
                                    with_info=True,
                                    as_supervised=False)

def _preprocess(sample):
  """Turn an MNIST feature dict into an (input, target) pair for the VAE.

  Pixel intensities are scaled to [0, 1] and then stochastically binarised
  against fresh uniform noise for each element. The same boolean image is
  returned twice because the autoencoder reconstructs its own input.

  NOTE(review): a pixel becomes True when intensity < noise, i.e. with
  probability 1 - intensity, so the binarised digits are colour-inverted.
  """
  intensity = tf.cast(sample['image'], tf.float32) / 255.
  noise = tf.random.uniform(tf.shape(intensity))
  binary = intensity < noise
  return binary, binary

# Shuffle individual examples *before* batching so that batch composition
# changes every epoch; shuffling after .batch() only permutes fixed batches.
train_dataset = (datasets['train']
                 .map(_preprocess)
                 .shuffle(int(10e3))
                 .batch(128)
                 .prefetch(tf.data.AUTOTUNE))
input_shape = datasets_info.features['image'].shape
encoded_size = 2   # latent dimensionality (size of z)
base_depth = 32    # filters in the first encoder conv layer
tfb = tfp.bijectors


In [3]:

def get_prior(num_modes, latent_dim):
    """Build a Gaussian-mixture prior over the latent space.

    The prior is a uniform mixture of `num_modes` diagonal Gaussians in
    `latent_dim` dimensions. Component means are tf.Variables initialised from
    a standard normal; component scales are Softplus-transformed variables
    initialised to 1 (so they stay positive). Mixture weights are fixed at
    1 / num_modes.

    Args:
      num_modes: number of mixture components; must be >= 1.
      latent_dim: dimensionality of each component.

    Returns:
      A tfd.MixtureSameFamily distribution with event shape [latent_dim].

    Raises:
      ValueError: if num_modes < 1.
    """
    if num_modes < 1:
        raise ValueError(f"num_modes must be >= 1, got {num_modes}")
    prior = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=[1 / num_modes] * num_modes),
        components_distribution=tfd.MultivariateNormalDiag(
            loc=tf.Variable(tf.random.normal(shape=[num_modes, latent_dim])),
            # Pass the initial value directly: TransformedVariable creates and
            # owns its own underlying (unconstrained) tf.Variable.
            scale_diag=tfp.util.TransformedVariable(
                tf.ones(shape=[num_modes, latent_dim]), bijector=tfb.Softplus())
        )
    )
    return prior

prior = get_prior(num_modes=num_mode, latent_dim=encoded_size)
 

In [4]:
# Encoder q(z | x): conv feature extractor -> full-covariance Gaussian over z.
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    # Binarised pixels -> {-0.5, +0.5}, i.e. zero-centred float input.
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(filters=base_depth, kernel_size=3, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=2 * base_depth, kernel_size=3, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    # Unconstrained parameters for the MultivariateNormalTriL layer below.
    tfkl.Dense(units=tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    # Emits q(z | x); the regulariser adds KL(q(z | x) || prior) to the losses.
    tfpl.MultivariateNormalTriL(
        event_size=encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),
])

In [5]:
from tensorflow import sigmoid
# Decoder p(x | z): dense -> transposed convs -> per-pixel Bernoulli.
decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Dense(6 * 6 * 32, activation="relu"),
    tfkl.Reshape([6, 6, 32]),
    # Spatial size with 'valid' padding: 6 -> 13 -> 27 -> 28.
    tfkl.Conv2DTranspose(2 * base_depth, 3, strides=2,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 3, strides=2,
                         padding='valid', activation=tf.nn.leaky_relu),
    # This output is used directly as Bernoulli logits, so keep it linear:
    # a leaky ReLU here would scale every negative logit by its slope (0.2).
    tfkl.Conv2DTranspose(1, 2, strides=1,
                         padding='valid', activation=None),
    tfkl.Flatten(),
    # Independent Bernoulli per pixel with event shape input_shape; the
    # layer's tensor value is a sample from that distribution.
    tfpl.IndependentBernoulli(input_shape, tfd.Distribution.sample),
])

In [6]:
# End-to-end VAE: x -> q(z | x) -> decoder applied to its tensor value -> p(x | z).
vae = tfk.Model(encoder.inputs, decoder(encoder.outputs[0]))

In [None]:
def negloglik(x, rv_x):
    """Keras loss: negative log-likelihood of targets `x` under distribution `rv_x`."""
    return -rv_x.log_prob(x)




# Loss = reconstruction NLL; the KL term is added by the encoder's regulariser.
vae.compile(
    loss=negloglik,
    optimizer=tf.optimizers.Adam(learning_rate=1e-3),
)
_ = vae.fit(train_dataset, epochs=50)

In [9]:
def _preprocess1(sample):
  """Binarise an MNIST example exactly like `_preprocess`, but keep its label.

  Returns (binary_image, label) for evaluation. Delegating to `_preprocess`
  keeps train and test binarisation identical by construction.
  """
  image, _ = _preprocess(sample)
  return image, sample['label']

In [19]:
# Test split as (binary_image, label) pairs in a single batch of 10000
# (presumably the whole MNIST test split — TODO confirm).
eval_dataset = datasets['test'].map(_preprocess1)
eval_dataset = eval_dataset.batch(10000).prefetch(tf.data.AUTOTUNE)

In [20]:
# First batch of the evaluation set.
xx = next(iter(eval_dataset))
x, y = xx  # binarised images, digit labels

In [12]:
#####compute test LL####
from scipy.special import logsumexp
# Importance-weighted estimate of the test log-likelihood:
#   log p(x) ~= logsumexp_k [log p(x|z_k) + log p(z_k) - log q(z_k|x)] - log K,
# with z_1..z_K ~ q(z|x). It is averaged over the test images and then over
# T independent repetitions; `stat` keeps one value per repetition and `w`
# is their mean (a log-likelihood, so typically negative; NLL = -w).
T = 100  # number of repetitions
K = 10   # importance samples per image
stat = []
for j in range(T):
  q_z = encoder(x)
  log_weights = []
  for r in range(K):
    z = q_z.sample()
    log_w = (decoder(z).log_prob(x)
             + prior.log_prob(z)
             - q_z.log_prob(z))
    log_weights.append(log_w)
  aa = tf.stack(log_weights, axis=0).numpy()  # shape (K, num_images)
  stat.append(np.mean(logsumexp(aa, axis=0) - np.log(K)))
w = float(np.mean(stat))
print(w)

In [None]:
###compute statistic T_{\gamma,J}########
def get_bin(x, n):
  """Binary digits of integer x, left-padded with zeros to at least n characters."""
  return format(x, 'b').zfill(n)

d = 2       # dimension of the points being compared (latent size)
n = 10000   # number of points in each group (prior draws / encoded images)
####define wavelet basis #####
def Tobinary(x):
  """Binary digits of x as a list of ints, zero-padded to d, most significant first."""
  return list(map(int, get_bin(x, d)))
def psi(x, index):
  """One-dimensional Haar function on [0, 1].

  index == 0: scaling function, 1 on [0, 1] and 0 elsewhere.
  index == 1: mother wavelet, 1 on [0, 1/2), -1 on [1/2, 1], 0 elsewhere.
  Any other index returns None.
  """
  if index == 0:
    return 1 if 0 <= x <= 1 else 0
  if index == 1:
    if 0 <= x < 1/2:
      return 1
    if 1/2 <= x <= 1:
      return -1
    return 0
  return None
def psi1(x, j, k, index):
  """Haar function at level j and translation k: 2**(j/2) * psi(2**j * x - k, index)."""
  amplitude = 2**(j / 2)
  return amplitude * psi((2**j) * x - k, index)
def getorder(k, j):
  """Flatten a multi-dimensional level-j cell index k into one row index.

  k[l] (expected in [0, 2**j)) is treated as the l-th base-2**j digit, least
  significant first, so the result lies in [0, 2**(len(k) * j)). The number
  of digits is taken from len(k), so the function does not depend on the
  global `d` (callers pass len(k) == d).

  Returns:
    int row index into the level-j coefficient table.
  """
  return int(sum(k_l * 2**(l * j) for l, k_l in enumerate(k)))
 
def NormalizeData(data):
    """Min-max scale each column of `data` independently to [0, 1]."""
    col_min = np.min(data, axis=0)
    col_range = np.max(data, axis=0) - col_min
    return (data - col_min) / col_range
# Two groups of d-dimensional points to compare: n draws from the prior (X)
# and one posterior draw per test image from the encoder (Y).
xx = next(iter(eval_dataset))
x = xx[0]
A = encoder(x).sample(1)
Y = tf.squeeze(A, axis=0).numpy()
if Y.shape[0] != n:
  raise ValueError(f"expected {n} encoded test images, got {Y.shape[0]}; "
                   "the X/Y split of Z below assumes both groups have n rows")
X = prior.sample(n).numpy()
# Normalise both groups jointly so they share one [0, 1]^d box for the wavelets.
Z = NormalizeData(np.vstack([X, Y]))
X = Z[0:n, ]
Y = Z[n:(2 * n), ]
###compute wavelet coefficients#######
import math
J = 8  # number of resolution levels, j = 0 .. J-1
# Per-level tables of shape (2**(d*j) dyadic cells, 2**d - 1 wavelet types):
# coef*1 will hold sums of wavelet values, coef*2 sums of their squares.
coefX1 = [np.zeros((2**(d*j), 2**d - 1)) for j in range(J)]
coefX2 = [np.zeros((2**(d*j), 2**d - 1)) for j in range(J)]
coefY1 = [np.zeros((2**(d*j), 2**d - 1)) for j in range(J)]
coefY2 = [np.zeros((2**(d*j), 2**d - 1)) for j in range(J)]
def _haar_cell(coord, j):
  """Translation index k of the level-j dyadic cell (k/2**j, (k+1)/2**j] holding coord.

  Cells are closed on the right (psi's support includes its right end), so a
  coordinate exactly on a dyadic point belongs to the cell on its left.
  coord == 0 has no cell on its left: the plain rule would give k = -1, i.e.
  an invalid or wrong row in getorder and a wavelet evaluated outside [0, 1].
  It is therefore placed in cell 0. Assumes coord lies in [0, 1].
  """
  scaled = 2**j * coord
  if scaled % 1 == 0:
    return max(scaled - 1, 0)
  return math.floor(scaled)


def _accumulate_haar_coefficients(points, coef1, coef2):
  """Fill coef1[j] / coef2[j] with per-cell sums of Haar wavelet values / their squares.

  For each point and level j < J, only the 2**d - 1 tensor-product wavelets
  supported on the point's own dyadic cell are nonzero, so only row
  getorder(k, j) of each table is updated. Tables are zeroed first so that
  re-running this cell does not double-count.
  """
  for table in list(coef1) + list(coef2):
    table.fill(0.0)
  wavelet_types = [Tobinary(l) for l in range(1, 2**d)]  # l = 1 .. 2**d - 1
  for pt in points:
    for j in range(J):
      k = np.array([_haar_cell(pt[ss], j) for ss in range(d)], dtype=float)
      k1 = getorder(k, j)
      for col, index in enumerate(wavelet_types):
        a = 1
        for ss in range(d):
          a = a * psi1(pt[ss], j, k[ss], index[ss])
        coef1[j][k1, col] += a
        coef2[j][k1, col] += a**2


# Local loop variables stay inside the helper, so the global `x` (test images)
# is not overwritten by this cell.
_accumulate_haar_coefficients(X[0:n], coefX1, coefX2)
_accumulate_haar_coefficients(Y[0:n], coefY1, coefY2)
###compute statistic T_{\gamma,J} based on wavelet coefficients#######
# T1 = sum over levels j, cells k, wavelet types l of
#   w_j * [ (S_X^2 - Q_X)/(n(n-1)) + (S_Y^2 - Q_Y)/(n(n-1)) - 2 S_X S_Y / n^2 ],
# where S_* are the per-cell sums and Q_* the sums of squares. This is an
# unbiased U-statistic estimate of sum_j w_j * sum_{k,l} (beta^X - beta^Y)^2.
# w_j = 2**(-2*(j+1)*2/4) = 2**-(j+1); the 2/4 factor presumably encodes the
# gamma of T_{gamma,J} — TODO confirm.
T1 = 0
pair_norm = n * (n - 1)
for j in range(J):
  level_weight = 2**(-2*(j+1)*2/4)
  for k in range(2**(d*j)):
    for l in range(1, 2**d):
      sx = coefX1[j][k, l-1]
      sy = coefY1[j][k, l-1]
      within_x = (sx**2 - coefX2[j][k, l-1]) / pair_norm
      within_y = (sy**2 - coefY2[j][k, l-1]) / pair_norm
      cross = 2 * sx * sy / (n * n)
      T1 = T1 + level_weight * (within_x + within_y - cross)

print(T1)