# ChoiceNet for Classification

In [1]:
import nbloader,os,warnings
warnings.filterwarnings("ignore") 
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
import tensorflow as tf
import tensorflow.contrib.slim as slim
from sklearn.utils import shuffle
from util import gpusession,create_gradient_clipping,load_mnist_with_noise,print_n_txt,mixup
%matplotlib inline  
%config InlineBackend.figure_format = 'retina'
if __name__ == "__main__":
    print ("TensorFlow version is [%s]."%(tf.__version__))

TensorFlow version is [1.4.1].


### Define ChoiceNet Class

In [2]:
class choiceNet_cls_class(object):
    def __init__(self,_name='',_xdim=[28,28,1],_ydim=10,_hdims=[64,64],_filterSizes=[3,3],_max_pools=[2,2],_feat_dim=128
                 ,_kmix=5,_actv=tf.nn.relu,_bn=slim.batch_norm
                 ,_rho_ref_train=0.95,_tau_inv=1e-4,_pi1_bias=0.0,_logSigmaZval=-2
                 ,_logsumexp_coef=0.1,_kl_reg_coef=0.1,_l2_reg_coef=1e-5
                 ,_momentum = 0.5
                 ,_USE_INPUT_BN=False,_USE_RESNET=False,_USE_GAP=False,_USE_KENDALL_LOSS=False,_USE_SGD=False
                 ,_USE_MIXUP=False
                 ,_GPU_ID=0,_VERBOSE=True):
        self.name = _name
        self.xdim = _xdim
        self.ydim = _ydim
        self.hdims = _hdims
        self.filterSizes = _filterSizes
        self.max_pools = _max_pools
        self.feat_dim = _feat_dim
        self.kmix = _kmix
        self.actv = _actv 
        self.bn   = _bn # slim.batch_norm / None
        self.rho_ref_train = _rho_ref_train
        self.tau_inv = _tau_inv
        self.pi1_bias = _pi1_bias
        self.logSigmaZval = _logSigmaZval
        self.logsumexp_coef = _logsumexp_coef
        self.kl_reg_coef = _kl_reg_coef
        self.l2_reg_coef = _l2_reg_coef
        self.momentum = _momentum
        self.USE_INPUT_BN = _USE_INPUT_BN
        self.USE_RESNET = _USE_RESNET
        self.USE_GAP = _USE_GAP
        self.USE_KENDALL_LOSS = _USE_KENDALL_LOSS
        self.USE_SGD = _USE_SGD
        self.USE_MIXUP = _USE_MIXUP
        self.GPU_ID = (int)(_GPU_ID)
        self.VERBOSE = _VERBOSE
        with tf.device('/device:GPU:%d'%(self.GPU_ID)):
            # Build model
            self.build_model()
            # Build graph
            self.build_graph()
            # Check parameters
            self.check_params()
        
    def build_model(self):
        _xdim = self.xdim[0]*self.xdim[1]*self.xdim[2] # Total dimension
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,_xdim],name='x') # Input [None x xdim]
        self.t = tf.placeholder(dtype=tf.float32,shape=[None,self.ydim],name='t') # Output [None x ydim]
        self.kp = tf.placeholder(dtype=tf.float32,shape=[],name='kp') # Keep probability 
        self.lr = tf.placeholder(dtype=tf.float32,shape=[],name='lr') # Learning rate
        self.is_training = tf.placeholder(dtype=tf.bool,shape=[]) # Training flag
        self.rho_ref = tf.placeholder(dtype=tf.float32,shape=[],name='rho_ref') 
        # Initailizers
        self.fully_init  = tf.random_normal_initializer(stddev=0.01)
        self.bias_init   = tf.constant_initializer(0.)
        self.bn_init     = {'beta': tf.constant_initializer(0.),
                           'gamma': tf.random_normal_initializer(1., 0.01)}
        self.bn_params   = {'is_training':self.is_training,'decay':0.9,'epsilon':1e-5,
                           'param_initializers':self.bn_init,'updates_collections':None}

        # Build graph
        with tf.variable_scope(self.name,reuse=False) as scope:
            with slim.arg_scope([slim.fully_connected],activation_fn=self.actv,
                                weights_initializer=self.fully_init,biases_initializer=self.bias_init,
                                normalizer_fn=self.bn,normalizer_params=self.bn_params,
                                weights_regularizer=None):            
                
                # List of features
                self.layers = []
                self.layers.append(self.x)

                # Reshape input 
                _net = tf.reshape(self.x,[-1]+self.xdim) 
                self.layers.append(_net)

                # Input normalization 
                if self.USE_INPUT_BN:
                    _net = slim.batch_norm(_net,param_initializers=self.bn_init,is_training=self.is_training,updates_collections=None)
                
                for hidx,hdim in enumerate(self.hdims): # For all layers
                    fs = self.filterSizes[hidx]
                    if self.USE_RESNET: # Use residual connection 
                        cChannelSize = _net.get_shape()[3] # Current channel size
                        if cChannelSize == hdim:
                            _identity = _net
                        else: # Expand dimension if required 
                            _identity = slim.conv2d(_net,hdim,[fs,fs],padding='SAME',activation_fn=None 
                                                  , weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
                                                  , normalizer_fn       = self.bn
                                                  , normalizer_params   = self.bn_params
                                                  , scope='identity_%d'%(hidx))
                        # First conv 
                        _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                         , activation_fn       = None 
                                         , weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
                                         , normalizer_fn       = self.bn
                                         , normalizer_params   = self.bn_params
                                         , scope='res_a_%d'%(hidx))
                        # Relu
                        _net = self.actv(_net)
                        self.layers.append(_net) # Append to layers
                        # Second conv
                        _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                         , activation_fn       = None
                                         , weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
                                         , normalizer_fn       = self.bn
                                         , normalizer_params   = self.bn_params
                                         , scope='res_b_%d'%(hidx))
                        # Skip connection
                        _net = _net + _identity
                        # Relu
                        _net = self.actv(_net)
                        self.layers.append(_net) # Append to layers
                    else: # Without residual connection
                        _net = slim.conv2d(_net,hdim,[fs,fs],padding='SAME'
                                         , activation_fn       = self.actv
                                         , weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
                                         , normalizer_fn       = self.bn
                                         , normalizer_params   = self.bn_params
                                         , scope='conv_%d'%(hidx))
                    # Max pooling (if required)
                    max_pool = self.max_pools[hidx]
                    if max_pool > 1:
                        _net = slim.max_pool2d(_net,[max_pool,max_pool],scope='pool_%d'%(hidx))
                        self.layers.append(_net) # Append to layers
                        
                if self.USE_GAP: # Global average pooling 
                    _net = tf.reduce_mean(_net,[1,2]) # [N x R]
                    self.layers.append(_net) # Append to layers
                    # Optional dense layer after GAP (this increases performance)
                    _net = slim.fully_connected(_net,self.feat_dim,scope='gap_fc') # [N x Q]
                    # Feature
                    self.feat = _net # [N x Q]
                else:
                    # Flatten 
                    _net = slim.flatten(_net, scope='flatten')
                    self.layers.append(_net) # Append to layers
                    # Dense
                    _net = slim.fully_connected(_net,self.feat_dim,scope='fc')
                    self.layers.append(_net) # Append to layers
                    # Feature
                    self.feat = _net # [N x Q]
                
                # Feature to K rhos
                _rho_raw = slim.fully_connected(self.feat,self.kmix,scope='rho_raw')
                # self.rho_temp = tf.nn.tanh(_rho_raw) # [N x K] # Regression
                self.rho_temp = tf.nn.sigmoid(_rho_raw) # [N x K] # Classification
                self.rho = tf.concat([self.rho_temp[:,0:1]*0.0+self.rho_ref,self.rho_temp[:,1:]]
                                     ,axis=1) # [N x K]
                
                # Sampler variables
                _Q = self.feat.get_shape().as_list()[1] # Feature dimension
                self.Q = _Q
                self.muW = tf.get_variable(name='muW',shape=[_Q,self.ydim],
                                          initializer=tf.random_normal_initializer(stddev=0.1)
                                           ,dtype=tf.float32) # [Q x D]
                self.logSigmaW = tf.get_variable(name='logSigmaW'
                                        ,shape=[_Q,self.ydim]
                                        ,initializer=tf.constant_initializer(-3.0)
                                        ,dtype=tf.float32) # [Q x D]
                self.muZ = tf.constant(np.zeros((_Q,self.ydim))
                                        ,name='muZ',dtype=tf.float32) # [Q x D]
                self.logSigmaZ = tf.constant(self.logSigmaZval*np.ones((_Q,self.ydim)) # -2.0 <== Important Heuristics
                                        ,name='logSigmaZ',dtype=tf.float32) # [Q x D]
                
                # Make sampler
                _N = tf.shape(self.x)[0]
                _muW_tile = tf.tile(self.muW[tf.newaxis,:,:]
                                    ,multiples=[_N,1,1]) # [N x Q x D]
                _sigmaW_tile = tf.exp(tf.tile(self.logSigmaW[tf.newaxis,:,:]
                                              ,multiples=[_N,1,1])) # [N x Q x D]
                _muZ_tile = tf.tile(self.muZ[tf.newaxis,:,:]
                                    ,multiples=[_N,1,1]) # [N x Q x D]
                _sigmaZ_tile = tf.exp(tf.tile(self.logSigmaZ[tf.newaxis,:,:]
                                              ,multiples=[_N,1,1])) # [N x Q x D]
                samplerList = []
                for jIdx in range(self.kmix): # For all K mixtures
                    _rho_j = self.rho[:,jIdx:jIdx+1] # [N x 1] 
                    _rho_tile = tf.tile(_rho_j[:,:,tf.newaxis]
                                        ,multiples=[1,_Q,self.ydim]) # [N x Q x D]
                    _epsW = tf.random_normal(shape=[_N,_Q,self.ydim],mean=0,stddev=1
                                             ,dtype=tf.float32) # [N x Q x D]
                    _W = _muW_tile + tf.sqrt(_sigmaW_tile)*_epsW # [N x Q x D]
                    _epsZ = tf.random_normal(shape=[_N,_Q,self.ydim]
                                             ,mean=0,stddev=1,dtype=tf.float32) # [N x Q x D]
                    _Z = _muZ_tile + tf.sqrt(_sigmaZ_tile)*_epsZ # [N x Q x D]
                    # Append to list
                    _Y = _rho_tile*_muW_tile + (1.0-_rho_tile**2) \
                        *(_rho_tile*tf.sqrt(_sigmaZ_tile)/tf.sqrt(_sigmaW_tile) \
                              *(_W-_muW_tile)+tf.sqrt(1-_rho_tile**2)*_Z)
                    samplerList.append(_Y) # Append 
                # Make list to tensor
                WlistConcat = tf.convert_to_tensor(samplerList) # K*[N x Q x D] => [K x N x Q x D]
                self.wSample = tf.transpose(WlistConcat,perm=[1,3,0,2]) # [N x D x K x Q]

                # K mean mixtures [N x D x K]
                _wTemp = tf.reshape(self.wSample
                                ,shape=[_N,self.kmix*self.ydim,_Q]) # [N x KD x Q]
                _featRsh = tf.reshape(self.feat,shape=[_N,_Q,1]) # [N x Q x 1]
                _mu = tf.matmul(_wTemp,_featRsh) # [N x KD x Q] x [N x Q x 1] => [N x KD x 1]
                self.mu = tf.reshape(_mu,shape=[_N,self.ydim,self.kmix]) # [N x D x K]
                
                # (optional) Add bias to mu
                USE_BIAS = False
                if USE_BIAS:
                    self.muBias = tf.get_variable(name='muBias'
                                            ,shape=[self.ydim]
                                            ,initializer=tf.constant_initializer(0.0)
                                            ,dtype=tf.float32) # [D]
                    muBias_tile = tf.tile(self.muBias[tf.newaxis,:,tf.newaxis]
                                        ,multiples=[_N,1,self.kmix]) # [N x D x K]
                    self.mu += muBias_tile

                # K var mixtures [N x D x K]
                _logvar_raw = slim.fully_connected(self.feat,self.ydim,scope='var_raw') # [N x D]
                _var_raw = tf.exp(_logvar_raw) # [N x D]
                _var_tile = tf.tile(_var_raw[:,:,tf.newaxis]
                                    ,multiples=[1,1,self.kmix]) # [N x D x K]
                _rho_tile = tf.tile(self.rho[:,tf.newaxis,:]
                                    ,multiples=[1,self.ydim,1]) # [N x D x K]
                _tau_inv = self.tau_inv
                self.var = (1.0-_rho_tile**2)*_var_tile + _tau_inv # [N x D x K]
                
                # Weight allocation probability pi [N x K]
                _pi_logits = slim.fully_connected(self.feat,self.kmix
                                                  ,scope='pi_logits') # [N x K]
                self.pi_temp = tf.nn.softmax(_pi_logits,dim=1) # [N x K]
                # Some heuristics to ensure that pi_1(x) is high enough
                self.pi_temp = tf.concat([self.pi_temp[:,0:1]+self.pi1_bias
                                          ,self.pi_temp[:,1:]],axis=1) # [N x K]
                self.pi = tf.nn.softmax(self.pi_temp,dim=1) # [N x K]
                
                # Intermediate tensors
                self.tensors = [self.x,self.feat,self.rho,self.mu,self.var,self.pi] 
    
    # Build graph
    def build_graph(self):
        # MDN loss
        _N = tf.shape(self.x)[0]
        t,mu,var = self.t,self.mu,self.var
        pi = self.pi # [N x K]
        yhat = mu + tf.sqrt(var)*tf.random_normal(shape=[_N,self.ydim,self.kmix]) # Sampled y [N x D x K]
        tTile = tf.tile(t[:,:,tf.newaxis],[1,1,self.kmix]) # Target [N x D x K]
        piTile = tf.tile(pi[:,tf.newaxis,:],[1,self.ydim,1]) # piTile: [N x D x K]
        
        if self.USE_KENDALL_LOSS: # Alex Kendal's loss extended to a mixture model
            self._loss_fit = tf.reduce_sum(-piTile*yhat*tTile,axis=[1,2]) # [N]
            self.loss_fit = tf.reduce_mean(self._loss_fit) # [1]
            
            self._loss_reg = pi*tf.reduce_logsumexp(yhat,axis=[1]) # [N x K]
            self.__loss_reg = tf.reduce_sum(self._loss_reg,axis=[1]) # [N]
            self.loss_reg = tf.reduce_mean(self.__loss_reg) # [1] 
            
            # self._loss_reg = tf.reduce_logsumexp(piTile*yhat,axis=[1,2]) # [N]
            # self.loss_reg = tf.reduce_mean(self._loss_reg) # [1]
        else: # Mine (normalized x)
            self.yhat_normalized = tf.nn.softmax(yhat,dim=1) # [N x D x K]
            self._loss_fit = tf.reduce_sum(-piTile*self.yhat_normalized*tTile,axis=[1,2]) # [N]
            self.loss_fit = tf.reduce_mean(self._loss_fit) # [1]
            
            self._loss_reg = pi*tf.reduce_logsumexp(yhat,axis=[1]) # [N x K]
            self.__loss_reg = tf.reduce_sum(self._loss_reg,axis=[1]) # [N]
            self.loss_reg = self.logsumexp_coef*tf.reduce_mean(self.__loss_reg) # [1] 
            
            # self._loss_reg = self.logsumexp_coef*tf.reduce_logsumexp(piTile*yhat,axis=[1,2]) # [N]
            # self.loss_reg = tf.reduce_mean(self._loss_reg) # [1]
        
        # KL-divergence regularizer 
        _eps = 1e-8
        self._kl_reg = self.kl_reg_coef*tf.reduce_sum(-self.rho
                        *(tf.log(self.pi+_eps)-tf.log(self.rho+_eps)),axis=1) # (N)
        self.kl_reg = tf.reduce_mean(self._kl_reg) # (1)
        
        # Weight decay 
        # _g_vars = tf.global_variables()
        _g_vars = tf.trainable_variables()
        self.c_vars = [var for var in _g_vars if '%s/'%(self.name) in var.name]
        self.l2_reg = self.l2_reg_coef*tf.reduce_sum(tf.stack([tf.nn.l2_loss(v) for v in self.c_vars])) # [1]

        # Total loss
        self.loss_total = tf.reduce_mean(self.loss_fit+self.loss_reg+self.kl_reg+self.l2_reg) # [1]
        # Optimizer
        GRAD_CLIP = True
        if GRAD_CLIP: # Gradient clipping
            if self.USE_SGD:
                # _optm = tf.train.GradientDescentOptimizer(learning_rate=self.lr)
                _optm = tf.train.MomentumOptimizer(learning_rate=self.lr,momentum=self.momentum)
            else:
                _optm = tf.train.AdamOptimizer(learning_rate=self.lr
                                               ,beta1=0.9,beta2=0.999,epsilon=1e-6)
            self.optm = create_gradient_clipping(self.loss_total
                                            ,_optm,tf.trainable_variables(),clipVal=1.0)
        else:
            if self.USE_SGD:
                self.optm = tf.train.GradientDescentOptimizer(learning_rate=self.lr).minimize(self.loss_total) 
            else:
                self.optm = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss_total) 
        
        # Compute accuray 
        maxIdx = tf.argmax(input=pi,axis=1, output_type=tf.int32) # Argmax Index [N]
        maxIdx = 0*tf.ones_like(maxIdx)
        coords = tf.stack([tf.transpose(gv) for gv in tf.meshgrid(tf.range(_N),tf.range(self.ydim))] + 
                          [tf.reshape(tf.tile(maxIdx[:,tf.newaxis],[1,self.ydim]),shape=(_N,self.ydim))]
                          ,axis=2) # [N x D x 3]
        mu_bar = tf.gather_nd(mu,coords) # [N x D]
        _corr = tf.equal(tf.argmax(mu_bar, 1), tf.argmax(self.t, 1))    
        self.accr = tf.reduce_mean(tf.cast(_corr,tf.float32)) # Accuracy
        
    # Check parameters
    def check_params(self):
        _g_vars = tf.global_variables()
        self.g_vars = [var for var in _g_vars if '%s/'%(self.name) in var.name]
        if self.VERBOSE:
            print ("==== Global Variables ====")
        for i in range(len(self.g_vars)):
            w_name  = self.g_vars[i].name
            w_shape = self.g_vars[i].get_shape().as_list()
            if self.VERBOSE:
                print (" [%02d] Name:[%s] Shape:[%s]" % (i,w_name,w_shape))
        # Print layers
        if self.VERBOSE:
            print ("Layers:")
            nLayers = len(self.layers)
            for i in range(nLayers):
                print ("[%02d/%d] %s %s"%(i,nLayers,self.layers[i].name,self.layers[i].shape))
    
    # Sampler
    def sampler(self,_sess,_x,n_samples=10):
        pi, mu, var = _sess.run([self.pi, self.mu, self.var],
                                feed_dict={self.x:_x,self.kp:1.0,self.is_training:False
                                          ,self.rho_ref:1.0}) #
        n_points = _x.shape[0]
        _y_sampled = np.zeros([n_points,self.ydim,n_samples])
        for i in range(n_points):
            for j in range(n_samples):
                k = np.random.choice(self.kmix,p=pi[i,:])
                k = 0
                _y_sampled[i,:,j] = mu[i,:,k] # + np.random.randn(1,self.ydim)*np.sqrt(var[i,:,k])
        return _y_sampled
    
    # Save 
    def save(self,_sess,_savename=None):
        """ Save name """
        if _savename==None:
            _savename='../net/net_%s.npz'%(self.name)
        """ Get global variables """
        self.g_wnames,self.g_wvals,self.g_wshapes = [],[],[]
        for i in range(len(self.g_vars)):
            curr_wname = self.g_vars[i].name
            curr_wvar  = [v for v in tf.global_variables() if v.name==curr_wname][0]
            curr_wval  = _sess.run(curr_wvar)
            
            curr_wval_sqz = curr_wval
            # curr_wval_sqz  = curr_wval.squeeze() # ???
            curr_wval_sqz = np.asanyarray(curr_wval_sqz,order=(1,-1))
            
            self.g_wnames.append(curr_wname)
            self.g_wvals.append(curr_wval_sqz)
            self.g_wshapes.append(curr_wval.shape)
        """ Save """
        np.savez(_savename,g_wnames=self.g_wnames,g_wvals=self.g_wvals,g_wshapes=self.g_wshapes)
        if self.VERBOSE:
            print ("[%s] Saved. Size is [%.4f]MB" % 
                   (_savename,os.path.getsize(_savename)/1000./1000.))
        
    # Restore
    def restore(self,_sess,_loadname=None):
        if _loadname==None:
            _loadname='../net/net_%s.npz'%(self.name)
        l = np.load(_loadname)
        g_wnames = l['g_wnames']
        g_wvals  = l['g_wvals']
        g_wshapes = l['g_wshapes']
        for widx,wname in enumerate(g_wnames):
            curr_wvar  = [v for v in tf.global_variables() if v.name==wname][0]
            _sess.run(tf.assign(curr_wvar,g_wvals[widx].reshape(g_wshapes[widx])))
        if self.VERBOSE:
            print ("Weight restored from [%s] Size is [%.4f]MB" % 
                   (_loadname,os.path.getsize(_loadname)/1000./1000.))
    
    # Train 
    def train(self,_sess,_trainimg,_trainlabel,_testimg,_testlabel,_valimg,_vallabel
              ,_maxEpoch=10,_batchSize=256,_lr=1e-3,_kp=0.8
              ,_LR_SCHEDULE=False,_PRINT_EVERY=10,_SAVE_BEST=True,_DO_AUGMENTATION=False,_VERBOSE_TRAIN=True):
        tf.set_random_seed(0)
        nTrain,nVal,nTest = _trainimg.shape[0],_valimg.shape[0],_testimg.shape[0]
        txtName = ('../res/res_%s.txt'%(self.name))
        f = open(txtName,'w') # Open txt file
        print_n_txt(_f=f,_chars='Text name: '+txtName)
        print_period=max(1,_maxEpoch//_PRINT_EVERY)
        maxIter,maxValAccr,maxTestAccr = max(nTrain//_batchSize,1),0.0,0.0
        for epoch in range(_maxEpoch+1): # For every epoch 
            _trainimg,_trainlabel = shuffle(_trainimg,_trainlabel) 
            for iter in range(maxIter): # For every iteration in one epoch
                start,end = iter*_batchSize,(iter+1)*_batchSize
                # Learning rate scheduling
                if _LR_SCHEDULE:
                    if epoch < 0.5*_maxEpoch:
                        _lr_use = _lr
                    elif epoch < 0.75*_maxEpoch:
                        _lr_use = _lr/10.0
                    else:
                        _lr_use = _lr/100.0
                else:
                    _lr_use = _lr
                if _DO_AUGMENTATION:
                    trainImgBatch = augment_img(_trainimg[start:end,:],self.xdim) 
                else:
                    trainImgBatch = _trainimg[start:end,:]
                if self.USE_MIXUP:
                    xBatch = trainImgBatch
                    tBatch = _trainlabel[start:end,:]
                    xBatch,tBatch = mixup(xBatch,tBatch,1/2)
                else:
                    xBatch = trainImgBatch
                    tBatch = _trainlabel[start:end,:]
                feeds = {self.x:xBatch,self.t:tBatch
                         ,self.rho_ref:self.rho_ref_train,self.kp:_kp,self.lr:_lr_use,self.is_training:True}
                _sess.run(self.optm,feed_dict=feeds)
            # Print training losses, training accuracy, validation accuracy, and test accuracy
            if (epoch%print_period)==0 or (epoch==(_maxEpoch)):
                batchSize4print = 512 
                # Compute train loss and accuracy
                maxIter4print = max(nTrain//batchSize4print,1)
                trainLoss,trainAccr,fit,reg,kl,l2,nTemp = 0,0,0,0,0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_train = {self.x:_trainimg[start:end,:],self.t:_trainlabel[start:end,:]
                                   ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
                    opers_train = [self.loss_total,self.accr,self.loss_fit,self.loss_reg,self.kl_reg,self.l2_reg]
                    _trainLoss,_trainAccr,_fit,_reg,_kl,_l2 = _sess.run(opers_train,feed_dict=feeds_train) 
                    _nTemp = end-start; nTemp+=_nTemp
                    trainLoss+=(_nTemp*_trainLoss);trainAccr+=(_nTemp*_trainAccr)
                    fit+=(_nTemp*_fit);reg+=(_nTemp*_reg);kl+=(_nTemp*_kl);l2+=(_nTemp*_l2)
                trainLoss/=nTemp;trainAccr/=nTemp
                fit/=nTemp;reg/=nTemp;kl/=nTemp;l2/=nTemp;
                # Compute validation loss and accuracy
                maxIter4print = max(nVal//batchSize4print,1)
                valLoss,valAccr,nTemp = 0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_val = {self.x:_valimg[start:end,:],self.t:_vallabel[start:end,:]
                                 ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
                    _valLoss,_valAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_val) 
                    _nTemp = end-start; nTemp+=_nTemp
                    valLoss+=(_nTemp*_valLoss); valAccr+=(_nTemp*_valAccr)
                valLoss/=nTemp;valAccr/=nTemp 
                # Compute test loss and accuracy
                maxIter4print = max(nTest//batchSize4print,1)
                testLoss,testAccr,nTemp = 0,0,0
                for iter in range(maxIter4print):
                    start,end = iter*batchSize4print,(iter+1)*batchSize4print
                    feeds_test = {self.x:_testimg[start:end,:],self.t:_testlabel[start:end,:]
                                  ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
                    _testLoss,_testAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_test) 
                    _nTemp = end-start; nTemp+=_nTemp
                    testLoss+=(_nTemp*_testLoss); testAccr+=(_nTemp*_testAccr)
                testLoss/=nTemp;testAccr/=nTemp
                # Compute max val accr 
                if valAccr > maxValAccr:
                    maxValAccr = valAccr
                    maxTestAccr = testAccr
                    if _SAVE_BEST: self.save(_sess) 
                strTemp = (("[%02d/%d] [Loss] train:%.3f(f:%.3f+r:%.3f+k:%.3f+l:%.3f) val:%.3f test:%.3f"
                            +" [Accr] train:%.1f%% val:%.1f%% test:%.1f%% maxVal:%.1f%% maxTest:%.1f%%")
                       %(epoch,_maxEpoch,trainLoss,fit,reg,kl,l2,valLoss,testLoss
                         ,trainAccr*100,valAccr*100,testAccr*100,maxValAccr*100,maxTestAccr*100))
                print_n_txt(_f=f,_chars=strTemp,_DO_PRINT=_VERBOSE_TRAIN)
        # Done 
        print ("Training finished.")
    
    # Test
    def test(self,_sess,_trainimg,_trainlabel,_testimg,_testlabel,_valimg,_vallabel):
        nTrain,nVal,nTest = _trainimg.shape[0],_valimg.shape[0],_testimg.shape[0]
        # Check accuracies (train, val, and test)
        batchSize4print = 512 
        # Compute train loss and accuracy
        maxIter4print = max(nTrain//batchSize4print,1)
        trainLoss,trainAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_train = {self.x:_trainimg[start:end,:],self.t:_trainlabel[start:end,:]
                           ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
            _trainLoss,_trainAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_train) 
            _nTemp = end-start; nTemp+=_nTemp
            trainLoss+=(_nTemp*_trainLoss); trainAccr+=(_nTemp*_trainAccr)
        trainLoss/=nTemp;trainAccr/=nTemp
        # Compute validation loss and accuracy
        maxIter4print = max(nVal//batchSize4print,1)
        valLoss,valAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_val = {self.x:_valimg[start:end,:],self.t:_vallabel[start:end,:]
                         ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
            _valLoss,_valAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_val) 
            _nTemp = end-start; nTemp+=_nTemp
            valLoss+=(_nTemp*_valLoss); valAccr+=(_nTemp*_valAccr)
        valLoss/=nTemp;valAccr/=nTemp
        # Compute test loss and accuracy
        maxIter4print = max(nTest//batchSize4print,1)
        testLoss,testAccr,nTemp = 0,0,0
        for iter in range(maxIter4print):
            start,end = iter*batchSize4print,(iter+1)*batchSize4print
            feeds_test = {self.x:_testimg[start:end,:],self.t:_testlabel[start:end,:]
                          ,self.rho_ref:1.0,self.kp:1.0,self.is_training:False}
            _testLoss,_testAccr = _sess.run([self.loss_total,self.accr],feed_dict=feeds_test) 
            _nTemp = end-start; nTemp+=_nTemp
            testLoss+=(_nTemp*_testLoss); testAccr+=(_nTemp*_testAccr)
        testLoss/=nTemp;testAccr/=nTemp
        strTemp = (("[%s] [Loss] train:%.3f val:%.3f test:%.3f"
                    +" [Accr] train:%.3f%% val:%.3f%% test:%.3f%%")
               %(self.name,trainLoss,valLoss,testLoss,trainAccr*100,valAccr*100,testAccr*100))
        print(strTemp)
        
if __name__ == "__main__":
    print ("choiceNet_cls_class defined.")    

choiceNet_cls_class defined.


### Train ChoiceNet on MNIST

In [3]:
# Setting
def get_mnist_config():
    trainimg,trainlabel,testimg,testlabel,valimg,vallabel \
        = load_mnist_with_noise(_errType='rp',_outlierRatio=0.9,_seed=0)
    xdim,ydim,hdims,filterSizes,max_pools,feat_dim = [28,28,1],10,[64,64],[3,3],[2,2],256
    kmix,actv,bn,VERBOSE = 10,tf.nn.relu,slim.batch_norm,True
    rho_ref_train,tau_inv,pi1_bias,logSigmaZval = 0.95,1e-4,0.0,-2
    logsumexp_coef,kl_reg_coef,l2_reg_coef = 1e-2,1e-4,1e-5
    USE_INPUT_BN,USE_RESNET,USE_GAP,USE_KENDALL_LOSS = False,True,False,False
    USE_MIXUP = False
    return trainimg,trainlabel,testimg,testlabel,valimg,vallabel, \
        xdim,ydim,hdims,filterSizes,max_pools,feat_dim, \
        kmix,actv,bn,VERBOSE, \
        rho_ref_train,tau_inv,pi1_bias,logSigmaZval,logsumexp_coef,kl_reg_coef,l2_reg_coef, \
        USE_INPUT_BN,USE_RESNET,USE_GAP,USE_KENDALL_LOSS,USE_MIXUP

In [4]:
if __name__ == "__main__":
    trainimg,trainlabel,testimg,testlabel,valimg,vallabel, \
    xdim,ydim,hdims,filterSizes,max_pools,feat_dim, \
    kmix,actv,bn,VERBOSE, \
    rho_ref_train,tau_inv,pi1_bias,logSigmaZval,logsumexp_coef,kl_reg_coef,l2_reg_coef, \
    USE_INPUT_BN,USE_RESNET,USE_GAP,USE_KENDALL_LOSS,USE_MIXUP = get_mnist_config()
    tf.reset_default_graph()
    tf.set_random_seed(0)
    CN = choiceNet_cls_class(_name='basic_choicenet_mnist'
                          ,_xdim=xdim,_ydim=ydim,_hdims=hdims,_filterSizes=filterSizes
                          ,_max_pools=max_pools,_feat_dim=feat_dim,_kmix=kmix,_actv=actv,_bn=slim.batch_norm
                          ,_rho_ref_train=rho_ref_train,_tau_inv=tau_inv,_pi1_bias=pi1_bias,_logSigmaZval=logSigmaZval
                          ,_logsumexp_coef=logsumexp_coef,_kl_reg_coef=kl_reg_coef,_l2_reg_coef=l2_reg_coef
                          ,_USE_INPUT_BN=USE_INPUT_BN,_USE_RESNET=USE_RESNET,_USE_GAP=USE_GAP,_USE_KENDALL_LOSS=USE_KENDALL_LOSS
                          ,_USE_MIXUP=USE_MIXUP,_GPU_ID=0,_VERBOSE=VERBOSE)
    sess = gpusession(); sess.run(tf.global_variables_initializer()) 
    CN.train(_sess=sess,_trainimg=trainimg,_trainlabel=trainlabel
               ,_testimg=testimg,_testlabel=testlabel,_valimg=valimg,_vallabel=vallabel
               ,_maxEpoch=50,_batchSize=256,_lr=1e-5,_LR_SCHEDULE=True,_kp=0.95,_PRINT_EVERY=10,_SAVE_BEST=True)
    sess.close()

Extracting ../data/train-images-idx3-ubyte.gz
Extracting ../data/train-labels-idx1-ubyte.gz
Extracting ../data/t10k-images-idx3-ubyte.gz
Extracting ../data/t10k-labels-idx1-ubyte.gz
==== Global Variables ====
 [00] Name:[basic_choicenet_mnist/identity_0/weights:0] Shape:[[3, 3, 1, 64]]
 [01] Name:[basic_choicenet_mnist/identity_0/BatchNorm/beta:0] Shape:[[64]]
 [02] Name:[basic_choicenet_mnist/identity_0/BatchNorm/moving_mean:0] Shape:[[64]]
 [03] Name:[basic_choicenet_mnist/identity_0/BatchNorm/moving_variance:0] Shape:[[64]]
 [04] Name:[basic_choicenet_mnist/res_a_0/weights:0] Shape:[[3, 3, 1, 64]]
 [05] Name:[basic_choicenet_mnist/res_a_0/BatchNorm/beta:0] Shape:[[64]]
 [06] Name:[basic_choicenet_mnist/res_a_0/BatchNorm/moving_mean:0] Shape:[[64]]
 [07] Name:[basic_choicenet_mnist/res_a_0/BatchNorm/moving_variance:0] Shape:[[64]]
 [08] Name:[basic_choicenet_mnist/res_b_0/weights:0] Shape:[[3, 3, 64, 64]]
 [09] Name:[basic_choicenet_mnist/res_b_0/BatchNorm/beta:0] Shape:[[64]]
 [10] 

[../net/net_basic_choicenet_mnist.npz] Saved. Size is [16.2410]MB
[35/50] [Loss] train:-0.319(f:-0.501+r:0.070+k:0.001+l:0.111) val:-0.627 test:-0.625 [Accr] train:54.9% val:98.2% test:97.9% maxVal:98.2% maxTest:97.9%
[40/50] [Loss] train:-0.322(f:-0.503+r:0.069+k:0.001+l:0.111) val:-0.624 test:-0.623 [Accr] train:55.0% val:98.2% test:97.9% maxVal:98.2% maxTest:97.9%
[45/50] [Loss] train:-0.324(f:-0.504+r:0.068+k:0.001+l:0.111) val:-0.626 test:-0.625 [Accr] train:55.0% val:98.1% test:97.9% maxVal:98.2% maxTest:97.9%
[../net/net_basic_choicenet_mnist.npz] Saved. Size is [16.2456]MB
[50/50] [Loss] train:-0.325(f:-0.505+r:0.068+k:0.001+l:0.111) val:-0.628 test:-0.627 [Accr] train:55.1% val:98.2% test:97.9% maxVal:98.2% maxTest:97.9%
Training finished.


### Restore and Re-run

In [5]:
if __name__ == "__main__":
    trainimg,trainlabel,testimg,testlabel,valimg,vallabel, \
    xdim,ydim,hdims,filterSizes,max_pools,feat_dim, \
    kmix,actv,bn,VERBOSE, \
    rho_ref_train,tau_inv,pi1_bias,logSigmaZval,logsumexp_coef,kl_reg_coef,l2_reg_coef, \
    USE_INPUT_BN,USE_RESNET,USE_GAP,USE_KENDALL_LOSS,USE_MIXUP = get_mnist_config()
    tf.reset_default_graph(); tf.set_random_seed(0)
    CN2 = choiceNet_cls_class(_name='basic_choicenet_mnist'
                          ,_xdim=xdim,_ydim=ydim,_hdims=hdims,_filterSizes=filterSizes
                          ,_max_pools=max_pools,_feat_dim=feat_dim,_kmix=kmix,_actv=actv,_bn=slim.batch_norm
                          ,_rho_ref_train=rho_ref_train,_tau_inv=tau_inv,_pi1_bias=pi1_bias,_logSigmaZval=logSigmaZval
                          ,_logsumexp_coef=logsumexp_coef,_kl_reg_coef=kl_reg_coef,_l2_reg_coef=l2_reg_coef
                          ,_USE_INPUT_BN=USE_INPUT_BN,_USE_RESNET=USE_RESNET,_USE_GAP=USE_GAP,_USE_KENDALL_LOSS=USE_KENDALL_LOSS
                          ,_USE_MIXUP=USE_MIXUP,_GPU_ID=0,_VERBOSE=False)
    sess = gpusession(); sess.run(tf.global_variables_initializer()) 
    CN2.restore(sess) # Restore weights
    CN2.test(sess,_trainimg=trainimg,_trainlabel=trainlabel
             ,_testimg=testimg,_testlabel=testlabel,_valimg=valimg,_vallabel=vallabel)
    sess.close()

Extracting ../data/train-images-idx3-ubyte.gz
Extracting ../data/train-labels-idx1-ubyte.gz
Extracting ../data/t10k-images-idx3-ubyte.gz
Extracting ../data/t10k-labels-idx1-ubyte.gz
[basic_choicenet_mnist] [Loss] train:-0.325 val:-0.625 test:-0.625 [Accr] train:55.042% val:98.199% test:97.862%
