Skip to content

Commit

Permalink
elu
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Jan 27, 2017
1 parent 0fcd995 commit ad89280
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
35 changes: 35 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ def test_get_output(self, flatten_layer, input_layer, input_data):
assert (result == input.reshape(input.shape[0], -1)).all()


class TestActivation:
@pytest.fixture
def activation(self):
from yadll.layers import Activation
return Activation

@pytest.fixture
def input_data(self):
from yadll.utils import shared_variable
return shared_variable(np.random.random((2, 3, 4, 5)))

@pytest.fixture
def input_layer(self, input_data):
from yadll.layers import InputLayer
shape = (2, 3, 4, 5,)
return InputLayer(shape, input=input_data)

def test_output_shape(self, activation, input_layer):
layer = activation(input_layer)
assert layer.output_shape == (2, 3 * 4 * 5)

def test_get_output(self, activation, input_layer, input_data):
from yadll.activations import tanh
layer = activation(input_layer, activation=tanh)
result = layer.get_output().eval()
input = np.asarray(input_data.eval())
assert np.allclose(result, np.tanh(input))


class TestDenseLayer:
@pytest.fixture
def dense_layer(self):
Expand Down Expand Up @@ -371,3 +400,9 @@ def lstm(self):
from yadll.layers import LSTM
return LSTM


class TestGRU:
@pytest.fixture
def gru(self):
from yadll.layers import GRU
return GRU
3 changes: 1 addition & 2 deletions yadll/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def from_conf(self, conf):
class InputLayer(Layer):
"""
Input layer of the data, it has no parameters, it just shapes the data as
the input for any network.
A ::class:`InputLayer` is always the first layer of any network.
the input for any network. A ::class:`InputLayer` is always the first layer of any network.
"""
nb_instances = 0

Expand Down

0 comments on commit ad89280

Please sign in to comment.