Skip to content

Commit

Permalink
prefer keras.backend functions to tf functions
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Jan 13, 2019
1 parent c9e2b41 commit 8e51af4
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions deepcell/layers/location.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ def compute_output_shape(self, input_shape):
def call(self, inputs):
input_shape = self.in_shape
if self.data_format == 'channels_first':
x = tf.range(0, input_shape[1], dtype=K.floatx())
y = tf.range(0, input_shape[2], dtype=K.floatx())
x = K.arange(0, input_shape[1], dtype=K.floatx())
y = K.arange(0, input_shape[2], dtype=K.floatx())
else:
x = tf.range(0, input_shape[0], dtype=K.floatx())
y = tf.range(0, input_shape[1], dtype=K.floatx())
x = K.arange(0, input_shape[0], dtype=K.floatx())
y = K.arange(0, input_shape[1], dtype=K.floatx())

x = tf.divide(x, tf.reduce_max(x))
y = tf.divide(y, tf.reduce_max(y))
x = x / K.max(x)
y = y / K.max(y)

loc_x, loc_y = tf.meshgrid(x, y, indexing='ij')

Expand Down Expand Up @@ -112,17 +112,17 @@ def call(self, inputs):
input_shape = self.in_shape

if self.data_format == 'channels_first':
z = tf.range(0, input_shape[1], dtype=K.floatx())
x = tf.range(0, input_shape[2], dtype=K.floatx())
y = tf.range(0, input_shape[3], dtype=K.floatx())
z = K.arange(0, input_shape[1], dtype=K.floatx())
x = K.arange(0, input_shape[2], dtype=K.floatx())
y = K.arange(0, input_shape[3], dtype=K.floatx())
else:
z = tf.range(0, input_shape[0], dtype=K.floatx())
x = tf.range(0, input_shape[1], dtype=K.floatx())
y = tf.range(0, input_shape[2], dtype=K.floatx())
z = K.arange(0, input_shape[0], dtype=K.floatx())
x = K.arange(0, input_shape[1], dtype=K.floatx())
y = K.arange(0, input_shape[2], dtype=K.floatx())

x = tf.divide(x, tf.reduce_max(x))
y = tf.divide(y, tf.reduce_max(y))
z = tf.divide(z, tf.reduce_max(z))
x = x / K.max(x)
y = y / K.max(y)
z = z / K.max(z)

loc_z, loc_x, loc_y = tf.meshgrid(z, x, y, indexing='ij')

Expand Down

0 comments on commit 8e51af4

Please sign in to comment.