From 4d164ea4cc02f4d9e729e800a611b65f091618d2 Mon Sep 17 00:00:00 2001 From: vyokky <7678676@qq.com> Date: Sun, 26 Nov 2017 00:24:42 +0000 Subject: [PATCH 1/2] fix typos and memory optimisation fix typos and memory optimisation --- tensorlayer/layers.py | 46 ++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py index 19564b4fc..dde22cb4f 100755 --- a/tensorlayer/layers.py +++ b/tensorlayer/layers.py @@ -1805,7 +1805,7 @@ def tf_batch_map_offsets(inputs, offsets, grid_offset): --------- inputs : tf.Tensor. shape = (b, h, w, c) offsets: tf.Tensor. shape = (b, h, w, 2*n) - grid_offset: Offset grids + grid_offset: Offset grids shape = (h, w, n, 2) Returns ------- @@ -1814,11 +1814,10 @@ def tf_batch_map_offsets(inputs, offsets, grid_offset): input_shape = inputs.get_shape() batch_size = tf.shape(inputs)[0] - kernel_n = int(int(offsets.get_shape()[3]) / 2) + kernel_n = int(int(offsets.get_shape()[3])/2) input_h = input_shape[1] input_w = input_shape[2] channel = input_shape[3] - batch_channel = batch_size * input_shape[3] # inputs (b, h, w, c) --> (b*c, h, w) inputs = _to_bc_h_w(inputs, input_shape) @@ -1826,13 +1825,15 @@ def tf_batch_map_offsets(inputs, offsets, grid_offset): # offsets (b, h, w, 2*n) --> (b, h, w, n, 2) offsets = tf.reshape(offsets, (batch_size, input_h, input_w, kernel_n, 2)) # offsets (b, h, w, n, 2) --> (b*c, h, w, n, 2) - offsets = tf.tile(offsets, [channel, 1, 1, 1, 1]) + # offsets = tf.tile(offsets, [channel, 1, 1, 1, 1]) coords = tf.expand_dims(grid_offset, 0) # grid_offset --> (1, h, w, n, 2) - coords = tf.tile(coords, [batch_channel, 1, 1, 1, 1]) + offsets # grid_offset --> (b*c, h, w, n, 2) + coords = tf.tile(coords, [batch_size, 1, 1, 1, 1]) + offsets # grid_offset --> (b, h, w, n, 2) + # clip out of bound coords = tf.stack([tf.clip_by_value(coords[:, :, :, :, 0], 0.0, tf.cast(input_h - 1, 'float32')), tf.clip_by_value(coords[:, :, :, :, 1], 0.0, tf.cast(input_w - 1, 'float32'))], axis=-1) + coords = tf.tile(coords, [channel, 1, 1, 1, 1]) mapped_vals = tf_batch_map_coordinates(inputs, coords) # (b*c, h, w, n) --> (b, h, w, n, c) @@ -4835,23 +4836,23 @@ class ConvLSTMLayer(Layer): The `Layer` class feeding into this layer. cell_shape : tuple, the shape of each cell width*height filter_size : tuple, the size of filter width*height - cell_fn : a TensorFlow's core Convolutional RNN cell as follow. + cell_fn : a Convolutional RNN cell as follow. feature_map : a int The number of feature map in the layer. initializer : initializer The initializer for initializing the parameters. n_steps : a int The sequence length. - initial_state : None or RNN State + initial_state : None or ConvLSTM State If None, initial_state is zero_state. return_last : boolen - If True, return the last output, "Sequence input and single output" - If False, return all outputs, "Synced sequence input and output" - - In other word, if you want to apply one or more RNN(s) on this layer, set to False. + - In other word, if you want to apply one or more ConvLSTM(s) on this layer, set to False. return_seq_2d : boolen - When return_last = False - - If True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it. - - If False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it. + - If True, return 4D Tensor [n_example, h, w, c], for stacking DenseLayer after it. + - If False, return 5D Tensor [n_example/n_steps, h, w, c], for stacking multiple ConvLSTM after it. name : a string or None An optional name to attach to this layer. @@ -4860,17 +4861,17 @@ class ConvLSTMLayer(Layer): outputs : a tensor The output of this RNN. return_last = False, outputs = all cell_output, which is the hidden state. - cell_output.get_shape() = (?, n_hidden) + cell_output.get_shape() = (?, h, w, c]) final_state : a tensor or StateTuple When state_is_tuple = False, - it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n - When state_is_tuple = True, it stores two elements: (c, h), in that order. + it is the final hidden and cell states, + When state_is_tuple = True, You can get the final state after each iteration during training, then feed it to the initial state of next iteration. initial_state : a tensor or StateTuple - It is the initial state of this RNN layer, you can use it to initialize + It is the initial state of this ConvLSTM layer, you can use it to initialize your state at the begining of each epoch or iteration according to your training procedure. @@ -4902,7 +4903,7 @@ def __init__( # self.inputs.get_shape().with_rank(2) # self.inputs.get_shape().with_rank(3) - # Input dimension should be rank 5 [batch_size, n_steps(max), n_features] + # Input dimension should be rank 5 [batch_size, n_steps(max), h, w, c] try: self.inputs.get_shape().with_rank(5) except: @@ -4920,16 +4921,7 @@ def __init__( print(" non specified batch_size, uses a tensor instead.") self.batch_size = batch_size - # Simplified version of tensorflow.models.rnn.rnn.py's rnn(). - # This builds an unrolled LSTM for tutorial purposes only. - # In general, use the rnn() or state_saving_rnn() from rnn.py. - # - # The alternative version of the code below is: - # - # from tensorflow.models.rnn import rnn - # inputs = [tf.squeeze(input_, [1]) - # for input_ in tf.split(1, num_steps, inputs)] - # outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state) + outputs = [] self.cell = cell = cell_fn(shape=cell_shape, filter_size=filter_size, num_features=feature_map) if initial_state is None: @@ -4954,11 +4946,11 @@ def __init__( else: if return_seq_2d: # PTB tutorial: stack dense layer after that, or compute the cost from the output - # 2D Tensor [n_example, n_hidden] + # 4D Tensor [n_example, h, w, c] self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, cell_shape[0] * cell_shape[1] * feature_map]) else: # : stack more RNN layer after that - # 5D Tensor [n_example/n_steps, n_steps, n_hidden] + # 5D Tensor [n_example/n_steps, n_steps, h, w, c] self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_steps, cell_shape[0], cell_shape[1], feature_map]) From 863cc656a86aa215c520d3df14d51342274d5bb4 Mon Sep 17 00:00:00 2001 From: vyokky <7678676@qq.com> Date: Sun, 26 Nov 2017 00:29:48 +0000 Subject: [PATCH 2/2] fix typos and memory optimisation --- tensorlayer/layers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py index dde22cb4f..101141f89 100755 --- a/tensorlayer/layers.py +++ b/tensorlayer/layers.py @@ -1888,7 +1888,10 @@ def __init__( Layer.__init__(self, name=name) self.inputs = layer.outputs self.offset_layer = offset_layer - + + if tf.__version__ < "1.4": + raise Exception("Deformable CNN layer requires tensrflow 1.4 or higher version") + print(" [TL] DeformableConv2dLayer %s: shape:%s, act:%s" % (self.name, str(shape), act.__name__)) @@ -6990,6 +6993,7 @@ def get_batch(self, data, bucket_id, PAD_ID=0, GO_ID=1, EOS_ID=2, UNK_ID=3): + #