# Training and Evaluation

In [1]:
import numpy as np
import tensorflow as tf
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
import datetime
import matplotlib.pyplot as plt
from src import *

In [2]:
def dense(x, scope, num_h, n_x):
    """
    standard affine layer
    scope = name tf variable scope
    num_h = number of hidden units
    """
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [n_x, num_h], initializer=tf.random_normal_initializer(stddev=0.04))
        b = tf.get_variable('b', [num_h], initializer=tf.constant_initializer(0))
        return tf.matmul(x, w)+b

def conv(x, scope, filter_h,filter_w, n_kernel, stride_h=1,stride_w=1, padding='SAME'):
    """
    Convolutional layer
    scope        = name tf variable scope
    filter_width = receptive field of kernel
    n_kernel     = # of kernels
    stride       = locations
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        n_x = x.get_shape().as_list()[-1]
        w = tf.get_variable('w',
                            [filter_h, filter_w, n_x, n_kernel],
                            initializer=tf.random_normal_initializer(stddev=0.04))
        b = tf.get_variable('b', [n_kernel], initializer=tf.constant_initializer(0))
        return tf.nn.convolution(x, w, padding=padding, strides=[stride_h, stride_w])+b    


def maxpool(X, scope, filter_h,filter_w, stride_h,stride_w, padding='VALID'):
    """
    Maxpool operation
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE) as scope:
        pool = tf.nn.max_pool(X, 
                            ksize=[1, filter_h, filter_w, 1], 
                            strides=[1, stride_h, stride_w, 1],
                            padding=padding)
    return pool    
    
def upsample(X, ratio=2):
    """
    Takes a 4D image tensor and increases spatial resolution by replicating values
    """
    n_h, n_w = X.get_shape().as_list()[1:3]
    return tf.image.resize_nearest_neighbor(X, [n_h*ratio, n_w*ratio])
    
    
def convOneDImages(X,batch_size,n_images, reuse=False, ):
    """
    Makes convolutions over 1D images  in a tensor of size 
    (batch_size,n_images, imgsize,1 ). Returns the average 
    values of the encodings of 1D images. 
    """

    with tf.variable_scope('convOneDImages', reuse=reuse):
        h = tf.nn.leaky_relu(conv(X, 'conv0',1,5,32,1,2), 0.1)    #32x32
        h = tf.nn.leaky_relu(conv(h, 'conv1', 1,3,64,1,1), 0.1)   #32x64
        h = maxpool(h,'pool1', 1,2,1,2,'VALID')                        #16x64
        h = tf.nn.leaky_relu(conv(h, 'conv2', 1,3,128,1,1), 0.1)  #16x128
        h = maxpool(h,'pool2', 1,2,1,2,'VALID')                        #8x128
        h = tf.nn.leaky_relu(conv(h, 'conv3', 1,3,256,1,1), 0.1)  #8x256
        h = maxpool(h,'pool3', 1,2,1,2,'VALID')                        #4x256
        h = tf.reshape(h, [batch_size*n_images, -1])
        
        h = tf.nn.leaky_relu(dense(h, 'fc_0', 512,4*256), 0.1)
        h = tf.nn.leaky_relu(dense(h, 'fc_1', 512,512), 0.1)
        h = tf.reshape(h, [batch_size, n_images, 512]) 

        h = tf.reduce_mean(h, 1)
        return h


def gen2DObj(X, batch_size, enc_dim = 512, reuse=False):
    """
    Takes encoding produced by the 1D images as input and
    generates a 2D image of the object
    """
    with tf.variable_scope('2D_object_generator', reuse=reuse):
        h = tf.nn.leaky_relu(dense(X, 'hz', num_h=6*6*32,n_x=enc_dim), 0.1)
        h = tf.nn.leaky_relu(dense(h, 'hz1', num_h=6*6*128,n_x=6*6*32), 0.1)
        h = tf.reshape(h, [batch_size, 6, 6, 128]) 
        h = upsample(h)                                          #12x12x128
        h = tf.nn.leaky_relu(conv(h, 'h1', 3,3, n_kernel=64), 0.1)    #12x12x64
        h = upsample(h)                                          #24x24x64
        h = tf.nn.leaky_relu(conv(h, 'h2', 3,3, n_kernel=32), 0.1)    #24x24x32
        h = upsample(h)                                          #48x48x32
        h = tf.nn.leaky_relu(conv(h, 'h3', 3,3, n_kernel=16), 0.1)    #48x48x16
        h = conv(h, 'hx', 1,1,1)
        return h    
    

class objGenNetwork(object):
    """
    Implementation of the model
    """
    def __init__(self,
                 imgDim = 64,
                 o_imgDim = 48,
                 model_n_img_per_obj = 32,
                 noise = 0.5,
                 minibatchSize = 50,
                 testSampleSize = 1000,
                 lr = 0.001,
                 nProcess_dataprep = 4,
                 training = True,
                 shift_n = 8,
                 vers='v1'
                ):
        
        self.imgDim = imgDim
        self.o_imgDim = o_imgDim
        self.model_n_img_per_obj = model_n_img_per_obj
        self.noise = noise
        self.minibatchSize = minibatchSize
        self.testSampleSize = testSampleSize
        self.lr = lr
        self.nProcess_dataprep = nProcess_dataprep
        self.training = training
        self.shift_n = shift_n
        self.skip_step = 1
        self.vers = vers
        
        self.gstep = tf.Variable(0, 
                                 dtype=tf.int32, 
                                 trainable=False,
                                 name='global_step')
        self.train_x = None
        self.train_y = None
        self.test_x = None
        self.test_y = None

    def data_generator(self):

        """
        Generates batches of random objects and their 1D 
        density images from random rotations. Objects are 
        encoded in a NxN occupancy map, where the M[i,j] is
        1 if the object occupies the area between the coordinates
        (i,j),(i+1,j),(i,j+1),(i+1,j+1), and 0 otherwise. Gaussian
        noise with N(0,(self.noise)^2 ) is applied over 1D images.
        Additionally we shift 1D images randomly to left and right 
        by 8 pixels
        """
        nProcess_dataprep=self.nProcess_dataprep
        minibatchSize = self.minibatchSize
        testSampleSize = self.testSampleSize
        imgDim = self.imgDim
        o_imgDim = self.o_imgDim
        noise = self.noise
        model_n_img_per_obj=self.model_n_img_per_obj
        trainingBatch=True
        #shift_n = self.shift_n 
        #translations are done within the gen_rand_poly_images function
        
        batches = minibatchSize if trainingBatch else testSampleSize
        
        pool = mp.Pool(processes=nProcess_dataprep)
        results = pool.map(gen_rand_poly_images,[model_n_img_per_obj] * batches )
        pool.close(); pool.join()
        
        batch_x = np.expand_dims(np.array([k[1] for k in results],
                                          dtype='float32'),
                                 axis=3)
        batch_y = np.expand_dims(np.array([k[2] for k in results],
                                          dtype='float32'),
                                 axis=4)
        
        # add gaussian noise
        if(noise > 0):
            batch_x = batch_x + np.random.normal(0,noise,batch_x.shape)

        if trainingBatch:
            self.train_x_new = batch_x
            self.train_y_new = batch_y
        else:
            self.new_test_x = batch_x
            self.new_test_y = batch_y
        
        
    def inference(self):
        h = convOneDImages(self.x_ph, self.minibatchSize, self.model_n_img_per_obj)
        self.logits = gen2DObj(h, self.minibatchSize)

        
    def loss(self):
        """
        Defines loss function
        We use cross entropy with logits over the pixels. Computes the minimum
        loss over all rotations
        """
        # 
        with tf.name_scope('loss'):
            tiled_logits = tf.tile(tf.expand_dims(self.logits, 1),[1,360,1,1,1])
            entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y_ph, logits=tiled_logits)
            entropy = tf.reduce_sum(entropy, axis = [2,3,4])
            entropy = tf.reduce_min(entropy,axis = 1)
            self.loss = tf.reduce_mean(entropy, name='loss')

    def optimize(self):
        """
        Optimization op
        """
        self.opt = tf.train.AdadeltaOptimizer(self.lr).minimize(self.loss, 
                                                global_step=self.gstep)

    def eval_graph(self):
        """
        Takes the most accurate rotation of the true object
        as reference and calculates the average accuracy of
        the occupancy grid of the object
        """
        with tf.name_scope('predict'):           
            self.predictions = tf.nn.sigmoid(self.logits)
        with tf.name_scope('prediction_eval'):
            tiledPreds = tf.tile(tf.expand_dims(self.predictions, 1),
                                 [1,360,1,1, 1])
            correctPreds = tf.equal(tf.round(tiledPreds),
                                    self.y_ph)
            hitsOnRotations = tf.reduce_sum(tf.cast(correctPreds,tf.float32),
                                            axis = [2,3,4])
            mostHits = tf.reduce_max(hitsOnRotations,axis=1)
            self.accuracy = tf.divide(tf.reduce_mean(mostHits),64.*64.)
        
        
    def eval_once(self, sess, writer, epoch, step):
        accuracy_batch, summaries = sess.run([self.accuracy,
                                              self.summary_op],
                                             feed_dict={self.x_ph: self.train_x_new,
                                                        self.y_ph: self.train_y_new})
        writer.add_summary(summaries, global_step=step)
        print('Accuracy at step {0}: {1} '.format(step,accuracy_batch))
        


    def summary(self):
        """
        Summary for TensorBoard
        """
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.histogram('histogram loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def build(self):
        """
        Builds the computation graph
        """
        self.x_ph = tf.placeholder(tf.float32, [self.minibatchSize, 
                                                self.model_n_img_per_obj,
                                                self.imgDim,
                                                1]) 
        self.y_ph = tf.placeholder(tf.float32, [self.minibatchSize,
                                                360,
                                                self.o_imgDim,
                                                self.o_imgDim,
                                                1])
        self.data_generator()
        self.inference()
        self.loss()
        self.optimize()
        self.eval_graph()
        self.summary()
    
    def train_one_epoch(self, sess, saver, writer, epoch, step):
#        start_time = time.time()
        self.training = True
        _, l, summaries = sess.run([self.opt, 
                                    self.loss,
                                    self.summary_op],
                                   feed_dict={self.x_ph: self.train_x_new,
                                              self.y_ph: self.train_y_new})
        writer.add_summary(summaries, global_step=step)
        if (step + 1) % self.skip_step == 0:
            print('Loss at step {0}: {1}'.format(step, l))
        step += 1
        saver.save(sess, 'checkpoints/cryoem/cryoem_'+self.vers, step)
#        print('Average loss at epoch {0}: {1}'.format(epoch, l))
#        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

    def train(self, n_epochs):
        """
        We prepare the training data in a paralel process during training.
        When the next batch is ready, it is first used as a test data, and
        we report the prediction accuracy over that batch. Afterwards this
        new batch is used for training
        """
        safe_mkdir('checkpoints')
        safe_mkdir('checkpoints/cryoem')
        writer = tf.summary.FileWriter('./graphs/cryoem_'+self.vers, tf.get_default_graph())

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/cryoem/cryoem_'+self.vers))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            
            step = self.gstep.eval()

            for epoch in range(n_epochs):
                self.train_x = self.train_x_new
                self.train_y = self.train_y_new

                pool = ThreadPool(processes=1)
                async_result = pool.apply_async(self.data_generator,())
                step = self.train_one_epoch(sess, saver, writer, epoch, step)
                pool.close()
                pool.join()
                self.eval_once(sess, writer, epoch, step)
                
                y_pred = sess.run(self.predictions,feed_dict={self.x_ph: self.train_x_new})
                plt.imsave('0true.png',self.train_y_new[5][0][:,:,0])
                plt.imsave('0pred.png',y_pred[5,:,:,0]>0.5)
                plt.imsave('1true.png',self.train_y_new[6][0][:,:,0])
                plt.imsave('1pred.png',y_pred[6,:,:,0]>0.5)
                
        writer.close()


In [3]:
model = objGenNetwork(noise=0,model_n_img_per_obj=64,vers='sigma0')
model.build()


INFO:tensorflow:Summary name histogram loss is illegal; using histogram_loss instead.


In [4]:
model.train(n_epochs=20)

INFO:tensorflow:Restoring parameters from checkpoints/cryoem/cryoem_sigma1-83


NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key 2D_object_generator/h1/b/Adadelta not found in checkpoint
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/home/serkan/anaconda3/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/serkan/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 497, in start
    self.io_loop.start()
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/home/serkan/anaconda3/lib/python3.6/asyncio/base_events.py", line 422, in run_forever
    self._run_once()
  File "/home/serkan/anaconda3/lib/python3.6/asyncio/base_events.py", line 1434, in _run_once
    handle._run()
  File "/home/serkan/anaconda3/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 536, in <lambda>
    self.io_loop.add_callback(lambda : self._handle_events(self.socket, 0))
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2907, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/serkan/anaconda3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2961, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-f46dca8bb92c>", line 1, in <module>
    model.train(n_epochs=20)
  File "<ipython-input-2-2e47f82e9d67>", line 295, in train
    saver = tf.train.Saver()
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1281, in __init__
    self.build()
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1293, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1330, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 778, in _build_internal
    restore_sequentially, reshape)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 397, in _AddRestoreOps
    restore_sequentially)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 829, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
    op_def=op_def)
  File "/home/serkan/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key 2D_object_generator/h1/b/Adadelta not found in checkpoint
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]


In [None]:
with tf.Session() as sess:    
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/cryoem/cryoem_sigma1'))
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    y_pred = sess.run(model.predictions,feed_dict={model.x_ph: model.train_x_new})



fig=plt.figure(figsize=(20, 8))
columns = 2
rows = 5
for i in range(0, 5):
    fig.add_subplot(5, 2, i*2+1)
    plt.imshow(model.train_y_new[0][0][:,:,0],interpolation='nearest', cmap="gray")
    fig.add_subplot(5, 2, i*2+2)
    plt.imshow(y_pred[i,:,:,0]>0.5,interpolation='nearest', cmap="gray")
    plt.imshow(img)
plt.show()
