diff --git a/tensorlayer/layers/merge.py b/tensorlayer/layers/merge.py index 1f4775c83..e4c6bde26 100644 --- a/tensorlayer/layers/merge.py +++ b/tensorlayer/layers/merge.py @@ -81,6 +81,8 @@ class ElementwiseLayer(Layer): combine_fn : a TensorFlow element-wise combine function e.g. AND is ``tf.minimum`` ; OR is ``tf.maximum`` ; ADD is ``tf.add`` ; MUL is ``tf.multiply`` and so on. See `TensorFlow Math API `__ . + act : activation function + The activation function of this layer. name : str A unique layer name. @@ -102,6 +104,7 @@ def __init__( self, layers, combine_fn=tf.minimum, + act=None, name='elementwise_layer', ): Layer.__init__(self, name=name) @@ -109,12 +112,13 @@ def __init__( logging.info("ElementwiseLayer %s: size:%s fn:%s" % (self.name, layers[0].outputs.get_shape(), combine_fn.__name__)) self.outputs = layers[0].outputs - # logging.info(self.outputs._shape, type(self.outputs._shape)) + for l in layers[1:]: - if str(self.outputs.get_shape()) != str(l.outputs.get_shape()): - raise Exception("Hint: the input shapes should be the same. %s != %s" % (self.outputs.get_shape(), str(l.outputs.get_shape()))) self.outputs = combine_fn(self.outputs, l.outputs, name=name) + if act: + self.outputs = act(self.outputs) + self.all_layers = list(layers[0].all_layers) self.all_params = list(layers[0].all_params) self.all_drop = dict(layers[0].all_drop)