Branch 202995903 #20483
Merged: 14 commits, Jul 2, 2018
13 changes: 13 additions & 0 deletions tensorflow/compiler/tests/BUILD
@@ -70,6 +70,19 @@ py_test(
],
)

tf_xla_py_test(
name = "adadelta_test",
size = "medium",
srcs = ["adadelta_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "adagrad_test",
size = "small",
134 changes: 134 additions & 0 deletions tensorflow/compiler/tests/adadelta_test.py
@@ -0,0 +1,134 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adadelta Optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adadelta


class AdadeltaOptimizerTest(xla_test.XLATestCase):

def testBasic(self):
num_updates = 4 # number of ADADELTA steps to perform
for dtype in self.float_types:
with self.test_session(), self.test_scope():
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
var0_init = [1.0, 2.0]
var1_init = [3.0, 4.0]
var0 = resource_variable_ops.ResourceVariable(
var0_init, dtype=dtype)
var1 = resource_variable_ops.ResourceVariable(
var1_init, dtype=dtype)

grads = constant_op.constant([grad, grad], dtype=dtype)

accum = 0.0
accum_update = 0.0

# ADADELTA gradient optimizer
rho = 0.95
epsilon = 1e-8
adadelta_opt = adadelta.AdadeltaOptimizer(
learning_rate=lr, rho=rho, epsilon=epsilon)
adadelta_update = adadelta_opt.apply_gradients(
zip([grads, grads], [var0, var1]))
self.evaluate(variables.global_variables_initializer())
opt_vars = adadelta_opt.variables()
self.assertStartsWith(opt_vars[0].name, var0._shared_name)
self.assertStartsWith(opt_vars[1].name, var0._shared_name)
self.assertStartsWith(opt_vars[2].name, var1._shared_name)
self.assertStartsWith(opt_vars[3].name, var1._shared_name)
self.assertEqual(4, len(opt_vars))
# Assign slots
slot = [None] * 2
slot_update = [None] * 2
self.assertEqual(["accum", "accum_update"],
adadelta_opt.get_slot_names())
slot[0] = adadelta_opt.get_slot(var0, "accum")
self.assertEqual(slot[0].get_shape(), var0.get_shape())
self.assertFalse(slot[0] in variables.trainable_variables())

slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
self.assertEqual(slot_update[0].get_shape(), var0.get_shape())
self.assertFalse(slot_update[0] in variables.trainable_variables())

slot[1] = adadelta_opt.get_slot(var1, "accum")
self.assertEqual(slot[1].get_shape(), var1.get_shape())
self.assertFalse(slot[1] in variables.trainable_variables())

slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
self.assertEqual(slot_update[1].get_shape(), var1.get_shape())
self.assertFalse(slot_update[1] in variables.trainable_variables())

# Fetch params to validate initial values
self.assertAllClose(var0_init, self.evaluate(var0))
self.assertAllClose(var1_init, self.evaluate(var1))

update = [None] * num_updates
tot_update = 0
for step in range(num_updates):
# Run adadelta update for comparison
self.evaluate(adadelta_update)

# Compute the expected Adadelta update for this step.
accum = accum * rho + (grad**2) * (1 - rho)
update[step] = (
np.sqrt(accum_update + epsilon) *
(1. / np.sqrt(accum + epsilon)) * grad)
accum_update = (
accum_update * rho + (update[step]**2) * (1.0 - rho))
tot_update += update[step] * lr

# Check that the accumulators have been updated
for slot_idx in range(2):
self.assertAllCloseAccordingToType(
np.array([accum, accum], dtype=dtype),
self.evaluate(slot[slot_idx]),
rtol=1e-5)

self.assertAllCloseAccordingToType(
np.array([accum_update, accum_update], dtype=dtype),
self.evaluate(slot_update[slot_idx]),
rtol=1e-5)

# Check that the parameters have been updated
self.assertAllCloseAccordingToType(
np.array(
[var0_init[0] - tot_update, var0_init[1] - tot_update],
dtype=dtype),
self.evaluate(var0),
rtol=1e-5)

self.assertAllCloseAccordingToType(
np.array(
[var1_init[0] - tot_update, var1_init[1] - tot_update],
dtype=dtype),
self.evaluate(var1),
rtol=1e-5)


if __name__ == "__main__":
test.main()
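For reference, the scalar recurrence that testBasic re-derives step by step can be pulled out into a small standalone NumPy sketch (a minimal illustration, not part of the PR; the function name and defaults here are assumptions):

```python
import numpy as np


def adadelta_reference(var, grad, lr, rho=0.95, epsilon=1e-8, num_steps=4):
  """Reference Adadelta updates, mirroring the accumulator math in testBasic."""
  var = np.asarray(var, dtype=np.float64)
  grad = np.asarray(grad, dtype=np.float64)
  accum = np.zeros_like(var)          # running average of squared gradients
  accum_update = np.zeros_like(var)   # running average of squared updates
  for _ in range(num_steps):
    accum = rho * accum + (1.0 - rho) * grad ** 2
    update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
    accum_update = rho * accum_update + (1.0 - rho) * update ** 2
    var = var - lr * update
  return var, accum, accum_update


# One of the (grad, lr) combinations exercised by the test.
print(adadelta_reference([1.0, 2.0], [0.2, 0.2], lr=1.0))
```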
69 changes: 69 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -457,5 +457,74 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrlV2);

class ResourceApplyAdadelta : public XlaOpKernel {
public:
explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}

void Compile(XlaOpKernelContext* ctx) override {
TensorShape var_shape, accum_shape, accum_update_shape;
xla::XlaOp var, accum, accum_update;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
&accum_update));

TensorShape lr_shape = ctx->InputShape(3);
TensorShape rho_shape = ctx->InputShape(4);
TensorShape epsilon_shape = ctx->InputShape(5);
TensorShape grad_shape = ctx->InputShape(6);

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
errors::InvalidArgument("rho is not a scalar: ",
rho_shape.DebugString()));

OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon_shape.DebugString()));

OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
"var and accum do not have the same shape",
var_shape.DebugString(), " ", accum_shape.DebugString()));

OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
errors::InvalidArgument(
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));

xla::XlaOp lr = ctx->Input(3);
xla::XlaOp rho = ctx->Input(4);
xla::XlaOp epsilon = ctx->Input(5);
xla::XlaOp grad = ctx->Input(6);

xla::XlaBuilder* b = ctx->builder();
xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);

accum = rho * accum + (one - rho) * xla::Pow(grad, two);
xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
xla::Pow(accum + epsilon, neg_half) * grad;
accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
var = var - update * lr;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
}

private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
ResourceApplyAdadelta);

} // namespace
} // namespace tensorflow
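Written out, the update that the ResourceApplyAdadelta kernel above compiles to XLA is the standard Adadelta recurrence, where accum tracks the running average of squared gradients and accum_update tracks the running average of squared updates; the kernel spells sqrt and rsqrt as Pow(·, 0.5) and Pow(·, −0.5):

```latex
\begin{aligned}
\text{accum}_t         &= \rho\,\text{accum}_{t-1} + (1-\rho)\,g_t^{2} \\
\Delta x_t             &= \frac{\sqrt{\text{accum\_update}_{t-1} + \epsilon}}
                               {\sqrt{\text{accum}_t + \epsilon}}\; g_t \\
\text{accum\_update}_t &= \rho\,\text{accum\_update}_{t-1} + (1-\rho)\,\Delta x_t^{2} \\
\text{var}_t           &= \text{var}_{t-1} - \text{lr}\,\Delta x_t
\end{aligned}
```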
40 changes: 29 additions & 11 deletions tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -48,6 +48,7 @@ int64 GetUniqueId() {
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
@@ -1586,6 +1587,7 @@ XlaOp XlaBuilder::Reduce(
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());

TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
operand_shape, init_shape, dimensions_to_reduce,
@@ -1839,16 +1841,24 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,

void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Send HLO takes two operands: a data operand and a token. Generate the
// token to pass into the send.
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));

// Send instruction produces a tuple of {aliased operand, U32 context}.
HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
*instr.mutable_shape() =
*send_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(
XlaOp send,
AddInstruction(std::move(instr), HloOpcode::kSend, {operand}));
send_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp send,
AddInstruction(std::move(send_instr), HloOpcode::kSend,
{operand, token}));

HloInstructionProto send_done_instr;
*send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
@@ -1860,14 +1870,22 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {

XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Recv HLO takes a single token operand. Generate the token to pass into
// the Recv and RecvDone instructions.
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));

// Recv instruction produces a tuple of {receive buffer, U32 context}.
*instr.mutable_shape() =
HloInstructionProto recv_instr;
*recv_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp recv,
AddInstruction(std::move(instr), HloOpcode::kRecv, {}));
recv_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
HloOpcode::kRecv, {token}));

HloInstructionProto recv_done_instr;
*recv_done_instr.mutable_shape() = shape;
5 changes: 3 additions & 2 deletions tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({}));
auto recv = builder.AddInstruction(
HloInstruction::CreateRecv(vec_, /*channel_id=*/0));
HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
auto send = builder.AddInstruction(
HloInstruction::CreateSend(recv_done, /*channel_id=*/1));
HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));

auto module = CreateNewModule();
tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -119,10 +119,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

auto* true_computation = conditional->true_computation();
auto* token =
true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
true_computation->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
/*channel_id=*/0));
token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
@@ -133,8 +135,10 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);

auto* true_computation = conditional->true_computation();
auto* token =
true_computation->AddInstruction(HloInstruction::CreateAfterAll({}));
auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv(
ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0));
ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
8 changes: 6 additions & 2 deletions tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -51,14 +51,18 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
// Skip Constant, Parameter, Reduce operation.
// Skip Constant, Parameter, Reduce, and AfterAll operation.
// TODO(b/35975797): Enable Reduce operation once arbitrary computation
// are supported by the evaluator.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
// operand in which case constant folding will be impossible and this
// special case is not necessary.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
instruction->opcode() == HloOpcode::kReduce) {
instruction->opcode() == HloOpcode::kReduce ||
instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
// Skip instructions with non-constant operands.