Skip to content

Commit

Permalink
Expose SetUpAlias to XLA Python client.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 306796534
Change-Id: I141c4e7ef52e96d26ad9aa0ec8d3094fb6f0ba42
  • Loading branch information
tomhennigan authored and tensorflower-gardener committed Apr 16, 2020
1 parent 261bacf commit 42052dc
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tensorflow/compiler/xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,15 @@ PYBIND11_MODULE(xla_extension, m) {
.def("IsConstant", &XlaBuilder::IsConstant)
.def("SetOpMetadata", &XlaBuilder::SetOpMetadata)
.def("SetSharding", &XlaBuilder::SetSharding)
.def("ClearSharding", &XlaBuilder::ClearSharding);
.def("ClearSharding", &XlaBuilder::ClearSharding)
.def("SetUpAlias",
[](XlaBuilder& builder, const std::vector<int64>& output_index,
int64 param_number, const std::vector<int64>& param_index) {
builder.SetUpAlias(
ShapeIndex(output_index.begin(), output_index.end()),
param_number,
ShapeIndex(param_index.begin(), param_index.end()));
});

m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor);
m.def("DLPackManagedTensorToBuffer",
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,16 @@ def Build(self, root=None, backend=None):
else:
return Computation(self._builder.Build(), backend=backend)

def SetUpAlias(self, output_index, param_number, param_index):
"""Adds a new input/output alias.
Args:
output_index: Iterable of int64 specifying the output index.
param_number: Parameter number.
param_index: Iterable of int64 specifying parameter index.
"""
return self._builder.SetUpAlias(output_index, param_number, param_index)

def GetShape(self, operand):
return self._builder.GetShape(operand)

Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/xla/python/xla_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,21 @@ def testSetSharding(self):
np.testing.assert_allclose(ans, 4.14)


class AliasTest(ComputationTest):

def testSetUpAlias(self):
c = self._NewComputation()
p1 = c.ParameterFromNumpy(NumpyArrayF32(1.0))
p2 = c.ParameterFromNumpy(NumpyArrayF32(1.0))
out = c.Add(p1, p2)
c.SetUpAlias([], 0, [])
c = c.Build(out)
with self.assertRaisesRegex(RuntimeError,
"Buffer aliasing is not supported "
"by XLA for non-TPU backends"):
c.Compile()


int_dtypes = [
np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
np.uint64
Expand Down

0 comments on commit 42052dc

Please sign in to comment.