In [1]:
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import notebook_util
notebook_util.setup_one_gpu()

import tensorflow as tf
import numpy as np
import random
import logpdf

from sklearn.metrics import accuracy_score

random.seed(31415)
np.random.seed(31415)
tf.set_random_seed(31415)

Picking GPU 0


In [None]:
import tensorflow.contrib.layers as tcl
class SSL:
    def __init__(self, x_dim = 50):
        self.name = 'SSL'
        
        # classifier params
        self.hidden_size = 500
        self.num_labels = 10
        
        # encode & decoder params
        self.z_dim = 50
        self.x_dim = x_dim
        self.batch_size   = 500
        self.dataset_size = 50000
        
        # training
        self.lr_decay_factor = 0.95
        self.learning_rate = 1e-3
        
        with tf.variable_scope(self.name):
            self._build_annealing()
            self._build_graph()
            self._build_train_op()
        self.check_parameters()
    
    def global_vars(self):
        var_list = [var for var in tf.global_variables() if self.name in var.name]
        return var_list
    
    def trainable_vars(self):
        var_list = [var for var in tf.trainable_variables() if self.name in var.name]
        return var_list
    
    def check_parameters(self):
        for var in tf.trainable_variables():
            print('%s: %s' % (var.name, var.get_shape()))
        print()
    
    def get_collection(self, collections):
        return [var for var in tf.get_collection(collections)]
    
    def classify(self, x, reuse = False):
        with tf.variable_scope('classifier', reuse = reuse):
            h = tcl.fully_connected(x, self.hidden_size, activation_fn = tf.nn.softplus)
            y = tcl.fully_connected(h, self.num_labels,  activation_fn = tf.nn.softplus)
            return y
    
    def reparameterize(self, mu, logvar):
        std = tf.exp(logvar * 0.5)
        eps = tf.random_normal(tf.shape(mu))
        z = mu + eps * std
        return z
    
    def encode(self, x, y, reuse = False):
        with tf.variable_scope('encoder', reuse = reuse):
            concat = tf.concat([x, y], 1)
            h = tcl.fully_connected(concat, self.hidden_size, activation_fn = tf.nn.softplus)
            mu     = tcl.fully_connected(h, self.z_dim, activation_fn = None)
            logvar = tcl.fully_connected(h, self.z_dim, activation_fn = None)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
    
    def decode(self, z, y, reuse = False):
        with tf.variable_scope('decoder', reuse = reuse):
            concat = tf.concat([z, y], 1)
            h = tcl.fully_connected(concat, self.hidden_size, activation_fn = tf.nn.softplus)
            mu     = tcl.fully_connected(h, self.x_dim, activation_fn = None)
            logvar = tcl.fully_connected(h, self.x_dim, activation_fn = None)
            return mu, logvar
        
    def likelihood(self, x, mu_x, logvar_x, y, z, mu_z, logvar_z):
        # uniform prior
        prior_y = (1. / self.num_labels) * tf.ones([tf.shape(x)[0], 10], tf.float32)
        logpy = - tf.nn.softmax_cross_entropy_with_logits(logits = prior_y, labels = y)
        
        kld = tf.reduce_sum(logpdf.KLD(mu_z, logvar_z), 1)
        logpx = tf.reduce_sum(logpdf.gaussian(x, mu_x, logvar_x), 1)
        likelihood = logpx + logpy - kld  
        return likelihood
    
    def prior_likelihood(self):
        likelihood = 0
        vars = self.trainable_vars()
        for var in vars:
            likelihood += tf.reduce_sum(logpdf.std_gaussian(var))
        return likelihood
    
    def _build_graph(self, reuse = False):
        self.x_l = tf.placeholder(tf.float32, shape = (None, self.x_dim))
        self.y_l = tf.placeholder(tf.int32, shape = (None, ))
        self.x_u = tf.placeholder(tf.float32, shape = (None, self.x_dim))
        
        self.y_l_onehot = tcl.one_hot_encoding(self.y_l, num_classes = self.num_labels)
                
        '''
            classifier, labelled
        '''
        scores_l = self.classify(self.x_l, reuse = reuse)
        # loss of classifier
        self.loss_clf = tf.nn.softmax_cross_entropy_with_logits(\
                        logits = scores_l, labels = self.y_l_onehot )

        '''
            labelled data, encoder & decoder
        '''
        z_l, mu_z_l, logvar_z_l = self.encode(self.x_l, self.y_l_onehot, reuse = reuse)
        mu_x_l, logvar_x_l = self.decode(z_l, self.y_l_onehot, reuse = reuse)
        
        # loss of labelled data, refered as L(x, y)
        likelihood_l = self.likelihood(self.x_l, mu_x_l, logvar_x_l, \
                                       self.y_l_onehot, z_l, mu_z_l, logvar_z_l)
        
        '''
            unlabelled data, encoder & decoder
        '''
        for i in range(self.num_labels):
            y_us = i*tf.ones([tf.shape(self.x_u)[0]], tf.int32)
            y_us = tcl.one_hot_encoding(y_us, num_classes = self.num_labels)
            
            z_u, mu_u, logvar_u = self.encode(self.x_u, y_us, reuse = True)
            mu_recon_u, logvar_recon_u = self.decode(z_u, y_us, reuse = True)
            
            _likelihood_u = self.likelihood(self.x_u, mu_recon_u, logvar_recon_u,\
                                y_us, z_u, mu_u, logvar_u)
            _likelihood_u = tf.expand_dims(_likelihood_u, 1)
            
            if i == 0:
                likelihood_u = tf.identity( _likelihood_u )
            else:
                likelihood_u = tf.concat([likelihood_u, _likelihood_u], 1)
            
        # with x & clf, give the dist over y
        scores_u = self.classify(self.x_u, reuse = True)
        y_u_prob = tf.nn.softmax(scores_u, dim=-1)
        
        # add the H(q(y|x))
        likelihood_u = tf.multiply(y_u_prob, likelihood_u + -tf.log(y_u_prob)) 
        likelihood_u = tf.reduce_sum(likelihood_u, 1)

        alpha = 0.1 * self.batch_size
        self.loss_clf = tf.reduce_sum(self.loss_clf, 0)
        self.loss_l = - tf.reduce_sum(likelihood_l, 0)
        self.loss_u = - tf.reduce_sum(likelihood_u, 0)
        self.loss = (self.loss_l + alpha* self.loss_clf + self.loss_u)/self.batch_size

        print('loss_u  : '+str(self.loss_u.shape))
        print('loss_l  : '+str(self.loss_l.shape))
        print('loss_clf: '+str(self.loss_clf.shape))
        
        prior_weight = 1./(self.dataset_size) 
        self.loss_prior = - self.prior_likelihood()
        self.loss += prior_weight * self.loss_prior
        
        self.pred_y = tf.argmax(scores_u, 1)
    
    def _build_train_op(self):
        self.global_step = tf.Variable(0, name="global_step", trainable = False)
        self.lr = tf.Variable(self.learning_rate, trainable=False, 
                    dtype=tf.float32)
        optimizer = tf.train.AdamOptimizer(self.lr)
        grads_and_vars = optimizer.compute_gradients(self.loss)
        def ClipIfNotNone(grad):
            if grad is None:
                return grad
            return tf.clip_by_value(grad, -1, 1)
        capped_gvs = [(ClipIfNotNone(grad), var) for grad, var in grads_and_vars]
        self.train_op = optimizer.apply_gradients(capped_gvs, self.global_step)
        self.lr_decay_op = self.lr.assign(
                self.lr * self.lr_decay_factor)
    
    def lr_decay(self, sess):
        _ = sess.run([self.lr_decay_op])
        
    def _build_annealing(self):
        self.kld_weight = tf.Variable(float(0.0), trainable=False, 
                                    dtype=tf.float32)
        kld_anneal_factor = 0.95
        self.anneal_decay_op = self.kld_weight.assign(
                self.kld_weight * kld_anneal_factor)

    def kld_anneal(self, sess):
        _ = sess.run([self.anneal_decay_op])
    
    def predict(self, x, sess):
        feed_dict = {
            self.x_u: x,
        }
        pred = sess.run([self.pred_y], feed_dict = feed_dict)[0]
        return pred
    
    def optimize(self, sess, x_l, y_l, x_u):
        feed_dict = {
            self.x_l: x_l,
            self.y_l: y_l,

            self.x_u: x_u,
        }
        eval_train = [self.train_op, self.loss]
        eval_loss  = [self.loss_clf, self.loss_l, self.loss_u, self.loss_prior]
        eval_vars = eval_train + eval_loss
        _, loss, loss_clf, loss_l, loss_u, loss_p = sess.run(eval_vars, feed_dict = feed_dict)
        return loss, loss_clf, loss_l, loss_u, loss_p

In [None]:
import time
from utils import (load_mnist_split, create_semisupervised, \
                   ssl_batch_gen, time_since, get_dims, sample)
from vae import VAE

'''
    here are the config params for the experiment.
'''
data_size = 50000
batch_size = 500
n_labelled = 100
n_epoch = 1000

# load data from mnist
train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist_split()
# split training set
X_labelled, Y_labelled, X_unlabelled, Y_unlabelled = create_semisupervised(\
                                                        train_x, train_y, n_labelled)

np.random.seed(31415)
tf.set_random_seed(31415)

data_x_l = X_labelled
data_x_u = X_unlabelled
data_y_l = Y_labelled

# Set config for tensorflow session.
tf_config = tf.ConfigProto(
    device_count = {'GPU': 1}, # single gpu
)
tf_config.gpu_options.allow_growth=True

x_dim = data_x_u.shape[1]
g = tf.Graph()
with g.as_default():
    with tf.Session(config = tf_config) as sess:
        model = SSL(x_dim = x_dim)
        var_list = model.global_vars()
        init_op = tf.variables_initializer(var_list)
        sess.run(init_op)

        start = time.time()
        for epoch in range(10000):
            l_batch_gen, u_batch_gen = ssl_batch_gen(data_x_l, data_y_l, data_x_u, 500, 1)

            for l_batch, u_batch in zip(l_batch_gen, u_batch_gen):
                x_l, y_l = zip(*l_batch)
                x_u = zip(*u_batch)[0]
                loss, loss_clf, loss_l, loss_u, loss_p = model.optimize(sess, x_l, y_l, x_u)
            
            if epoch % 10 == 0:
                pred_valid = model.predict(valid_x, sess)
                pred_test  = model.predict(test_x,  sess)
                accuracy_valid = accuracy_score(valid_y, pred_valid)
                accuracy_test  = accuracy_score(test_y,  pred_test)

                print('time: %s' % time_since(start))
                print('epoch: %d' % epoch)
                print('  labelled loss: %.2f' % loss_l)
                print('unlabelled loss: %.2f' % loss_u)
                print('classifier loss: %.2f' % loss_clf)
                print('     prior loss: %.2f' % loss_p)
                print('     total loss: %.2f' % loss)
                print(' valid accuracy: %.3f' % accuracy_valid)
                print('  test accuracy: %.3f' % accuracy_test)
                print()
            
            if epoch % 10 == 0:
                model.lr_decay(sess)

loss_u  : ()
loss_l  : ()
loss_clf: ()
SSL/classifier/fully_connected/weights:0: (784, 500)
SSL/classifier/fully_connected/biases:0: (500,)
SSL/classifier/fully_connected_1/weights:0: (500, 10)
SSL/classifier/fully_connected_1/biases:0: (10,)
SSL/encoder/fully_connected/weights:0: (794, 500)
SSL/encoder/fully_connected/biases:0: (500,)
SSL/encoder/fully_connected_1/weights:0: (500, 50)
SSL/encoder/fully_connected_1/biases:0: (50,)
SSL/encoder/fully_connected_2/weights:0: (500, 50)
SSL/encoder/fully_connected_2/biases:0: (50,)
SSL/decoder/fully_connected/weights:0: (60, 500)
SSL/decoder/fully_connected/biases:0: (500,)
SSL/decoder/fully_connected_1/weights:0: (500, 784)
SSL/decoder/fully_connected_1/biases:0: (784,)
SSL/decoder/fully_connected_2/weights:0: (500, 784)
SSL/decoder/fully_connected_2/biases:0: (784,)

time: 0 m 4 s
epoch: 0
  labelled loss: -406.02
unlabelled loss: -187689.25
classifier loss: 8.96
     prior loss: 1527875.50
     total loss: -344.74
 valid accuracy: 0.101
 