diff --git a/test/run_tests.sh b/test/run_tests.sh
index 8ab19d244f7b..8b3452f00308 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -140,6 +140,7 @@ function run_op_tests {
   run_pjrt python3 "$CDIR/pjrt/test_mesh_service.py"
   run_pjrt python3 "$CDIR/test_xla_sharding.py"
   run_test python3 "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
+  run_test python3 "$CDIR/test_input_output_aliases.py"
 }
 
 function run_mp_op_tests {
diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
new file mode 100644
index 000000000000..5efd0c30671a
--- /dev/null
+++ b/test/test_input_output_aliases.py
@@ -0,0 +1,38 @@
+import sys
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
+import unittest
+
+
+class MetricsTest(unittest.TestCase):
+
+  def test_non_view(self):
+    xla_device = xm.xla_device()
+    # This is a special case where we want to sync t1's and t2's
+    # values, since they will have device_data IR instead of XLAData.
+    # The HLO looks like:
+    # ENTRY %IrToHlo.4 (p0.1: f32[4,2,2], p1.2: f32[4,2,2]) -> (f32[4,2,2], f32[4,2,2]) {
+    #   %p0.1 = f32[4,2,2]{2,1,0} parameter(0)
+    #   %p1.2 = f32[4,2,2]{2,1,0} parameter(1)
+    #   ROOT %tuple.3 = (f32[4,2,2]{2,1,0}, f32[4,2,2]{2,1,0}) tuple(f32[4,2,2]{2,1,0} %p0.1, f32[4,2,2]{2,1,0} %p1.2)
+    # }
+    t1 = torch.randn(4, 2, 2).to(xla_device)
+    t2 = torch.randn(4, 2, 2).to(xla_device)
+    xm.mark_step()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+
+    # Check in-place op aliasing.
+    t3 = t1 + t2
+    t1 *= 2.0
+    t2 += 2.0
+    xm.mark_step()
+
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 4.0)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 8a06608c8b9c..23882e7ed4ac 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -859,12 +859,14 @@ XLAGraphExecutor::BuildInputOutputAliases(
     LoweringContext* lowering_ctx) {
   std::unordered_map<int64_t, size_t> output_tensor_id_map;
   std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
+  // tensors[indices] represent all tensors that need to be updated after
+  // the execution. We can only alias the current buffers of these tensors,
+  // since those buffers are no longer needed after execution.
   for (size_t i = 0; i < indices.size(); ++i) {
     size_t tensor_index = indices[i];
     int64_t tensor_id = tensors[tensor_index]->GetUniqueId();
     output_tensor_id_map[tensor_id] = i;
   }
-  // TODO we need xla_shape here.
   const auto& parameters_data = lowering_ctx->GetParametersData();
   std::vector<ssize_t> alias_map(indices.size(), -1);
   for (size_t i = 0; i < parameters_data.size(); ++i) {
@@ -873,11 +875,17 @@ XLAGraphExecutor::BuildInputOutputAliases(
             parameters_data[i]->info());
     if (data_info != nullptr && !data_info->read_only) {
       auto it = output_tensor_id_map.find(data_info->tensor_id);
+      // Finding a parameter buffer's TensorId in output_tensor_id_map means
+      // the buffer is no longer needed after execution, since the XLATensor
+      // will get a new buffer.
       if (it != output_tensor_id_map.end()) {
         size_t output_index = it->second;
         xla::XlaOp root = lowering_ctx->GetResult(output_index);
         const xla::Shape& root_shape = XlaHelpers::ShapeOfXlaOp(root);
         auto parameter_data_shape = UnwrapXlaData(parameters_data[i])->shape();
+        // Only alias the existing buffer with the new value when both have
+        // the same shape and the existing buffer has not already been
+        // aliased.
         if (parameter_data_shape == root_shape && alias_map[output_index] < 0) {
           // parameter is not a tuple so param_index will always be {}
           lowering_ctx->builder()->SetUpAlias(
@@ -892,7 +900,8 @@
       }
     }
   }
-  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", alias_map.size());
+  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount",
+                          input_output_alias_pair.size());
   return input_output_alias_pair;
 }
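For reference, here is a minimal standalone sketch (not part of the patch) of what the new test exercises: reading the `InputOutputAliasCount` metric around a `mark_step()` that contains an in-place update. It assumes a working torch_xla install with an XLA device attached; it only uses the same APIs the test itself relies on, and the variable names are illustrative.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t = torch.randn(4, 2, 2).to(device)
xm.mark_step()  # materialize t on the device; records an initial alias count

before = met.metric_data("InputOutputAliasCount")[1]
t *= 2.0  # in-place update: t's current buffer can be donated to the output
xm.mark_step()
after = met.metric_data("InputOutputAliasCount")[1]
# The delta is how many input buffers the latest execution aliased
# with its outputs.
print(after - before)
```

With the metric fix above, this reports the aliases actually set up (`input_output_alias_pair.size()`) rather than the number of output candidates (`alias_map.size()`), so the deltas observed here line up with the assertions in `test_input_output_aliases.py`.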