diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 255fe0079cd5..b832b8f4b9fc 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -58,13 +58,16 @@ def test_resnet18(self):
     resnet18.eval()
     xla_resnet18 = torchvision.models.resnet18().to(device)
     xla_resnet18.eval()
+    # materialize the fake data for test purposes
+    xm.mark_step()
+    xm.wait_device_ops()
+    met.clear_all()
     for data, _ in loader:
       output = self.run_model_with_dynamo(xla_resnet18, data)
       torch.allclose(resnet18(data.cpu()), output.cpu())
-    # One graph for initial input data materialization. Another grpah for the
-    # real model code.
-    self.assertEqual(met.metric_data('CompileTime')[0], 2)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 2)
+    # We only expect one graph for the resnet18 inference.
+    self.assertEqual(met.metric_data('CompileTime')[0], 1)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count)
     self.assertEqual(
         met.metric_data('RunCachedGraphInputData')[0], sample_count)
     self.assertEqual(
@@ -157,10 +160,8 @@ def test_resnet18(self):
     # Graph 3: sync input for backward
     # Graph 4: sync input for backward (TODO(JackCaoG) understand why there are two graphs)
     self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    # We execute 3 grphs per step, and currently cache the graph for forward and backward
-    # will each take 1 additional execution.
-    # TODO(JackCaoG): Optimize the 2 cached execution.
-    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3 + 2)
+    # We execute 3 graphs per step.
+    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count * 3)
     # one for each forward and one for each backward
     self.assertEqual(
         met.metric_data('RunCachedGraphInputData')[0], sample_count * 2)
diff --git a/test/dynamo/test_dynamo_integrations_util.py b/test/dynamo/test_dynamo_integrations_util.py
index b1755e939cf9..3b3adf39e40d 100644
--- a/test/dynamo/test_dynamo_integrations_util.py
+++ b/test/dynamo/test_dynamo_integrations_util.py
@@ -109,8 +109,8 @@ def test_run_cached_graph(self):
     xla_dummy_model = dummy_model.to(xla_device)
     xla_out = xla_dummy_model(xla_input)
     hash = torch_xla._XLAC._get_graph_hash([xla_out])
-    # Force trigger an execution to cache this computation.
-    torch_xla._XLAC._xla_sync_multi([xla_out], [])
+    # Warm up the cache.
+    torch_xla._XLAC._xla_warm_up_cache([xla_out], [])
     # It is the caller of `run_cached_graph`'s job to make sure the input order
     # matches the graph input order. Upstream dynamo has a more completed
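For reference, a rough, self-contained version of the inference flow that the updated resnet18 test asserts on could look like the sketch below. The `torchxla_trace_once` backend string and the tiny random loader are assumptions for illustration (the real test goes through its `run_model_with_dynamo` helper), and the metric values in the comments simply mirror the assertions above.

```python
import torch
import torch._dynamo as dynamo
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()

sample_count = 3
loader = [torch.randn(4, 3, 224, 224).to(device) for _ in range(sample_count)]

# Materialize the randomly initialized weights and the fake inputs up front so
# the metrics below only reflect the dynamo-traced inference graph.
xm.mark_step()
xm.wait_device_ops()
met.clear_all()

dynamo_resnet18 = dynamo.optimize("torchxla_trace_once")(xla_resnet18)
for data in loader:
  output = dynamo_resnet18(data)

# With warm-up-only cache population there is no extra execution for the
# warm-up itself: one compile, one cached-graph execution per sample.
print(met.metric_data('CompileTime')[0])  # expected: 1
print(met.metric_data('ExecuteTime')[0])  # expected: sample_count
```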
diff --git a/test/dynamo/test_num_output.py b/test/dynamo/test_num_output.py
index 138e78289aa4..f05a2964935b 100644
--- a/test/dynamo/test_num_output.py
+++ b/test/dynamo/test_num_output.py
@@ -82,10 +82,10 @@ def do_test(self, model_class, expected_num_output):
     graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                             graph_input_tensor_ids,
                                             graph_input_xla_values)
-    torch_xla._XLAC._xla_sync_multi(outputs, [])
+    torch_xla._XLAC._xla_warm_up_cache(outputs, [])

     def run_cached_graph(*inputs):
-      torch_xla._XLAC._xla_sync_multi(inputs, [])
+      torch_xla._XLAC._xla_warm_up_cache(inputs, [])
       xla_graph_inputs = graph_input_matcher(inputs)
       xla_graph_outputs = torch_xla._XLAC._run_cached_graph(
           xla_graph_hash, xla_graph_inputs)
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index e5b1d74e8334..9b1724ebfb87 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -275,8 +275,8 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
                                           graph_input_tensor_ids,
                                           graph_input_xla_values)

-  # compiles+runs graph rooted at tensors in 'args_and_out'
-  torch_xla._XLAC._xla_sync_multi(args_and_out, [])
+  # compiles and caches the graph rooted at tensors in 'args_and_out'
+  torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

   def optimized_mod(*args):
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 921e73da75a3..5c0f37b217da 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -329,11 +329,11 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> Recv(
 void SyncTensors(const std::vector<at::Tensor>& tensors,
                  const std::vector<std::string>& devices, bool wait,
-                 bool sync_xla_data) {
+                 bool sync_xla_data, bool warm_up_cache_only = false) {
   std::vector<XLATensorPtr> xtensors =
       GetXlaTensors(tensors, /*want_all=*/false);
   XLAGraphExecutor::Get()->SyncTensorsGraph(&xtensors, devices, wait,
-                                            sync_xla_data);
+                                            sync_xla_data, warm_up_cache_only);
 }

 void SyncLiveTensors(const std::string& device_str,
@@ -1240,6 +1240,14 @@ void InitXlaModuleBindings(py::module m) {
         },
         py::arg("tensors"), py::arg("devices"), py::arg("wait") = true,
         py::arg("sync_xla_data") = true);
+  m.def("_xla_warm_up_cache",
+        [](const std::vector<at::Tensor>& tensors,
+           const std::vector<std::string>& devices) {
+          NoGilSection nogil;
+          SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false,
+                      /*warm_up_cache_only=*/true);
+        },
+        py::arg("tensors"), py::arg("devices"));
   m.def("_xla_sync_live_tensors",
         [](const std::string& device, const std::vector<std::string>& devices,
            bool wait) {
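To make the intended semantics of the new binding concrete, here is a small sketch (not part of the patch) of how `_xla_warm_up_cache` differs from `_xla_sync_multi`. The metric values in the comments are illustrative and assume a fresh process with no other pending work.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t = torch.ones(2, 2, device=device)
out = t + t  # pending IR, nothing compiled or executed yet

met.clear_all()
# Compile and cache the graph rooted at `out` without executing it; `out`
# keeps its IR, whereas `_xla_sync_multi` would also run the graph and
# replace the IR with device data.
torch_xla._XLAC._xla_warm_up_cache([out], [])

print(met.metric_data('CompileTime')[0])  # illustrative: 1
print(met.metric_data('ExecuteTime'))     # illustrative: None (nothing ran)
```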
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 947acbdb88af..97a0432193b1 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -322,7 +322,8 @@ std::vector<XLATensorPtr> XLAGraphExecutor::GetLiveTensors(
 void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
                                         absl::Span<const std::string> devices,
-                                        bool wait, bool sync_ltc_data) {
+                                        bool wait, bool sync_ltc_data,
+                                        bool warm_up_cache_only) {
   TF_VLOG(4) << "Trying to sync the value of " << tensors->size()
              << " tensor(s)";
   tensorflow::profiler::TraceMe activity(
@@ -331,14 +332,18 @@ void XLAGraphExecutor::SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
       xla::sys_util::GetEnvBool("XLA_SYNC_TENSORS_OPBYOP", false);
   SyncTensorsConfig config;
   config.sync_ltc_data = sync_ltc_data;
+  if (warm_up_cache_only) {
+    config.force_ltc_data = false;
+  }
   if (op_by_op) {
     OpByOpAsync async = SyncTensorsGraphOpByOp(tensors, devices, config);
     if (wait) {
       async.Wait();
     }
   } else {
-    auto async = SyncTensorsGraphInternal(tensors, devices, config);
-    if (wait && async != nullptr) {
+    auto async =
+        SyncTensorsGraphInternal(tensors, devices, config, warm_up_cache_only);
+    if (wait && async != nullptr && !warm_up_cache_only) {
       async->mwait.Wait();
     }
   }
 }
@@ -395,6 +400,7 @@ torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
     const std::vector<at::Tensor>& tensors) {
   SyncTensorsConfig config;
   config.sync_ltc_data = true;
+  config.force_ltc_data = false;
   SyncTensorCollection coll = CollectSyncTensors(tensors, config);
   absl::Span<const size_t> indices = coll.indices;
@@ -1191,7 +1197,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
 std::shared_ptr<XLAGraphExecutor::Async>
 XLAGraphExecutor::SyncTensorsGraphInternal(
     std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
-    const SyncTensorsConfig& config) {
+    const SyncTensorsConfig& config, bool warm_up_cache_only) {
   tensorflow::profiler::TraceMe activity(
       "SyncTensorsGraphInternal", tensorflow::profiler::TraceMeLevel::kInfo);
   SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
@@ -1226,10 +1232,14 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
         std::move(compile_result.computation), compile_result.is_sharded);
     GetComputationCache()->Add(coll.hash, cached_computation);
-    return ScheduleSyncTensorsGraph(
-        tensors, &coll, std::move(compile_result.parameters_data),
-        compile_result.device.toString(), std::move(cached_computation),
-        tensor_data_vec);
+    if (warm_up_cache_only) {
+      return nullptr;
+    } else {
+      return ScheduleSyncTensorsGraph(
+          tensors, &coll, std::move(compile_result.parameters_data),
+          compile_result.device.toString(), std::move(cached_computation),
+          tensor_data_vec);
+    }
   }
 }

 }  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h
index 7a0e35ac09d3..b198a49b84dd 100644
--- a/torch_xla/csrc/xla_graph_executor.h
+++ b/torch_xla/csrc/xla_graph_executor.h
@@ -115,7 +115,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // We don't use the upstream one given we have OpbyOp mode.
   void SyncTensorsGraph(std::vector<XLATensorPtr>* tensors,
                         absl::Span<const std::string> devices, bool wait,
-                        bool sync_ltc_data);
+                        bool sync_ltc_data, bool warm_up_cache_only = false);

   // Makes sure that any outstanding IR operation accumulated over live tensors,
   // gets turned into device data. If wait is true, the sync operation will be
@@ -338,7 +338,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // our CachedComputation is different from upstream.
   std::shared_ptr<Async> SyncTensorsGraphInternal(
       std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
-      const SyncTensorsConfig& config);
+      const SyncTensorsConfig& config, bool warm_up_cache_only = false);
 };

 }  // namespace torch_xla
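Putting the pieces together, the warm-up/replay split that the bridge relies on can be sketched roughly as below. This is a condensed, illustrative view rather than a drop-in replacement for `extract_compiled_graph`: `args_and_out` and `tensor_id_to_arg_idx` are the bridge's own objects, `GraphInputMatcher` is imported from `torch_xla.core.dynamo_bridge` as in the tests above, and `_get_tensors_xla_device_data_node` is assumed to behave as it does in dynamo_bridge.py.

```python
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.core.dynamo_bridge import GraphInputMatcher


def warm_up_and_make_runner(args_and_out, tensor_id_to_arg_idx):
  # Hash of the graph rooted at the collected input and output tensors.
  xla_graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
  (graph_input_tensor_ids, graph_input_xla_values
  ) = torch_xla._XLAC._get_tensors_xla_device_data_node(args_and_out)
  graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                          graph_input_tensor_ids,
                                          graph_input_xla_values)

  # Compile and cache only: with warm_up_cache_only=true on the C++ side,
  # nothing is executed and the tensors keep their pending IR.
  torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

  def run(*args):
    # The real bridge first makes sure the concrete inputs are materialized on
    # the device, then gathers them in graph-parameter order and replays the
    # cached computation.
    xla_graph_inputs = graph_input_matcher(args)
    return torch_xla._XLAC._run_cached_graph(xla_graph_hash, xla_graph_inputs)

  return run
```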