Skip to content

Commit

Permalink
fix input dtype not matching variable dtype (#1386)
Browse files Browse the repository at this point in the history
* fix not input dtype not matching variable dtype

* fix input dtype not matching variable dtype

* Fix incorrect variable name input_dtype -> inputs_dtype

* Fix code format

* Fix code format
  • Loading branch information
gehuangyi20 authored and ppwwyyxx committed Jan 22, 2020
1 parent 12ad257 commit 4ac2e22
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tensorpack/models/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ def Conv2D(
if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)

# matching input dtype (ex. tf.float16) since the default dtype of variable if tf.float32
inputs_dtype = inputs.dtype
W = tf.get_variable(
'W', filter_shape, initializer=kernel_initializer)
'W', filter_shape, dtype=inputs_dtype, initializer=kernel_initializer)

if use_bias:
b = tf.get_variable('b', [out_channel], initializer=bias_initializer)
b = tf.get_variable('b', [out_channel], dtype=inputs_dtype, initializer=bias_initializer)

if split == 1:
conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
Expand Down Expand Up @@ -238,9 +240,11 @@ def Conv2DTranspose(
None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1],
filters]

W = tf.get_variable('W', kernel_shape + [filters, channels_in], initializer=kernel_initializer)
inputs_dtype = inputs.dtype
W = tf.get_variable('W', kernel_shape + [filters, channels_in],
dtype=inputs_dtype, initializer=kernel_initializer)
if use_bias:
b = tf.get_variable('b', [filters], initializer=bias_initializer)
b = tf.get_variable('b', [filters], dtype=inputs_dtype, initializer=bias_initializer)
conv = tf.nn.conv2d_transpose(
inputs, W, out_shape_dyn,
shape4d(strides, data_format=data_format),
Expand Down

0 comments on commit 4ac2e22

Please sign in to comment.