In [1]:
import numpy as np 
import tensorflow as tf
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
from src import *


In [2]:
def dense(x, scope, num_h, n_x, haveBias = True):
    """
    Standard affine layer
    
    scope = name tf variable scope
    num_h = number of hidden units
    num_x = number of input units
    """
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [n_x, num_h], initializer=tf.random_normal_initializer(stddev=0.04))
        if(haveBias):
            b = tf.get_variable('b', [num_h], initializer=tf.constant_initializer(0))
            return(tf.matmul(x, w)+b)
        else:
            return(tf.matmul(x, w))

def lrelu(x, alpha):
    return(tf.nn.relu(x) - alpha * tf.nn.relu(-x))
    
def conv(x, scope, filter_h,filter_w, n_kernel, stride_h=1,stride_w=1, padding='SAME'):
    """
    Convolutional layer
    
    scope        = name tf variable scope
    filter_h     = height of the receptive field
    filter_w     = width of the receptive field
    n_kernel     = # of kernels
    stride_h     = stride height
    stride_w     = stride width
    """
    with tf.variable_scope(scope, reuse=False):
        n_x = x.get_shape().as_list()[-1]
        w = tf.get_variable('w',
                            [filter_h, filter_w, n_x, n_kernel],
                            initializer=tf.random_normal_initializer(stddev=0.04))
        b = tf.get_variable('b', [n_kernel], initializer=tf.constant_initializer(0))
        return tf.nn.convolution(x, w, padding=padding, strides=[stride_h, stride_w])+b    


def bnorm(X,isTraining,scope='batch_norm',axis=-1):
    """
    Batch normalization layer
    
    X          = input
    isTraining = True during training, False otherwise.
    axis       = axis for normalization
    scope      = name tf variable scope
    
    """
    return(tf.layers.batch_normalization(
        inputs=X,
        axis=axis, 
        training = isTraining,
        name=scope
    ))


class Attention:
    """
    Attention class
    Reference:
    https://github.com/tensorflow/models/blob/master/official/transformer/model/attention_layer.py
    https://github.com/DongjunLee/transformer-tensorflow/blob/master/transformer/attention.py
    """

    def __init__(
        self,
        nObs,
        hiddenSize,
        numHeads=1,
        dropout=0.2,
        isTraining=True
        ):
        assert hiddenSize % numHeads == 0
        self.nObs=nObs
        self.hiddenSize = hiddenSize
        self.numHeads = numHeads
        self.dropout = dropout
        self.isTraining=isTraining
        self.dimPerHead = hiddenSize//numHeads

    def multi_head(self, q, k, v):
        q, k, v = self._linear_projection(q, k, v)
        qs, ks, vs = self._split_heads(q, k, v)
        outputs = self._scaled_dot_product(qs, ks, vs)
        output = self._concat_heads(outputs)
        output = tf.layers.dense(output, self.hiddenSize)
        return(output)
        #return(tf.nn.dropout(output, 1.0 - self.dropout))

    def _linear_projection(self, q, k, v):
        q = tf.layers.dense(q, self.hiddenSize, use_bias=False)#,activation=tf.nn.relu)
        k = tf.layers.dense(k, self.hiddenSize, use_bias=False)#,activation=tf.nn.relu)
        v = tf.layers.dense(v, self.hiddenSize, use_bias=False)#,activation=tf.nn.relu)
        return(q,k,v)

    def _split_heads(self, q, k, v):
        def splitAndTranspose(x, numHeads, hiddenSize):
            t_shape = x.get_shape().as_list()
            x = tf.reshape(x, [-1] + t_shape[1:-1] + [numHeads, hiddenSize // numHeads])
            return( tf.transpose(x, [0, 2, 1, 3]) )# [batchSize, numHeads, numOfObs, hiddenSize//numHeads]

        qs = splitAndTranspose(q, self.numHeads, self.hiddenSize)
        ks = splitAndTranspose(k, self.numHeads, self.hiddenSize)
        vs = splitAndTranspose(v, self.numHeads, self.hiddenSize)
        return(qs,ks,vs)

    def _scaled_dot_product(self, qs, ks, vs):
        o = tf.matmul(qs, ks, transpose_b=True)
        o /= self.dimPerHead**0.5
        o = tf.nn.softmax(o)
        return tf.matmul(o, vs)

    def _concat_heads(self, outputs):
        def transpose_then_concat_last_two_dimenstion(tensor):
            tensor = tf.transpose(tensor, [0, 2, 1, 3]) # [batch_size, numOfObs, numHeads, dim]
            t_shape = tensor.get_shape().as_list()
            numHeads, dim = t_shape[-2:]
            return(tf.reshape(tensor, [-1] + t_shape[1:-2] + [numHeads * dim]) )

        return(transpose_then_concat_last_two_dimenstion(outputs))


def convObservations(X,batch_size,n_images,isTraining=False, reuse=False):
    """
    Makes convolutions over noisy signals with cyclic shifts.
    X.size = (batch_size,nObservations, signalDim,1 ). Returns the average 
    value of the encodings of the observations
    """

    with tf.variable_scope('convObservations', reuse=reuse):
        h = lrelu(conv(X, 'conv0',1,5,32,1,1), 0.1)  # 5x64
        h = bnorm(h,isTraining,'bnorm_1d_0')
        h = lrelu(conv(h, 'conv1',1,3,64,1,1), 0.1)  # 5x64
        h = bnorm(h,isTraining,'bnorm_1d_1')
        h = lrelu(conv(h, 'conv2',1,3,128,1,2), 0.1)  # 3x128
        h = bnorm(h,isTraining,'bnorm_1d_2')
        h = lrelu(conv(h, 'conv3',1,3,256,1,2), 0.1)  # 2x128
        h = bnorm(h,isTraining,'bnorm_1d_3')
        h = tf.reshape(h, [-1, n_images, 512])

        h = lrelu(tf.layers.dense(h, 512, name='fc_0'),0.1)
        h = bnorm(h,isTraining,'bnorm_1d_fc0')
        h = lrelu(tf.layers.dense(h, 512, name='fc_1'),0.1)

        attn = Attention(
            nObs=n_images,
            hiddenSize=512,
            numHeads=8,
            dropout=0.1,
            isTraining=isTraining
            )

        h = tf.add(attn.multi_head(h,h,h),h)

        hff = lrelu(tf.layers.dense(h, 512, name='fc_2'),0.1)
        hff = bnorm(hff,isTraining,'bnorm_1d_fc2')
        hff = lrelu(tf.layers.dense(h, 512, name='fc_3'),0.1)
        h = tf.add(hff,h)
        
        h = tf.reduce_mean(h, 1)
        return(h)        


    
def decodeSignal(X, batch_size, enc_dim = 512, isTraining=False, reuse=False):
    """
    Takes encoding produced by the observations as input and
    generates a the underlying true signal
    """
    with tf.variable_scope('decodeSignal', reuse=reuse):
        h = lrelu(dense(X, 'hz0', num_h=256,n_x=enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        h = bnorm(h,isTraining,'bnorm_hz0')
        h = lrelu(dense(h, 'hz1', num_h=128,n_x=256+enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        h = bnorm(h,isTraining,'bnorm_hz1')
        h = lrelu(dense(h, 'hz2', num_h=64,n_x=128+enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        h = bnorm(h,isTraining,'bnorm_hz2')
        h = lrelu(dense(h, 'hz3', num_h=32,n_x=64+enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        h = bnorm(h,isTraining,'bnorm_hz3')
        h = lrelu(dense(h, 'hz4', num_h=16,n_x=32+enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        h = bnorm(h,isTraining,'bnorm_hz4')
        h = lrelu(dense(h, 'hz5', num_h=8,n_x=16+enc_dim), 0.1)
        h = tf.concat([X, h], 1)
        #h = bnorm(h,isTraining,'bnorm_hz5')
        h = lrelu(dense(h, 'hz6', num_h=5,n_x=8+enc_dim), 0.1)
        h = dense(h, 'z', num_h=5,n_x=5)
        
        return(h)    
    

class objGenNetwork(object):
    """
    Implementation of the model
    """
    def __init__(self,
                 signalDim = 5,
                 nObservationsPerSignal = 64,
                 noise = 2,
                 minibatchSize = 64,
                 testSampleSize = 1000,
                 lr = 0.001,
                 training = True,
                 skipStep = 1,
                 nProcessesDataPrep=4,
                 vers='NOT_SPECIFIED',
                 evalAfterStep=0,
                 evalNTimes=1
                ):
        self.signalDim = signalDim
        self.nObservationsPerSignal = nObservationsPerSignal
        self.noise = noise
        self.minibatchSize = minibatchSize
        self.testSampleSize = testSampleSize
        self.lr = lr
        self.isTraining = training
        self.skipStep = skipStep
        self.nProcessesDataPrep = nProcessesDataPrep
        if (vers=='NOT_SPECIFIED'):
            self.vers = str(signalDim)+'D'+'_sigma_'+str(noise) +'_obs_' + str(nObservationsPerSignal )
        else:
            self.vers = vers
        self.logFile = 'log_'+ self.vers +'.txt'
        self.evalAfterStep = evalAfterStep
        self.evalNTimes = evalNTimes
        
        self.gstep = tf.Variable(0, 
                                 dtype=tf.int32, 
                                 trainable=False,
                                 name='global_step')
        self.train_x = None
        self.train_y = None
        self.test_x = None
        self.test_y = None

    def data_generator(self,trainingBatch=True):

        """
        Generates batches of random signals and their noisy 
        observations with cyclic shifts. |Signal| = dim(Signal)
        """
        minibatchSize = self.minibatchSize
        testSampleSize = self.testSampleSize
        signalDim = self.signalDim
        noise = self.noise
        nObservationsPerSignal=self.nObservationsPerSignal
        
        
        batches = minibatchSize if trainingBatch else testSampleSize

        
        poolData = mp.Pool(processes= self.nProcessesDataPrep)
        results = poolData.starmap(genSignal,[(
            signalDim,
            nObservationsPerSignal,
            noise)] * batches )
        poolData.close(); poolData.join()

        batch_x = np.expand_dims(
            np.array([k[1] for k in results],dtype='float32'),
            axis=3
        )
        batch_y = np.array([k[0] for k in results],dtype='float32')
        

        if trainingBatch:
            self.train_x_new = batch_x
            self.train_y_new = batch_y
        else:
            self.test_x_new = batch_x
            self.test_y_new = batch_y
        
        
    def inference(self):
        h = convObservations(self.x_ph, 
                             self.minibatchSize,
                             self.nObservationsPerSignal,
                             isTraining=self.isTraining
                            )
        self.preds = decodeSignal(h, 
                                  self.minibatchSize,
                                  isTraining=self.isTraining
                                 )

        
    def loss(self):
        """
        Defines loss function
        We use mean squared loss over the predicted and the true signal
        under the best fitting cyclic shift.
        """
        # 
        with tf.name_scope('loss'):
            tiled_preds = tf.tile(tf.expand_dims(self.preds, 1),[1,self.signalDim,1])
            entropy = tf.squared_difference(self.y_ph,tiled_preds)
            entropy = tf.reduce_sum(entropy, axis = 2)
            entropy = tf.reduce_min(entropy,axis = 1)
            self.loss = tf.reduce_mean(entropy, name='loss')

    def optimize(self):
        """
        Optimization op
        """
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
#            self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, 
#                                                                   global_step=self.gstep)
            
            optimizer = tf.train.GradientDescentOptimizer(self.lr)
            grads_and_vars = optimizer.compute_gradients(self.loss)
       	    with tf.name_scope('dropgrad'):
                grads_and_vars = [(tf.nn.dropout(g,0.05), v) for g, v in grads_and_vars]
            self.opt = optimizer.apply_gradients(grads_and_vars,global_step=self.gstep)





    def additionalEvalMetrics(self):
        """
        Takes the most accurate rotation of the true object
        as reference and calculates the average accuracy of
        the occupancy grid of the object
        """
        with tf.name_scope('prediction_eval'):
            tiledPreds = tf.tile(tf.expand_dims(self.preds, 1),
                                 [1,self.signalDim,1])
            MAE = tf.reduce_sum(tf.abs(tiledPreds - self.y_ph),
                                axis = 2)
            MAE = tf.reduce_min(MAE,axis=1)
            self.MAE = tf.reduce_mean(MAE,name="MAE")
        
        
    def testEval(self, sess, writer, epoch, step,evalNTimes):
        self.isTraining = False
        l2_arr = np.zeros((evalNTimes,))
        mae_arr = np.zeros((evalNTimes,))
        for i in range(evalNTimes):
            self.test_x = np.copy(self.test_x_new)
            self.test_y = np.copy(self.test_y_new)
            
            pool = ThreadPool(processes=1)
            async_result = pool.apply_async(self.data_generator,(False,))
            mae_arr[i], l2_arr[i] = sess.run(
                    [self.MAE,self.loss],
                     feed_dict={self.x_ph: self.test_x,
                                self.y_ph: self.test_y}
                     )
            pool.close()
            pool.join()


        if (self.logFile==False):
            print('test MAE at step {0:.6}: {1:.6} '.format(step,mae_arr.mean()))
            print('test loss at step {0:.6}: {1:.6} '.format(step,l2_arr.mean()))
        else:
            with open(self.logFile,'a') as lgfile:
                lgfile.write('{0}\t{1:.6}\t{2:.6}\n'.format(step,l2_arr.mean(),mae_arr.mean()))
            

    def summary(self):
        """
        Summary for TensorBoard
        """
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('MAE', self.MAE)
            tf.summary.histogram('histogram loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def build(self):
        """
        Builds the computation graph
        """
        self.x_ph = tf.placeholder(tf.float32, [None, 
                                                None,
                                                self.signalDim,
                                                1]) 
        self.y_ph = tf.placeholder(tf.float32, [None,
                                                self.signalDim,
                                                self.signalDim])
        self.data_generator()
        self.data_generator(trainingBatch=False)
        self.inference()
        self.loss()
        self.optimize()
        self.additionalEvalMetrics()
        self.summary()
    
    def train_one_epoch(self, sess, saver, writer, epoch, step):
#        start_time = time.time()
        self.isTraining = True
        _, l, summaries = sess.run([self.opt, 
                                    self.loss,
                                    self.summary_op],
                                   feed_dict={self.x_ph: self.train_x_new,
                                              self.y_ph: self.train_y_new})
        writer.add_summary(summaries, global_step=step)
        #if (step + 1) % self.skipStep == 0:
        #    print('training Loss at step {0}: {1}'.format(step, l))
        step += 1
        saver.save(sess, 'checkpoints/cryoem_'+self.vers+'/cpoint', global_step=step)
#        print('Average loss at epoch {0}: {1}'.format(epoch, l))
#        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

    def train(self, n_epochs):
        """
        Calls the training ops and prepares the training data
        for the next batch in a parallel process.
        """
        safe_mkdir('checkpoints')
        safe_mkdir('checkpoints/cryoem_'+self.vers)
        writer = tf.summary.FileWriter('./graphs/cryoem_'+self.vers, tf.get_default_graph())

        tVars = tf.trainable_variables()
        defGraph = tf.get_default_graph()

        for v in defGraph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): 
            if (('bnorm_' in v.name) and
                ('/Adam' not in v.name) and
                ('Adagrad' not in v.name) and
                (v not in tVars )):
                tVars.append(v)
                
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver(var_list= tVars)
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/cryoem_'+self.vers+'/cpoint'))

            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            
            step = self.gstep.eval()

            for epoch in range(n_epochs):
                self.train_x = np.copy(self.train_x_new)
                self.train_y = np.copy(self.train_y_new)
                
                pool = ThreadPool(processes=1)
                async_result = pool.apply_async(self.data_generator,())
                step = self.train_one_epoch(sess, saver, writer, epoch, step)
                pool.close()
                pool.join()

                if ( ((step + 1) % self.skipStep == 0) and (step>self.evalAfterStep ) ) :
                    self.testEval(sess, writer, epoch, step,self.evalNTimes)
                    
                    
        writer.close()
        self.isTraining = False

In [4]:
model = objGenNetwork(
        signalDim = 5,
         nObservationsPerSignal = 128,
         noise = 2,
         minibatchSize = 32,
         testSampleSize = 64,
         evalNTimes=16,
         lr = 0.0001,
         training = True,
         skipStep = 10,
         nProcessesDataPrep=2,
         vers='deneme',
         evalAfterStep=0
         )


In [5]:
model.build()

ValueError: None values not supported.