17 changes: 9 additions & 8 deletions test/dynamo/test_dynamo.py
@@ -58,13 +58,16 @@ def test_resnet18(self):
resnet18.eval()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
# materialize the fake data for test purposes
xm.mark_step()
xm.wait_device_ops()
met.clear_all()
for data, _ in loader:
output = self.run_model_with_dynamo(xla_resnet18, data)
torch.allclose(resnet18(data.cpu()), output.cpu())
# One graph for initial input data materialization. Another graph for the
# real model code.
self.assertEqual(met.metric_data('CompileTime')[0], 2)
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 2)
# We only expect one graph for the resnet18 inference.
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
self.assertEqual(
met.metric_data('RunCachedGraphInputData')[0], sample_count)
self.assertEqual(
@@ -157,10 +160,8 @@ def test_resnet18(self):
# Graph 3: sync input for backward
# Graph 4: sync input for backward (TODO(JackCaoG) understand why there are two graphs)
self.assertEqual(met.metric_data('CompileTime')[0], 4)
# We execute 3 graphs per step, and caching the graph for forward and backward
# currently takes 1 additional execution each.
# TODO(JackCaoG): Optimize the 2 cached executions.
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3 + 2)
# We execute 3 graphs per step.
self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3)
# one for each forward and one for each backward
self.assertEqual(
met.metric_data('RunCachedGraphInputData')[0], sample_count * 2)
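For reference, the counters asserted above come from torch_xla's debug metrics API. A minimal sketch (not part of this PR) of the same materialize / clear / run / check pattern with a toy computation instead of resnet18, assuming a working XLA device and no other work between clear_all() and the checks:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 4).to(device)
xm.mark_step()            # materialize the input data first
xm.wait_device_ops()
met.clear_all()

y = (x @ x).sum()
xm.mark_step()            # one graph: compiled once, executed once
xm.wait_device_ops()
print(met.metric_data('CompileTime')[0])  # expected: 1
print(met.metric_data('ExecuteTime')[0])  # expected: 1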
4 changes: 2 additions & 2 deletions test/dynamo/test_dynamo_integrations_util.py
@@ -109,8 +109,8 @@ def test_run_cached_graph(self):
xla_dummy_model = dummy_model.to(xla_device)
xla_out = xla_dummy_model(xla_input)
hash = torch_xla._XLAC._get_graph_hash([xla_out])
# Force trigger an execution to cache this computation.
torch_xla._XLAC._xla_sync_multi([xla_out], [])
# Warm up the cache.
torch_xla._XLAC._xla_warm_up_cache([xla_out], [])

# It is the caller of `run_cached_graph`'s job to make sure the input order
# matches the graph input order. Upstream dynamo has a more complete
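The warm-up call slots into the cached-graph flow exercised here roughly as below. This is a sketch rather than PR code; graph_hash and ordered_inputs are illustrative names, and the input list for _run_cached_graph must follow the cached graph's parameter order (which the dynamo bridge derives with GraphInputMatcher), so that call is left commented.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)
xla_input = torch.randn(2, 10).to(device)

xla_out = model(xla_input)
graph_hash = torch_xla._XLAC._get_graph_hash([xla_out])

# Compile and cache the graph rooted at `xla_out` without executing it.
torch_xla._XLAC._xla_warm_up_cache([xla_out], [])

# Later, replay the cached graph with inputs in the graph's parameter order:
#   torch_xla._XLAC._run_cached_graph(graph_hash, ordered_inputs)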
4 changes: 2 additions & 2 deletions test/dynamo/test_num_output.py
@@ -82,10 +82,10 @@ def do_test(self, model_class, expected_num_output):
graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
graph_input_tensor_ids,
graph_input_xla_values)
torch_xla._XLAC._xla_sync_multi(outputs, [])
torch_xla._XLAC._xla_warm_up_cache(outputs, [])

def run_cached_graph(*inputs):
torch_xla._XLAC._xla_sync_multi(inputs, [])
torch_xla._XLAC._xla_warm_up_cache(inputs, [])
xla_graph_inputs = graph_input_matcher(inputs)
xla_graph_outputs = torch_xla._XLAC._run_cached_graph(
xla_graph_hash, xla_graph_inputs)
4 changes: 2 additions & 2 deletions torch_xla/core/dynamo_bridge.py
@@ -275,8 +275,8 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
graph_input_tensor_ids,
graph_input_xla_values)

# compiles+runs graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_sync_multi(args_and_out, [])
# compiles and caches the graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

def optimized_mod(*args):
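extract_compiled_graph is the bridge entry point that performs this warm-up; dynamo normally invokes it with the GraphModule it captured. A rough sketch of driving it by hand — whether a plain torch.fx.symbolic_trace module is accepted here is an assumption, not something this PR states:

import torch
import torch.fx
import torch_xla.core.xla_model as xm
import torch_xla.core.dynamo_bridge as dynamo_bridge

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)
args = [torch.randn(2, 10).to(device)]

# Hand a GraphModule to the bridge; the returned callable is expected to
# replay the graph that was compiled and cached during the warm-up above.
gm = torch.fx.symbolic_trace(model)
optimized_mod = dynamo_bridge.extract_compiled_graph(gm, args)
out = optimized_mod(*args)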
12 changes: 10 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -329,11 +329,11 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> Recv(

void SyncTensors(const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool wait,
bool sync_xla_data) {
bool sync_xla_data, bool warm_up_cache_only = false) {
std::vector<XLATensorPtr> xtensors =
GetXlaTensors(tensors, /*want_all=*/false);
XLAGraphExecutor::Get()->SyncTensorsGraph(&xtensors, devices, wait,
sync_xla_data);
sync_xla_data, warm_up_cache_only);
}

void SyncLiveTensors(const std::string& device_str,
@@ -1240,6 +1240,14 @@ void InitXlaModuleBindings(py::module m) {
},
py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
py::arg("sync_xla_data") = true);
m.def("_xla_warm_up_cache",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices) {
NoGilSection nogil;
SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false,
/*warm_up_cache_only=*/true);
},
py::arg("tensors"), py::arg("devices"));
m.def("_xla_sync_live_tensors",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait) {
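From Python, the new binding is a drop-in alternative to _xla_sync_multi for cases where only the compilation cache should be populated. A small sketch of the difference, assuming an XLA device is available:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
u = torch.ones(2, 2).to(device) + 1
v = torch.ones(2, 2).to(device) + 2

# _xla_sync_multi compiles the graph rooted at `u` and executes it.
torch_xla._XLAC._xla_sync_multi([u], [])

# _xla_warm_up_cache forwards wait=False, sync_xla_data=False and
# warm_up_cache_only=True to SyncTensors: the graph rooted at `v` is
# compiled and cached, but no execution is scheduled.
torch_xla._XLAC._xla_warm_up_cache([v], [])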
26 changes: 18 additions & 8 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -322,7 +322,8 @@ std::vector<XLATensorPtr> XLAGraphExecutor::GetLiveTensors(

void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
absl::Span<const std::string> devices,
bool wait, bool sync_ltc_data) {
bool wait, bool sync_ltc_data,
bool warm_up_cache_only) {
TF_VLOG(4) << "Trying to sync the value of " << tensors->size()
<< " tensor(s)";
tensorflow::profiler::TraceMe activity(
@@ -331,14 +332,18 @@ void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
xla::sys_util::GetEnvBool("XLA_SYNC_TENSORS_OPBYOP", false);
SyncTensorsConfig config;
config.sync_ltc_data = sync_ltc_data;
if (warm_up_cache_only) {
config.force_ltc_data = false;
}
if (op_by_op) {
OpByOpAsync async = SyncTensorsGraphOpByOp(tensors, devices, config);
if (wait) {
async.Wait();
}
} else {
auto async = SyncTensorsGraphInternal(tensors, devices, config);
if (wait && async != nullptr) {
auto async =
SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only);
if (wait && async != nullptr && !warm_up_cache_only) {
async->mwait.Wait();
}
}
@@ -395,6 +400,7 @@ torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
const std::vector<XLATensorPtr>& tensors) {
SyncTensorsConfig config;
config.sync_ltc_data = true;
config.force_ltc_data = false;

SyncTensorCollection coll = CollectSyncTensors(tensors, config);
absl::Span<const size_t> indices = coll.indices;
@@ -1191,7 +1197,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
std::shared_ptr<XLAGraphExecutor::Async>
XLAGraphExecutor::SyncTensorsGraphInternal(
std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
const SyncTensorsConfig& config) {
const SyncTensorsConfig& config, bool warm_up_cache_only) {
tensorflow::profiler::TraceMe activity(
"SyncTensorsGraphInternal", tensorflow::profiler::TraceMeLevel::kInfo);
SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
@@ -1226,10 +1232,14 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
std::move(compile_result.computation), compile_result.is_sharded);
GetComputationCache()->Add(coll.hash, cached_computation);

return ScheduleSyncTensorsGraph(
tensors, &coll, std::move(compile_result.parameters_data),
compile_result.device.toString(), std::move(cached_computation),
tensor_data_vec);
if (warm_up_cache_only) {
return nullptr;
} else {
return ScheduleSyncTensorsGraph(
tensors, &coll, std::move(compile_result.parameters_data),
compile_result.device.toString(), std::move(cached_computation),
tensor_data_vec);
}
}

} // namespace torch_xla
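With warm_up_cache_only set, the executor still collects the sync tensors, compiles, and adds the computation to the cache, but returns before ScheduleSyncTensorsGraph. The observable effect from Python should be a CompileTime increment without a matching ExecuteTime increment — a sketch under the assumption that nothing else touches the device in between:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(3, 3).to(device)
xm.mark_step()
xm.wait_device_ops()
met.clear_all()

y = x * 2 + 1
torch_xla._XLAC._xla_warm_up_cache([y], [])

print(met.metric_data('CompileTime')[0])  # expected: 1, the graph was compiled
print(met.metric_data('ExecuteTime'))     # expected: no samples, nothing ran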
4 changes: 2 additions & 2 deletions torch_xla/csrc/xla_graph_executor.h
@@ -115,7 +115,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
// We don't use the upstream one given we have OpbyOp mode.
void SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
absl::Span<const std::string> devices, bool wait,
bool sync_ltc_data);
bool sync_ltc_data, bool warm_up_cache_only = false);

// Makes sure that any outstanding IR operation accumulated over live tensors,
// gets turned into device data. If wait is true, the sync operation will be
@@ -338,7 +338,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
// our CachedComputation is different from upstream.
std::shared_ptr<Async> SyncTensorsGraphInternal(
std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
const SyncTensorsConfig& config);
const SyncTensorsConfig& config, bool warm_up_cache_only = false);
};

} // namespace torch_xla