Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions third_party/xla_client/xla_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,17 @@ StatusOr<string> GetComputationHloText(const XlaComputation& computation) {

void ReportComputationError(
const Status& status,
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations) {
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations,
tensorflow::gtl::ArraySlice<const Shape* const> output_shapes) {
std::stringstream ss;
for (size_t i = 0; i < computations.size(); ++i) {
string hlo_text = GetComputationHloText(*computations[i]).ValueOrDie();
MaybeSaveHloGraph(hlo_text, i);
ss << ">>> Dumping Computation " << i << "\n";
ss << hlo_text << "\n";
if (i < output_shapes.size() && output_shapes[i] != nullptr) {
ss << "OutputShape: " << *output_shapes[i] << "\n\n";
}
}
ss << "StackTrace:\n" << tensorflow::CurrentStackTrace() << "\n";
ss << "Status: " << status << "\n";
Expand All @@ -75,9 +79,10 @@ void ReportComputationError(

// Checks whether an action on the given computations produced an error
// Status and, if so, forwards to ReportComputationError, which dumps each
// computation's HLO text (and, with this change, its output shape) before
// surfacing the error.
// NOTE(review): this span is pasted from a GitHub diff view with no +/-
// markers — the pre-change signature line ending in "computations) {" and
// its two-line replacement both appear below, so the span is not
// compilable as-is; only the two-parameter-list version ending in
// "output_shapes) {" is the post-merge code.
void CheckComputationStatus(
const Status& status,
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations) {
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations,
tensorflow::gtl::ArraySlice<const Shape* const> output_shapes) {
if (!status.ok()) {
ReportComputationError(status, computations, output_shapes);
}
}

Expand Down
6 changes: 4 additions & 2 deletions third_party/xla_client/xla_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ StatusOr<string> GetComputationHloText(const XlaComputation& computation);

void ReportComputationError(
const Status& status,
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations);
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations,
tensorflow::gtl::ArraySlice<const Shape* const> output_shapes);

// Checks whether an action on the given computation generated an error, and if
// that was the case, emit error and computations HLO text.
void CheckComputationStatus(
const Status& status,
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations);
tensorflow::gtl::ArraySlice<const XlaComputation* const> computations,
tensorflow::gtl::ArraySlice<const Shape* const> output_shapes);

size_t ShapeHash(const Shape& shape);

Expand Down
16 changes: 11 additions & 5 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,12 @@ void XrtComputationClient::CheckCompileStatus(
const SessionWork& session_work) {
if (!status.ok()) {
std::vector<const XlaComputation*> computations;
std::vector<const Shape*> output_shapes;
for (auto li : session_work.index_mapping) {
computations.push_back(&instances[li].computation);
output_shapes.push_back(instances[li].output_shape);
}
util::ReportComputationError(status, computations);
util::ReportComputationError(status, computations, output_shapes);
}
}

Expand All @@ -454,7 +456,7 @@ XrtComputationClient::ExecuteComputation(
std::vector<tensorflow::Tensor> outputs;
util::CheckComputationStatus(
session->session()->Run(feed_inputs, {exec_ops.front()}, &outputs),
{&computation.computation()});
{&computation.computation()}, {&computation.program_shape().result()});
XLA_CHECK_EQ(outputs.size(), 1);

return GetComputationResults(outputs[0], computation.program_shape().result(),
Expand Down Expand Up @@ -518,14 +520,17 @@ XrtComputationClient::RunComputations(
auto session_runner = [&, this, session]() {
std::vector<tensorflow::Output> exec_nodes;
std::vector<const XlaComputation*> xla_computations;
std::vector<const Shape*> output_shapes;
for (auto replica : replicas) {
exec_nodes.push_back(exec_ops[replica]);
xla_computations.push_back(&computations[replica]->computation());
output_shapes.push_back(
&computations[replica]->program_shape().result());
}
std::vector<tensorflow::Tensor> outputs;
util::CheckComputationStatus(
session->session()->Run(feed_inputs, exec_nodes, &outputs),
xla_computations);
xla_computations, output_shapes);
XLA_CHECK_EQ(outputs.size(), exec_nodes.size());

for (size_t i = 0; i < outputs.size(); ++i) {
Expand Down Expand Up @@ -634,7 +639,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::ExecuteChainedXrt(
std::vector<tensorflow::Tensor> outputs;
util::CheckComputationStatus(
session->session()->Run(feed_inputs, {cached_node.outputs[0]}, &outputs),
{});
{}, {});
XLA_CHECK_EQ(outputs.size(), 1);

std::vector<DataPtr> results;
Expand Down Expand Up @@ -692,7 +697,8 @@ XrtComputationClient::ExecuteChainedSplit(
std::vector<tensorflow::Tensor> outputs;
util::CheckComputationStatus(
session->session()->Run(feed_inputs, {exec_ops.front()}, &outputs),
{&op.computation->computation()});
{&op.computation->computation()},
{&op.computation->program_shape().result()});
XLA_CHECK_EQ(outputs.size(), 1);
ops_outputs[i] = GetComputationResults(
outputs[0], op.computation->program_shape().result(),
Expand Down