Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ function run_op_tests {
run_pjrt python3 "$CDIR/pjrt/test_mesh_service.py"
run_pjrt python3 "$CDIR/test_xla_sharding.py"
run_test python3 "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test python3 "$CDIR/test_input_output_aliases.py"
}

function run_mp_op_tests {
Expand Down
38 changes: 38 additions & 0 deletions test/test_input_output_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest


class MetricsTest(unittest.TestCase):
  """Checks the InputOutputAliasCount metric produced by the graph executor."""

  def test_non_view(self):
    device = xm.xla_device()
    # Special case: both tensors carry device_data IR (not XLAData), so the
    # sync must still upload their values. The resulting HLO is just a
    # pass-through of the two parameters:
    # ENTRY %IrToHlo.4 (p0.1: f32[4,2,2], p1.2: f32[4,2,2]) -> (f32[4,2,2], f32[4,2,2]) {
    #   %p0.1 = f32[4,2,2]{2,1,0} parameter(0)
    #   %p1.2 = f32[4,2,2]{2,1,0} parameter(1)
    #   ROOT %tuple.3 = (f32[4,2,2]{2,1,0}, f32[4,2,2]{2,1,0}) tuple(f32[4,2,2]{2,1,0} %p0.1, f32[4,2,2]{2,1,0} %p1.2)
    # }
    t1 = torch.randn(4, 2, 2).to(device)
    t2 = torch.randn(4, 2, 2).to(device)
    xm.mark_step()
    alias_count = met.metric_data("InputOutputAliasCount")[1]
    self.assertEqual(alias_count, 2.0)

    # Verify aliasing for in-place ops. t3 is a fresh (non-aliased) output;
    # only the buffers of t1 and t2, which are overwritten in place, should
    # be donated, adding two more aliases.
    t3 = t1 + t2
    t1 *= 2.0
    t2 += 2.0
    xm.mark_step()

    alias_count = met.metric_data("InputOutputAliasCount")[1]
    self.assertEqual(alias_count, 4.0)


if __name__ == '__main__':
  # Pass exit=False so unittest.main() returns the TestProgram instead of
  # raising SystemExit itself; otherwise the line below is unreachable dead
  # code (the default exit=True calls sys.exit() inside main()).
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
13 changes: 11 additions & 2 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -859,12 +859,14 @@ XLAGraphExecutor::BuildInputOutputAliases(
LoweringContext* lowering_ctx) {
std::unordered_map<int64_t, size_t> output_tensor_id_map;
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
// tensors[indices] represent all tensors that needs to be updated after
// the execution. We can only alias the current buffer of these tensors since
// those buffers are no longer needed after execution.
for (size_t i = 0; i < indices.size(); ++i) {
size_t tensor_index = indices[i];
int64_t tensor_id = tensors[tensor_index]->GetUniqueId();
output_tensor_id_map[tensor_id] = i;
}
// TODO we need xla_shape here.
const auto& parameters_data = lowering_ctx->GetParametersData();
std::vector<ssize_t> alias_map(indices.size(), -1);
for (size_t i = 0; i < parameters_data.size(); ++i) {
Expand All @@ -873,11 +875,17 @@ XLAGraphExecutor::BuildInputOutputAliases(
parameters_data[i]->info());
if (data_info != nullptr && !data_info->read_only) {
auto it = output_tensor_id_map.find(data_info->tensor_id);
// Parameter buffer's TensorId in output_tensor_id_map means
// this buffer is not needed after execution since XLATensor will get a
// new buffer.
if (it != output_tensor_id_map.end()) {
size_t output_index = it->second;
xla::XlaOp root = lowering_ctx->GetResult(output_index);
const xla::Shape& root_shape = XlaHelpers::ShapeOfXlaOp(root);
auto parameter_data_shape = UnwrapXlaData(parameters_data[i])->shape();
// Need to check whether existing buffer and the new value has the same
// shape and the existing buffer has not been aliased before aliasing
// the existing and new buffer.
if (parameter_data_shape == root_shape && alias_map[output_index] < 0) {
// parameter is not a tuple so param_index will always be {}
lowering_ctx->builder()->SetUpAlias(
Expand All @@ -892,7 +900,8 @@ XLAGraphExecutor::BuildInputOutputAliases(
}
}
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", alias_map.size());
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount",
input_output_alias_pair.size());
return input_output_alias_pair;
}

Expand Down