# Training and Evaluation

Still work in progress. Will be constructing a graph neural network to input the data and find the L2 orbit loss. Afterward I will also try a GAN to do the same task and compare their performances.

In [1]:
import numpy as np
import tensorflow as tf
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
import datetime
from src import *

In [2]:
def dense(x, scope, num_h, n_x):
    """
    standard affine layer
    scope = name tf variable scope
    num_h = number of hidden units
    """
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [n_x, num_h], initializer=tf.random_normal_initializer(stddev=0.04))
        b = tf.get_variable('b', [num_h], initializer=tf.constant_initializer(0))
        return tf.matmul(x, w)+b

def noise(X, std=1):
    """
    adds gaussian noise with unit variance
    """
    noise = tf.random_normal(shape=tf.shape(X), mean=0.0, stddev=std, dtype=tf.float32) 
    return(input_layer + noise)
    
def conv(x, scope, filter_h,filter_w, n_kernel, stride_h=1,stride_w=1, padding='SAME'):
    """
    1d convolutions over multiple 1d images
    scope        = name tf variable scope
    filter_width = receptive field of kernel
    n_kernel     = # of kernels
    stride       = locations
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        n_x = x.get_shape().as_list()[-1]
        w = tf.get_variable('w',
                            [filter_h, filter_w, n_x, n_kernel],
                            initializer=tf.random_normal_initializer(stddev=0.04))
        b = tf.get_variable('b', [n_kernel], initializer=tf.constant_initializer(0))
        return tf.nn.convolution(x, w, padding=padding, strides=[stride_h, stride_w])+b    


def maxpool(X, scope, filter_h,filter_w, stride_h,stride_w, padding='VALID'):
    """
    maxpool operation over multiple 1d images
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE) as scope:
        pool = tf.nn.max_pool(X, 
                            ksize=[1, filter_h, filter_w, 1], 
                            strides=[1, stride_h, stride_w, 1],
                            padding=padding)
    return pool    
    
def upsample(X, ratio=2):
    """
    Takes a 4D image tensor and increases spatial resolution by replicating values
    """
    n_h, n_w = X.get_shape().as_list()[1:3]
    return tf.image.resize_nearest_neighbor(X, [n_h*ratio, n_w*ratio])
    
    
def convOneDImages(X, reuse=False):
    """
    Makes convolutions over 1D images and takes their averages
    for given object. Currently trying an approach similar to GNN.
    """
    n_batch = tf.shape(X)[0]
    model_n_img_per_obj = X.get_shape().as_list()[1]
    model_n_img_per_obj = X
    with tf.variable_scope('convOneDImages', reuse=reuse):
        h = tf.nn.leaky_relu(conv(X, 'conv0',1,5,32,1,2), 0.1)    #32x32
        h = tf.nn.leaky_relu(conv(h, 'conv1', 1,3,64,1,1), 0.1)   #32x64
        h = maxpool(h,'pool1', 1,2,1,2,'VALID')                        #16x64
        h = tf.nn.leaky_relu(conv(h, 'conv2', 1,3,128,1,1), 0.1)  #16x128
        h = maxpool(h,'pool2', 1,2,1,2,'VALID')                        #8x128
        h = tf.nn.leaky_relu(conv(h, 'conv3', 1,3,256,1,1), 0.1)  #8x256
        h = maxpool(h,'pool3', 1,2,1,2,'VALID')                        #4x256
        h = tf.reshape(h, [50* 32, -1])                   # batchsize 50, n_img=32
        
        h = tf.nn.leaky_relu(dense(h, 'fc_0', 512,4*256), 0.1)
        h = tf.nn.leaky_relu(dense(h, 'fc_1', 512,512), 0.1)
        h = tf.reshape(h, [50, 32, 512]) 

        h = tf.reduce_mean(h, 1)
        return h


def gen2DObj(X, reuse=False):
    """
    Takes encoding produced by the 1D images as input and
    generates a 2D image of the object
    """
    with tf.variable_scope('2D_object_generator', reuse=reuse):
        h = tf.nn.leaky_relu(dense(X, 'hz', num_h=6*6*32,n_x=512), 0.1)
        h = tf.nn.leaky_relu(dense(h, 'hz1', num_h=6*6*128,n_x=6*6*32), 0.1)
        h = tf.reshape(h, [50, 6, 6, 128]) 
        h = upsample(h)                                          #12x12x128
        h = tf.nn.leaky_relu(conv(h, 'h1', 3,3, n_kernel=64), 0.1)    #12x12x64
        h = upsample(h)                                          #24x24x64
        h = tf.nn.leaky_relu(conv(h, 'h2', 3,3, n_kernel=32), 0.1)    #24x24x32
        h = upsample(h)                                          #48x48x32
        h = tf.nn.leaky_relu(conv(h, 'h3', 3,3, n_kernel=16), 0.1)    #48x48x16
        h = conv(h, 'hx', 1,1,1)
        return h    
    

class objGenNetwork(object):
    """
    not finished.
    """
    def __init__(self,
                 imgDim = 64,
                 o_imgDim = 48,
                 model_n_img_per_obj = 32,
                 noise = 0.5,
                 minibatchSize = 50,
                 testSampleSize = 1000,
                 lr = 0.001,
                 nProcess_dataprep = 4,
                 training = True,
                 shift_n = 8
                ):
        
        self.imgDim = imgDim
        self.o_imgDim = o_imgDim
        self.model_n_img_per_obj = model_n_img_per_obj
        self.noise = noise
        self.minibatchSize = minibatchSize
        self.testSampleSize = testSampleSize
        self.lr = lr
        self.nProcess_dataprep = nProcess_dataprep
        self.training = training
        self.shift_n = shift_n
        self.skip_step = 1
        
        self.gstep = tf.Variable(0, 
                                 dtype=tf.int32, 
                                 trainable=False,
                                 name='global_step')
        self.train_x = None
        self.train_y = None
        self.test_x = None
        self.test_y = None

    def data_generator(self):

        """
        Reads chunks of data starting for the i'th iteration.
        Prepares train_x, test_x, train_y, test_y 
        """
        nProcess_dataprep=self.nProcess_dataprep
        minibatchSize = self.minibatchSize
        testSampleSize = self.testSampleSize
        imgDim = self.imgDim
        o_imgDim = self.o_imgDim
        noise = self.noise
        shift_n = self.shift_n
        model_n_img_per_obj=self.model_n_img_per_obj
        trainingBatch=True
        
        batches = minibatchSize if trainingBatch else testSampleSize
        
        pool = mp.Pool(processes=nProcess_dataprep)
        results = pool.map(gen_rand_poly_images,[model_n_img_per_obj] * batches )
        pool.close(); pool.join()
        
        batch_x = np.expand_dims(np.array([k[1] for k in results],
                                          dtype='float32'),
                                 axis=3)
        batch_y = np.expand_dims(np.array([k[2] for k in results],
                                          dtype='float32'),
                                 axis=4)
        
        # add gaussian noise
        if(noise > 0):
            batch_x = batch_x + np.random.normal(0,1,batch_x.shape)

        if trainingBatch:
            self.train_x_new = batch_x
            self.train_y_new = batch_y
        else:
            self.new_test_x = batch_x
            self.new_test_y = batch_y
        
        
    def inference(self):
        h = convOneDImages(self.x_ph)
        self.logits = gen2DObj(h)

        
    def loss(self):
        """
        define loss function
        use softmax cross entropy with logits as the loss function
        compute mean cross entropy, softmax is applied internally
        """
        # 
        with tf.name_scope('loss'):
            tiled_logits = tf.tile(tf.expand_dims(self.logits, 1),[1,360,1,1,1])
            entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y_ph, logits=tiled_logits)
            entropy = tf.reduce_min(entropy,axis = 1)
            self.loss = tf.reduce_mean(entropy, name='loss')

    def optimize(self):
        """
        define optimization op
        """
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, 
                                                global_step=self.gstep)

    def eval_graph(self):
        """
        For each prediction, takes the most accurate rotation
        of the true object as reference and calculates the average
        accuracy of the occupancy grid of the object
        """
        with tf.name_scope('predict'):           
            preds = tf.nn.sigmoid(self.logits)
            tiled_preds = tf.tile(tf.expand_dims(preds, 1),[1,360,1,1, 1])
            correct_preds = tf.equal(tf.round(tiled_preds), self.y_ph)
            accuracy_best = tf.reduce_max(tf.cast(correct_preds, tf.float32), axis = 1)
            self.accuracy = tf.reduce_mean(accuracy_best)
        
        
    def eval_once(self, sess, writer, epoch, step):
        accuracy_batch, summaries = sess.run([self.accuracy,
                                              self.summary_op],
                                             feed_dict={self.x_ph: self.train_x_new,
                                                        self.y_ph: self.train_y_new})
        writer.add_summary(summaries, global_step=step)
        print('Accuracy at epoch {0}: {1} '.format(epoch,accuracy_batch))
        


    def summary(self):
        """
        Summary for TensorBoard
        """
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.histogram('histogram loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def build(self):
        """
        Build the computation graph
        """
        self.x_ph = tf.placeholder(tf.float32, [None, 32,64,1]) 
        self.y_ph = tf.placeholder(tf.float32, [None, 360,48,48,1])
        self.data_generator()
        self.inference()
        self.loss()
        self.optimize()
        self.eval_graph()
        self.summary()
    
    def train_one_epoch(self, sess, saver, writer, epoch, step):
#        start_time = time.time()
        self.training = True
        _, l, summaries = sess.run([self.opt, 
                                    self.loss,
                                    self.summary_op],
                                   feed_dict={self.x_ph: self.train_x_new,
                                              self.y_ph: self.train_y_new})
        writer.add_summary(summaries, global_step=step)
        if (step + 1) % self.skip_step == 0:
            print('Loss at step {0}: {1}'.format(step, l))
        step += 1
        saver.save(sess, 'checkpoints/cryoem/cryoem_v1', step)
#        print('Average loss at epoch {0}: {1}'.format(epoch, l))
#        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

    def train(self, n_epochs):
        """
        The train function alternates between training one epoch and evaluating
        """
        safe_mkdir('checkpoints')
        safe_mkdir('checkpoints/cryoem')
        writer = tf.summary.FileWriter('./graphs/cryoem_v1', tf.get_default_graph())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/cryoem/cryoem_v1'))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            
            step = self.gstep.eval()

            for epoch in range(n_epochs):
                self.train_x = self.train_x_new
                self.train_y = self.train_y_new

                pool = ThreadPool(processes=1)
                async_result = pool.apply_async(self.data_generator,())
                step = self.train_one_epoch(sess, saver, writer, epoch, step)
                pool.close()
                pool.join()
                self.eval_once(sess, writer, epoch, step)
        writer.close()


In [3]:
model = objGenNetwork()
model.build()


INFO:tensorflow:Summary name histogram loss is illegal; using histogram_loss instead.


In [None]:
model.train(n_epochs=50)

Loss at step 0: 0.6927354335784912
Accuracy at epoch 0: 0.9044097065925598 
Loss at step 1: 0.679400622844696
Accuracy at epoch 1: 0.8859201669692993 
Loss at step 2: 0.6101904511451721
Accuracy at epoch 2: 0.8989322781562805 
Loss at step 3: 0.39158859848976135
Accuracy at epoch 3: 0.9072916507720947 
Loss at step 4: 0.2797037661075592
Accuracy at epoch 4: 0.903967022895813 
Loss at step 5: 0.24782182276248932
Accuracy at epoch 5: 0.9213975667953491 
Loss at step 6: 0.17488929629325867
Accuracy at epoch 6: 0.93729168176651 
Loss at step 7: 0.1565966159105301
Accuracy at epoch 7: 0.9616666436195374 
Loss at step 8: 0.15781818330287933
Accuracy at epoch 8: 0.9823524355888367 
Loss at step 9: 0.1516440361738205
Accuracy at epoch 9: 0.9825607538223267 
Loss at step 10: 0.1366826891899109
Accuracy at epoch 10: 0.9806163311004639 
Loss at step 11: 0.12341755628585815
Accuracy at epoch 11: 0.9817361235618591 
Loss at step 12: 0.11431533843278885
Accuracy at epoch 12: 0.9913802146911621 
Loss