Allow setting XLA sharding annotations on TF2 ResourceVariables.
xla_sharding API calls on ResourceVariables simply store the sharding proto on the ResourceVariable Python object. A new XlaSharding op is created whenever a ReadVariableOp is created, provided the ResourceVariable has its _xla_sharding field set.

PiperOrigin-RevId: 615245263
swachhandl authored and tensorflower-gardener committed Mar 13, 2024
1 parent 1115619 commit 42eb06b
Showing 5 changed files with 128 additions and 6 deletions.
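
A minimal usage sketch of the new behavior, pieced together from the tests added in this commit. The module paths, the enable call, and the xla_sharding functions are taken from the diff itself, but the end-to-end snippet is illustrative rather than an excerpt of a supported public API.

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

# Opt in to storing XLA shardings on ResourceVariables.
context.enable_xla_sharding_for_resource_variables()

var = variables.Variable(array_ops.ones([4, 5, 6], dtype=dtypes.float32))

@def_function.function
def annotate_and_read(v):
  # Stores the replicated OpSharding proto on the variable object; no op is
  # emitted at this point.
  v = xla_sharding.replicate(v)
  # Converting the variable to a tensor creates a ReadVariableOp, which is now
  # followed by an XlaSharding op carrying the stored proto.
  return v + 1.0

graph_def = annotate_and_read.get_concrete_function(var).graph.as_graph_def()
print(any(node.op == 'XlaSharding' for node in graph_def.node))  # Expected: True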
5 changes: 4 additions & 1 deletion tensorflow/python/compiler/xla/experimental/BUILD
@@ -14,6 +14,8 @@ py_strict_library(
deps = [
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/ops:resource_variable_ops",
"//third_party/py/numpy",
"@local_xla//xla:xla_data_proto_py",
],
@@ -29,11 +31,12 @@ py_strict_test(
# copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos",
"//third_party/py/numpy",
"@local_xla//xla:xla_data_proto_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:tensor",
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/ops:array_ops",
"//tensorflow/python/ops:variables",
"@absl_py//absl/testing:absltest",
],
)
45 changes: 45 additions & 0 deletions tensorflow/python/compiler/xla/experimental/xla_sharding.py
@@ -19,6 +19,8 @@
from tensorflow.compiler.tf2xla.python import xla as tf2xla
from local_xla.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops


class Sharding(object):
@@ -221,6 +223,21 @@ def apply_to_tensor(self,
if unspecified_dims:
assert use_sharding_op and not assign_tuple_sharding
proto = self._proto

# If passed a BaseResourceVariable instead of a tf.Tensor, simply store the
# sharding proto on the BaseResourceVariable object. An XlaShardingOp will be
# created down the line whenever the BaseResourceVariable creates a
# ReadVariableOp.
if (
isinstance(tensor, resource_variable_ops.BaseResourceVariable)
and context.xla_sharding_for_resource_variables_enabled()
):
if assign_tuple_sharding:
proto = self._create_tuple_proto(num_outputs=1)
# pylint: disable=protected-access
tensor._set_xla_sharding(proto)
return tensor

if use_sharding_op:
if assign_tuple_sharding:
proto = self._create_tuple_proto(num_outputs=1)
@@ -293,6 +310,20 @@ def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
if sharding is None:
return to_tensor

# If passed a BaseResourceVariable instead of a tf.Tensor, simply store the
# sharding proto on the BaseResourceVariable object. An XlaShardingOp will be
# created down the line whenever the BaseResourceVariable creates a
# ReadVariableOp.
if (
isinstance(to_tensor, resource_variable_ops.BaseResourceVariable)
and context.xla_sharding_for_resource_variables_enabled()
):
proto = xla_data_pb2.OpSharding()
proto.ParseFromString(sharding)
# pylint: disable=protected-access
to_tensor._set_xla_sharding(proto)
return to_tensor

if use_sharding_op:
to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
attr_value = attr_value_pb2.AttrValue(s=sharding)
@@ -416,6 +447,20 @@ def get_tensor_sharding(tensor):
Returns:
The attribute representing XLA sharding on tensor's op.
"""
# If passed a BaseResourceVariable instead of a tf.Tensor, simply return the
# sharding proto stored in the _xla_sharding field of the BaseResourceVariable
# object.
if (
isinstance(tensor, resource_variable_ops.BaseResourceVariable)
and context.xla_sharding_for_resource_variables_enabled()
):
# pylint: disable=protected-access
sharding = tensor._get_xla_sharding()
if sharding is None:
return None
else:
return sharding.SerializeToString()

try:
return get_op_sharding(tensor.op)
except AttributeError:
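
A rough illustration of the three new branches in this module, assuming the resource-variable sharding flag has been enabled as in the tests: annotating a variable stores the proto, get_tensor_sharding returns it serialized, and copy_sharding transfers it to another variable, all without emitting any ops. The snippet is a sketch based on the diff, not code from it.

import numpy as np

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables

context.enable_xla_sharding_for_resource_variables()

v = variables.Variable(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
w = variables.Variable(array_ops.ones([4, 5, 6], dtype=dtypes.float32))

# apply_to_tensor() (reached via tile()) stores the OpSharding proto on the
# variable object instead of inserting an XlaSharding op.
xla_sharding.tile(v, np.array([2, 1, 6]))

# get_tensor_sharding() returns the stored proto serialized to bytes.
assert xla_sharding.get_tensor_sharding(v) is not None

# copy_sharding() parses those bytes back into an OpSharding proto and stores
# it on the destination variable, again without emitting any ops.
xla_sharding.copy_sharding(v, w)
assert xla_sharding.get_tensor_sharding(w) is not None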
39 changes: 34 additions & 5 deletions tensorflow/python/compiler/xla/experimental/xla_sharding_test.py
@@ -20,11 +20,12 @@
from google.protobuf.message import DecodeError
from local_xla.xla import xla_data_pb2
from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables


class ShardingTest(test_util.TensorFlowTestCase):
@@ -67,13 +68,25 @@ def test_sharding_factory_functions_can_return_sharding_objects(self):
class XlaShardingTest(test_util.TensorFlowTestCase):
"""Tests for non-member functions in the module xla_sharding.py."""

def setUp(self):
super().setUp()
context.enable_xla_sharding_for_resource_variables()

def _graph_has_xla_sharding_op(self, graph):
for node in graph.node:
if node.op == 'XlaSharding' and any(
'ReadVariableOp' in input for input in node.input
):
return True

return False

def test_replicate_annotates_tensor_correctly(self):

@def_function.function
def replicate_helper(tensor):
replicated_tensor = xla_sharding.replicate(
array_ops.ones([4, 5, 6], dtype=dtypes.float32))
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
replicated_tensor = xla_sharding.replicate(tensor)
replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
self.assertIsNotNone(replicated_sharding)
self.assertIsNone(
@@ -84,13 +97,17 @@ def replicate_helper(tensor):
result = replicate_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
self.assertAllEqual(in_tensor, result)

var = variables.Variable(initial_value=in_tensor, name='var')
graph = replicate_helper.get_concrete_function(var).graph.as_graph_def()
self.assertTrue(self._graph_has_xla_sharding_op(graph))

def test_tile_annotates_tensor_correctly(self):

@def_function.function
def tile_helper(tensor):
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
self.assertIsInstance(tiled_tensor, tensor_lib.Tensor)
self.assertIsInstance(tiled_tensor, type(tensor))
tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
# This is the shape of the tile assignment [2, 1, 6]
@@ -102,13 +119,17 @@ def tile_helper(tensor):
result = tile_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
self.assertAllEqual(in_tensor, result)

var = variables.Variable(initial_value=in_tensor, name='var')
graph = tile_helper.get_concrete_function(var).graph.as_graph_def()
self.assertTrue(self._graph_has_xla_sharding_op(graph))

def test_split_annotates_tensor_correctly(self):

@def_function.function
def split_helper(tensor):
self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
split_tensor = xla_sharding.split(tensor, 2, 3)
self.assertIsInstance(split_tensor, tensor_lib.Tensor)
self.assertIsInstance(split_tensor, type(tensor))
split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
expected_shape = [1, 1, 3]
@@ -119,6 +140,10 @@ def split_helper(tensor):
result = split_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
self.assertAllEqual(in_tensor, result)

var = variables.Variable(initial_value=in_tensor, name='var')
graph = split_helper.get_concrete_function(var).graph.as_graph_def()
self.assertTrue(self._graph_has_xla_sharding_op(graph))

def test_split_raises_error_with_incommensurate_dimensions(self):

@def_function.function
@@ -158,6 +183,10 @@ def copy_helper(tensor):
result = copy_helper(array_ops.ones([4, 5, 6], dtype=dtypes.float32))
self.assertAllEqual(in_tensor, result)

var = variables.Variable(initial_value=in_tensor, name='var')
graph = copy_helper.get_concrete_function(var).graph.as_graph_def()
self.assertTrue(self._graph_has_xla_sharding_op(graph))

def test_get_sharding_tile_shape_returns_none_on_none_input(self):
self.assertIsNone(xla_sharding.get_sharding_tile_shape(None))

1 change: 1 addition & 0 deletions tensorflow/python/ops/BUILD
@@ -3122,6 +3122,7 @@ py_strict_library(
":resource_variable_ops_gen",
":state_ops",
":state_ops_gen",
"//tensorflow/compiler/tf2xla/ops:gen_xla_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/core/function/trace_type",
"//tensorflow/python:pywrap_tensorflow",
44 changes: 44 additions & 0 deletions tensorflow/python/ops/resource_variable_ops.py
@@ -19,8 +19,10 @@
import functools
import weakref

from absl import logging
import numpy as np

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.function import trace_type
@@ -482,6 +484,31 @@ def __init__( # pylint: disable=super-init-not-called
self._constraint = constraint
self._cached_shape_as_list = None
self._validate_shape = validate_shape
self._xla_sharding = None
self._variable_read = False

def _get_xla_sharding(self):
return self._xla_sharding

def _set_xla_sharding(self, xla_sharding):
"""Annotates this `ResourceVariable` with `xla_sharding`.
`xla_sharding` will be used to create an `XlaShardingOp` whenever a
`ReadVariableOp` is created.
Args:
xla_sharding: The xla.OpSharding proto to annotate this ResourceVariable
with.
"""
if self._variable_read and not context.executing_eagerly():
logging.warning(
"This variable (%s) has already been read (i.e. a ReadVariableOp has"
" already been generated), so a new XlaShardingOp using this sharding"
" will not be created unless the variable is read again. If that is not"
" possible, please set the XLA sharding before reading the variable.",
self.name,
)
self._xla_sharding = xla_sharding

def __repr__(self):
if context.executing_eagerly() and not self._in_graph_mode:
@@ -796,6 +823,7 @@ def _read_variable_op(self, no_copy=False):
The value of the variable.
"""
variable_accessed(self)
self._variable_read = True

def read_and_set_handle(no_copy):
if no_copy and forward_compat.forward_compatible(2022, 5, 3):
@@ -819,6 +847,22 @@ def read_and_set_handle(no_copy):
"ReadVariableOp", [result], [self.handle],
backward_function=lambda x: [x],
forward_function=lambda x: [x])

# Create an XlaShardingOp if this ResourceVariable is annotated with an XLA
# sharding, i.e. its _xla_sharding field is set. Please see the design at
# http://shortn/_RGoruJpzrv for more details.
if (
context.xla_sharding_for_resource_variables_enabled()
and not context.executing_eagerly()
and self._xla_sharding is not None
):
sharding_string = self._xla_sharding.SerializeToString()
result = gen_xla_ops.xla_sharding(result, sharding=sharding_string)
# pylint: disable=protected-access
result.op._set_attr(
"_XlaSharding",
attr_value_pb2.AttrValue(s=sharding_string),
)
return result

def read_value(self):
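
To make the ordering caveat in _set_xla_sharding concrete, here is a hedged sketch based on the warning logic above (it is not code from this commit): if a ReadVariableOp has already been generated during tracing, annotating the variable afterwards only affects subsequent reads.

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.ops import variables

context.enable_xla_sharding_for_resource_variables()

v = variables.Variable([[1.0, 2.0], [3.0, 4.0]])

@def_function.function
def read_then_annotate(var):
  first = var + 0.0            # ReadVariableOp created here, without sharding.
  xla_sharding.replicate(var)  # Logs the warning above during tracing.
  second = var + 0.0           # This read is wrapped in an XlaSharding op.
  return first, second

read_then_annotate.get_concrete_function(v)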
