Skip to content

Commit

Permalink
Remove a print
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Jun 30, 2018
1 parent ba90031 commit 540aae7
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions group_norm.py
Expand Up @@ -132,15 +132,9 @@ def call(self, inputs, **kwargs):
group_axes[self.axis] = input_shape[self.axis] // self.groups
group_axes.insert(1, self.groups)

if self.axis != 0:
group_shape = [group_axes[0], self.groups] + group_axes[2:]
else:
# case where group normalization is done on the batch axis
group_shape = [group_axes[0], self.groups] + group_axes[2:]

# group_shape[self.axis + 1] = input_shape[self.axis] // self.groups
# reshape inputs to new group shape
group_shape = [group_axes[0], self.groups] + group_axes[2:]
group_shape = K.stack(group_shape)
print(group_shape)
inputs = K.reshape(inputs, group_shape)

group_reduction_axes = list(range(len(group_axes)))
Expand Down Expand Up @@ -195,7 +189,7 @@ def compute_output_shape(self, input_shape):
if __name__ == '__main__':
from keras.layers import Input
from keras.models import Model
ip = Input(shape=(None, None, 2))
ip = Input(shape=(None, None, 4))
#ip = Input(batch_shape=(100, None, None, 2))
x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(ip)
model = Model(ip, x)
Expand Down

0 comments on commit 540aae7

Please sign in to comment.