diff --git a/test/run_tests.sh b/test/run_tests.sh
index 386b3b817742..a5ac8f2933c8 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -40,13 +40,13 @@ if [ "$LOGFILE" != "" ]; then
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA 2>&1 | tee $LOGFILE
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA 2>&1 | tee $LOGFILE
   python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
-  # run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
+  run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
   python3 "$CDIR/test_mp_replication.py" "$@" 2>&1 | tee $LOGFILE
 else
   python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA
   python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
-  # run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
+  run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   python3 "$CDIR/test_mp_replication.py" "$@"
 fi
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 5c5268379183..affbb2d6ecfa 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -432,6 +432,16 @@ def mark_step():
   ms.save_metrics()
 
 
+def wait_device_ops(devices=[]):
+  """Waits for all the async operations on the given devices to complete.
+
+  Args:
+    devices (string..., optional): The devices whose async ops need to be waited
+      for. If empty, all the local devices will be waited for.
+  """
+  torch_xla._XLAC._xla_wait_device_ops(devices=devices)
+
+
 def optimizer_step(optimizer, barrier=False, optimizer_args={}):
   """Run the provided optimizer step and issue the XLA device step computation.
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a943a2cec68f..26f5935126bb 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -322,6 +322,12 @@ void InitXlaModuleBindings(py::module m) {
         StepMarker(device, devices, wait);
       },
       py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+  m.def("_xla_wait_device_ops",
+        [](const std::vector<std::string>& devices) {
+          NoGilSection nogil;
+          XLATensor::WaitDeviceOps(devices);
+        },
+        py::arg("devices"));
   m.def("_xla_counter_names", []() { return xla::metrics::GetCounterNames(); });
   m.def("_xla_counter_value", [](const std::string& name) -> py::object {
     xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp
index 10c6cd569919..b68397aee7f2 100644
--- a/torch_xla/csrc/op_by_op_executor.cpp
+++ b/torch_xla/csrc/op_by_op_executor.cpp
@@ -43,8 +43,8 @@ const xla::Shape& GetParameterShape(const ir::Output& operand,
 
 size_t ComputeNodeKey(
     const ir::Node* node,
-    tensorflow::gtl::ArraySlice<const xla::Shape*> input_shapes) {
-  size_t key = 0x129b98d6968b7;
+    tensorflow::gtl::ArraySlice<const xla::Shape*> input_shapes, size_t seed) {
+  size_t key = seed;
   const auto& operands = node->operands();
   for (size_t i = 0; i < operands.size(); ++i) {
     key = xla::util::HashCombine(key, xla::util::ShapeHash(GetParameterShape(
@@ -72,6 +72,11 @@ xla::XlaComputation BuildNodeComputation(
   return ConsumeValue(loctx.Build());
 }
 
+size_t GetNodesKeySeed(const std::string& device,
+                       tensorflow::gtl::ArraySlice<const std::string> devices) {
+  return xla::util::MHash(device, devices);
+}
+
 }  // namespace
 
 OpByOpExecutor::OpByOpExecutor(size_t compile_cache_size)
@@ -97,6 +102,9 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
     node_to_index[post_order[i]] = i;
   }
 
+  auto compilation_devices =
+      xla::ComputationClient::Get()->GetCompilationDevices(device, devices);
+  size_t nodes_key_seed = GetNodesKeySeed(device, compilation_devices);
   Device exec_device(device);
   std::vector<size_t> cache_keys;
   std::unordered_map<size_t, std::vector<size_t>> compile_indices;
@@ -126,7 +134,7 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
       op_input_shapes.push_back(ops_shapes[op_index]);
     }
 
-    size_t cache_key = ComputeNodeKey(node, op_input_shapes);
+    size_t cache_key = ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
     cxop.computation = compile_cache_.Get(cache_key);
     if (cxop.computation == nullptr) {
       XLA_COUNTER("OpByOpCompileCacheMiss", 1);
@@ -145,12 +153,9 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
           ConsumeValue(computation.GetProgramShape());
       compile_shapes.push_back(MakeShapeWithDeviceLayout(
           program_shape.result(), exec_device.hw_type));
-      compile_instances.push_back(
-          {std::move(computation), device,
-           xla::ComputationClient::Get()->GetCompilationDevices(device,
-                                                                devices),
-           &compile_shapes.back()});
-
+      compile_instances.push_back({std::move(computation), device,
+                                   compilation_devices,
+                                   &compile_shapes.back()});
       ops_shapes[i] = &compile_shapes.back();
     } else {
       ops_shapes[i] =
@@ -171,8 +176,12 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
   // If we missed the cache for certain ops, compile them now and fixup the
   // chained ops vector.
   if (!compile_instances.empty()) {
+    TF_VLOG(3) << "Compiling " << compile_instances.size()
+               << " computations on device " << device;
     auto computation_ptrs =
         xla::ComputationClient::Get()->Compile(std::move(compile_instances));
+    TF_VLOG(3) << "Compiling " << computation_ptrs.size()
+               << " computations on device " << device << " done!";
     for (size_t i = 0; i < computation_ptrs.size(); ++i) {
       compile_cache_.Add(cache_keys[i], computation_ptrs[i]);
       for (auto index : compile_indices[cache_keys[i]]) {
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 348ef37b2134..19ca9883c932 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -1019,8 +1019,8 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors(
       tensors[at_tensor_index[i]].data()->xla_data = std::move(handles[i]);
     }
   }
-  TF_VLOG(4) << "Tensors graph hash " << coll.hash << " on device '"
-             << coll.device << "'";
+  TF_VLOG(4) << "Tensors graph hash " << coll.hash << " on device "
+             << coll.device;
   return coll;
 }
 
@@ -1171,6 +1171,21 @@ void XLATensor::MarkStep(const Device* device) {
   DeviceContextArena::Get()->ClearProfileData(device);
 }
 
+void XLATensor::WaitDeviceOps(
+    tensorflow::gtl::ArraySlice<const std::string> devices) {
+  std::set<Device> wait_devices;
+  if (!devices.empty()) {
+    for (auto& device_str : devices) {
+      wait_devices.insert(Device(device_str));
+    }
+  } else {
+    for (auto& device_str : xla::ComputationClient::Get()->GetLocalDevices()) {
+      wait_devices.insert(Device(device_str));
+    }
+  }
+  LockDevices(wait_devices);
+}
+
 XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp(
     std::vector<XLATensor>* tensors,
     tensorflow::gtl::ArraySlice<const std::string> devices,
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 2e4ebae80cf9..5a6e8b0fbdf6 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -144,6 +144,11 @@ class XLATensor {
   // the computation boundaries.
   static void MarkStep(const Device* device);
 
+  // Waits for all the outstanding operations on all the supplied devices.
+  // If devices is empty, the wait will happen for all local devices.
+  static void WaitDeviceOps(
+      tensorflow::gtl::ArraySlice<const std::string> devices);
+
   // Retrieves the PyTorch CPU tensors behind the XLA tensors IR operations.
   // All the tensors must be on the same device.
   static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors);
diff --git a/torch_xla/distributed/data_parallel.py b/torch_xla/distributed/data_parallel.py
index fe65a0044379..6b196af1c6d5 100644
--- a/torch_xla/distributed/data_parallel.py
+++ b/torch_xla/distributed/data_parallel.py
@@ -136,6 +136,7 @@ def __call__(self, loop_fn, loader, fixed_batch_size=False, batchdim=0):
                        torch.device(self._device_ids[0]), self._contexts[0])
       ]
+    xm.wait_device_ops()
     para_loader = pl.ParallelLoader(
         loader,
         self._device_ids,