
Commit

RNN
pchavanne committed Jan 30, 2017
1 parent bd5ecb9 commit 524806a
Showing 1 changed file with 51 additions and 36 deletions.
87 changes: 51 additions & 36 deletions yadll/layers.py
@@ -658,8 +658,11 @@ class RNN(Layer):
"""
nb_instances = 0

def __init__(self, incoming, n_hidden, n_out, activation=sigmoid, last_only=True, **kwargs):
def __init__(self, incoming, n_hidden, n_out, activation=sigmoid,
last_only=True, go_backwards=False, allow_gc=False, **kwargs):
super(RNN, self).__init__(incoming, **kwargs)
self.allow_gc = allow_gc
self.go_backwards = go_backwards
self.last_only = last_only
self.activation = get_activation(activation)
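The two new flags are stored on the layer and forwarded to theano.scan in get_output below. A minimal construction sketch (input_layer and the sizes here are assumed for illustration, not part of the commit):

    # hypothetical usage: scan the sequence in reverse and let Theano
    # garbage-collect intermediate scan buffers to save memory
    rnn = RNN(incoming=input_layer, n_hidden=128, n_out=10,
              go_backwards=True, allow_gc=True)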

@@ -683,11 +686,12 @@ def one_step(self, x_t, h_tm1, W, U, b):
     def get_output(self, **kwargs):
         X = self.input_layer.get_output(**kwargs)
         h_t, updates = theano.scan(fn=self.one_step,
-                                   sequences=X,
-                                   outputs_info=[self.h0, None],
-                                   non_sequences=self.params,
-                                   allow_gc=False,
-                                   strict=True)
+                                   sequences=X,
+                                   outputs_info=[self.h0, None],
+                                   non_sequences=self.params,
+                                   go_backwards=self.go_backwards,
+                                   allow_gc=self.allow_gc,
+                                   strict=True)
         return h_t
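theano.scan drives one_step over the first dimension of X, carrying the hidden state from step to step; go_backwards simply iterates that dimension in reverse. Roughly what it computes, sketched in plain NumPy (shapes and names assumed, not from the commit):

    import numpy as np

    def sigmoid(z):
        return 1. / (1. + np.exp(-z))

    def rnn_unrolled(X, h0, W, U, b, go_backwards=False):
        # X: (nb_time_steps, n_in); replays h_t = sigmoid(x_t.W + h_{t-1}.U + b)
        steps = X[::-1] if go_backwards else X
        h, hs = h0, []
        for x_t in steps:
            h = sigmoid(x_t.dot(W) + h.dot(U) + b)
            hs.append(h)
        return np.stack(hs)  # one hidden state per step, like h_t above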


@@ -696,26 +700,26 @@ class LSTM(Layer):
     Long Short Term Memory
     .. math ::
-        i_t &= \sigma(x_t.W_i + h_{t-1}.U_i + b_i) && \text{Input gate}\\
-        f_t &= \sigma(x_t.W_f + h_{t-1}.U_f + b_f) && \text{Forget gate}\\
-        \tilde{C_t} &= \tanh(x_t.W_c + h_{t-1}.U_c + b_c) && \text{Cell gate}\\
-        C_t &= f_t * C_{t-1} + i_t * \tilde{C_t} && \text{Cell state}\\
-        o_t &= \sigma(x_t.W_o + h_{t-1}.U_o + b_o) && \text{Output gate}\\
+        i_t &= \sigma(x_t.W_i + h_{t-1}.U_i + b_i)\\
+        f_t &= \sigma(x_t.W_f + h_{t-1}.U_f + b_f)\\
+        \tilde{C_t} &= \tanh(x_t.W_c + h_{t-1}.U_c + b_c)\\
+        C_t &= f_t * C_{t-1} + i_t * \tilde{C_t}\\
+        o_t &= \sigma(x_t.W_o + h_{t-1}.U_o + b_o)\\
         h_t &= o_t * \tanh(C_t) && \text{Hidden state}\\
         \text{with Peephole connections:}\\
-        i_t &= \sigma(x_t.W_i + h_{t-1}.U_i + C_{t-1}.P_i + b_i) && \text{Input gate}\\
-        f_t &= \sigma(x_t.W_f + h_{t-1}.U_f + C_{t-1}.P_f + b_f) && \text{Forget gate}\\
-        \tilde{C_t} &= \tanh(x_t.W_c + h_{t-1}.U_c + b_c) && \text{Cell gate}\\
-        C_t &= f_t * C_{t-1} + i_t * \tilde{C_t} && \text{Cell state}\\
-        o_t &= \sigma(x_t.W_o + h_{t-1}.U_o + C_t.P_o + b_o) && \text{Output gate}\\
-        h_t &= o_t * \tanh(C_t) && \text{Hidden state}\\
+        i_t &= \sigma(x_t.W_i + h_{t-1}.U_i + C_{t-1}.P_i + b_i)\\
+        f_t &= \sigma(x_t.W_f + h_{t-1}.U_f + C_{t-1}.P_f + b_f)\\
+        \tilde{C_t} &= \tanh(x_t.W_c + h_{t-1}.U_c + b_c)\\
+        C_t &= f_t * C_{t-1} + i_t * \tilde{C_t}\\
+        o_t &= \sigma(x_t.W_o + h_{t-1}.U_o + C_t.P_o + b_o)\\
+        h_t &= o_t * \tanh(C_t)\\
         \text{with tied forget and input gates:}\\
-        C_t &= f_t * C_{t-1} + (1 - f_t) * \tilde{C_t} && \text{Cell state}\\
+        C_t &= f_t * C_{t-1} + (1 - f_t) * \tilde{C_t}\\
 
     Parameters
     ----------
     incoming : a `Layer`
-        The incoming layer
+        The incoming layer with an output_shape = (nb_batches, nb_time_steps, nb_dim)
     n_hidden : int or tuple of int
         (n_hidden, n_input, n_forget, n_cell, n_output).
         If an int is provided all gates have the same number of units
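The equations above map one-to-one onto one_step below; as a cross-check, here is a single peephole step in plain NumPy (a sketch with assumed dict-of-gates parameters, independent of the commit):

    import numpy as np

    def sigmoid(z):
        return 1. / (1. + np.exp(-z))

    def lstm_peephole_step(x_t, h_tm1, c_tm1, W, U, P, b):
        # W, U, P, b are dicts keyed by gate name: 'i', 'f', 'c', 'o'
        i_t = sigmoid(x_t.dot(W['i']) + h_tm1.dot(U['i']) + c_tm1.dot(P['i']) + b['i'])
        f_t = sigmoid(x_t.dot(W['f']) + h_tm1.dot(U['f']) + c_tm1.dot(P['f']) + b['f'])
        c_tilde = np.tanh(x_t.dot(W['c']) + h_tm1.dot(U['c']) + b['c'])
        c_t = f_t * c_tm1 + i_t * c_tilde          # cell state
        o_t = sigmoid(x_t.dot(W['o']) + h_tm1.dot(U['o']) + c_t.dot(P['o']) + b['o'])
        h_t = o_t * np.tanh(c_t)                   # hidden state
        return h_t, c_t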
@@ -740,8 +744,11 @@ class LSTM(Layer):
"""
nb_instances = 0

def __init__(self, incoming, n_hidden, n_out, peephole=False, tied_i_f=False, activation=tanh, last_only=True, **kwargs):
def __init__(self, incoming, n_hidden, n_out, peephole=False, tied_i_f=False, activation=tanh,
last_only=True, go_backwards=False, allow_gc=False, **kwargs):
super(LSTM, self).__init__(incoming, **kwargs)
self.allow_gc = allow_gc
self.go_backwards = go_backwards
self.last_only = last_only
self.peephole = peephole # input and forget gates layers look at the cell state
self.tied = tied_i_f # only input new values to the state when we forget something
@@ -756,6 +763,8 @@ def __init__(self, incoming, n_hidden, n_out, peephole=False, tied_i_f=False, activation=tanh, last_only=True, **kwargs):
         self.W_i = orthogonal(shape=(self.n_in, self.n_i), name='W_i')
         self.U_i = orthogonal(shape=(self.n_hidden, self.n_i), name='U_i')
         self.b_i = uniform(shape=self.n_i, scale=(-0.5, .5), name='b_i')
+        if self.peephole:
+            self.P_i = orthogonal(shape=(self.n_hidden, self.n_i), name='P_i')
         # forget gate
         self.W_f = orthogonal(shape=(self.n_in, self.n_f), name='W_f')
         self.U_f = orthogonal(shape=(self.n_hidden, self.n_f), name='U_f')
@@ -778,24 +787,26 @@ def __init__(self, incoming, n_hidden, n_out, peephole=False, tied_i_f=False, activation=tanh, last_only=True, **kwargs):
         self.U = T.concatenate([self.U_i, self.U_f, self.U_c, self.U_o])
         self.b = T.concatenate([self.b_i, self.b_f, self.b_c, self.b_o])
 
         if peephole:
             self.P_i = orthogonal(shape=(self.n_c, self.n_i), name='W_ci')
             self.P_f = orthogonal(shape=(self.n_c, self.n_f), name='W_cf')
             self.P_o = orthogonal(shape=(self.n_c, self.n_o), name='W_co')
             self.params.extend([self.P_i, self.P_f, self.P_o])
 
         self.c0 = constant(shape=self.n_hidden, name='c0')
         self.h0 = self.activation(self.c0)
 
-    def one_step(self, x_t, h_tm1, c_tm1, W_i, U_i, b_i,
-                 W_f, U_f, b_f,
-                 W_c, U_c, b_c,
-                 W_o, U_o, b_o):
+    def one_step(self, x_t, h_tm1, c_tm1, W_i, U_i, b_i, W_f, U_f, b_f, W_c, U_c, b_c, W_o, U_o, b_o):
         # forget gate
         f_t = sigmoid(T.dot(x_t, W_f) + T.dot(h_tm1, U_f) + b_f)
         # input gate
-        i_t = sigmoid(T.dot(x_t, W_i) + T.dot(h_tm1, U_i) + b_i)
-
-        # cell state
-        c_tt = self.activation(T.dot(x_t, W_c) + T.dot(h_tm1, U_c) + b_c)
-        c_t = f_t * c_tm1 + i_t * c_tt
-        if self.tied:
-            c_t = f_t * c_tm1 + (1 - f_t) * c_tt
+        if self.tied:
+            i_t = 1. - f_t
+        else:
+            i_t = sigmoid(T.dot(x_t, W_i) + T.dot(h_tm1, U_i) + b_i)
+        # cell state
+        c_tilde_t = self.activation(T.dot(x_t, W_c) + T.dot(h_tm1, U_c) + b_c)
+        c_t = f_t * c_tm1 + i_t * c_tilde_t
         # output gate
         o_t = sigmoid(T.dot(x_t, W_o) + T.dot(h_tm1, U_o) + b_o)
         # if self.peephole:
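Tying the gates makes the cell update a convex combination of old state and candidate: with i_t = 1 - f_t, C_t = f_t * C_{t-1} + (1 - f_t) * \tilde{C_t}, so the cell only admits new information where it forgets old. A quick numeric illustration (values made up):

    f_t, c_tm1, c_tilde_t = 0.9, 1.0, -1.0
    c_t = f_t * c_tm1 + (1. - f_t) * c_tilde_t   # 0.9*1.0 + 0.1*(-1.0) = 0.8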
@@ -808,11 +819,15 @@ def one_step(self, x_t, h_tm1, c_tm1, W_i, U_i, b_i,
     def get_output(self, **kwargs):
         X = self.input_layer.get_output(**kwargs)
         [h_vals, _], _ = theano.scan(fn=self.one_step,
-                                     sequences=X,
-                                     outputs_info=[self.h0, self.c0, None],
-                                     non_sequences=self.params,
-                                     allow_gc=False,
-                                     strict=True)
+                                     sequences=X,
+                                     outputs_info=[self.h0, self.c0, None],
+                                     non_sequences=self.params,
+                                     go_backwards=self.go_backwards,
+                                     allow_gc=self.allow_gc,
+                                     strict=True)
+        if self.last_only:
+            h_vals = h_vals[-1]
+
         return h_vals
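With last_only=True the layer returns only the final element of the scanned sequence rather than the full trajectory, which is the usual input to a sequence-level classifier. A hypothetical wiring sketch (input_layer and the sizes are assumed, not part of the commit):

    # read the sequence in reverse and keep only the last hidden state
    lstm = LSTM(incoming=input_layer, n_hidden=64, n_out=10,
                peephole=True, go_backwards=True, last_only=True)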


