Branch 195301913 #19072

Merged on May 4, 2018
Changes from all commits
64 commits
ba1c33f
ArraysExtraInfo: Add name_regexp field and regexp name matching.
tensorflower-gardener May 2, 2018
72fd2b8
Use experimental auto_sharding in multi worker dataset.
guptapriya May 2, 2018
22eed54
Automated g4 rollback of changelist 195091587
tensorflower-gardener May 2, 2018
d6d4355
Add Name String to GraphOptimizationPass and Log Registered Passes
May 2, 2018
c439434
Instantiate SwapDimension1And2InTensor3 for Eigen::half
tensorflower-gardener May 2, 2018
1f47bbd
Optimized the analysis of rank and size operations.
benoitsteiner May 2, 2018
e408e81
Internal-only change.
May 2, 2018
489640a
Fix some nits in cpu_literal_caching_test that I noticed after submis…
May 2, 2018
5597237
Initialize all members of CollectiveParams at construction time to avoid
tensorflower-gardener May 2, 2018
c08bf79
Renames _regression_head_with_mean_squared_error_loss to _regression_…
tensorflower-gardener May 2, 2018
1cc2258
Automated g4 rollback of changelist 194981511
hawkinsp May 2, 2018
156483f
[XLA:GPU] Unroll unfused elementwise op kernels.
d0k May 2, 2018
9b6cba1
Internal-only change.
May 2, 2018
1ea4a77
Replaced calls to tensorflow::StringPiece::ToString with std::string …
tensorflower-gardener May 2, 2018
3c0afb1
Turn on two half precision tests for GPU.
bixia1 May 2, 2018
ce0ef22
docs: Link to the appropriately branched version of the live colab no…
asimshankar May 2, 2018
262b176
Added support for packing of symbolic shapes
benoitsteiner May 2, 2018
ad491ad
[XLA] Redesign: Dump HloSnapshot in local service as well. And suppor…
tensorflower-gardener May 2, 2018
b182fd8
Increasing test size to reflect recent additions and prevent test tim…
jsimsa May 2, 2018
79f6d50
Fix tsan failure in batch_dataset_op_test.
saxenasaurabh May 2, 2018
2706eeb
Re-enabling a test.
jsimsa May 2, 2018
f9e8a75
[XLA] Add new optimization that sinks constants into while loop bodies
May 2, 2018
bd6c00a
Fix a bug in create_python_api.py
tensorflower-gardener May 2, 2018
8f61038
Updated ABSL to latest version in workspace.bzl.
tensorflower-gardener May 2, 2018
08fec96
Fix support for batch_normalization with mixed precision
tensorflower-gardener May 2, 2018
d030ea9
Add steps_per_run to LoggingTensorHook and StepCounterHook and other …
chrisying May 2, 2018
1d92d50
[TF:XLA] Bump open source llvm revision to r331338
May 2, 2018
9180cc2
[XLA] BF16 propagation: do not change if propagation is confined insi…
ukoxyz May 2, 2018
4704ae7
Optimize LogicalOr and LogicalAnd with all true or false inputs:
tensorflower-gardener May 2, 2018
5e9e696
Replaced calls to tensorflow::StringPiece::ToString with std::string …
hawkinsp May 2, 2018
4c256cd
Add prefetching to one device distribution strategy.
guptapriya May 2, 2018
85566b2
Adding a version of rolled triangular solver code for the right-multi…
tensorflower-gardener May 2, 2018
0237e86
Adds the EvalListener support for run_local.
May 2, 2018
49f2afe
Allow evaluation and prediction through warm-starting (no current che…
tensorflower-gardener May 2, 2018
30927ec
Mark all nodes processed by AddOpsRewrite/MinBCast stages with a tag.
tensorflower-gardener May 2, 2018
1f4efb7
Add RNNEstimator which takes in arbitrary heads.
tensorflower-gardener May 2, 2018
c7a5787
Enable reshape of _ScopedAllocatorConcat output.
dubey May 2, 2018
8a022d3
Allow `Layer.add_loss` to receive non-tensor; fixes error triggered w…
fchollet May 2, 2018
dde83d4
Handle negative values when slicing symbolic shapes
benoitsteiner May 3, 2018
a1ef905
BufferValue is a new base class for LogicalBuffer and HloValue. This …
fdxmw May 3, 2018
7833890
Add a collect_trace option to run_op_benchmark for cases when callers…
annarev May 3, 2018
8f0a90b
Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
rryan May 3, 2018
db329cf
Automated g4 rollback of changelist 195091587
tensorflower-gardener May 3, 2018
1a4f746
Remove duplicated emplace_back floor operator.
tensorflower-gardener May 3, 2018
2b1a03c
Compute shape of segment_ids dynamically in _unsorted_segment_N
tensorflower-gardener May 3, 2018
223be4a
Replaced calls to tensorflow::StringPiece::ToString with std::string …
tensorflower-gardener May 3, 2018
ebad5d6
Update ops-related pbtxt files.
tensorflower-gardener May 3, 2018
71f97c8
Fix tf.variable_scope unique name after entering root scope
Mostafa-Alaa May 3, 2018
f600046
Expose Interpreter to tensorflow.contrib.lite
aselle May 3, 2018
985351d
Simplify getter and setter method for GraphOptimizationPass::name_
May 3, 2018
283e8fe
Use tensorflow size to determine number of elements instead of the st…
May 3, 2018
a88a7e3
Post-transform pass to dedupe large constant arrays.
tensorflower-gardener May 3, 2018
85a4759
[XLA] Redesign: add ExecuteGraph to grpc service.
tensorflower-gardener May 3, 2018
a16ba4f
Do not delegate temporary tensors to NNAPI.
tensorflower-gardener May 3, 2018
e585463
Simplify file reading and support SavedModel.
May 3, 2018
4b767a8
Small fix for an eager colab notebook.
allenlavoie May 3, 2018
775d1c0
[TF:XLA] Bump open source llvm revision to r331442
May 3, 2018
ceda304
Enable unary chain hoisting optimization for concat/split/splitv by d…
tensorflower-gardener May 3, 2018
fded0f9
Change all std::bind usages in GCS to lambdas. Fix the wrong #define …
rxsang May 3, 2018
278e68c
Simplified the implementation of shape_n since the optimized code pat…
benoitsteiner May 3, 2018
41dcb67
Fix bugs in model pruner.
tensorflower-gardener May 3, 2018
5a64e60
Checkpointable: Utilities to read object metadata
allenlavoie May 3, 2018
7529268
tfdbg + tflearn: replace deprecated classes and methods in example & …
caisq May 3, 2018
ff2bf27
Merge commit for internal changes
caisq May 4, 2018
2 changes: 1 addition & 1 deletion tensorflow/c/c_api.cc
@@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,

for (int i = 0; i < size; ++i) {
TensorId id = results.missing_unused_input_map_keys[i];
tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
tf_results->missing_unused_key_names_data.push_back(std::string(id.first));
tf_results->missing_unused_key_names[i] =
tf_results->missing_unused_key_names_data.back().c_str();
tf_results->missing_unused_key_indexes[i] = id.second;
4 changes: 2 additions & 2 deletions tensorflow/c/c_api_test.cc
@@ -1368,15 +1368,15 @@ TEST(CAPI, SavedModel) {
}

const tensorflow::string input_op_name =
tensorflow::ParseTensorName(input_name).first.ToString();
std::string(tensorflow::ParseTensorName(input_name).first);
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

const tensorflow::string output_op_name =
tensorflow::ParseTensorName(output_name).first.ToString();
std::string(tensorflow::ParseTensorName(output_name).first);
TF_Operation* output_op =
TF_GraphOperationByName(graph, output_op_name.c_str());
ASSERT_TRUE(output_op != nullptr);
6 changes: 3 additions & 3 deletions tensorflow/c/checkpoint_reader.cc
@@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
v2_reader_->key().ToString() /* full var's name */,
std::string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
@@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() {
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
string key = v2_reader_->key().ToString();
string key = std::string(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
2 changes: 1 addition & 1 deletion tensorflow/cc/framework/cc_op_gen.cc
@@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) {
if (IsCPPKeyword(name)) {
return strings::StrCat(name, "_");
}
return name.ToString();
return std::string(name);
}

void InferArgAttributes(const OpDef::ArgDef& arg,
2 changes: 1 addition & 1 deletion tensorflow/cc/framework/scope.cc
@@ -220,7 +220,7 @@ std::unordered_set<string> Scope::Impl::GetColocationConstraints(
for (const string& entry : node_constraints) {
StringPiece s(entry);
if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
current_constraints.insert(s.ToString());
current_constraints.insert(std::string(s));
}
}
} else {
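The five files above apply a single mechanical migration: calls to tensorflow::StringPiece::ToString() are replaced with an explicit std::string construction. A minimal sketch of the pattern, assuming StringPiece behaves like std::string_view (a non-owning view over a byte range); the function name here is illustrative:

#include <string>
#include <string_view>

// Stand-in for tensorflow::StringPiece, a non-owning view type.
using StringPiece = std::string_view;

// Before: piece.ToString() materialized an owning copy via a member helper.
// After: the copy is made with an explicit std::string constructor call,
// which keeps the allocation visible at each call site.
std::string MaterializeKey(StringPiece piece) {
  return std::string(piece);  // copies the viewed bytes into an owning string
}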
17 changes: 11 additions & 6 deletions tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -81,19 +81,24 @@ class FFTOp : public GenericFftOp {
explicit FFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {}
};
REGISTER_XLA_OP(Name("FFT"), FFTOp<1>);
REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>);
REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>);
REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>);
REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
FFTOp<2>);
REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
FFTOp<3>);

template <int FFTRank>
class IFFTOp : public GenericFftOp {
public:
explicit IFFTOp(OpKernelConstruction* ctx)
: GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {}
};
REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>);
REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>);
REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>);
REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64),
IFFTOp<1>);
REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64),
IFFTOp<2>);
REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64),
IFFTOp<3>);

template <int FFTRank>
class RFFTOp : public GenericFftOp {
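These FFT registrations previously matched any element type; constraining the new Tcomplex attribute to DT_COMPLEX64 keeps the XLA kernels from claiming graphs that use the complex128 support added to the regular TF ops in commit 8f0a90b above. If complex128 XLA kernels were added later, a registration following the same pattern might look like this hypothetical sketch (not part of this diff):

// Hypothetical: a double-precision registration alongside the DT_COMPLEX64
// ones above. Shown only to illustrate the TypeConstraint mechanism.
REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX128),
                FFTOp<1>);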
179 changes: 155 additions & 24 deletions tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -82,13 +82,6 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
block_size);
}

// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};

std::map<int, xla::XlaComputation> base_computations;
auto get_base_triangular_solve =
[&](int k) -> xla::StatusOr<xla::XlaComputation*> {
@@ -117,16 +110,21 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
"b");

// We use a left-looking subroutine on the block diagonal in some common
// cases, while falling back to a recursive call in unsupported cases. The
// left-looking subroutine is written with a While loop and so yields much
// faster compile times. Moreover, the left-looking variant can give
// higher performance on smaller (sub)problems.
// We use a left-looking or right-looking subroutine on the block diagonal
// in the lower=true cases, while falling back to a recursive call in
// others. The left-looking and right-looking subroutines are written with
// a While loop and so yield much faster compile times. Moreover, they
// can give higher performance on smaller (sub)problems.
if (left_side && lower) {
TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
b_param, transpose_a,
conjugate_a)
.status());
} else if (!left_side && lower) {
TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param,
b_param, transpose_a,
conjugate_a)
.status());
} else {
TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
left_side, lower, transpose_a,
@@ -169,7 +167,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
update = builder->Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
@@ -219,7 +219,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
update = builder->Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
@@ -268,7 +270,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
update = builder->Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
@@ -318,7 +322,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
update = builder->Div(b_slice, a_slice_conj);
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
@@ -371,11 +377,6 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
batch_dimensions.push_back(a_size);
}

auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) {
auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};

// The main computation is performed in a While loop.

// Allocate the output and set its first or last row,
@@ -391,7 +392,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(auto a_slice_conj,
MaybeConjugate(builder, a_slice, conjugate_a));
auto update = builder->Div(b_slice, a_slice_conj);
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
}
@@ -493,7 +496,9 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
// body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
{i, i}, {1, 1}));
auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
TF_ASSIGN_OR_RETURN(auto a_elt_conj,
MaybeConjugate(bodyb.get(), a_elt, conjugate_a));
auto div_result = bodyb->Div(result_row, a_elt_conj);
TF_ASSIGN_OR_RETURN(body_out,
DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
div_result, {i, zero}));
@@ -513,4 +518,130 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
}

xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
const xla::XlaOp& a,
const xla::XlaOp& b,
bool transpose_a,
bool conjugate_a) {
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
const int64 ndims = xla::ShapeUtil::Rank(a_shape);

std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
int64 a_size = a_shape.dimensions(i);
batch_dimensions.push_back(a_size);
}

// The main computation is performed in a While loop.
xla::XlaOp output = Zeros(builder, b_shape);

// Construct the initial loop carry tuple,
// if transpose_a:
// init = (0, output, a, b)
// else:
// init = (n-1, output, a, b)
std::vector<xla::Shape> tuple_shapes = {
// The loop iteration counter is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(xla::S32, {}),
// The output has the shape of b, with one row updated each iteration.
b_shape,
// The coefficient matrix a is a loop invariant.
a_shape,
// The right-hand-side matrix b is a loop invariant.
b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = builder->ConstantR0<int32>(transpose_a ? 0 : n - 1);
auto init = builder->Tuple({init_i, output, a, b});

// Construct the loop condition function,
// def cond_fun(loop_carry):
// i, output, a, b = loop_carry
// return i < n if transpose_a else i >= 0
std::unique_ptr<xla::XlaBuilder> condb =
builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
{
auto i = condb->GetTupleElement(
condb->Parameter(0, tuple_shape,
"TriangularSolveRightLookingWhileTuple"),
0);
if (transpose_a) {
condb->Lt(i, condb->ConstantR0<int32>(n));
} else {
condb->Ge(i, condb->ConstantR0<int32>(0));
}
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

// Construct the loop body function,
// def body_fun(loop_carry):
// i, output, a, b = loop_carry
// if transpose_a:
// a_row = np.swapaxes(a[..., :, i:i+1], -1, -2)
// else:
// a_row = a[..., :, i:i+1]
// result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
// output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
// if transpose_a:
// return (i + 1, output, a, b)
// else:
// return (i - 1, output, a, b)
// We have to do some extra FLOPs propagating zeros in the matrix multiply
// because we can't have the size of its arguments depend on the loop counter.
std::unique_ptr<xla::XlaBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
{
auto input_tuple = bodyb->Parameter(
0, tuple_shape, "TriangularSolveRightLookingWhileTuple");

// i, output, a, b = loop_carry
auto i = bodyb->GetTupleElement(input_tuple, 0);
auto body_out = bodyb->GetTupleElement(input_tuple, 1);
auto body_a = bodyb->GetTupleElement(input_tuple, 2);
auto body_b = bodyb->GetTupleElement(input_tuple, 3);
auto zero = bodyb->ConstantR0<int32>(0);

// We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
// i:i+1]), but since we can't have intermediate array sizes depend on the
// loop counter, we instead exploit the fact that we initialized the output
// to all zeros and use that as zero-padding (doing unnecessary FLOPs).
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a,
/*transpose_x=*/false,
/*transpose_y=*/transpose_a,
/*conjugate_x=*/false,
/*conjugate_y=*/conjugate_a));
// result = b - np.matmul(output, a)
auto result = bodyb->Sub(body_b, b_update);
// result_row = result[..., :, i:i+1]
TF_ASSIGN_OR_RETURN(
auto result_row,
DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1}));

// body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a,
{i, i}, {1, 1}));
TF_ASSIGN_OR_RETURN(auto a_ii_conj,
MaybeConjugate(bodyb.get(), a_ii, conjugate_a));
auto div_result = bodyb->Div(result_row, a_ii_conj);
TF_ASSIGN_OR_RETURN(body_out,
DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
div_result, {zero, i}));

// if transpose_a:
// return (i + 1, body_out, a, b)
// else:
// return (i - 1, body_out, a, b)
auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? 1 : -1));
bodyb->Tuple({next_i, body_out, body_a, body_b});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
auto triangular_solve_right_looking_while = builder->While(cond, body, init);
return builder->GetTupleElement(triangular_solve_right_looking_while, 1);
}

} // namespace tensorflow
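For reference, the recurrence the right-looking loop implements in the non-transposed case (solve x * a = b for x, with a lower triangular, sweeping columns from i = n-1 down to 0) can be written as the scalar sketch below. All names are illustrative, not part of the TensorFlow sources:

#include <cstddef>
#include <vector>

// Solve x * a = b for x, where a is n x n lower triangular and b is m x n.
// Column i of x depends only on columns j > i, so we sweep i downward; the
// still-untouched columns of x are zero, mirroring the zero-padding trick
// the XLA loop uses to keep operand shapes independent of the counter.
std::vector<std::vector<double>> RightLookingSolve(
    const std::vector<std::vector<double>>& a,    // n x n, lower triangular
    const std::vector<std::vector<double>>& b) {  // m x n
  const std::size_t m = b.size();
  const std::size_t n = a.size();
  std::vector<std::vector<double>> x(m, std::vector<double>(n, 0.0));
  for (std::size_t i = n; i-- > 0;) {
    for (std::size_t r = 0; r < m; ++r) {
      // result_row = b[r][i] - (x @ a)[r][i]
      double acc = b[r][i];
      for (std::size_t j = i + 1; j < n; ++j) acc -= x[r][j] * a[j][i];
      x[r][i] = acc / a[i][i];  // divide by the diagonal element a[i][i]
    }
  }
  return x;
}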
6 changes: 6 additions & 0 deletions tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -69,6 +69,12 @@ xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
bool transpose_a,
bool conjugate_a);

xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
const xla::XlaOp& a,
const xla::XlaOp& b,
bool transpose_a,
bool conjugate_a);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_
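Both solvers declared above are built on the same xla::XlaBuilder While-loop idiom: pack the loop state into a tuple, build the condition and body in sub-builders whose parameter 0 carries that state, and extract the answer with GetTupleElement. A stripped-down, counter-only sketch of the idiom, using only builder calls that appear in the diff above (error handling elided; a sketch, not TensorFlow code):

// Sketch: compute `while (i < 10) i = i + 1;` with the XlaBuilder API used
// in this PR. For brevity the carry is a single S32 scalar, not a tuple.
xla::StatusOr<xla::XlaOp> CountToTen(xla::XlaBuilder* builder) {
  xla::Shape carry_shape = xla::ShapeUtil::MakeShape(xla::S32, {});

  // Condition: i < 10. The last op built in a sub-builder becomes its root.
  std::unique_ptr<xla::XlaBuilder> condb = builder->CreateSubBuilder("cond");
  condb->Lt(condb->Parameter(0, carry_shape, "i"),
            condb->ConstantR0<int32>(10));
  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

  // Body: i + 1.
  std::unique_ptr<xla::XlaBuilder> bodyb = builder->CreateSubBuilder("body");
  bodyb->Add(bodyb->Parameter(0, carry_shape, "i"),
             bodyb->ConstantR0<int32>(1));
  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

  return builder->While(cond, body, builder->ConstantR0<int32>(0));
}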
7 changes: 7 additions & 0 deletions tensorflow/compiler/tf2xla/lib/util.cc
@@ -230,4 +230,11 @@ xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
return builder->Transpose(x, permutation);
}

xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
const xla::XlaOp& x, bool conjugate) {
TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
auto perform_conj = shape.element_type() == xla::C64 && conjugate;
return perform_conj ? builder->Conj(x) : x;
}

} // namespace tensorflow
5 changes: 5 additions & 0 deletions tensorflow/compiler/tf2xla/lib/util.h
@@ -85,6 +85,11 @@ xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
const xla::XlaOp& x);

// Applies a complex conjugation operation if `x` is complex and `conjugate`
// is true; otherwise returns its argument.
xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
const xla::XlaOp& x, bool conjugate);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
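A representative call site for the new helper, mirroring how the triangular-solve code above uses it (illustrative fragment):

// Divide b_slice by a_slice, conjugating a_slice first when requested.
// MaybeConjugate returns a StatusOr because GetShape can fail, so the
// result is unwrapped with TF_ASSIGN_OR_RETURN.
TF_ASSIGN_OR_RETURN(xla::XlaOp a_slice_conj,
                    MaybeConjugate(builder, a_slice, conjugate_a));
xla::XlaOp update = builder->Div(b_slice, a_slice_conj);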