Skip to content

Commit

Permalink
Merge pull request #590 from ufal/nematus_variable
Browse files Browse the repository at this point in the history
Importing Nematus models
  • Loading branch information
varisd committed Nov 21, 2017
2 parents 8f2d663 + 0186206 commit 916a3d2
Show file tree
Hide file tree
Showing 9 changed files with 756 additions and 87 deletions.
44 changes: 24 additions & 20 deletions neuralmonkey/decoders/decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=too-many-lines
import math
from typing import (cast, Iterable, List, Callable, Optional,
Any, Tuple, NamedTuple, Union)
from typing import (cast, Iterable, List, Callable, Optional, Any, Tuple,
NamedTuple)

import numpy as np
import tensorflow as tf
Expand All @@ -13,19 +13,20 @@
PAD_TOKEN_INDEX)
from neuralmonkey.model.model_part import ModelPart, FeedDict
from neuralmonkey.model.sequence import EmbeddedSequence
from neuralmonkey.model.stateful import (TemporalStatefulWithOutput,
SpatialStatefulWithOutput)
from neuralmonkey.model.stateful import Stateful
from neuralmonkey.logging import log, warn
from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell
from neuralmonkey.nn.ortho_gru_cell import OrthoGRUCell, NematusGRUCell
from neuralmonkey.nn.utils import dropout
from neuralmonkey.decoders.encoder_projection import (
linear_encoder_projection, concat_encoder_projection, empty_initial_state)
linear_encoder_projection, concat_encoder_projection, empty_initial_state,
EncoderProjection)
from neuralmonkey.decoders.output_projection import (OutputProjectionSpec,
nonlinear_output)
from neuralmonkey.decorators import tensor


RNN_CELL_TYPES = {
"NematusGRU": NematusGRUCell,
"GRU": OrthoGRUCell,
"LSTM": tf.contrib.rnn.LSTMCell
}
Expand Down Expand Up @@ -62,9 +63,7 @@ class Decoder(ModelPart):
# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments,too-many-branches,too-many-statements
def __init__(self,
# TODO only stateful, attention will need temporal or spat.
encoders: List[Union[TemporalStatefulWithOutput,
SpatialStatefulWithOutput]],
encoders: List[Stateful],
vocabulary: Vocabulary,
data_id: str,
name: str,
Expand All @@ -73,9 +72,7 @@ def __init__(self,
rnn_size: int = None,
embedding_size: int = None,
output_projection: OutputProjectionSpec = None,
encoder_projection: Callable[
[tf.Tensor, Optional[int], Optional[List[Any]]],
tf.Tensor]=None,
encoder_projection: EncoderProjection = None,
attentions: List[BaseAttention] = None,
embeddings_source: EmbeddedSequence = None,
attention_on_input: bool = True,
Expand Down Expand Up @@ -164,7 +161,8 @@ def __init__(self,
assert self.rnn_size is not None

if self._rnn_cell_str not in RNN_CELL_TYPES:
raise ValueError("RNN cell must be a either 'GRU' or 'LSTM'")
raise ValueError("RNN cell must be a either 'GRU', 'LSTM', or "
"'NematusGRU'. Not {}".format(self._rnn_cell_str))

if self.output_projection_spec is None:
log("No output projection specified - using tanh projection")
Expand Down Expand Up @@ -365,7 +363,11 @@ def _get_rnn_cell(self) -> tf.contrib.rnn.RNNCell:
return RNN_CELL_TYPES[self._rnn_cell_str](self.rnn_size)

def _get_conditional_gru_cell(self) -> tf.contrib.rnn.GRUCell:
return tf.contrib.rnn.GRUCell(self.rnn_size)
if self._rnn_cell_str == "NematusGRU":
return NematusGRUCell(
self.rnn_size, use_state_bias=True, use_input_bias=False)

return RNN_CELL_TYPES[self._rnn_cell_str](self.rnn_size)

def embed_input_symbol(self, *args) -> tf.Tensor:
loop_state = LoopState(*args)
Expand Down Expand Up @@ -403,16 +405,17 @@ def body(*args) -> LoopState:

# Run the RNN.
cell = self._get_rnn_cell()
if self._rnn_cell_str == "GRU":
cell_output, state = cell(rnn_input,
loop_state.prev_rnn_output)
next_state = state
if self._rnn_cell_str in ["GRU", "NematusGRU"]:
cell_output, next_state = cell(
rnn_input, loop_state.prev_rnn_output)

attns = [
a.attention(cell_output, loop_state.prev_rnn_output,
rnn_input, att_loop_state, loop_state.step)
for a, att_loop_state in zip(
self.attentions,
loop_state.attention_loop_states)]

if self.attentions:
contexts, att_loop_states = zip(*attns)
else:
Expand All @@ -421,8 +424,9 @@ def body(*args) -> LoopState:
if self._conditional_gru:
cell_cond = self._get_conditional_gru_cell()
cond_input = tf.concat(contexts, -1)
cell_output, state = cell_cond(cond_input, state,
scope="cond_gru_2_cell")
cell_output, next_state = cell_cond(
cond_input, next_state, scope="cond_gru_2_cell")

elif self._rnn_cell_str == "LSTM":
prev_state = tf.contrib.rnn.LSTMStateTuple(
loop_state.prev_rnn_state, loop_state.prev_rnn_output)
Expand Down
142 changes: 85 additions & 57 deletions neuralmonkey/decoders/encoder_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,47 @@
This module contains different variants of projection of encoders into the
initial state of the decoder.
"""
Encoder projections are specified in the configuration file. Each encoder
projection function has a unified type ``EncoderProjection``, which is a
callable that takes three arguments:
1. ``train_mode`` -- boolean tensor specifying whether the train mode is on
2. ``rnn_size`` -- the size of the resulting initial state
3. ``encoders`` -- a list of ``Stateful`` objects used as the encoders.
from typing import List, Optional, Callable
To enable further parameterization of encoder projection functions, one can
use higher-order functions.
"""
from typing import List, Callable, cast

import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.model.stateful import Stateful
from neuralmonkey.model.stateful import Stateful, TemporalStatefulWithOutput
from neuralmonkey.nn.utils import dropout
from neuralmonkey.logging import log


# pylint: disable=invalid-name
EncoderProjection = Callable[
[tf.Tensor, int, List[Stateful]], tf.Tensor]
# pylint: enable=invalid-name


# pylint: disable=unused-argument
# The function must conform the API
def empty_initial_state(train_mode: tf.Tensor,
rnn_size: Optional[int],
rnn_size: int,
encoders: List[Stateful] = None) -> tf.Tensor:
"""Return an empty vector.
Arguments:
train_mode: tf 0-D bool Tensor specifying the training mode (not used)
rnn_size: The size of the resulting vector
encoders: The list of encoders (not used)
"""
"""Return an empty vector."""
if rnn_size is None:
raise ValueError("You must supply rnn_size for this type of "
"encoder projection")
raise ValueError(
"You must supply rnn_size for this type of encoder projection")
return tf.zeros([rnn_size])


def linear_encoder_projection(
dropout_keep_prob: float) -> Callable[
[tf.Tensor, Optional[int], Optional[List[Stateful]]],
tf.Tensor]:
def linear_encoder_projection(dropout_keep_prob: float) -> EncoderProjection:
"""Return a linear encoder projection.
Return a projection function which applies dropout on concatenated
Expand All @@ -45,61 +52,82 @@ def linear_encoder_projection(
Arguments:
dropout_keep_prob: The dropout keep probability
"""
check_argument_types()

def func(train_mode: tf.Tensor,
rnn_size: Optional[int] = None,
encoders: Optional[List[Stateful]] = None) -> tf.Tensor:
"""Linearly project encoders' encoded value.
Linearly project encoders' encoded value to rnn_size
and apply dropout.
Arguments:
train_mode: tf 0-D bool Tensor specifying the training mode
rnn_size: The size of the resulting vector
encoders: The list of encoders
"""
if rnn_size is None:
raise ValueError("You must supply rnn_size for this type of "
"encoder projection")
rnn_size: int,
encoders: List[Stateful]) -> tf.Tensor:

if encoders is None or not encoders:
raise ValueError("There must be at least one encoder for this type"
" of encoder projection")
if rnn_size is None:
raise ValueError(
"You must supply rnn_size for this type of encoder projection")

encoded_concat = tf.concat([e.output for e in encoders], 1)
encoded_concat = dropout(
encoded_concat, dropout_keep_prob, train_mode)
en_concat = concat_encoder_projection(train_mode, None, encoders)
en_concat = dropout(en_concat, dropout_keep_prob, train_mode)

return tf.layers.dense(encoded_concat, rnn_size,
name="encoders_projection")
return tf.layers.dense(en_concat, rnn_size, name="encoders_projection")

return func
return cast(EncoderProjection, func)


def concat_encoder_projection(
train_mode: tf.Tensor,
rnn_size: Optional[int] = None,
encoders: Optional[List[Stateful]] = None) -> tf.Tensor:
"""Create the initial state by concatenating the encoders' encoded values.
rnn_size: int = None,
encoders: List[Stateful] = None) -> tf.Tensor:
"""Concatenate the encoded values of the encoders."""

Arguments:
train_mode: tf 0-D bool Tensor specifying the training mode (not used)
rnn_size: The size of the resulting vector (not used)
encoders: The list of encoders
"""
if encoders is None or not encoders:
raise ValueError("There must be at least one encoder for this type "
"of encoder projection")

if rnn_size is not None:
assert rnn_size == sum(e.output.get_shape()[1].value
for e in encoders)

encoded_concat = tf.concat([e.output for e in encoders], 1)
output_size = sum(e.output.get_shape()[1].value for e in encoders)
if rnn_size is not None and rnn_size != output_size:
raise ValueError("RNN size supplied for concat projection ({}) does "
"not match the size of the concatenated vectors ({})."
.format(rnn_size, output_size))

# pylint: disable=no-member
log("The inferred rnn_size of this encoder projection will be {}"
.format(encoded_concat.get_shape()[1].value))
# pylint: enable=no-member
.format(output_size))

encoded_concat = tf.concat([e.output for e in encoders], 1)
return encoded_concat


def nematus_projection(dropout_keep_prob: float = 1.0) -> EncoderProjection:
"""Return encoder projection used in Nematus.
The initial state is a dense projection with tanh activation computed on
the averaged states of the encoders. Dropout is applied to the means
(before the projection).
Arguments:
dropout_keep_prob: The dropout keep probability.
"""
check_argument_types()

def func(
train_mode: tf.Tensor,
rnn_size: int,
encoders: List[TemporalStatefulWithOutput]) -> tf.Tensor:

if len(encoders) != 1:
raise ValueError("Exactly one encoder required for this type of "
"projection. {} given.".format(len(encoders)))
encoder = encoders[0]

# shape (batch, time)
masked_sum = tf.reduce_sum(
encoder.temporal_states
* tf.expand_dims(encoder.temporal_mask, 2), 1)

# shape (batch, 1)
lengths = tf.reduce_sum(encoder.temporal_mask, 1, keep_dims=True)

means = masked_sum / lengths
means = dropout(means, dropout_keep_prob, train_mode)

return tf.layers.dense(means, rnn_size,
activation=tf.tanh,
name="encoders_projection")

return cast(EncoderProjection, func)
46 changes: 45 additions & 1 deletion neuralmonkey/decoders/output_projection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
"""Module with different variants of projection functions for RNN outputs."""
"""Output Projection Module.
This module contains different variants of projection functions of decoder
outputs into the logit function inputs.
Output projections are specified in the configuration file. Each output
projection function has a unified type ``OutputProjection``, which is a
callable that takes four arguments and returns a tensor:
1. ``prev_state`` -- the hidden state of the decoder.
2. ``prev_output`` -- embedding of the previously decoded word (or train input)
3. ``ctx_tensots`` -- a list of context vectors (for each attention object)
To enable further parameterization of output projection functions, one can
use higher-order functions.
"""
from typing import Union, Tuple, List, Callable
import tensorflow as tf
from typeguard import check_argument_types

from neuralmonkey.nn.projection import multilayer_projection, maxout
from neuralmonkey.nn.utils import dropout


# pylint: disable=invalid-name
Expand Down Expand Up @@ -57,6 +72,35 @@ def _projection(prev_state, prev_output, ctx_tensors, train_mode):
return _projection, output_size


def nematus_output(
output_size: int,
dropout_keep_prob: float = 1.0) -> Tuple[OutputProjection, int]:
"""Apply nonlinear one-hidden-layer deep output.
Implementation consistent with Nematus.
Can be used instead of (and is in theory equivalent to) nonlinear_output.
Projects the RNN state, embedding of the previously outputted word, and
concatenation of all context vectors into a shared vector space, sums them
up and apply a hyperbolic tangent activation function.
"""
check_argument_types()

def _projection(prev_state, prev_output, ctx_tensors, train_mode):
prev_state = dropout(prev_state, dropout_keep_prob, train_mode)
prev_output = dropout(prev_output, dropout_keep_prob, train_mode)
ctx_concat = tf.concat(ctx_tensors, 1)
ctx = dropout(ctx_concat, dropout_keep_prob, train_mode)

logit_rnn = tf.layers.dense(prev_state, output_size, name="rnn_state")
logit_emb = tf.layers.dense(prev_output, output_size, name="prev_out")
logit_ctx = tf.layers.dense(ctx, output_size, name="context")

return tf.tanh(logit_rnn + logit_emb + logit_ctx)

return _projection, output_size


def nonlinear_output(
output_size: int,
activation_fn: Callable[[tf.Tensor], tf.Tensor] = tf.tanh
Expand Down
Loading

0 comments on commit 916a3d2

Please sign in to comment.