Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push from google. #1

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions docker/dev.dockerfile
@@ -1,21 +1,21 @@
# Commands:
#
# device="gpu"; (Set device to cpu to build and run CPU only docker)
# device="gpu"; (Leave empty to build and run CPU only docker)
#
# sudo docker build --tag tensorflow:lingvo --build-arg bazel_version=0.13.1 $(test "$device" = "cpu" && echo "--build-arg base_image=ubuntu:16.04") - < lingvo/docker/dev.dockerfile
# sudo docker build --tag tensorflow:lingvo --build-arg bazel_version=0.13.1 $(test "$device" = "gpu" && echo "--build-arg base_image=nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04") - < lingvo/docker/dev.dockerfile
# sudo docker run --rm $(test "$device" = "gpu" && echo "--runtime=nvidia") -it -v /tmp/lingvo:/tmp/lingvo -v ${HOME}/.gitconfig:/home/${USER}/.gitconfig:ro --name lingvo tensorflow:lingvo bash

# TODO(drpng): upgrade to latest (17.10)
ARG gpu_base_image="nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04"
ARG base_image=$gpu_base_image
ARG cpu_base_image="ubuntu:16.04"
ARG base_image=$cpu_base_image
FROM $base_image

LABEL maintainer="Patrick Nguyen <drpng@google.com>"

# Re-declare args because the args declared before FROM can't be used in any
# instruction after a FROM.
ARG gpu_base_image="nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04"
ARG base_image=$gpu_base_image
ARG cpu_base_image="ubuntu:16.04"
ARG base_image=$cpu_base_image

# Pick up some TF dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
Expand Down Expand Up @@ -60,7 +60,7 @@ RUN pip --no-cache-dir install \
&& \
python -m ipykernel.kernelspec

RUN pip install tf-nightly$(test "$base_image" = "$gpu_base_image" && echo "-gpu")
RUN pip install tf-nightly$(test "$base_image" != "$cpu_base_image" && echo "-gpu")

ARG bazel_version
# This is to install bazel, for development purposes.
Expand Down
175 changes: 145 additions & 30 deletions lingvo/core/attention.py
Expand Up @@ -45,6 +45,9 @@ def _ApplyAttentionDropout(params, x, step_state=None, prng_seed=None):
Returns:
A Tensor with the same shape as `x`.
"""
if params.atten_dropout_prob == 0:
return x

if params.atten_dropout_deterministic:
if isinstance(step_state, py_utils.NestedMap):
assert 'global_step' in step_state, step_state.DebugString()
Expand Down Expand Up @@ -525,6 +528,37 @@ def EncodeSource(src_w, vecs, ctxs):

self._encode_source = EncodeSource

def PackSource(self,
               theta,
               source_vecs,
               source_contexts,
               source_padding,
               source_segment_id=None):
  """Packs an entire source sequence for later attention computations.

  Stateless counterpart of InitForSourcePacked: nothing is recorded on the
  layer instance, the packed tensors are only returned to the caller.

  Args:
    theta: A nested map object containing weights' values of this
      layer and its children layers.
    source_vecs: A single tensor of shape [time, batch_size, source_dim].
    source_contexts: A single tensor of shape [time, batch_size, some_dim].
    source_padding: A tensor of shape [time, batch_size].
    source_segment_id: Optional tensor of shape [time, batch_size]; an
      all-zeros tensor is substituted when None.

  Returns:
    A (concated_source_vecs, concated_source_contexts, source_padding,
    source_segment_id) tuple.
  """
  with tf.name_scope(self.params.name):
    segment_id = (
        tf.zeros_like(source_padding)
        if source_segment_id is None else source_segment_id)
    encoded_vecs, encoded_ctxs = self._encode_source(
        theta.source_var, source_vecs, source_contexts)
    return (encoded_vecs, encoded_ctxs, source_padding, segment_id)

def InitForSourcePacked(self,
theta,
source_vecs,
Expand All @@ -547,14 +581,11 @@ def InitForSourcePacked(self,
Returns:
Concated source vectors, concated source contexts, and source paddings.
"""
with tf.name_scope(self.params.name):
if source_segment_id is None:
source_segment_id = tf.zeros_like(source_padding)
self._source_init_done = True
self._source_padding = source_padding
self._source_segment_id = source_segment_id
(self._concated_source_vecs, self._concated_source_contexts) = (
self._encode_source(theta.source_var, source_vecs, source_contexts))
self._source_init_done = True

(self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id) = self.PackSource(
theta, source_vecs, source_contexts, source_padding, source_segment_id)
return (self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id)

Expand Down Expand Up @@ -807,6 +838,38 @@ def Atten(per_dim_scale, source_padding, source_segment_id,

self._ctx_vec = Atten

def PackSource(self,
               theta,
               source_vecs,
               source_contexts,
               source_padding,
               source_segment_id=None):
  """Packs source vectors. Does not change attention state.

  Args:
    theta: A nested map object containing weights' values of this
      layer and its children layers.
    source_vecs: A tensor of shape [time, source_batch, source_dim].
    source_contexts: A tensor of shape [time, source_batch, context_dim].
    source_padding: A tensor of shape [time, source_batch].
    source_segment_id: A tensor of shape [time, source_batch], or None, in
      which case an all-zeros tensor of that shape is used.

  Returns:
    A tuple (concated_source_vecs, concated_source_contexts, source_padding,
    source_segment_id), where concated_source_vecs is a tensor of shape
    [time, batch_size, hidden_dim], concated_source_contexts is a tensor of
    shape [batch_size, time, some_dim], source_padding is a tensor of shape
    [time, batch_size], and source_segment_id is a tensor of shape
    [time, batch_size]. Note the mismatch between concated_source_vecs and
    concated_source_contexts. In concated_source_vecs, time is the first dim,
    while it is the second dim in concated_source_contexts.
  """
  # tf.identity gives the packed vecs a distinct graph node without copying
  # the underlying values.
  concated_source_vecs = tf.identity(source_vecs)
  # Move batch to the leading dim (see the Returns note on the layout
  # mismatch between vecs and contexts).
  concated_source_contexts = tf.transpose(source_contexts, [1, 0, 2])
  if source_segment_id is None:
    source_segment_id = tf.zeros_like(source_padding)
  return (concated_source_vecs, concated_source_contexts, source_padding,
          source_segment_id)

def InitForSourcePacked(self,
theta,
source_vecs,
Expand All @@ -833,13 +896,9 @@ def InitForSourcePacked(self,
while it is the second dim in concated_source_contexts.
"""
self._source_init_done = True
self._concated_source_vecs = tf.identity(source_vecs)
self._concated_source_contexts = tf.transpose(source_contexts, [1, 0, 2])
self._source_padding = source_padding
if source_segment_id is None:
self._source_segment_id = tf.zeros_like(source_padding)
else:
self._source_segment_id = source_segment_id
(self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id) = self.PackSource(
theta, source_vecs, source_contexts, source_padding, source_segment_id)
return (self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id)

Expand Down Expand Up @@ -1067,8 +1126,41 @@ def InitForSourcePacked(self,
num_heads] and source_segment_id is a tensor of shape
[source_seq_len, batch_size * num_heads].
"""
p = self.params
self._source_init_done = True
(self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id) = self.PackSource(
theta, source_vecs, source_contexts, source_padding, source_segment_id)
return (self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id)

def PackSource(self,
theta,
source_vecs,
source_contexts,
source_padding,
source_segment_id=None):
"""Packs source vectors. Does not change attention state.

Args:
theta: A nested map object containing weights' values of this
layer and its children layers.
source_vecs: A tensor of shape [time, source_batch, source_dim].
source_contexts: A tensor of shape [time, source_batch, context_dim].
source_padding: A tensor of shape [time, source_batch].
source_segment_id: A tensor of shape [time, source_batch].

Returns:
(concated_source_vecs, concated_source_contexts, source_padding,
source_segment_id) tuple where concated_source_vecs is a tensor of shape
[source_seq_len, batch_size * num_heads, orig_source_dim / num_heads],
concated_source_contexts is a tensor of shape [source_batch_size *
num_heads, source_seq_len, orig_context_dim / num_heads],
source_padding is a tensor of shape [source_seq_len, batch_size *
num_heads] and source_segment_id is a tensor of shape
[source_seq_len, batch_size * num_heads].
"""

p = self.params
if not p.enable_source_proj:
assert p.source_dim == p.hidden_dim
if not p.enable_query_proj:
Expand Down Expand Up @@ -1097,7 +1189,6 @@ def InitForSourcePacked(self,
source_projected = tf.reshape(
source_projected,
[time_steps, batch_size * num_heads, hidden_depth // num_heads])
self._source_seq_len = tf.shape(source_padding)[0]
if p.use_source_vec_as_attention_value:
source_contexts_reshaped = source_projected
else:
Expand Down Expand Up @@ -1125,13 +1216,12 @@ def InitForSourcePacked(self,
tf.reshape(source_segment_id, [time_steps, batch_size, 1]),
[1, 1, num_heads]), [time_steps, batch_size * num_heads])

(self._concated_source_vecs, self._concated_source_contexts,
self._source_padding,
self._source_segment_id) = self.atten.InitForSourcePacked(
(concated_source_vecs, concated_source_contexts,
source_padding, source_segment_id) = self.atten.PackSource(
theta.atten, source_projected, source_contexts_reshaped,
source_padding_replicated, source_segment_id_repl)
return (self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id)
return (concated_source_vecs, concated_source_contexts, source_padding,
source_segment_id)

def ExtendSourcePacked(self, theta, new_source_vecs, new_source_contexts,
new_source_paddings, new_source_segment_ids,
Expand Down Expand Up @@ -1970,6 +2060,34 @@ def EncodeSource(src_w, vecs, ctxs):

self._encode_source = EncodeSource

def PackSource(self,
               theta,
               source_vecs,
               source_contexts,
               source_padding,
               source_segment_id=None):
  """Packs source vectors; attention state on the layer is left untouched.

  Args:
    theta: A nested map object containing weights' values of this
      layer and its children layers.
    source_vecs: A single tensor of shape [time, batch_size, source_dim].
    source_contexts: A single tensor of shape [time, batch_size, some_dim].
    source_padding: A tensor of shape [time, batch_size].
    source_segment_id: Optional tensor of shape [time, batch_size]; defaults
      to all zeros when not provided.

  Returns:
    A (concated_source_vecs, concated_source_contexts, source_padding,
    source_segment_id) tuple.
  """
  with tf.name_scope(self.params.name):
    packed_vecs, packed_ctxs = self._encode_source(
        theta.source_var, source_vecs, source_contexts)
    if source_segment_id is None:
      source_segment_id = tf.zeros_like(source_padding)
    return (packed_vecs, packed_ctxs, source_padding, source_segment_id)

def InitForSourcePacked(self,
theta,
source_vecs,
Expand All @@ -1989,14 +2107,11 @@ def InitForSourcePacked(self,
Returns:
Concated source vectors, concated source contexts, and source paddings.
"""
with tf.name_scope(self.params.name):
self._source_init_done = True
self._source_padding = source_padding
(self._concated_source_vecs, self._concated_source_contexts) = (
self._encode_source(theta.source_var, source_vecs, source_contexts))
if source_segment_id is None:
source_segment_id = tf.zeros_like(source_padding)
self._source_segment_id = source_segment_id
self._source_init_done = True
(self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id) = self.PackSource(
theta, source_vecs, source_contexts, source_padding, source_segment_id)

return (self._concated_source_vecs, self._concated_source_contexts,
self._source_padding, self._source_segment_id)

Expand Down
30 changes: 30 additions & 0 deletions lingvo/core/layers.py
Expand Up @@ -1276,6 +1276,36 @@ def FProp(self, theta, inputs):
return inputs


class DeterministicDropoutLayer(base_layer.LayerBase):
  """Dropout layer backed by deterministically-seeded random ops.

  Dropout is applied only during training; in eval mode (or when
  keep_prob == 1.0) inputs pass through unchanged.
  """

  @classmethod
  def Params(cls):
    p = super(DeterministicDropoutLayer, cls).Params()
    p.Define('keep_prob', 1.0, 'Keep probability.')
    p.Define('seed', None, 'Random seed')
    return p

  def FProp(self, theta, inputs):
    """Apply dropout to inputs.

    Args:
      theta: A nested map object containing weights' values of this
        layer and its children layers.
      inputs: The inputs tensor.

    Returns:
      inputs with dropout applied at training time.
    """
    p = self.params
    # Dropout is a no-op at eval time or with keep_prob == 1.0; skip the op.
    if p.keep_prob < 1.0 and not p.is_eval:
      return py_utils.DeterministicDropout(
          inputs,
          p.keep_prob,
          py_utils.GetOpSeedPair(op_seed=p.seed))
    else:
      return inputs


class LayerNorm(base_layer.LayerBase):
"""Layer normalization.

Expand Down
29 changes: 29 additions & 0 deletions lingvo/core/layers_test.py
Expand Up @@ -1535,5 +1535,34 @@ def testLayerNormBProp(self):
self.assertAllClose(sg, ng, rtol=1e-02, atol=1e-02)


class DeterministicDropoutTest(tf.test.TestCase):
  """Checks that DeterministicDropoutLayer produces a reproducible mask."""

  def testDeterministicDropoutLayer(self):
    # keep_prob=0.7 so surviving elements are scaled by 1/0.7 and dropped
    # elements are exactly zero.
    params = layers.DeterministicDropoutLayer.Params().Set(keep_prob=0.7)
    params.name = 'drop'
    dropout = layers.DeterministicDropoutLayer(params)

    x = tf.ones([4, 6], dtype=tf.float32)

    with self.test_session() as sess:
      graph = tf.get_default_graph()
      global_step = py_utils.GetOrCreateGlobalStep()
      # NOTE(review): tf.assign only *creates* the assign op here; it is never
      # passed to sess.run, so the global step presumably keeps its initial
      # value when FProp runs — confirm the expected mask below was generated
      # under that condition.
      tf.assign(global_step, tf.constant(1234, dtype=tf.int64))
      # Seed the per-step seed collection that the deterministic dropout op
      # reads its second seed from.
      graph.add_to_collection('step_seed', tf.constant(5678, dtype=tf.int64))

      x = dropout.FProp(dropout.theta, x)
      tf.global_variables_initializer().run()
      x_val = sess.run(x)
      print(np.array_repr(x_val))
      # Golden mask for the (global_step, step_seed) pair above: zeros mark
      # dropped positions, survivors are scaled by 1/keep_prob.
      # pyformat: disable
      self.assertAllClose(
          [[1.0 / 0.7, 0.0000000, 1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7],
           [1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7],
           [1.0 / 0.7, 1.0 / 0.7, 1.0 / 0.7, 0.0000000, 1.0 / 0.7, 1.0 / 0.7],
           [0.0000000, 1.0 / 0.7, 0.0000000, 0.0000000, 1.0 / 0.7, 0.0000000]],
          x_val)
      # pyformat: enable


# Standard TensorFlow test entry point.
if __name__ == '__main__':
  tf.test.main()