Merged
4 changes: 2 additions & 2 deletions test/run_tests.sh

```diff
@@ -40,13 +40,13 @@ if [ "$LOGFILE" != "" ]; then
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA 2>&1 | tee $LOGFILE
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA 2>&1 | tee $LOGFILE
   python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
-  # run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
+  run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
   python3 "$CDIR/test_mp_replication.py" "$@" 2>&1 | tee $LOGFILE
 else
   python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA
   python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA
   python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
-  # run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
+  run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
   python3 "$CDIR/test_mp_replication.py" "$@"
 fi
```
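For context, `run_opbyop` appears to be a wrapper defined earlier in this script that re-runs the given test command with op-by-op execution enabled; the change re-enables those previously commented-out runs. A hedged sketch of requesting the same mode from Python, assuming the usual pytorch/xla environment-variable names (which must be set before `torch_xla` is imported):

```python
import os

# Assumed op-by-op configuration knobs; set before torch_xla loads.
os.environ['XLA_GET_TENSORS_OPBYOP'] = '1'   # fetch tensor values op-by-op
os.environ['XLA_SYNC_TENSORS_OPBYOP'] = '1'  # execute sync graphs op-by-op

import torch_xla.core.xla_model as xm  # noqa: E402
```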
10 changes: 10 additions & 0 deletions torch_xla/core/xla_model.py

```diff
@@ -432,6 +432,16 @@ def mark_step():
   ms.save_metrics()
 
 
+def wait_device_ops(devices=[]):
+  """Waits for all the async operations on the given devices to complete.
+
+  Args:
+    devices (string..., optional): The devices whose async ops need to be waited
+      for. If empty, all the local devices will be waited for.
+  """
+  torch_xla._XLAC._xla_wait_device_ops(devices=devices)
+
+
 def optimizer_step(optimizer, barrier=False, optimizer_args={}):
   """Run the provided optimizer step and issue the XLA device step computation.
 
```
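A minimal usage sketch for the new helper: `mark_step()` only queues the pending graph for asynchronous execution, so an explicit barrier is needed before, e.g., trusting wall-clock timings. Variable names and the device-string format are assumptions for illustration:

```python
import time

import torch_xla.core.xla_model as xm

start = time.time()
xm.mark_step()        # dispatches the pending graph asynchronously
xm.wait_device_ops()  # barrier: returns once all local devices are idle
print('step wall time:', time.time() - start)

# Waiting on an explicit subset (device-string format assumed):
xm.wait_device_ops(devices=['TPU:0'])
```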
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp

```diff
@@ -322,6 +322,12 @@ void InitXlaModuleBindings(py::module m) {
           StepMarker(device, devices, wait);
         },
         py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+  m.def("_xla_wait_device_ops",
+        [](const std::vector<std::string>& devices) {
+          NoGilSection nogil;
+          XLATensor::WaitDeviceOps(devices);
+        },
+        py::arg("devices"));
   m.def("_xla_counter_names", []() { return xla::metrics::GetCounterNames(); });
   m.def("_xla_counter_value", [](const std::string& name) -> py::object {
     xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
```
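The binding enters a `NoGilSection` before blocking, so the wait releases the Python GIL and other interpreter threads keep running while the devices drain. A small sketch of why that matters (the heartbeat thread is hypothetical, not part of the PR):

```python
import threading
import time

import torch_xla.core.xla_model as xm

def heartbeat(stop):
    while not stop.is_set():
        print('interpreter still responsive during the device wait')
        time.sleep(1)

stop = threading.Event()
t = threading.Thread(target=heartbeat, args=(stop,))
t.start()
xm.wait_device_ops()  # blocks this thread only; the GIL is released
stop.set()
t.join()
```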
27 changes: 18 additions & 9 deletions torch_xla/csrc/op_by_op_executor.cpp

```diff
@@ -43,8 +43,8 @@ const xla::Shape& GetParameterShape(const ir::Output& operand,
 
 size_t ComputeNodeKey(
     const ir::Node* node,
-    tensorflow::gtl::ArraySlice<const xla::Shape*> input_shapes) {
-  size_t key = 0x129b98d6968b7;
+    tensorflow::gtl::ArraySlice<const xla::Shape*> input_shapes, size_t seed) {
+  size_t key = seed;
   const auto& operands = node->operands();
   for (size_t i = 0; i < operands.size(); ++i) {
     key = xla::util::HashCombine(key, xla::util::ShapeHash(GetParameterShape(
@@ -72,6 +72,11 @@ xla::XlaComputation BuildNodeComputation(
   return ConsumeValue(loctx.Build());
 }
 
+size_t GetNodesKeySeed(const std::string& device,
+                       tensorflow::gtl::ArraySlice<const std::string> devices) {
+  return xla::util::MHash(device, devices);
+}
+
 }  // namespace
 
 OpByOpExecutor::OpByOpExecutor(size_t compile_cache_size)
@@ -97,6 +102,9 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
     node_to_index[post_order[i]] = i;
   }
 
+  auto compilation_devices =
+      xla::ComputationClient::Get()->GetCompilationDevices(device, devices);
+  size_t nodes_key_seed = GetNodesKeySeed(device, compilation_devices);
   Device exec_device(device);
   std::vector<size_t> cache_keys;
   std::unordered_map<size_t, std::vector<size_t>> compile_indices;
@@ -126,7 +134,7 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
       op_input_shapes.push_back(ops_shapes[op_index]);
     }
 
-    size_t cache_key = ComputeNodeKey(node, op_input_shapes);
+    size_t cache_key = ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
     cxop.computation = compile_cache_.Get(cache_key);
     if (cxop.computation == nullptr) {
       XLA_COUNTER("OpByOpCompileCacheMiss", 1);
@@ -145,12 +153,9 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
           ConsumeValue(computation.GetProgramShape());
       compile_shapes.push_back(MakeShapeWithDeviceLayout(
           program_shape.result(), exec_device.hw_type));
-      compile_instances.push_back(
-          {std::move(computation), device,
-           xla::ComputationClient::Get()->GetCompilationDevices(device,
-                                                                devices),
-           &compile_shapes.back()});
-
+      compile_instances.push_back({std::move(computation), device,
+                                   compilation_devices,
+                                   &compile_shapes.back()});
       ops_shapes[i] = &compile_shapes.back();
     } else {
       ops_shapes[i] =
@@ -171,8 +176,12 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
   // If we missed the cache for certain ops, compile them now and fixup the
   // chained ops vector.
   if (!compile_instances.empty()) {
+    TF_VLOG(3) << "Compiling " << compile_instances.size()
+               << " computations on device " << device;
     auto computation_ptrs =
         xla::ComputationClient::Get()->Compile(std::move(compile_instances));
+    TF_VLOG(3) << "Compiling " << computation_ptrs.size()
+               << " computations on device " << device << " done!";
     for (size_t i = 0; i < computation_ptrs.size(); ++i) {
       compile_cache_.Add(cache_keys[i], computation_ptrs[i]);
       for (auto index : compile_indices[cache_keys[i]]) {
```
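The substantive change here: the op-by-op compile-cache key used to start from a fixed constant, so identical node/shape combinations could alias cache entries across different device contexts; seeding the key with `MHash(device, devices)` keeps those entries distinct (and the hoisted `compilation_devices` avoids recomputing the device list per node). A minimal Python model of the idea, standing in for the actual `xla::util` hash functions:

```python
# Illustrative stand-ins for xla::util::MHash / HashCombine / ShapeHash.
def nodes_key_seed(device, compilation_devices):
    return hash((device, tuple(compilation_devices)))

def node_key(node_op, input_shapes, seed):
    key = seed
    for shape in input_shapes:
        key = hash((key, shape))  # stands in for HashCombine(ShapeHash(...))
    return hash((key, node_op))

shapes = ['f32[2,2]', 'f32[2,2]']
k0 = node_key('aten::add', shapes, nodes_key_seed('TPU:0', ['TPU:0', 'TPU:1']))
k1 = node_key('aten::add', shapes, nodes_key_seed('TPU:1', ['TPU:0', 'TPU:1']))
# Same op and shapes, different device context: the keys no longer collide.
assert k0 != k1
```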
19 changes: 17 additions & 2 deletions torch_xla/csrc/tensor.cpp

```diff
@@ -1019,8 +1019,8 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors(
       tensors[at_tensor_index[i]].data()->xla_data = std::move(handles[i]);
     }
   }
-  TF_VLOG(4) << "Tensors graph hash " << coll.hash << " on device '"
-             << coll.device << "'";
+  TF_VLOG(4) << "Tensors graph hash " << coll.hash << " on device "
+             << coll.device;
   return coll;
 }
 
@@ -1171,6 +1171,21 @@ void XLATensor::MarkStep(const Device* device) {
   DeviceContextArena::Get()->ClearProfileData(device);
 }
 
+void XLATensor::WaitDeviceOps(
+    tensorflow::gtl::ArraySlice<const std::string> devices) {
+  std::set<Device> wait_devices;
+  if (!devices.empty()) {
+    for (auto& device_str : devices) {
+      wait_devices.insert(Device(device_str));
+    }
+  } else {
+    for (auto& device_str : xla::ComputationClient::Get()->GetLocalDevices()) {
+      wait_devices.insert(Device(device_str));
+    }
+  }
+  LockDevices(wait_devices);
+}
+
 XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp(
     std::vector<XLATensor>* tensors,
     tensorflow::gtl::ArraySlice<const std::string> devices,
```
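The wait is built entirely on the existing per-device locks: `LockDevices` presumably cannot acquire a device's lock until in-flight async work releases it, so once every lock in the set has been acquired (and immediately dropped, since the return value is discarded) each device is idle. A rough Python analogue, illustration only and not the real C++ locking scheme:

```python
import threading

# Hypothetical per-device locks; in the real code, queued async
# executions hold the device lock until they complete.
device_locks = {'CPU:0': threading.Lock(), 'TPU:0': threading.Lock()}

def wait_device_ops(devices=()):
    # An empty argument means "all local devices", mirroring WaitDeviceOps().
    for name in devices or device_locks:
        with device_locks[name]:  # blocks until pending work releases it
            pass
```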
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor.h

```diff
@@ -144,6 +144,11 @@ class XLATensor {
   // the computation boundaries.
   static void MarkStep(const Device* device);
 
+  // Waits for all the outstanding operations on all the supplied devices.
+  // If devices is empty, the wait will happen for all local devices.
+  static void WaitDeviceOps(
+      tensorflow::gtl::ArraySlice<const std::string> devices);
+
   // Retrieves the PyTorch CPU tensors behind the XLA tensors IR operations.
   // All the tensors must be on the same device.
   static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors);
```
1 change: 1 addition & 0 deletions torch_xla/distributed/data_parallel.py

```diff
@@ -136,6 +136,7 @@ def __call__(self, loop_fn, loader, fixed_batch_size=False, batchdim=0):
           torch.device(self._device_ids[0]), self._contexts[0])
       ]
 
+    xm.wait_device_ops()
    para_loader = pl.ParallelLoader(
        loader,
        self._device_ids,
```
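With this barrier, `DataParallel.__call__` drains any still-running device work (e.g., from the previous epoch) before constructing the next `ParallelLoader`. A hedged end-to-end sketch of where it now takes effect; the network class, loop function, loader, and epoch count are placeholders:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MyNetwork, device_ids=devices)  # placeholder network

for epoch in range(1, num_epochs + 1):  # num_epochs: placeholder
  # __call__ now begins with xm.wait_device_ops(), so work queued by the
  # previous epoch has fully completed before the new ParallelLoader is
  # built and replication starts.
  model_parallel(train_loop_fn, train_loader)  # placeholders
```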