Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609250166
  • Loading branch information
tensorflower-gardener committed Feb 22, 2024
1 parent 07cca9a commit b0a8809
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 25 deletions.
4 changes: 2 additions & 2 deletions tensorflow/compiler/jit/deadness_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ class DeadnessAnalysis {
friend class DeadnessAnalysis;
};

virtual tsl::StatusOr<DeadnessPredicate> GetPredicateFor(Node* n,
int oidx) const = 0;
virtual absl::StatusOr<DeadnessPredicate> GetPredicateFor(Node* n,
int oidx) const = 0;

// Prints out the internal state of this instance. For debugging purposes
// only.
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/deadness_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ limitations under the License.
namespace tensorflow {
namespace {

tsl::StatusOr<bool> HasInputsWithMismatchingDeadness(
absl::StatusOr<bool> HasInputsWithMismatchingDeadness(
const DeadnessAnalysis& deadness_analysis, const Node& n) {
std::optional<DeadnessAnalysis::DeadnessPredicate> pred;
for (const Edge* edge : n.in_edges()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ uint64 Signature::Hash::operator()(const Signature& signature) const {
return h;
}

StatusOr<Signature> Signature::Build(
absl::StatusOr<Signature> Signature::Build(
const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args) {
Signature signature;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct DeviceCompilationClusterSignature {
string HumanString() const;

// Builds the signature for a compilation.
static StatusOr<DeviceCompilationClusterSignature> Build(
static absl::StatusOr<DeviceCompilationClusterSignature> Build(
const NameAttrList& function,
absl::Span<const XlaCompiler::Argument> args);
};
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/device_compilation_profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ DeviceCompilationProfiler::~DeviceCompilationProfiler() {
cluster_compile_stats_.clear();
}

StatusOr<DeviceCompilationProfiler::ClusterCompileStats>
absl::StatusOr<DeviceCompilationProfiler::ClusterCompileStats>
DeviceCompilationProfiler::GetCompileStats(const NameAttrList& function) const {
mutex_lock lock(mu_);

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/device_compilation_profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class DeviceCompilationProfiler : public ResourceBase {
};

// Returns the compilation statistics for the given cluster.
StatusOr<ClusterCompileStats> GetCompileStats(
absl::StatusOr<ClusterCompileStats> GetCompileStats(
const NameAttrList& function) const;

// Determines whether the cluster should be compiled. Creates and inserts an
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/jit/device_compiler_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ class DeviceCompilerClient {

// Serializes an available `executable` to string using `ClientType` and
// returns it.
virtual StatusOr<std::string> SerializeExecutable(
virtual absl::StatusOr<std::string> SerializeExecutable(
const ExecutableType& executable) = 0;

// Compiles `result` (HLO) to a serializable executable (eg.
// xla::AotCompilationResult) using `ClientType`, serializes it to string and
// returns it.
virtual StatusOr<std::string> BuildSerializedExecutable(
virtual absl::StatusOr<std::string> BuildSerializedExecutable(
const XlaCompiler::Options& options,
const XlaCompiler::CompilationResult& result) = 0;

Expand Down
12 changes: 7 additions & 5 deletions tensorflow/compiler/jit/device_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,16 @@ class DeviceInfoCache {
return names_[device.id()];
}

StatusOr<DeviceId> GetIdFor(absl::string_view name);
absl::StatusOr<DeviceId> GetIdFor(absl::string_view name);

using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;

DeviceRegistration* GetCompilationDevice(DeviceId device) const {
return id_to_compilation_device_[device.id()];
}

StatusOr<DeviceRegistration*> GetCompilationDevice(absl::string_view name) {
absl::StatusOr<DeviceRegistration*> GetCompilationDevice(
absl::string_view name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
return GetCompilationDevice(device_id);
}
Expand All @@ -137,7 +138,8 @@ class DeviceInfoCache {

using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;

StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(absl::string_view device_name) {
absl::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
absl::string_view device_name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
return std::cref(*id_to_device_type_[device_id.id()]);
}
Expand Down Expand Up @@ -196,7 +198,7 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
// case it is the responsibility of the optimization pass that injected the
// CPU nodes into the cluster to ensure that these nodes can be compiled by
// the unknown XLA backend.
StatusOr<jit::DeviceId> PickDeviceForXla(
absl::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);

Expand All @@ -205,7 +207,7 @@ StatusOr<jit::DeviceId> PickDeviceForXla(
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.
StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla(
absl::StatusOr<std::optional<jit::DeviceId>> MaybePickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/jit/encapsulate_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ struct XlaClusterInfo {
// dependencies and control dependencies. cluster_deps maps the name name of an
// outside compilation cluster to a set of names of outside compilation clusters
// that it depends on.
tsl::StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
absl::StatusOr<
std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(
const Graph* g, const string& outside_compilation_attr_name);

Expand Down
5 changes: 3 additions & 2 deletions tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
// XlaLaunch -> NodeA
static Status BuildXlaLaunchOps(
Graph* graph,
const std::function<StatusOr<bool>(const Node&)>& is_xla_launch_node,
const std::function<StatusOr<XlaFunctionInfo>(const Node&)>&
const std::function<absl::StatusOr<bool>(const Node&)>&
is_xla_launch_node,
const std::function<absl::StatusOr<XlaFunctionInfo>(const Node&)>&
get_xla_function_info,
bool add_edges_to_output_of_downstream_nodes);
};
Expand Down
12 changes: 6 additions & 6 deletions tensorflow/compiler/jit/get_compiler_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ limitations under the License.

namespace tensorflow {

static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
static absl::StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
xla::LocalClient* local_client,
const XlaCompiler::CompilationResult& result,
const XlaCompiler::Options& options,
Expand Down Expand Up @@ -93,7 +93,7 @@ static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
return std::move(executables[0]);
}

static StatusOr<std::string> BuildHLOString(
static absl::StatusOr<std::string> BuildHLOString(
IrExportStage stage, const XlaCompiler::CompilationResult& result,
xla::LocalClient* local_client, const XlaCompiler::Options& options) {
switch (stage) {
Expand Down Expand Up @@ -138,7 +138,7 @@ static StatusOr<std::string> BuildHLOString(
case IrExportStage::OPTIMIZED_HLO_DOT: {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
BuildExecutable(local_client, result, options));
StatusOr<std::string> graph = xla::RenderGraph(
absl::StatusOr<std::string> graph = xla::RenderGraph(
*executable->executable()->module().entry_computation(),
"Visualization",
/*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
Expand All @@ -149,7 +149,7 @@ static StatusOr<std::string> BuildHLOString(
}
}

static StatusOr<std::vector<XlaCompiler::Argument>>
static absl::StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArgumentFromTensorSpec(
const FunctionBody* fbody, absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
Expand Down Expand Up @@ -328,7 +328,7 @@ absl::StatusOr<std::string> CompileAndBuildHLOString(
* - `input_handles`: Contains all concrete_fn inputs tensors, including
* captured inputs.
*/
StatusOr<std::string> GetCompilerIr(
absl::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const ArgShapeAndDType> input_arg_shape_and_dtype,
Expand Down Expand Up @@ -386,7 +386,7 @@ StatusOr<std::string> GetCompilerIr(
function, args);
}

StatusOr<std::string> GetCompilerIr(
absl::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, absl::string_view platform_name,
EagerContext* context,
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/jit/get_compiler_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ enum class CompilerArgSource {
// Returns the IR format of the selected stage for a given function `func_name`
// using library runtime `runtime` on a device `dev` with given
// `inputs_arg_shape_and_dtype` and `input_handles`.
StatusOr<std::string> GetCompilerIr(
absl::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const ArgShapeAndDType> input_arg_shape_and_dtype,
Expand All @@ -64,7 +64,7 @@ StatusOr<std::string> GetCompilerIr(
// Returns the IR format of the selected stage for a given function `func_name`
// using library runtime `runtime` on a platform `platform_name` with given
// `inputs_arg_shape_and_dtype` and `input_handles`.
StatusOr<std::string> GetCompilerIr(
absl::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, absl::string_view platform_name,
EagerContext* context,
Expand Down

0 comments on commit b0a8809

Please sign in to comment.