4 changes: 4 additions & 0 deletions docs/source/index.rst
@@ -73,6 +73,10 @@ spmd
.. autoclass:: HybridMesh
.. autoclass:: ShardingSpec

experimental
----------------------------------
.. automodule:: torch_xla.experimental
.. autofunction:: eager_mode

debug
----------------------------------
12 changes: 12 additions & 0 deletions examples/eager/train_decoder_only_eager.py
@@ -0,0 +1,12 @@
import sys
import os
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla

if __name__ == '__main__':
torch_xla.experimental.eager_mode(True)
base = TrainDecoderOnlyBase()
base.start_training()
12 changes: 12 additions & 0 deletions test/debug_tool/test_pt_xla_debug.py
@@ -29,6 +29,18 @@ def setUpClass(cls):
assert False, "This test should be run with PT_XLA_DEBUG_FILE"
open(cls.debug_file_name, 'w').close()

def test_eager_mark_step(self):
torch_xla.experimental.eager_mode(True)
device = xm.xla_device()
t1 = torch.randn(5, 9, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
# We expect PT_XLA_DEBUG not to output anything under eager mode
self.assertEqual(len(lines), 0)
torch_xla.experimental.eager_mode(False)
open(self.debug_file_name, 'w').close()

def test_user_mark_step(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
19 changes: 19 additions & 0 deletions test/test_metrics.py
@@ -52,6 +52,25 @@ def test_tracing_time_metrics(self):
self.assertIn('LazyTracing', met.metric_names())
self.assertGreater(met.metric_data('LazyTracing')[0], 1)

def test_eager_metrics(self):
torch_xla.experimental.eager_mode(True)
xla_device = xm.xla_device()
met.clear_all()
t1 = torch.tensor(156, device=xla_device)
t2 = t1 + 100
xm.wait_device_ops()
self.assertIn('EagerOpCompileTime', met.metric_names())
# one for constant, one for add
self.assertEqual(met.metric_data('EagerOpCompileTime')[0], 2)
self.assertIn('EagerOpExecuteTime', met.metric_names())
# one for constant, one for add
self.assertEqual(met.metric_data('EagerOpExecuteTime')[0], 2)
# mark_step should be a no-op
xm.mark_step()
self.assertNotIn('CompileTime', met.metric_names())
self.assertNotIn('ExecuteTime', met.metric_names())
torch_xla.experimental.eager_mode(False)

def test_short_metrics_report_default_list(self):
xla_device = xm.xla_device()
t1 = torch.tensor(1456, device=xla_device)
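The two assertions above are what separate the new per-op eager metrics from the existing graph metrics. As a minimal sketch (not part of this PR, assuming a working torch_xla install with an attached device), the same counters can be inspected interactively:

```python
# Sketch: inspect the new eager metrics outside the test suite.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
t1 = torch.tensor(156, device=device)  # one eager compile + execute for the constant
t2 = t1 + 100                          # one eager compile + execute for the add
xm.wait_device_ops()                   # wait for the eagerly dispatched work to finish

# Per-op work is recorded under EagerOp* instead of CompileTime/ExecuteTime.
print(met.metric_data('EagerOpCompileTime')[0])  # -> 2
print(met.metric_data('EagerOpExecuteTime')[0])  # -> 2

torch_xla.experimental.eager_mode(False)
```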
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -36,3 +36,4 @@ python3 examples/data_parallel/train_resnet_xla_ddp.py
python3 examples/fsdp/train_decoder_only_fsdp_v2.py
python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
python3 examples/train_resnet_amp.py
python3 examples/eager/train_decoder_only_eager.py
13 changes: 13 additions & 0 deletions torch_xla/csrc/debug_util.cpp
@@ -267,6 +267,12 @@ void DebugUtil::analyze_graph_execution_python_frame(
return;
}

// don't output analysis for eager mode execution/compilation
if (XLAGraphExecutor::Get()->UseEagerMode() &&
source != GraphAnalysisSource::DynamoExecution) {
return;
}

if (pt_xla_debug_level <= 1 && source != GraphAnalysisSource::Compilation) {
// for debug level <=1, only output compilation analysis in this function.
return;
@@ -385,6 +391,13 @@ void DebugUtil::post_compilation_analysis(
if (pt_xla_debug_level <= 0 || !is_master_process) {
return;
}

// don't output analysis for eager mode execution/compilation.
// TODO(JackCaoG): enable this for eager+dynamo
if (XLAGraphExecutor::Get()->UseEagerMode()) {
return;
}

static const std::string debug_output_prefix = "Post Compilation Analysis: ";
std::stringstream ss;
ss << "\n"
5 changes: 5 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2403,6 +2403,11 @@ void InitXlaModuleBindings(py::module m) {
});
m.def("_get_xla_enable_device_data_cache",
[]() { return FLAGS_torch_lazy_enable_device_data_cache; });
m.def("_set_use_eager_mode", [](bool use_eager_mode) {
XLAGraphExecutor::Get()->SetUseEagerMode(use_eager_mode);
});
m.def("_get_use_eager_mode",
[]() { return XLAGraphExecutor::Get()->UseEagerMode(); });
m.def("_replace_xla_tensor",
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
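These two bindings are what the Python-level switch rides on. A quick sanity check (a sketch, not part of this PR; it peeks at the private `_XLAC` getter purely for illustration):

```python
# Sketch: the public API is torch_xla.experimental.eager_mode(); the private
# _XLAC functions below are the bindings added in this file.
import torch_xla
import torch_xla.experimental

print(torch_xla._XLAC._get_use_eager_mode())  # False by default
torch_xla.experimental.eager_mode(True)       # forwards to _XLAC._set_use_eager_mode(True)
print(torch_xla._XLAC._get_use_eager_mode())  # True
```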
12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/computation_client.cc
@@ -77,12 +77,24 @@ metrics::Metric* ComputationClient::CompileMetric() {
return metric;
}

metrics::Metric* ComputationClient::EagerCompileMetric() {
static metrics::Metric* metric =
new metrics::Metric("EagerOpCompileTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::ExecuteMetric() {
static metrics::Metric* metric =
new metrics::Metric("ExecuteTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::EagerExecuteMetric() {
static metrics::Metric* metric =
new metrics::Metric("EagerOpExecuteTime", metrics::MetricFnTime);
return metric;
}

metrics::Metric* ComputationClient::ExecuteReplicatedMetric() {
static metrics::Metric* metric =
new metrics::Metric("ExecuteReplicatedTime", metrics::MetricFnTime);
25 changes: 14 additions & 11 deletions torch_xla/csrc/runtime/computation_client.h
@@ -46,6 +46,7 @@ class XlaCoordinator;
// ComputationClient.
struct ClientExecuteOptions {
bool explode_tuple{true};
bool eager_mode{false};
};

class ComputationClient {
@@ -213,16 +214,14 @@
// of torch::lazy::Computation?
struct CompileInstance {
CompileInstance() = default;
CompileInstance(xla::XlaComputation computation,
std::string compilation_device,
std::vector<std::string> devices,
const xla::Shape* output_shape,
bool parameter_is_tupled_arguments = false,
bool is_sharded = false,
bool allow_spmd_sharding_propagation_to_output = true,
bool use_auto_spmd_partitioning = false,
std::vector<int64_t> auto_spmd_mesh_shape = {},
std::vector<int64_t> auto_spmd_mesh_ids = {})
CompileInstance(
xla::XlaComputation computation, std::string compilation_device,
std::vector<std::string> devices, const xla::Shape* output_shape,
bool parameter_is_tupled_arguments = false, bool is_sharded = false,
bool allow_spmd_sharding_propagation_to_output = true,
bool use_auto_spmd_partitioning = false,
std::vector<int64_t> auto_spmd_mesh_shape = {},
std::vector<int64_t> auto_spmd_mesh_ids = {}, bool eager_mode = false)
: computation(std::move(computation)),
compilation_device(std::move(compilation_device)),
devices(std::move(devices)),
@@ -233,7 +232,8 @@
allow_spmd_sharding_propagation_to_output),
use_auto_spmd_partitioning(use_auto_spmd_partitioning),
auto_spmd_mesh_shape(auto_spmd_mesh_shape),
auto_spmd_mesh_ids(auto_spmd_mesh_ids) {}
auto_spmd_mesh_ids(auto_spmd_mesh_ids),
eager_mode(eager_mode) {}

xla::XlaComputation computation;
std::string compilation_device;
@@ -245,6 +245,7 @@
bool use_auto_spmd_partitioning;
std::vector<int64_t> auto_spmd_mesh_shape;
std::vector<int64_t> auto_spmd_mesh_ids;
bool eager_mode;
};

struct ExecuteComputationOptions : public ClientExecuteOptions {};
@@ -430,7 +431,9 @@ class ComputationClient {
static metrics::Metric* TransferToDeviceTransformMetric();
static metrics::Metric* TransferFromDeviceMetric();
static metrics::Metric* CompileMetric();
static metrics::Metric* EagerCompileMetric();
static metrics::Metric* ExecuteMetric();
static metrics::Metric* EagerExecuteMetric();
static metrics::Metric* ExecuteReplicatedMetric();
static metrics::Metric* ExecuteParallelMetric();
static metrics::Metric* ExecuteChainedMetric();
12 changes: 10 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -531,7 +531,11 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(

std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
std::vector<ComputationClient::CompileInstance> instances) {
metrics::TimedSection timed(CompileMetric());
auto metrics_fn = CompileMetric;
if (instances[0].eager_mode) {
metrics_fn = EagerCompileMetric;
}
metrics::TimedSection timed(metrics_fn());
tsl::profiler::TraceMe activity("PjRtComputationClient::Compile",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::ComputationPtr> computations;
@@ -695,7 +699,11 @@
// Shared ownership of the timed section ensures that it will only get logged
// once both `ExecuteComputation` and the async work in `ExecuteSharded` are
// complete; a copy is held from the lambda that releases it when done.
auto timed = std::make_shared<metrics::TimedSection>(ExecuteMetric());
auto metrics_fn = ExecuteMetric;
if (options.eager_mode) {
metrics_fn = EagerExecuteMetric;
}
auto timed = std::make_shared<metrics::TimedSection>(metrics_fn());
tsl::profiler::TraceMe activity("PjRtComputationClient::ExecuteComputation",
tsl::profiler::TraceMeLevel::kInfo);
TF_VLOG(1) << "Executing PjRt computation on " << device;
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor.cpp
@@ -83,10 +83,11 @@ XLATensorPtr XLATensor::Create(
std::optional<at::ScalarType> logical_element_type) {
XLATensorPtr xtensor = c10::make_intrusive<XLATensor>(
XLATensor(std::move(ir_value), device, logical_element_type));
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
if (UseEagerDebugMode()) {
XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
graph_executor->RegisterTensor(xtensor->data());
if (UseEagerDebugMode() || graph_executor->UseEagerMode()) {
std::vector<XLATensorPtr> xtensors({xtensor});
XLAGraphExecutor::Get()->ApplyEagerSync(xtensors);
graph_executor->ApplyEagerSync(xtensors);
}
return xtensor;
}
18 changes: 12 additions & 6 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -79,7 +79,7 @@ bool ShouldSyncIrValue(const torch::lazy::Value& ir_value) {

XLAGraphExecutor::ComputationCache* CreateComputationCache() {
static const size_t kMaxCacheSize =
runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 1024);
runtime::sys_util::GetEnvInt("XLA_COMPILATION_CACHE_SIZE", 2048);
static const bool readonlyPersistentCache =
runtime::sys_util::GetEnvBool("XLA_PERSISTENT_CACHE_READ_ONLY", false);
static std::string persistentCacheDir =
@@ -814,6 +814,7 @@

std::vector<torch::lazy::BackendDataPtr> results;
if (async->cached_computation->is_sharded) {
// TODO(JackCaoG): handle eager mode
std::vector<std::string> devices =
runtime::GetComputationClient()->GetLocalDevices();
runtime::ComputationClient::ExecuteReplicatedOptions execute_options;
@@ -1088,7 +1089,8 @@
std::shared_ptr<XLAGraphExecutor::Async> async = std::make_shared<Async>(
coll, std::move(parameters_data), std::move(tensors_data),
std::move(cached_computation));
auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs]() {
auto syncfn = [async, hash = coll->hash, sharding_specs = sharding_specs,
use_eager_mode = UseEagerMode()]() {
try {
std::vector<torch::lazy::BackendDataPtr> results;
// Execute replicated if the compiled computation is partitioned.
@@ -1117,9 +1119,13 @@
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
<< async->device << " ...";
results = torch::lazy::getBackend()->ExecuteComputation(
async->cached_computation->computation, async->parameters_data,
async->device);
std::vector<runtime::ComputationClient::DataPtr> outputs =
runtime::GetComputationClient()->ExecuteComputation(
*async->cached_computation->computation,
UnwrapXlaData(async->parameters_data), async->device.toString(),
{/*explode_tuple=*/true,
/*eager_mode=*/use_eager_mode});
results = WrapXlaData(outputs);
TORCH_LAZY_COUNTER("ExecuteComputation", 1);
TF_VLOG(3) << "Executing IR graph hash "
<< torch::lazy::HashToString(hash) << " on device "
@@ -1372,7 +1378,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
runtime::GetComputationClient()->GetCompilationDevices(
coll.device.toString(), devices),
&shape, should_wrap_parameter, is_sharded});

instances.front().eager_mode = UseEagerMode();
if (use_autosharding) {
TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
TF_CHECK(is_sharded) << "Auto-sharding pass requires SPMD mode.";
7 changes: 7 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
@@ -187,6 +187,12 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
void ClearPendingIrs(std::vector<XLATensorPtr> tensors,
const torch::lazy::BackendDevice& device);

void SetUseEagerMode(bool use_eager_mode) {
use_eager_mode_ = use_eager_mode;
}

bool UseEagerMode() { return use_eager_mode_; }

private:
// This is just to group results from compile(). Since our computation is
// different, we don't reuse the upstream CompilationResult.
@@ -361,6 +367,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const SyncTensorsConfig& config, bool warm_up_cache_only = false);

ComputationCache* computation_cache_;
bool use_eager_mode_ = false;
};

} // namespace torch_xla
5 changes: 5 additions & 0 deletions torch_xla/experimental/__init__.py
@@ -0,0 +1,5 @@
from .eager import eager_mode

__all__ = [
"eager_mode",
]
9 changes: 9 additions & 0 deletions torch_xla/experimental/eager.py
@@ -0,0 +1,9 @@
import torch_xla


def eager_mode(enable: bool):
"""Configure torch_xla's default executation mode.
Under eager mode only functions that was `torch_xla.compile`d will be
traced and compiled. Other torch ops will be executed eagerly.
"""
torch_xla._XLAC._set_use_eager_mode(enable)
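The docstring suggests the intended split: keep the hot path inside a compiled function and let everything else run eagerly. A hedged sketch of that pattern follows (the `torch_xla.compile` wrapper is referenced by the docstring but is not part of this diff, so its exact entry point is an assumption here):

```python
# Sketch of the eager + compiled-region split described in the docstring above.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(True)

device = xm.xla_device()
w = torch.randn(4, 4, device=device)  # runs eagerly, op by op


def step(x):
  # Everything inside the compiled function is traced and compiled as one graph.
  return (x @ w).sum()


compiled_step = torch_xla.compile(step)  # assumed entry point, per the docstring
out = compiled_step(torch.randn(2, 4, device=device))
print(out)
```

Running a plain script such as examples/eager/train_decoder_only_eager.py (added above) works the same way: flip the mode once at startup, then train as usual.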