Skip to content

Commit

Permalink
first yadll commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pchavanne committed Aug 30, 2016
1 parent 6fe0cc9 commit 803e19a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def input_data(self):
@pytest.fixture
def input_layer(self, input_data):
    """Build an InputLayer with an unspecified (None) batch dimension.

    Removed a dead assignment: `shape = (10, 20)` was immediately
    overwritten by `shape = (None, 20)` (leftover from an earlier diff).
    """
    from yadll.layers import InputLayer
    shape = (None, 20)  # None: batch size is decided at run time
    return InputLayer(shape, input_var=input_data)

@pytest.fixture
Expand Down Expand Up @@ -310,6 +310,11 @@ def input_layer(self, input_data):
def layer(self, pool_layer, input_layer):
return pool_layer(input_layer, poolsize=(2, 2))

def test_output_shape(self, layer):
assert layer.output_shape() == (layer.input_shape[0],
layer.input_shape[1],
layer.input_shape[2] / layer.poolsize[0],
layer.input_shape[3] / layer.poolsize[1])

class TestConvLayer:
@pytest.fixture
Expand Down Expand Up @@ -339,6 +344,20 @@ def rbm(self):
return RBM


class TestBatchNormalization:
    """Tests for the BatchNormalization layer."""

    @pytest.fixture
    def batch_normalization(self):
        """Provide the BatchNormalization class under test."""
        from yadll.layers import BatchNormalization as layer_cls
        return layer_cls


class TestLayerNormalization:
    """Tests for the LayerNormalization layer."""

    @pytest.fixture
    def layer_normalization(self):
        """Provide the LayerNormalization class under test.

        Fixes a copy-paste bug: the fixture previously imported and
        returned BatchNormalization, so these tests never exercised
        LayerNormalization at all.
        """
        from yadll.layers import LayerNormalization
        return LayerNormalization


class TestRNN:
@pytest.fixture
def rnn(self):
Expand Down
41 changes: 41 additions & 0 deletions yadll/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ def __init__(self, incoming, corruption_level=0.5, **kwargs):

def get_output(self, stochastic=True, **kwargs):
    """Return the incoming layer's output, optionally masked with a
    Bernoulli (dropout-style) mask.

    Parameters
    ----------
    stochastic : bool
        When True and ``self.p != 1``, the input is multiplied by a
        binomial mask drawn from ``T_rng``; otherwise the input is
        passed through unchanged.

    Returns
    -------
    Theano symbolic expression for the (possibly masked) output.
    """
    X = self.input_layer.get_output(stochastic=stochastic, **kwargs)
    # A batch dimension declared as None is filled in with the symbolic
    # runtime batch size so the binomial mask below has a complete shape.
    # NOTE(review): this permanently overwrites self.input_shape with a
    # symbolic Theano value on the first call — confirm later readers of
    # input_shape can cope with a non-literal entry.
    if self.input_shape[0] is None:
        lst = list(self.input_shape)
        lst[0] = T.shape(X)[0]
        self.input_shape = tuple(lst)
    if self.p != 1 and stochastic:
        # presumably self.p is the keep probability (mask is 1 with
        # probability p) — TODO confirm against the layer's __init__
        X = X * T_rng.binomial(self.input_shape, n=1, p=self.p, dtype=floatX)
    return X
Expand Down Expand Up @@ -521,6 +525,43 @@ def get_unsupervised_cost(self, persistent=None, k=1, **kwargs):
return monitoring_cost, updates


class BatchNormalization(Layer):
    """
    Normalize the activations of the previous layer for each mini-batch.

    The normalization itself is not implemented yet (TODO): the layer
    currently acts as an identity pass-through.

    References
    ----------
    ..[1] http://jmlr.org/proceedings/papers/v37/ioffe15.pdf
    ..[2] https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py#L6
    """
    def __init__(self, incoming, **kwargs):
        super(BatchNormalization, self).__init__(incoming, **kwargs)

    def get_output(self, **kwargs):
        """Return the incoming layer's output unchanged (TODO: normalize)."""
        return self.input_layer.get_output(**kwargs)


class LayerNormalization(Layer):
    """
    Layer normalization: per reference [1] it normalizes across the
    features of each individual sample, independent of the batch (the
    previous wording, "at each batch", described batch normalization).

    Not implemented yet — get_output is currently an identity
    pass-through of the incoming layer.

    References
    ----------
    ..[1] http://arxiv.org/pdf/1607.06450v1.pdf
    """
    def __init__(self, incoming, **kwargs):
        super(LayerNormalization, self).__init__(incoming, **kwargs)

    def get_output(self, **kwargs):
        """Return the incoming layer's output unchanged."""
        X = self.input_layer.get_output(**kwargs)
        # TODO: apply layer normalization; identity for now.
        return X


class RNN(Layer):
"""
Recurrent Neural Network
Expand Down

0 comments on commit 803e19a

Please sign in to comment.