From 0cffd3cc49a52e90c3319e6dc95ed9d78671278c Mon Sep 17 00:00:00 2001 From: vyokky <7678676@qq.com> Date: Fri, 24 Nov 2017 23:01:52 +0000 Subject: [PATCH] add Deformable CNN and ConvLSTM add Deformable CNN and ConvLSTM --- tensorlayer/layers.py | 507 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 507 insertions(+) diff --git a/tensorlayer/layers.py b/tensorlayer/layers.py index 433249be0..f052fed5c 100755 --- a/tensorlayer/layers.py +++ b/tensorlayer/layers.py @@ -2804,6 +2804,218 @@ def __init__( self.all_params.extend( variables ) + +def _to_bc_h_w(x, x_shape): + """(b, h, w, c) -> (b*c, h, w)""" + x = tf.transpose(x, [0, 3, 1, 2]) + x = tf.reshape(x, (-1, x_shape[1], x_shape[2])) + return x + + +def _to_b_h_w_n_c(x, x_shape): + """(b*c, h, w, n) -> (b, h, w, n, c)""" + x = tf.reshape( + x, (-1, x_shape[4], x_shape[1], x_shape[2], x_shape[3])) + x = tf.transpose(x, [0, 2, 3, 4, 1]) + return x + +def tf_repeat(a, repeats): + """TensorFlow version of np.repeat for 1D""" + # https://github.com/tensorflow/tensorflow/issues/8521 + assert len(a.get_shape()) == 1 + + a = tf.expand_dims(a, -1) + a = tf.tile(a, [1, repeats]) + a = tf_flatten(a) + return a + +def tf_batch_map_coordinates(inputs, coords): + """Batch version of tf_map_coordinates + + Only supports 2D feature maps + + Parameters + ---------- + input : tf.Tensor. shape = (b*c, h, w) + coords : tf.Tensor. shape = (b*c, h, w, n, 2) + + Returns + ------- + tf.Tensor. shape = (b*c, h, w, n) + """ + + input_shape = inputs.get_shape() + coords_shape = coords.get_shape() + batch_channel = tf.shape(inputs)[0] + input_h = int(input_shape[1]) + input_w = int(input_shape[2]) + kernel_n = int(coords_shape[3]) + n_coords = input_h * input_w * kernel_n + + coords_lt = tf.cast(tf.floor(coords), 'int32') + coords_rb = tf.cast(tf.ceil(coords), 'int32') + coords_lb = tf.stack([coords_lt[:, :, :, :, 0], coords_rb[:, :, :, :, 1]], axis=-1) + coords_rt = tf.stack([coords_rb[:, :, :, :, 0], coords_lt[:, :, :, :, 1]], axis=-1) + + idx = tf_repeat(tf.range(batch_channel), n_coords) + + vals_lt = _get_vals_by_coords(inputs, coords_lt, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_rb = _get_vals_by_coords(inputs, coords_rb, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_lb = _get_vals_by_coords(inputs, coords_lb, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_rt = _get_vals_by_coords(inputs, coords_rt, idx, (batch_channel, input_h, input_w, kernel_n)) + + coords_offset_lt = coords - tf.cast(coords_lt, 'float32') + + vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, :, :, :, 0] + vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, :, :, :, 0] + mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, :, :, :, 1] + + return mapped_vals + +def tf_batch_map_offsets(inputs, offsets, grid_offset): + """Batch map offsets into input + + Parameters + --------- + inputs : tf.Tensor. shape = (b, h, w, c) + offsets: tf.Tensor. shape = (b, h, w, 2*n) + grid_offset: Offset grids + + Returns + ------- + tf.Tensor. shape = (b, h, w, c) + """ + + input_shape = inputs.get_shape() + batch_size = tf.shape(inputs)[0] + kernel_n = int(int(offsets.get_shape()[3]) / 2) + input_h = input_shape[1] + input_w = input_shape[2] + channel = input_shape[3] + batch_channel = batch_size * input_shape[3] + + # inputs (b, h, w, c) --> (b*c, h, w) + inputs = _to_bc_h_w(inputs, input_shape) + + # offsets (b, h, w, 2*n) --> (b, h, w, n, 2) + offsets = tf.reshape(offsets, (batch_size, input_h, input_w, kernel_n, 2)) + # offsets (b, h, w, n, 2) --> (b*c, h, w, n, 2) + offsets = tf.tile(offsets, [channel, 1, 1, 1, 1]) + + coords = tf.expand_dims(grid_offset, 0) # grid_offset --> (1, h, w, n, 2) + coords = tf.tile(coords, [batch_channel, 1, 1, 1, 1]) + offsets # grid_offset --> (b*c, h, w, n, 2) + # clip out of bound + coords = tf.stack([tf.clip_by_value(coords[:, :, :, :, 0], 0.0, tf.cast(input_h - 1, 'float32')), + tf.clip_by_value(coords[:, :, :, :, 1], 0.0, tf.cast(input_w - 1, 'float32'))], axis=-1) + + mapped_vals = tf_batch_map_coordinates(inputs, coords) + # (b*c, h, w, n) --> (b, h, w, n, c) + mapped_vals = _to_b_h_w_n_c(mapped_vals, [batch_size, input_h, input_w, kernel_n, channel]) + + return mapped_vals + +# ## 2D deformable convolutional layer +class DeformableConv2dLayer(Layer): + """The :class:`DeformableConv2dLayer` class is a + `Deformable Convolutional Layer ` + + Parameters + ----------- + layer : TensorLayer layer. + offset_layer: TensorLayer layer, to predict the offset of convolutional operations. The shape of its output should be (batchsize, input height, input width, 2*(number of element in the convolutional kernel)) + e.g. if apply a 3*3 kernel, the number of the last dimension should be 18 (2*3*3) + channel_multiplier : int, The number of channels to expand to. + filter_size : tuple (height, width) for filter size. + strides : tuple (height, width) for strides. + act : None or activation function. + shape : list of shape + shape of the filters, [filter_height, filter_width, in_channels, out_channels]. + W_init : weights initializer + The initializer for initializing the weight matrix. + b_init : biases initializer or None + The initializer for initializing the bias vector. If None, skip biases. + W_init_args : dictionary + The arguments for the weights tf.get_variable(). + b_init_args : dictionary + The arguments for the biases tf.get_variable(). + name : a string or None + An optional name to attach to this layer. + + + Note + ----------- + - The stride is fixed as (1, 1, 1, 1) + - `The padding is fixed as 'same' + - The current implementation is memory-inefficient, please use carefully + """ + def __init__( + self, + layer=None, + offset_layer=None, + act=tf.identity, + shape=[3, 3, 10, 10], + W_init=tf.truncated_normal_initializer(stddev=0.02), + b_init=tf.constant_initializer(value=0.0), + W_init_args={}, + b_init_args={}, + name='deformable_conv_2d_layer', + ): + Layer.__init__(self, name=name) + self.inputs = layer.outputs + self.offset_layer = offset_layer + + print(" [TL] DeformableConv2dLayer %s: shape:%s, act:%s" % + (self.name, str(shape), act.__name__)) + + with tf.variable_scope(name) as vs: + offset = self.offset_layer.outputs + assert offset.get_shape()[-1] == 2 * shape[0] * shape[1] + + ## Grid initialisation + input_h = int(self.inputs.get_shape()[1]) + input_w = int(self.inputs.get_shape()[2]) + kernel_n = shape[0] * shape[1] + initial_offsets = tf.stack(tf.meshgrid(tf.range(shape[0]), + tf.range(shape[1]), + indexing='ij')) # initial_offsets --> (kh, kw, 2) + initial_offsets = tf.reshape(initial_offsets, (-1, 2)) # initial_offsets --> (n, 2) + initial_offsets = tf.expand_dims(initial_offsets, 0) # initial_offsets --> (1, n, 2) + initial_offsets = tf.expand_dims(initial_offsets, 0) # initial_offsets --> (1, 1, n, 2) + initial_offsets = tf.tile(initial_offsets, [input_h, input_w, 1, 1]) # initial_offsets --> (h, w, n, 2) + initial_offsets = tf.cast(initial_offsets, 'float32') + grid = tf.meshgrid( + tf.range(input_h), tf.range(input_w), indexing='ij' + ) + grid = tf.stack(grid, axis=-1) + grid = tf.cast(grid, 'float32') # grid --> (h, w, 2) + grid = tf.expand_dims(grid, 2) # grid --> (h, w, 1, 2) + grid = tf.tile(grid, [1, 1, kernel_n, 1]) # grid --> (h, w, n, 2) + grid_offset = grid + initial_offsets # grid_offset --> (h, w, n, 2) + + input_deform = tf_batch_map_offsets(self.inputs, offset, grid_offset) + + W = tf.get_variable(name='W_conv2d', shape=[1, 1, shape[0] * shape[1], shape[-2], shape[-1]], + initializer=W_init, **W_init_args) + b = tf.get_variable(name='b_conv2d', shape=(shape[-1]), initializer=b_init, **b_init_args) + + self.outputs = tf.reshape(act( + tf.nn.conv3d(input_deform, W, strides=[1, 1, 1, 1, 1], padding='VALID', name=None) + b), + (tf.shape(self.inputs)[0], input_h, input_w, shape[-1])) + + ## fixed + self.all_layers = list(layer.all_layers) + self.all_params = list(layer.all_params) + self.all_drop = dict(layer.all_drop) + + ## offset_layer + self.all_layers.extend(offset_layer.all_layers) + self.all_params.extend(offset_layer.all_params) + self.all_drop.update(offset_layer.all_drop) + + ## this layer + self.all_layers.extend([self.outputs]) + + # ## Normalization layer class LocalResponseNormLayer(Layer): """The :class:`LocalResponseNormLayer` class is for Local Response Normalization, see ``tf.nn.local_response_normalization`` or ``tf.nn.lrn`` for new TF version. @@ -4454,6 +4666,301 @@ def __init__( self.all_layers.extend( [self.outputs] ) self.all_params.extend( rnn_variables ) + +class ConvRNNCell(object): + """Abstract object representing an Convolutional RNN cell. + """ + + def __call__(self, inputs, state, scope=None): + """Run this RNN cell on inputs, starting from the given state. + """ + raise NotImplementedError("Abstract method") + + @property + def state_size(self): + """size(s) of state(s) used by this cell. + """ + raise NotImplementedError("Abstract method") + + @property + def output_size(self): + """Integer or TensorShape: size of outputs produced by this cell.""" + raise NotImplementedError("Abstract method") + + def zero_state(self, batch_size, dtype): + """Return zero-filled state tensor(s). + Args: + batch_size: int, float, or unit Tensor representing the batch size. + dtype: the data type to use for the state. + Returns: + tensor of shape '[batch_size x shape[0] x shape[1] x num_features] + filled with zeros + """ + + shape = self.shape + num_features = self.num_features + zeros = tf.zeros([batch_size, shape[0], shape[1], num_features * 2]) + return zeros + + +class BasicConvLSTMCell(ConvRNNCell): + """Basic Conv LSTM recurrent network cell. The + """ + + def __init__(self, shape, filter_size, num_features, forget_bias=1.0, input_size=None, + state_is_tuple=False, activation=tf.nn.tanh): + """Initialize the basic Conv LSTM cell. + Args: + shape: int tuple thats the height and width of the cell + filter_size: int tuple thats the height and width of the filter + num_features: int thats the depth of the cell + forget_bias: float, The bias added to forget gates (see above). + input_size: Deprecated and unused. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. + """ + # if not state_is_tuple: + # logging.warn("%s: Using a concatenated state is slower and will soon be " + # "deprecated. Use state_is_tuple=True.", self) + if input_size is not None: + logging.warn("%s: The input_size parameter is deprecated.", self) + self.shape = shape + self.filter_size = filter_size + self.num_features = num_features + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation + + @property + def state_size(self): + return (LSTMStateTuple(self._num_units, self._num_units) + if self._state_is_tuple else 2 * self._num_units) + + @property + def output_size(self): + return self._num_units + + def __call__(self, inputs, state, scope=None): + """Long short-term memory cell (LSTM).""" + with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + print state + # c, h = tf.split(3, 2, state) + c, h = tf.split(state, 2, 3) + concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + # i, j, f, o = tf.split(3, 4, concat) + i, j, f, o = tf.split(concat, 4, 3) + + new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * tf.nn.sigmoid(o) + + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat([new_c, new_h], 3) + return new_h, new_state + + +def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None): + """convolution: + Args: + args: a 4D Tensor or a list of 4D, batch x n, Tensors. + filter_size: int tuple of filter height and width. + num_features: int, number of features. + bias_start: starting value to initialize the bias; 0 by default. + scope: VariableScope for the created subgraph; defaults to "Linear". + Returns: + A 4D Tensor with shape [batch h w num_features] + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + + # Calculate the total size of arguments on dimension 1. + total_arg_size_depth = 0 + shapes = [a.get_shape().as_list() for a in args] + for shape in shapes: + if len(shape) != 4: + raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes)) + if not shape[3]: + raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes)) + else: + total_arg_size_depth += shape[3] + + dtype = [a.dtype for a in args][0] + + # Now the computation. + with tf.variable_scope(scope or "Conv"): + matrix = tf.get_variable( + "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype=dtype) + if len(args) == 1: + res = tf.nn.conv2d(args[0], matrix, strides=[1, 1, 1, 1], padding='SAME') + else: + res = tf.nn.conv2d(tf.concat(args, 3), matrix, strides=[1, 1, 1, 1], padding='SAME') + if not bias: + return res + bias_term = tf.get_variable( + "Bias", [num_features], + dtype=dtype, + initializer=tf.constant_initializer( + bias_start, dtype=dtype)) + return res + bias_term + +## ConvLSTM layer +class ConvLSTMLayer(Layer): + """ + The :class:`ConvLSTMLayer` class is a Convolutional LSTM layer. + `Convolutional LSTM Layer ` + Parameters + ---------- + layer : a :class:`Layer` instance + The `Layer` class feeding into this layer. + cell_shape : tuple, the shape of each cell width*height + filter_size : tuple, the size of filter width*height + cell_fn : a TensorFlow's core Convolutional RNN cell as follow. + feature_map : a int + The number of feature map in the layer. + initializer : initializer + The initializer for initializing the parameters. + n_steps : a int + The sequence length. + initial_state : None or RNN State + If None, initial_state is zero_state. + return_last : boolen + - If True, return the last output, "Sequence input and single output" + - If False, return all outputs, "Synced sequence input and output" + - In other word, if you want to apply one or more RNN(s) on this layer, set to False. + return_seq_2d : boolen + - When return_last = False + - If True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it. + - If False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it. + name : a string or None + An optional name to attach to this layer. + + Variables + -------------- + outputs : a tensor + The output of this RNN. + return_last = False, outputs = all cell_output, which is the hidden state. + cell_output.get_shape() = (?, n_hidden) + + final_state : a tensor or StateTuple + When state_is_tuple = False, + it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n + When state_is_tuple = True, it stores two elements: (c, h), in that order. + You can get the final state after each iteration during training, then + feed it to the initial state of next iteration. + + initial_state : a tensor or StateTuple + It is the initial state of this RNN layer, you can use it to initialize + your state at the begining of each epoch or iteration according to your + training procedure. + + batch_size : int or tensor + Is int, if able to compute the batch_size, otherwise, tensor for ``?``. + + """ + + def __init__( + self, + layer=None, + cell_shape=None, + feature_map=1, + filter_size=(3, 3), + cell_fn=BasicConvLSTMCell, + initializer=tf.random_uniform_initializer(-0.1, 0.1), + n_steps=5, + initial_state=None, + return_last=False, + return_seq_2d=False, + name='convlstm_layer', + ): + Layer.__init__(self, name=name) + self.inputs = layer.outputs + print(" tensorlayer:Instantiate RNNLayer %s: feature_map:%d, n_steps:%d, " + "in_dim:%d %s, cell_fn:%s " % (self.name, feature_map, + n_steps, self.inputs.get_shape().ndims, self.inputs.get_shape(), + cell_fn.__name__)) + # You can get the dimension by .get_shape() or ._shape, and check the + # dimension by .with_rank() as follow. + # self.inputs.get_shape().with_rank(2) + # self.inputs.get_shape().with_rank(3) + + # Input dimension should be rank 5 [batch_size, n_steps(max), n_features] + try: + self.inputs.get_shape().with_rank(5) + except: + raise Exception("RNN : Input dimension should be rank 5 : [batch_size, n_steps, input_x, " + "input_y, feature_map]") + + fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0] + + if fixed_batch_size.value: + batch_size = fixed_batch_size.value + print(" RNN batch_size (concurrent processes): %d" % batch_size) + else: + from tensorflow.python.ops import array_ops + batch_size = array_ops.shape(self.inputs)[0] + print(" non specified batch_size, uses a tensor instead.") + self.batch_size = batch_size + + # Simplified version of tensorflow.models.rnn.rnn.py's rnn(). + # This builds an unrolled LSTM for tutorial purposes only. + # In general, use the rnn() or state_saving_rnn() from rnn.py. + # + # The alternative version of the code below is: + # + # from tensorflow.models.rnn import rnn + # inputs = [tf.squeeze(input_, [1]) + # for input_ in tf.split(1, num_steps, inputs)] + # outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state) + outputs = [] + self.cell = cell = cell_fn(shape=cell_shape, filter_size=filter_size, num_features=feature_map) + if initial_state is None: + self.initial_state = cell.zero_state(batch_size, dtype=tf.float32) # 1.2.3 + state = self.initial_state + # with tf.variable_scope("model", reuse=None, initializer=initializer): + with tf.variable_scope(name, initializer=initializer) as vs: + for time_step in range(n_steps): + if time_step > 0: tf.get_variable_scope().reuse_variables() + (cell_output, state) = cell(self.inputs[:, time_step, :, :, :], state) + outputs.append(cell_output) + + # Retrieve just the RNN variables. + # rnn_variables = [v for v in tf.all_variables() if v.name.startswith(vs.name)] + rnn_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) + + print(" n_params : %d" % (len(rnn_variables))) + + if return_last: + # 2D Tensor [batch_size, n_hidden] + self.outputs = outputs[-1] + else: + if return_seq_2d: + # PTB tutorial: stack dense layer after that, or compute the cost from the output + # 2D Tensor [n_example, n_hidden] + self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, cell_shape[0] * cell_shape[1] * feature_map]) + else: + # : stack more RNN layer after that + # 5D Tensor [n_example/n_steps, n_steps, n_hidden] + self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_steps, cell_shape[0], + cell_shape[1], feature_map]) + + self.final_state = state + + self.all_layers = list(layer.all_layers) + self.all_params = list(layer.all_params) + self.all_drop = dict(layer.all_drop) + self.all_layers.extend([self.outputs]) + self.all_params.extend(rnn_variables) + # Advanced Ops for Dynamic RNN def advanced_indexing_op(input, index): """Advanced Indexing for Sequences, returns the outputs by given sequence lengths.