Skip to content
This repository has been archived by the owner on Mar 31, 2019. It is now read-only.

Commit

Permalink
untabify
Browse files Browse the repository at this point in the history
  • Loading branch information
justheuristic committed Aug 9, 2017
1 parent 0790d48 commit 94ef1b4
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions agentnet/utils/layers/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_output_for(self, input, **kwargs):
:param kwargs: no effect
:return: upcasted tensor
"""
if not hasattr(self.broadcast_layer, "symbolic_input_shape"):
if not hasattr(self.broadcast_layer, "symbolic_input_shape"):
raise ValueError("UpcastLayer.get_output_for must be called after respective BroadcastLayer.get_output_for")

# symbolic shape. dirty hack to handle "None" axes
Expand All @@ -196,11 +196,12 @@ def get_output_shape_for(self, input_shape, **kwargs):

# this one is NOT symbolic. list() is used as a shallow copy op.
original_shape = list(self.broadcast_layer.input_shape)
broadcasted_dims = [original_shape[ax] for ax in self.broadcast_layer.broadcasted_axes if ax != 0]

if input_shape[0] is None or None in broadcasted_dims:
new_batch_size = None

else:
broadcasted_dims = [original_shape[ax] for ax in self.broadcast_layer.broadcasted_axes if ax != 0]
new_batch_size = original_shape[0] * np.prod(broadcasted_dims)

new_shape = (new_batch_size,) + tuple(input_shape)[1:]
Expand Down

0 comments on commit 94ef1b4

Please sign in to comment.