Skip to content

Commit

Permalink
get_layer test
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Dec 12, 2016
1 parent f99cf22 commit 0b0ea99
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion examples/model_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

# Create connected layers
# Input layer
l_in = yadll.layers.InputLayer(shape=(hp.batch_size, 28 * 28), name='Input')
l_in = yadll.layers.InputLayer(input_shape=(hp.batch_size, 28 * 28), name='Input')
# Dropout Layer 1
l_dro1 = yadll.layers.Dropout(incoming=l_in, corruption_level=0.4, name='Dropout 1')
# Dense Layer 1
Expand Down
4 changes: 4 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def test_get_reguls(self, layer):
def test_named_layer(self, named_layer):
assert named_layer.name == 'layer_name'

def test_unnamed_layer(self, layer):
from yadll.layers import Layer
assert layer.name == 'Layer ' + str(Layer.nb_instances)

def test_get_output(self, layer):
with pytest.raises(NotImplementedError):
layer.get_output()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def input(self, x):
@pytest.fixture
def layer(self, input):
from yadll.layers import DenseLayer
return DenseLayer(incoming=input, nb_units=25)
return DenseLayer(incoming=input, nb_units=25, name='DenseLayer 1')

@pytest.fixture
def layer2(self, input):
Expand All @@ -46,6 +46,8 @@ def test_network(self, network, x, input, layer, layer2, unsupervised_layer, cap
assert net.reguls == 0
assert net.params == [layer.W, layer.b]
assert net.name == 'test_network'
assert net.get_layer('DenseLayer 1') is layer
assert net['DenseLayer 1'] is layer
assert net.has_unsupervised_layer is False
net.add(unsupervised_layer)
assert net.has_unsupervised_layer is True
Expand Down
3 changes: 3 additions & 0 deletions yadll/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def get_layer(self, layer_name):
"""
return self.layers[self.layer_names.index(layer_name)]

def __getitem__(self, layer_name):
return self.get_layer(layer_name)

def save_params(self, file):
"""
Save the parameters of the network to file with cPickle
Expand Down

0 comments on commit 0b0ea99

Please sign in to comment.