52 changes: 24 additions & 28 deletions tensorlayer/layers.py
@@ -1805,7 +1805,7 @@ def tf_batch_map_offsets(inputs, offsets, grid_offset):
---------
inputs : tf.Tensor. shape = (b, h, w, c)
offsets: tf.Tensor. shape = (b, h, w, 2*n)
grid_offset: Offset grids
grid_offset: Offset grids shape = (h, w, n, 2)

Returns
-------
@@ -1814,25 +1814,26 @@ def tf_batch_map_offsets(inputs, offsets, grid_offset):

input_shape = inputs.get_shape()
batch_size = tf.shape(inputs)[0]
kernel_n = int(int(offsets.get_shape()[3]) / 2)
kernel_n = int(int(offsets.get_shape()[3])/2)
input_h = input_shape[1]
input_w = input_shape[2]
channel = input_shape[3]
batch_channel = batch_size * input_shape[3]

# inputs (b, h, w, c) --> (b*c, h, w)
inputs = _to_bc_h_w(inputs, input_shape)

# offsets (b, h, w, 2*n) --> (b, h, w, n, 2)
offsets = tf.reshape(offsets, (batch_size, input_h, input_w, kernel_n, 2))
# offsets (b, h, w, n, 2) --> (b*c, h, w, n, 2)
offsets = tf.tile(offsets, [channel, 1, 1, 1, 1])
# offsets = tf.tile(offsets, [channel, 1, 1, 1, 1])

coords = tf.expand_dims(grid_offset, 0) # grid_offset --> (1, h, w, n, 2)
coords = tf.tile(coords, [batch_channel, 1, 1, 1, 1]) + offsets # grid_offset --> (b*c, h, w, n, 2)
coords = tf.tile(coords, [batch_size, 1, 1, 1, 1]) + offsets # grid_offset --> (b, h, w, n, 2)

# clip out of bound
coords = tf.stack([tf.clip_by_value(coords[:, :, :, :, 0], 0.0, tf.cast(input_h - 1, 'float32')),
tf.clip_by_value(coords[:, :, :, :, 1], 0.0, tf.cast(input_w - 1, 'float32'))], axis=-1)
coords = tf.tile(coords, [channel, 1, 1, 1, 1])

mapped_vals = tf_batch_map_coordinates(inputs, coords)
# (b*c, h, w, n) --> (b, h, w, n, c)
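For orientation, a minimal sketch of how the reworked shapes flow through a call; all sizes (b, h, w, c, n) are illustrative, and it assumes tf_batch_map_offsets is reachable at module level as defined in this file:

import tensorflow as tf
import tensorlayer as tl

b, h, w, c, n = 2, 8, 8, 3, 9  # illustrative: batch, height, width, channels, kernel points (3x3)
inputs = tf.placeholder(tf.float32, [b, h, w, c])
offsets = tf.placeholder(tf.float32, [b, h, w, 2 * n])   # one (dy, dx) pair per kernel point
grid_offset = tf.placeholder(tf.float32, [h, w, n, 2])   # the n sampling locations at every position

# coords are clipped to the image bounds, tiled over channels, then bilinearly sampled
mapped = tl.layers.tf_batch_map_offsets(inputs, offsets, grid_offset)  # assumed module-level access
# mapped: (b, h, w, n, c), matching the comment above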
@@ -1887,7 +1888,10 @@ def __init__(
Layer.__init__(self, name=name)
self.inputs = layer.outputs
self.offset_layer = offset_layer


if tf.__version__ < "1.4":
raise Exception("Deformable CNN layer requires tensrflow 1.4 or higher version")

print(" [TL] DeformableConv2dLayer %s: shape:%s, act:%s" %
(self.name, str(shape), act.__name__))

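A hedged construction sketch for the guarded layer; the offset-field convolution and the [k_h, k_w, in_c, out_c] shapes below are assumptions inferred from the print statement above, not a documented recipe:

import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
net = tl.layers.InputLayer(x, name='input')
# offset field: 2 * n output channels for an n-point kernel (3x3 -> 18); sizes are hypothetical
offset = tl.layers.Conv2dLayer(net, act=tf.identity, shape=[3, 3, 1, 18], name='offset_field')
# raises the Exception above when running on TensorFlow < 1.4
net = tl.layers.DeformableConv2dLayer(net, act=tf.nn.relu, offset_layer=offset,
                                      shape=[3, 3, 1, 32], name='deformable_conv')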
@@ -4835,23 +4839,23 @@ class ConvLSTMLayer(Layer):
The `Layer` class feeding into this layer.
cell_shape : tuple, the shape of each cell width*height
filter_size : tuple, the size of filter width*height
cell_fn : a TensorFlow's core Convolutional RNN cell as follow.
cell_fn : a Convolutional RNN cell, as follows.
feature_map : an int
The number of feature maps in the layer.
initializer : initializer
The initializer for initializing the parameters.
n_steps : an int
The sequence length.
initial_state : None or RNN State
initial_state : None or ConvLSTM State
If None, initial_state is zero_state.
return_last : boolean
- If True, return the last output, "Sequence input and single output"
- If False, return all outputs, "Synced sequence input and output"
- In other word, if you want to apply one or more RNN(s) on this layer, set to False.
- In other words, if you want to apply one or more ConvLSTM(s) on this layer, set to False.
return_seq_2d : boolean
- When return_last = False
- If True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it.
- If False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it.
- If True, return 4D Tensor [n_example, h, w, c], for stacking DenseLayer after it.
- If False, return 5D Tensor [n_example/n_steps, n_steps, h, w, c], for stacking multiple ConvLSTM after it.
name : a string or None
An optional name to attach to this layer.

@@ -4860,17 +4864,17 @@ class ConvLSTMLayer(Layer):
outputs : a tensor
The output of this RNN.
return_last = False, outputs = all cell_output, which is the hidden state.
cell_output.get_shape() = (?, n_hidden)
cell_output.get_shape() = (?, h, w, c)

final_state : a tensor or StateTuple
When state_is_tuple = False,
it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n
When state_is_tuple = True, it stores two elements: (c, h), in that order.
it is the final hidden and cell states.
When state_is_tuple = True,
You can get the final state after each iteration during training, then
feed it to the initial state of the next iteration.

initial_state : a tensor or StateTuple
It is the initial state of this RNN layer, you can use it to initialize
It is the initial state of this ConvLSTM layer; you can use it to initialize
your state at the beginning of each epoch or iteration according to your
training procedure.

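To make the revised shape documentation concrete, a minimal usage sketch under assumed sizes; keyword names follow the parameter list above, and cell_fn is left at whatever convolutional cell the module provides by default, which is an assumption here:

import tensorflow as tf
import tensorlayer as tl

batch_size, n_steps, h, w, c = 4, 10, 32, 32, 3  # hypothetical sizes
x = tf.placeholder(tf.float32, [batch_size, n_steps, h, w, c])  # rank-5 input, as checked below

net = tl.layers.InputLayer(x, name='in')
net = tl.layers.ConvLSTMLayer(net,
                              cell_shape=(h, w),       # spatial size of each cell
                              filter_size=(3, 3),
                              feature_map=16,          # channels of each step's output
                              n_steps=n_steps,
                              return_last=False,
                              return_seq_2d=False,     # keep the 5D sequence output
                              name='convlstm')
# net.outputs: [n_example/n_steps, n_steps, h, w, 16]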
@@ -4902,7 +4906,7 @@ def __init__(
# self.inputs.get_shape().with_rank(2)
# self.inputs.get_shape().with_rank(3)

# Input dimension should be rank 5 [batch_size, n_steps(max), n_features]
# Input dimension should be rank 5 [batch_size, n_steps(max), h, w, c]
try:
self.inputs.get_shape().with_rank(5)
except:
@@ -4920,16 +4924,7 @@ def __init__(
print(" non specified batch_size, uses a tensor instead.")
self.batch_size = batch_size

# Simplified version of tensorflow.models.rnn.rnn.py's rnn().
# This builds an unrolled LSTM for tutorial purposes only.
# In general, use the rnn() or state_saving_rnn() from rnn.py.
#
# The alternative version of the code below is:
#
# from tensorflow.models.rnn import rnn
# inputs = [tf.squeeze(input_, [1])
# for input_ in tf.split(1, num_steps, inputs)]
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)

outputs = []
self.cell = cell = cell_fn(shape=cell_shape, filter_size=filter_size, num_features=feature_map)
if initial_state is None:
@@ -4954,11 +4949,11 @@ def __init__(
else:
if return_seq_2d:
# PTB tutorial: stack dense layer after that, or compute the cost from the output
# 2D Tensor [n_example, n_hidden]
# 4D Tensor [n_example, h, w, c]
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, cell_shape[0] * cell_shape[1] * feature_map])
else:
# <akara>: stack more RNN layer after that
# 5D Tensor [n_example/n_steps, n_steps, n_hidden]
# 5D Tensor [n_example/n_steps, n_steps, h, w, c]
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_steps, cell_shape[0],
cell_shape[1], feature_map])

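As a sanity check on the two reshapes above, a NumPy stand-in that mirrors tf.reshape(tf.concat(outputs, 1), ...) for both modes; sizes are illustrative:

import numpy as np

batch, n_steps, h, w, c = 4, 10, 32, 32, 16  # illustrative sizes
# stand-in for the per-step cell outputs: n_steps arrays of shape (batch, h, w, c)
outputs = [np.zeros((batch, h, w, c), np.float32) for _ in range(n_steps)]
concat = np.concatenate(outputs, axis=1)          # (batch, n_steps * h, w, c), like tf.concat(outputs, 1)

seq_2d = concat.reshape(-1, h * w * c)            # (batch * n_steps, h*w*c)   -- return_seq_2d=True
seq_5d = concat.reshape(-1, n_steps, h, w, c)     # (batch, n_steps, h, w, c)  -- return_seq_2d=False
print(seq_2d.shape, seq_5d.shape)                 # (40, 16384) (4, 10, 32, 32, 16)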
@@ -6998,6 +6993,7 @@ def get_batch(self, data, bucket_id, PAD_ID=0, GO_ID=1, EOS_ID=2, UNK_ID=3):






#