Skip to content

Commit

Permalink
[xla] functional_hlo_runner: fix handling of PJRT client with 0 memor…
Browse files Browse the repository at this point in the history
…y spaces

PiperOrigin-RevId: 631512356
  • Loading branch information
cota authored and tensorflower-gardener committed May 7, 2024
1 parent 15f8f10 commit 8fad1c4
Showing 1 changed file with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,18 @@ FunctionalHloRunner::CopyArgumentsToDevice(
TF_RET_CHECK(!shape.IsTuple()) << "Param tuple without flattened_arguments";
return non_tuple_memory_space(shape);
};
auto buffer_from_host_literal = [&client, &argument_memory_space](
const HloModule* module,
PjRtDevice* device, int arg_i,
const Literal& literal)
-> absl::StatusOr<std::unique_ptr<PjRtBuffer>> {
if (client.memory_spaces().empty()) {
return client.BufferFromHostLiteral(literal, device);
}
TF_ASSIGN_OR_RETURN(PjRtMemorySpace * memory_space,
argument_memory_space(module, device, arg_i));
return client.BufferFromHostLiteral(literal, memory_space);
};

absl::Span<const PjRtLoadedExecutable::LogicalDeviceIds>
addressable_device_logical_ids =
Expand Down Expand Up @@ -1321,10 +1333,9 @@ FunctionalHloRunner::CopyArgumentsToDevice(
LOG(INFO) << "device_id=" << curr_device_id
<< ", input = " << literal.ToString();
}
TF_ASSIGN_OR_RETURN(PjRtMemorySpace * memory_space,
argument_memory_space(module, curr_device, arg_i));
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> argument_buffer,
client.BufferFromHostLiteral(literal, memory_space));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> argument_buffer,
buffer_from_host_literal(module, curr_device, arg_i, literal));
argument_buffers[i].push_back(std::move(argument_buffer));
}
}
Expand Down

0 comments on commit 8fad1c4

Please sign in to comment.