Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,8 +1331,6 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
static const size_t parameter_wrapping_threadshold =
runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool using_pjrt =
runtime::sys_util::GetEnvString("PJRT_DEVICE", "").size() > 0;
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
LoweringContext lowering_ctx("SyncTensorsGraph", coll.device,
po_data->post_order,
Expand Down Expand Up @@ -1393,7 +1391,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
// TODO(yeounoh) enable wrapping with auto-sharding.
bool should_wrap_parameter =
(program_shape.parameters_size() >= parameter_wrapping_threadshold) &&
using_pjrt && !use_autosharding;
!use_autosharding;
if (should_wrap_parameter) {
TF_VLOG(3) << "Wrapping graph with " << program_shape.parameters_size()
<< " parameters. Threadshold = "
Expand Down