Move pywrap_dtensor_device.cc to //third_party/tensorflow/python.
PiperOrigin-RevId: 438567492
jiawenhao authored and tensorflower-gardener committed Mar 31, 2022
1 parent 9334a8c commit 9575427
Showing 4 changed files with 56 additions and 60 deletions.
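
In effect, the commit relocates the pybind11 extension one package up and renames the module from _dtensor_device to _pywrap_dtensor_device. The resulting import-path change, as it appears in the dtensor_device.py hunks below:

# Before this commit: the extension was built and imported inside the DTensor package.
from tensorflow.dtensor.python import _dtensor_device
# After this commit: it is built under tensorflow/python and imported from there.
from tensorflow.python import _pywrap_dtensor_device
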
34 changes: 1 addition & 33 deletions tensorflow/dtensor/python/BUILD
@@ -3,12 +3,6 @@
load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py", "tf_py_test")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")

# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")

default_visibility = [
"//tensorflow/dtensor:dtensor-internal",
]
@@ -157,10 +151,10 @@ pytype_strict_library(
name = "dtensor_device",
srcs = ["dtensor_device.py"],
deps = [
":_dtensor_device",
":gen_dtensor_ops",
":layout",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:_pywrap_dtensor_device",
"//tensorflow/python:device",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/eager:context",
@@ -172,32 +166,6 @@ pytype_strict_library(
],
)

tf_python_pybind_extension(
name = "_dtensor_device",
srcs = ["pywrap_dtensor_device.cc"],
features = ["-layering_check"],
module_name = "_dtensor_device",
deps = [
":pywrap_densor_device_headers",
"//tensorflow/python/lib/core:pybind11_lib",
"//tensorflow/python/lib/core:pybind11_status_headers",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@pybind11",
],
)

tf_pybind_cc_library_wrapper(
name = "pywrap_densor_device_headers",
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/dtensor/cc:dtensor_device_cc",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/python/lib/core:safe_pyobject_ptr",
],
)

# -----------------------------------------------------------------------------
# Utilities.

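
The two rules deleted above are not dropped: they reappear, with only the target and module renamed, in tensorflow/python/BUILD further down. For Python callers the extension's API surface is unchanged; only the import path moves. A minimal sketch of that invariant, assuming a TensorFlow build that already includes this change (the names checked are a subset of the calls made in dtensor_device.py below):

# Hedged sketch, not part of the commit: the bindings the DTensor device layer
# relies on should all resolve under the new module path.
from tensorflow.python import _pywrap_dtensor_device as dd

for fn in ("Allocate", "AddMesh", "Pack", "Unpack", "IsSparseDTensor",
           "FetchLayout", "SetSameShapePolicy", "SetTPUCoreIDs"):
    assert hasattr(dd, fn), f"missing binding: {fn}"
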
54 changes: 28 additions & 26 deletions tensorflow/dtensor/python/dtensor_device.py
@@ -24,9 +24,9 @@

# pylint: disable=g-direct-tensorflow-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.dtensor.python import _dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.framework import device as tf_device
@@ -68,7 +68,7 @@ def __init__(self, meshes: List[layout_lib.Mesh], is_async=True):
self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
_next_device_number)
_next_device_number += 1
device, device_info = _dtensor_device.Allocate(self.name)
device, device_info = _pywrap_dtensor_device.Allocate(self.name)
context.register_custom_device(device, self.name, device_info)

self._device_info = device_info
@@ -154,25 +154,25 @@ def _register_mesh(self, mesh: layout_lib.Mesh):
"""Idempotently register `mesh` with the dtensor device."""
with self._mesh_lock:
if mesh not in self._meshes:
_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
self._is_async, False)
_pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
self._is_async, False)
self._meshes.add(mesh)
if mesh.device_type().upper() == "TPU":
logging.info(
"Registering virtual 1:1 mapped host mesh %s for mesh %s",
mesh.host_mesh().to_string(), mesh.to_string())
_dtensor_device.AddMesh(self._device_info,
mesh.host_mesh().to_string(), self._is_async,
True)
_pywrap_dtensor_device.AddMesh(self._device_info,
mesh.host_mesh().to_string(),
self._is_async, True)
self._meshes.add(mesh.host_mesh())
embedding_host_mesh = self._create_embedding_host_mesh(mesh)
if embedding_host_mesh:
logging.info(
"Registering embedding host mesh %s on each client for mesh %s ",
embedding_host_mesh.to_string(), mesh.to_string())
_dtensor_device.AddMesh(self._device_info,
embedding_host_mesh.to_string(),
self._is_async, False)
_pywrap_dtensor_device.AddMesh(self._device_info,
embedding_host_mesh.to_string(),
self._is_async, False)
self._meshes.add(embedding_host_mesh)

@property
@@ -334,7 +334,7 @@ def pack(self, tensors, layout):
else:
is_sparse = False
try:
return _dtensor_device.Pack(
return _pywrap_dtensor_device.Pack(
context.context()._handle, # pylint: disable=protected-access
tensors,
layout.to_string(),
@@ -362,14 +362,14 @@ def unpack(self, tensor):
raise TypeError(
"Received Variable input to unpack, Variable is not supported.")
try:
tensors = _dtensor_device.Unpack(
tensors = _pywrap_dtensor_device.Unpack(
context.context()._handle, # pylint: disable=protected-access
tensor,
self._device_info)
except core._NotOkStatusException as e: # pylint: disable=protected-access
raise core._status_to_exception(e) from None # pylint: disable=protected-access

is_sparse = _dtensor_device.IsSparseDTensor(
is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
context.context()._handle, # pylint: disable=protected-access.
tensor,
self._device_info)
@@ -395,7 +395,7 @@ def fetch_layout(self, tensor):
if issubclass(type(tensor), resource_variable_ops.BaseResourceVariable):
tensor = tensor.read_value()
try:
layout_string = _dtensor_device.FetchLayout(
layout_string = _pywrap_dtensor_device.FetchLayout(
context.context()._handle, # pylint: disable=protected-access
tensor,
self._device_info)
@@ -412,7 +412,7 @@ def set_same_shape_policy(self, enabled):
Args:
enabled: A boolean indicating whether to use the policy.
"""
_dtensor_device.SetSameShapePolicy(self._device_info, enabled)
_pywrap_dtensor_device.SetSameShapePolicy(self._device_info, enabled)

def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
"""Sets the singleton global device ID-to-physical core ID map.
@@ -421,10 +421,11 @@ def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
mesh_name: The name of a mesh. If empty, set the default mapping.
tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
"""
_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name, tpu_core_ids)
_pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
tpu_core_ids)

def clear_tpu_core_ids(self):
_dtensor_device.ClearTPUCoreIDs(self._device_info)
_pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)

def tpu_core_ids_to_locations(self, tpu_core_ids):
"""Translates TPU core IDs to TPU core locations.
@@ -435,7 +436,7 @@ def tpu_core_ids_to_locations(self, tpu_core_ids):
Returns:
A list of corresponding TPU core locations.
"""
return _dtensor_device.TPUCoreIDsToLocations(
return _pywrap_dtensor_device.TPUCoreIDsToLocations(
context.context()._handle, # pylint: disable=protected-access
self._device_info,
tpu_core_ids)
@@ -450,18 +451,19 @@ def tpu_core_locations_to_ids(self, tpu_core_locations):
Returns:
A list of corresponding TPU core IDs.
"""
return _dtensor_device.TPUCoreLocationsToIDs(
return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
context.context()._handle, # pylint: disable=protected-access
self._device_info,
tpu_core_locations)

@contextlib.contextmanager
def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
self._register_mesh(mesh)
_dtensor_device.ExperimentalSetDefaultMesh(self._device_info,
mesh.to_string().encode("utf-8"))
_pywrap_dtensor_device.ExperimentalSetDefaultMesh(
self._device_info,
mesh.to_string().encode("utf-8"))
yield
_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
_pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)

@contextlib.contextmanager
def _default_layout(self, layout: layout_lib.Layout):
@@ -491,8 +493,8 @@ def _default_layout(self, layout: layout_lib.Layout):
try:
previous_default = self._current_output_layout
self._current_output_layout = layout.to_string().encode("utf-8")
_dtensor_device.ExperimentalSetDefaultLayout(self._device_info,
self._current_output_layout)
_pywrap_dtensor_device.ExperimentalSetDefaultLayout(
self._device_info, self._current_output_layout)
if context.executing_eagerly():
graph = None
previous_graph_size = None
@@ -524,7 +526,7 @@ def _default_layout(self, layout: layout_lib.Layout):

self._current_output_layout = previous_default # pytype: disable=name-error # py39-upgrade
if self._current_output_layout is None:
_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
_pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
else:
_dtensor_device.ExperimentalSetDefaultLayout(
_pywrap_dtensor_device.ExperimentalSetDefaultLayout(
self._device_info, self._current_output_layout.decode("utf-8"))
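
The patched dtensor_device.py keeps its public behavior end to end; every native call simply routes through the relocated module. A hedged usage sketch — mesh, layout, t0, and t1 stand in for a pre-built Mesh, a matching Layout, and two component tensors, none of which are defined in this diff:

# Hedged sketch, assuming an eager TensorFlow context and pre-built inputs.
from tensorflow.dtensor.python import dtensor_device

device = dtensor_device.DTensorDevice(meshes=[mesh])  # AddMesh via _register_mesh
packed = device.pack([t0, t1], layout)                # _pywrap_dtensor_device.Pack
print(device.fetch_layout(packed))                    # _pywrap_dtensor_device.FetchLayout
parts = device.unpack(packed)                         # _pywrap_dtensor_device.Unpack
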
26 changes: 26 additions & 0 deletions tensorflow/python/BUILD
@@ -4097,6 +4097,32 @@ tf_python_pybind_extension(
],
)

tf_python_pybind_extension(
name = "_pywrap_dtensor_device",
srcs = ["pywrap_dtensor_device.cc"],
features = ["-layering_check"],
module_name = "_pywrap_dtensor_device",
deps = [
":pywrap_densor_device_headers",
"//tensorflow/python/lib/core:pybind11_lib",
"//tensorflow/python/lib/core:pybind11_status_headers",
"//third_party/python_runtime:headers", # buildcleaner: keep
"@pybind11",
],
)

tf_pybind_cc_library_wrapper(
name = "pywrap_densor_device_headers",
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/dtensor/cc:dtensor_device_cc",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//tensorflow/python/lib/core:safe_pyobject_ptr",
],
)

py_library(
name = "tf2",
srcs = ["tf2.py"],
2 changes: 1 addition & 1 deletion tensorflow/{dtensor/python → python}/pywrap_dtensor_device.cc
@@ -80,7 +80,7 @@ void ConvertToTensor(TFE_Context* ctx, PyObject* input,
output_handle->reset(EagerTensorFromHandle(handle));
}

PYBIND11_MODULE(_dtensor_device, m) {
PYBIND11_MODULE(_pywrap_dtensor_device, m) {
m.def("Allocate", [](const std::string& name) {
TFE_CustomDevice* device = new TFE_CustomDevice;
std::unique_ptr<PyObject, decltype(&PyXDecref)> device_capsule(
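
The macro rename above is the load-bearing line of the move: PYBIND11_MODULE(name, m) emits the PyInit_<name> entry point that CPython looks up when loading the extension's shared object, so the name must match the module_name in the new tf_python_pybind_extension rule. A hedged illustration of the contract, not taken from the commit:

# With the macro renamed, this import resolves PyInit__pywrap_dtensor_device
# in the extension .so:
from tensorflow.python import _pywrap_dtensor_device
# Had the macro kept the old name, the same import would fail with:
#   ImportError: dynamic module does not define module export function
#   (PyInit__pywrap_dtensor_device)
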
