diff --git a/tensorlayer/layers/image_resampling.py b/tensorlayer/layers/image_resampling.py index 84aa744d6..fc0f5833b 100644 --- a/tensorlayer/layers/image_resampling.py +++ b/tensorlayer/layers/image_resampling.py @@ -59,19 +59,31 @@ def __init__( if not isinstance(size, (list, tuple)) and len(size) == 2: raise AssertionError() - if len(self.inputs.get_shape()) == 3: if is_scale: - size_h = size[0] * tf.shape(self.inputs)[0] - size_w = size[1] * tf.shape(self.inputs)[1] + input_shape = self.inputs.shape.as_list() + if input_shape[0] is not None: + size_h = size[0] * input_shape[0] + else: + size_h = size[0] * tf.shape(self.inputs)[0] + if input_shape[1] is not None: + size_w = size[1] * input_shape[1] + else: + size_w = size[1] * tf.shape(self.inputs)[1] size = [size_h, size_w] elif len(self.inputs.get_shape()) == 4: if is_scale: - size_h = size[0] * tf.shape(self.inputs)[1] - size_w = size[1] * tf.shape(self.inputs)[2] + input_shape = self.inputs.shape.as_list() + if input_shape[1] is not None: + size_h = size[0] * input_shape[1] + else: + size_h = size[0] * tf.shape(self.inputs)[1] + if input_shape[2] is not None: + size_w = size[1] * input_shape[2] + else: + size_w = size[1] * tf.shape(self.inputs)[2] size = [size_h, size_w] - else: raise Exception("Donot support shape %s" % tf.shape(self.inputs)) @@ -135,14 +147,28 @@ def __init__( if len(self.inputs.get_shape()) == 3: if is_scale: - size_h = size[0] * tf.shape(self.inputs)[0] - size_w = size[1] * tf.shape(self.inputs)[1] + input_shape = self.inputs.shape.as_list() + if input_shape[0] is not None: + size_h = size[0] * input_shape[0] + else: + size_h = size[0] * tf.shape(self.inputs)[0] + if input_shape[1] is not None: + size_w = size[1] * input_shape[1] + else: + size_w = size[1] * tf.shape(self.inputs)[1] size = [size_h, size_w] elif len(self.inputs.get_shape()) == 4: if is_scale: - size_h = size[0] * tf.shape(self.inputs)[1] - size_w = size[1] * tf.shape(self.inputs)[2] + input_shape = self.inputs.shape.as_list() + if input_shape[1] is not None: + size_h = size[0] * input_shape[1] + else: + size_h = size[0] * tf.shape(self.inputs)[1] + if input_shape[2] is not None: + size_w = size[1] * input_shape[2] + else: + size_w = size[1] * tf.shape(self.inputs)[2] size = [size_h, size_w] else: