In [None]:
import gzip
import os
import pickle
from datetime import datetime
from pprint import pprint

import matplotlib.pyplot as plt 
import numpy as np
import scipy.io as sio
import tensorflow as tf
from keras import objectives
from scipy.optimize import linear_sum_assignment
from six.moves import cPickle
from sklearn.manifold import TSNE
from sklearn.mixture import GaussianMixture 
from sklearn.utils.linear_assignment_ import linear_assignment 
from tensorflow.contrib.slim import fully_connected as fc

### parameters

In [None]:
# Hyper-parameters of the model and of the training run.
params = dict(
    beta_arr=np.linspace(1, 5, 500, dtype='float32'),  # KL weights, 500 values from 1 to 5
    intermediate_dims=[500, 500, 2000],                # widths of the hidden layers
    latent_dim=10,
    batch_size=100,
    cluster_number=10,
)

# The number of epochs is tied to the length of the beta schedule.
params.update(num_epochs=len(params['beta_arr']))

### learning rate parameters

In [None]:
# Step-decay schedule of the learning rate.
lr_params = dict(
    lr_start=0.002,     # learning rate of the first epochs
    lr_decaysteps=20,   # number of epochs between two consecutive decays
    lr_decay=0.9,       # factor the learning rate is multiplied by at each decay
)

### logs

In [None]:
# Record of the run: the configuration plus three lists that start empty
# and collect the accuracy and the two loss terms.
logs = dict(
    parameters=params,
    learning_rate_parameters=lr_params,
    acc=[],
    kl_div_loss=[],
    recon_loss=[],
)

### data set class

In [None]:
class Dataset:
    """In-memory data set that is served in consecutive mini-batches.

    Attributes:
        X: sample array; the first axis indexes samples, the second features.
        Y: label array aligned with the first axis of X.
        num_data: number of samples, X.shape[0].
        input_dim: number of features, X.shape[1].
        num_batches: number of batches in one pass for the most recently
            used batch size (0 until a batch size is known).
        batch_index: position, within the current pass, of the batch that
            the next call to next_batch returns.
    """

    num_batches = 0
    batch_index = 0

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

        self.calculate_num_data()
        self.calculate_input_dim()

    def calculate_num_data(self):
        """Store the number of samples (first axis of X) in num_data."""
        self.num_data = self.X.shape[0]

    def calculate_input_dim(self):
        """Store the number of features (second axis of X) in input_dim."""
        self.input_dim = self.X.shape[1]

    def calculate_num_batches(self, batch_size):
        """Store how many batches of batch_size cover the data (last one may be short)."""
        self.num_batches = int(np.ceil(self.num_data / batch_size))

    def next_batch(self, batch_size):
        """Return the next (batch_x, batch_y) pair and advance the cursor.

        Batches are consecutive, non-shuffled slices of X and Y. The last
        batch of a pass holds the remaining samples and may be shorter than
        batch_size; after it the cursor wraps to the start.

        The batch boundaries are derived from batch_size itself, so the
        method also works when calculate_num_batches was never called or was
        called with a different size; num_batches is refreshed accordingly.

        Raises:
            ValueError: if batch_size is not positive.
        """
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got %r' % (batch_size,))

        # keep the public batch count consistent with the size actually used
        self.calculate_num_batches(batch_size)

        start = self.batch_index * batch_size
        if start >= self.num_data:
            # the cursor can only be past the end if the batch size grew
            # in the middle of a pass; start a fresh pass
            self.batch_index = 0
            start = 0
        stop = min(start + batch_size, self.num_data)

        batch_x = self.X[start:stop, :]
        batch_y = self.Y[start:stop]

        # wrap around once the end of the data has been served
        self.batch_index = 0 if stop >= self.num_data else self.batch_index + 1

        return batch_x, batch_y

### loading MNIST data

In [None]:
# NOTE(review): unpickling executes code embedded in the file, so the archive
# must come from a trusted source.
# The context manager closes the gzip handle as soon as the data is loaded.
with gzip.open('../MNIST_dataset/mnist.pkl.gz', 'rb') as mnist_file:
    (x_train, y_train), (x_test, y_test) = cPickle.load(mnist_file, encoding="bytes")

# normalize data (divides by 255; assumes 8-bit pixel intensities)
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# flatten data: one row per sample
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

# test + train data will be used for unsupervised clustering
X = np.concatenate((x_train,x_test))
Y = np.concatenate((y_train,y_test))

MNISTdata = Dataset(X, Y)

# the data now lives in MNISTdata; drop the temporary names
del x_train, x_test, X, Y

# calculate number of batches
MNISTdata.calculate_num_batches(params['batch_size'])

# update parameters
params['input_dim'] = MNISTdata.input_dim
params['num_data'] = MNISTdata.num_data

### how to update the learning rate at $n$-th epoch

In [None]:
def update_learning_rate(epoch_index, min_lr=0.0005, schedule=None):
    """Return the step-decayed learning rate for the given epoch.

    The rate starts at schedule['lr_start'] and is multiplied by
    schedule['lr_decay'] once every schedule['lr_decaysteps'] epochs; it is
    never allowed to drop below min_lr.

    Args:
        epoch_index: zero-based index of the epoch.
        min_lr: lower bound of the returned learning rate.
        schedule: dict with the keys 'lr_start', 'lr_decaysteps' and
            'lr_decay'. Defaults to the notebook-level lr_params.

    Returns:
        The learning rate to use in that epoch.
    """
    if schedule is None:
        schedule = lr_params

    # number of decays that have been applied up to this epoch
    num_decays = np.floor(epoch_index / schedule['lr_decaysteps'])
    decayed_lr = schedule['lr_start'] * schedule['lr_decay'] ** num_decays

    return max(decayed_lr, min_lr)

### set the global batch index to 0

In [None]:
def initialize_global_batch_index():
    """Set the TensorFlow variable `global_batch_index` back to 0 in `sess`."""
    reset_counter = global_batch_index.assign(0)
    sess.run(reset_counter)

### unsupervised clustering accuracy (ACC) score

In [None]:
def acc_score(y_true, y_class, num_ground_label):
    """Unsupervised clustering accuracy (ACC).

    Cluster indices are matched one-to-one to ground-truth labels such that
    the number of samples on which cluster and label agree is maximal; the
    score is the fraction of samples that agree under that matching.

    Args:
        y_true: integer array of ground-truth labels in [0, num_ground_label).
        y_class: integer array of cluster indices in [0, num_ground_label),
            with as many elements as y_true.
        num_ground_label: number of labels (and of clusters).

    Returns:
        Accuracy in [0, 1]; higher means better.

    Raises:
        AssertionError: if y_true and y_class differ in size.
    """
    N = y_true.size  # number of data

    assert y_class.size == y_true.size  # stop if sizes are not equal

    # occurance_matrix[c, l] = number of samples with cluster c and label l;
    # np.add.at accumulates repeated (c, l) pairs without a Python loop
    occurance_matrix = np.zeros([num_ground_label, num_ground_label], dtype=np.int64)
    np.add.at(occurance_matrix, (np.ravel(y_class), np.ravel(y_true)), 1)

    # solve the linear assignment problem (Hungarian / Kuhn-Munkres method,
    # maximum weight matching in a bipartite graph); the matrix is negated
    # because the solver minimises the total cost
    # this makes sure that each cluster is assigned to exactly one label
    cluster_ind, label_ind = linear_sum_assignment(-occurance_matrix)

    # example:
    # occurance_matrix = [ [5, 15, 100], [15, 5, 50], [10, 50, 40] ]
    # cluster_ind, label_ind = [0, 1, 2], [2, 0, 1]
    # matched counts -> [100, 15, 50] # most frequent entries (if possible)

    return occurance_matrix[cluster_ind, label_ind].sum() / N

###  visualisation of the latent space, using t-SNE algorithm

In [None]:
def plot_latent_space(space, num_samp):
    """Scatter-plot a 2-D t-SNE embedding of latent codes and the GMM means.

    `space` is evaluated in `sess` on all of MNISTdata.X; the first
    `num_samp` rows and their labels are kept, the current rows of `gmm_mu`
    are appended, and everything is embedded with TSNE. One scatter series
    is drawn per label 0-9 plus one for the appended means.

    Args:
        space: tensor to evaluate with `x` fed from MNISTdata.X.
        num_samp: number of leading samples to plot.
    """
    # one colour per label 0-9; the last entry is reserved for the centers
    palette = ['blue', 'green', 'red', 'cyan', 'magenta', 'orange', 'black', 'yellow', 'pink', 'brown', 'olivedrab']

    # evaluate the latent tensor on the complete data set
    latent_all = sess.run(space, feed_dict={x: MNISTdata.X})

    # keep only the leading samples for the plot
    latent_head = latent_all[:num_samp, :]
    labels_head = MNISTdata.Y[:num_samp]

    print('Features are transformed to the latent space and %i samples are taken to be plotted.' % num_samp)
    print('***************************************************************************************************')

    gmm_means = sess.run(gmm_mu)

    # the means receive the extra label 10 so that they form their own series
    center_labels = 10 * np.ones(params['cluster_number'])
    point_labels = np.concatenate((labels_head, center_labels), axis=0).astype(int)
    tsne_input = np.concatenate((latent_head, gmm_means), axis=0)

    # embed samples and means jointly in two dimensions
    embedding = TSNE(n_components=2).fit_transform(tsne_input)

    plt.figure(figsize=(10, 10))
    for digit in range(10):
        rows = np.where(point_labels == digit)
        plt.scatter(embedding[rows, 0], embedding[rows, 1], c=palette[digit], s=8, label=str(digit))

    rows = np.where(point_labels == 10)
    plt.scatter(embedding[rows, 0], embedding[rows, 1], c=palette[10], s=200, label='centers', marker='P')

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0, fontsize='xx-large', markerscale=3)

### encoder: $\mathbf{x} \rightarrow (\boldsymbol\mu_\mathrm{e}, \boldsymbol\Sigma_\mathrm{e})$

In [None]:
def encoder(x_):
    """Map inputs to the two parameter vectors of the encoder distribution.

    Three ReLU fully connected layers, with widths taken from the first
    three entries of params['intermediate_dims'], feed two linear heads of
    width params['latent_dim']. Variables live in the 'encoder' scope and
    are reused on repeated calls.

    Returns:
        (mu_enc_, log_sigma_enc_): outputs of the 'LatentLayerMean' and
        'LatentLayerVariance' heads. The second one is consumed elsewhere in
        this notebook as a log-variance (exp(. / 2) in sampler, exp(.) in
        the KL loss).
    """
    hidden_scopes = ('HiddenLayer1', 'HiddenLayer2', 'HiddenLayer3')

    with tf.variable_scope(name_or_scope='encoder', reuse=tf.AUTO_REUSE):
        hidden = x_
        for layer_idx, layer_scope in enumerate(hidden_scopes):
            hidden = fc(hidden, params['intermediate_dims'][layer_idx], activation_fn=tf.nn.relu, scope=layer_scope)

        # both heads are linear and read the last hidden layer
        latent_mean = fc(hidden, params['latent_dim'], activation_fn=None, scope='LatentLayerMean')
        latent_log_var = fc(hidden, params['latent_dim'], activation_fn=None, scope='LatentLayerVariance')

    return latent_mean, latent_log_var

### sampler: $(\boldsymbol\mu_\mathrm{e}, \boldsymbol\Sigma_\mathrm{e}) \rightarrow \mathbf{u}$

In [None]:
def sampler(mu_enc_, log_sigma_enc_):
    """Draw a latent sample with the reparameterisation trick.

    Returns mu_enc_ + exp(log_sigma_enc_ / 2) * eps, where eps is standard
    normal noise with the same shape as log_sigma_enc_.
    """
    noise = tf.random_normal(shape=tf.shape(log_sigma_enc_), mean=0, stddev=1, dtype=tf.float32)
    # halving the log turns the log-variance into a log-standard-deviation
    std_enc = tf.exp(log_sigma_enc_ / 2)
    return mu_enc_ + std_enc * noise

### decoder: $\mathbf{u} \rightarrow \mathbf{\hat{x}}$

In [None]:
def decoder(u_):
    """Map latent codes back to input space.

    Three ReLU fully connected layers, whose widths are the last three
    entries of params['intermediate_dims'] read back to front, are followed
    by a sigmoid output layer of width params['input_dim']. Variables live
    in the 'decoder' scope and are reused on repeated calls.

    Returns:
        x_hat_: output of the sigmoid 'OutputLayer'.
    """
    hidden_scopes = ('HiddenLayer1', 'HiddenLayer2', 'HiddenLayer3')

    with tf.variable_scope(name_or_scope='decoder', reuse=tf.AUTO_REUSE):
        hidden = u_
        for depth, layer_scope in enumerate(hidden_scopes, start=1):
            # negative index: widths are consumed in reverse order
            hidden = fc(hidden, params['intermediate_dims'][-depth], activation_fn=tf.nn.relu, scope=layer_scope)

        reconstruction = fc(hidden, params['input_dim'], activation_fn=tf.sigmoid, scope='OutputLayer')

    return reconstruction

### reconstruction loss ( cross entropy between $\mathbf{x}$ and $\mathbf{\hat{x}}$ )

In [None]:
def reconstruction_loss(x_, x_hat_):
    """Batch mean of the binary cross-entropy between x_ and x_hat_.

    The Keras cross-entropy is scaled by params['input_dim'] before
    averaging (presumably to turn Keras' per-feature mean into a
    per-sample sum -- confirm against the Keras version in use).
    """
    per_sample_bce = objectives.binary_crossentropy(x_, x_hat_)
    scaled_bce = params['input_dim'] * per_sample_bce
    return tf.reduce_mean(scaled_bce)

### KL divergence loss, variational approximation

In [None]:
def KLdiv_loss_variational(u_, mu_enc_, log_sigma_enc_, gmm_coef_, gmm_mu_, gmm_sigma_):
    """Variational approximation of the KL divergence to the GMM, averaged over the batch.

    For every sample and every mixture component c, the closed-form KL
    divergence KL_c between the diagonal Gaussians
    N(mu_enc_, exp(log_sigma_enc_)) and N(gmm_mu_[c], gmm_sigma_[c]) is
    computed; the per-sample value is -logsumexp_c(log(gmm_coef_[c]) - KL_c).

    The component axis is introduced with singleton dimensions and
    broadcasting, so the function works for any batch size.

    Args:
        u_: unused; kept so that existing call sites keep working.
        mu_enc_: encoder means, assumed shape (batch, latent).
        log_sigma_enc_: encoder log-variances, assumed shape (batch, latent).
        gmm_coef_: mixture coefficients, assumed shape (cluster,). Their
            logarithm is taken, so they must be positive.
        gmm_mu_: component means, assumed shape (cluster, latent).
        gmm_sigma_: component variances, assumed shape (cluster, latent).
            They are divided by and passed to a logarithm, so they must be
            positive.

    Returns:
        Scalar tensor: mean over the batch of the per-sample values.
    """
    sigma_enc_ = tf.exp(log_sigma_enc_)

    # encoder statistics: insert the component axis -> (batch, 1, latent)
    MU_ENC = tf.expand_dims(mu_enc_, 1)
    SIGMA_ENC = tf.expand_dims(sigma_enc_, 1)

    # GMM parameters: insert the batch axis -> (1, cluster[, latent])
    GMM_COEF = tf.expand_dims(gmm_coef_, 0)
    GMM_MU = tf.expand_dims(gmm_mu_, 0)
    GMM_SIGMA = tf.expand_dims(gmm_sigma_, 0)

    # KL between two diagonal Gaussians, summed over the latent axis;
    # broadcasting pairs every sample with every component
    KL = 0.5 * tf.reduce_sum(tf.square(MU_ENC-GMM_MU)/GMM_SIGMA + tf.math.log(GMM_SIGMA/SIGMA_ENC) - 1 + SIGMA_ENC/GMM_SIGMA, axis=2)

    # combine the components; logsumexp keeps the computation in the log domain
    KL_variational = - tf.reduce_logsumexp(tf.math.log(GMM_COEF) - KL, axis=1)

    return tf.reduce_mean(KL_variational)

### calculate assignment probabilities $P_{C|X} = Q_{C|U}$, then do the prediction by taking the argmax

In [None]:
def predict(u_, gmm_coef_, gmm_mu_, gmm_sigma_):
    """Assign every latent code to the most probable mixture component.

    For each row of u_ the density of every diagonal-Gaussian component is
    evaluated, weighted by its mixture coefficient and normalised over the
    components; the index of the largest resulting probability is returned.

    The component axis is introduced with singleton dimensions and
    broadcasting, so the function works for any number of rows in u_.

    Args:
        u_: latent codes, assumed shape (rows, latent).
        gmm_coef_: mixture coefficients, assumed shape (cluster,).
        gmm_mu_: component means, assumed shape (cluster, latent).
        gmm_sigma_: component variances, assumed shape (cluster, latent);
            they are divided by and passed to a logarithm, so they must be
            positive.

    Returns:
        Integer tensor with one component index per row of u_.
    """
    # latent codes: insert the component axis -> (rows, 1, latent)
    U = tf.expand_dims(u_, 1)

    # GMM parameters: insert the row axis -> (1, cluster[, latent])
    GMM_COEF = tf.expand_dims(gmm_coef_, 0)
    GMM_MU = tf.expand_dims(gmm_mu_, 0)
    GMM_SIGMA = tf.expand_dims(gmm_sigma_, 0)

    # per-dimension terms of -2 * log N(u; mu_c, sigma_c)
    temp1 = tf.math.log(2*np.pi*GMM_SIGMA) + tf.square(U-GMM_MU)/GMM_SIGMA
    # component densities; the small constant keeps the normalisation below
    # away from 0/0.
    # NOTE(review): when every density underflows, only the constant remains
    # and the argmax reduces to the largest mixture coefficient -- a
    # log-domain argmax would avoid that; confirm before changing results.
    temp2 = tf.exp( -0.5 * tf.reduce_sum(temp1, axis=2) ) + 1e-10
    q_u_c_ = temp2 * GMM_COEF

    # assignment probabilities, normalised over the components
    gamma_ = q_u_c_ / tf.reduce_sum(q_u_c_, axis=1, keepdims=True)

    return tf.argmax(gamma_, axis=1)

### train function

In [None]:
def train(num_epochs, beta_curr):
    """Train for `num_epochs` epochs with a fixed KL weight and log the results.

    In every epoch each batch of MNISTdata is used for one optimizer step,
    after which the three losses are evaluated on the same batch in a
    separate run. At the end of the epoch the cluster assignment of the
    whole data set is evaluated, scored with acc_score, appended to `logs`
    together with the batch-averaged losses, and printed.

    Args:
        num_epochs: number of epochs to run.
        beta_curr: value fed to the `beta` placeholder (weight of the KL term).
    """
    # take the batch size from the shared configuration, so that it always
    # matches the one used to compute MNISTdata.num_batches
    batch_size = params['batch_size']

    # initialise epoch index from the number of optimizer steps taken so far
    epoch_index = int( sess.run(global_batch_index) / MNISTdata.num_batches )

    # loop over epochs
    for _ in range(num_epochs):

        # learning rate for the current epoch
        lr_rate = update_learning_rate(epoch_index)

        # running sums of the losses, averaged after the batch loop
        recon_loss_sum = 0
        kl_div_loss_sum = 0
        loss_sum = 0

        # loop over batches
        for _batch in range(MNISTdata.num_batches):

            # get the next batch (labels are not used for training)
            batch_x = MNISTdata.next_batch(batch_size)[0]

            # optimize
            sess.run(optimizer, feed_dict={x: batch_x, beta: beta_curr, learning_rate: lr_rate})

            # current loss, evaluated after the update in a second run
            total_loss_curr, recon_loss_curr, kl_div_loss_curr = sess.run([total_loss, recon_loss, kl_div_loss], feed_dict={x: batch_x, beta: beta_curr})

            # add the loss for each batch, to be averaged after the batch loop
            loss_sum += total_loss_curr
            recon_loss_sum += recon_loss_curr
            kl_div_loss_sum += kl_div_loss_curr

        # average losses
        loss_av = loss_sum / MNISTdata.num_batches
        recon_loss_av = recon_loss_sum / MNISTdata.num_batches
        kl_div_loss_av = kl_div_loss_sum / MNISTdata.num_batches

        # estimate of y on the whole data set
        y_hat_ = sess.run(y_hat, feed_dict={x: MNISTdata.X})

        # accuracy score
        acc = acc_score(MNISTdata.Y, y_hat_, params['cluster_number'])

        # update epoch index
        epoch_index = int( sess.run(global_batch_index) / MNISTdata.num_batches )

        # save accuracy score and averaged losses
        logs['acc'].append(acc)
        logs['kl_div_loss'].append(kl_div_loss_av)
        logs['recon_loss'].append(recon_loss_av)

        print('-----------------------------------------------------------------------------------------------------------')
        print('[Epoch {0: 3d}] Learning Rate: {1:g} \t Beta: {2:g}'.format(epoch_index, lr_rate, beta_curr))
        print('[Losses averaged over batches] \t Total Loss: {0:.4f} \t Reconstruction Loss: {1:.4f} \t KL Loss: {2:.4f}'.format(loss_av, recon_loss_av, kl_div_loss_av))
        print('ACC score: % {0:.2f}'.format(100*acc))

### placeholders

In [None]:
# network input: float32 rows of width params['input_dim'], any number of rows
x = tf.placeholder(tf.float32, shape=[None, params['input_dim']])
# scalar learning rate handed to the optimizer
learning_rate = tf.placeholder(tf.float32, shape=()) 
# scalar weight of the KL term in the total loss
beta = tf.placeholder(tf.float32, shape=()) 

### GMM parameters

In [None]:
# Initial values of the GMM parameters (no random seed is set, so they differ between runs):
# equal mixture coefficients that sum to one
gmm_coef_init = np.ones(params['cluster_number'], dtype='float32') / params['cluster_number']
# standard-normal component means, one row per cluster
gmm_mu_init = np.random.randn(params['cluster_number'], params['latent_dim']).astype('float32')    
# non-negative values (0.25 * |standard normal|), one row per cluster
gmm_sigma_init = 0.25*np.abs(np.random.randn(params['cluster_number'], params['latent_dim']).astype('float32'))   

# reuse = False: creating these variables a second time in the same graph raises an error
with tf.variable_scope(name_or_scope = 'gmm', reuse = False): 
    gmm_coef = tf.get_variable(name = 'coef', initializer = gmm_coef_init)
    gmm_mu = tf.get_variable(name = 'mu', initializer = gmm_mu_init)
    gmm_sigma = tf.get_variable(name = 'sigma', initializer = gmm_sigma_init)

### define the network

In [None]:
# Build the graph: input -> encoder statistics -> sampled latent code -> reconstruction
mu_enc, log_sigma_enc = encoder(x)
u = sampler(mu_enc, log_sigma_enc)
x_hat = decoder(u)

### define losses

In [None]:
# Training objective: reconstruction term plus the KL term weighted by the beta placeholder
recon_loss = reconstruction_loss(x, x_hat)
kl_div_loss = KLdiv_loss_variational(u, mu_enc, log_sigma_enc, gmm_coef, gmm_mu, gmm_sigma) 
total_loss = recon_loss + beta * kl_div_loss

### define estimate

In [None]:
# Predicted cluster index per input row, computed from the sampled latent code u
y_hat = predict(u, gmm_coef, gmm_mu, gmm_sigma)

### define the optimizer

In [None]:
# trainable variables, grouped by the prefix of their variable name
enc_var = [var for var in tf.trainable_variables() if var.name.startswith('encoder')]
dec_var = [var for var in tf.trainable_variables() if var.name.startswith('decoder')]
gmm_var = [var for var in tf.trainable_variables() if var.name.startswith('gmm')]

# batch counter, which increases at each run of the optimizer
# (passed as global_step below; not trainable itself)
global_batch_index = tf.Variable(0, name='global_step', trainable=False)

# optimizer: Adam on the total loss, updating encoder, decoder and GMM variables;
# the learning rate is read from the learning_rate placeholder at every step
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(total_loss, var_list = enc_var + dec_var + gmm_var, global_step = global_batch_index) 

### print the model network variables

In [None]:
# List the variables handed to the optimizer, one group per model part.
star_rule = '***************************************************************************************************'
dash_rule = '---------------------------------------------------------------------------------------------------'

print(star_rule)
print('Trainable variables:')
# every group is followed by a rule; the last one closes the listing
for var_group, closing_rule in ((enc_var, dash_rule), (dec_var, dash_rule), (gmm_var, star_rule)):
    pprint(var_group)
    print(closing_rule)

### open the session and initialise the variables

In [None]:
# open the session
# (it is kept open for the rest of the notebook and is not closed in the cells shown here)
sess = tf.Session()

# run the initializer: gives every variable created so far its initial value
sess.run(tf.global_variables_initializer())

### initial ACC

In [None]:
# transform x to latent space, x -> u
# (u is the sampled latent code, so this output contains random noise)
latent_space = sess.run(u, feed_dict={x: MNISTdata.X})

# apply EM on latent space: diagonal-covariance mixture with one component per
# cluster, fitted from 2 initialisations
gmm =  GaussianMixture(n_components = params['cluster_number'], covariance_type = 'diag', n_init = 2).fit(latent_space)    

# initial accuracy score of the fitted mixture's own cluster assignments
acc = acc_score(MNISTdata.Y, gmm.predict(latent_space), params['cluster_number'])  
    
print('---------------------------------------------------------------------------------------------------------------------------------')
print('Initial Accuracy: ', acc)
print('---------------------------------------------------------------------------------------------------------------------------------')

### initialize the GMM parameters from the GaussianMixture (EM) fit on the latent space

In [None]:
# Overwrite the GMM variables with the parameters of the fitted GaussianMixture:
# mixture weights, component means and (diagonal) covariances
assign_op1 = gmm_coef.assign(gmm.weights_)
assign_op2 = gmm_mu.assign(gmm.means_)
assign_op3 = gmm_sigma.assign(gmm.covariances_)    
sess.run([assign_op1, assign_op2, assign_op3]) 

print('---------------------------------------------------------------------------------------------------------------------------------')
print('GMM parameters are initialized! (VAE)')
print('---------------------------------------------------------------------------------------------------------------------------------')

### display initial GMM parameters

In [None]:
# Print the current value of each GMM variable, separated by rules.
gmm_rule = '---------------------------------------------------------------------------------------------------------------------------------'

print(gmm_rule)
for heading, gmm_variable in (('Initial GMM coefficient:', gmm_coef),
                              ('Initial GMM mean:', gmm_mu),
                              ('Initial GMM variance:', gmm_sigma)):
    print(heading)
    print(sess.run(gmm_variable))
    print(gmm_rule)

### train the network 

In [None]:
# Run one training epoch per entry of the beta schedule, in order.
epoch_betas = params['beta_arr']
for epoch in range(params['num_epochs']):
    train(num_epochs=1, beta_curr=epoch_betas[epoch])

### save logs

In [None]:
# Build the file name: <data set>_NN_<hidden widths>_<latent dim>_<timestamp>
f_dataset_name = 'MNIST_VIB-GMM'

current_time = '_' + datetime.now().strftime('%b%d_%H%M%S')

# network part of the name, e.g. '_NN_500_500_2000_10'
layer_sizes = list(params['intermediate_dims']) + [params['latent_dim']]
f_nn_name = '_NN_' + '_'.join(str(size) for size in layer_sizes)

f_name = f_dataset_name + f_nn_name + current_time

# make sure the target directory exists; otherwise open() below would fail
os.makedirs('logs', exist_ok=True)
log_path = os.path.join('logs', f_name + '.pickle')

# save to pickle file
with open(log_path, 'wb') as f:
    pickle.dump(logs, f)
print('***************************************************************************************************')
print('Logs stored in: ' + log_path)
print('***************************************************************************************************')

###  visualisation of the latent space, using t-SNE algorithm

In [None]:
# Plot the first 2000 sampled latent codes together with the GMM means
plot_latent_space(u, num_samp = 2000)