[DTensor] Allow DTensor support cuda-like device #102468

Closed · wants to merge 3 commits
25 changes: 18 additions & 7 deletions torch/distributed/_tensor/device_mesh.py
@@ -49,6 +49,16 @@ def get_current_mesh(self) -> "DeviceMesh":
 mesh_resources: _MeshEnv = _MeshEnv()
 
 
+def _get_device_handle(device_type: str = "cuda"):
+    """
+    Get the module corresponding to device_type, which is cuda or a cuda-like device.
+    For example, when device_type is "cuda", the module `torch.cuda` is returned.
+    Return None when device_type is "cpu" or there is no corresponding module;
+    otherwise return the corresponding module.
+    """
+    return getattr(torch, device_type, None) if device_type != "cpu" else None
+
+
 class DeviceMesh(object):
     """
     DeviceMesh represents a mesh of devices, where layout of devices could be
@@ -65,7 +75,7 @@ class DeviceMesh(object):
 
     DeviceMesh can be used as a context manager.
     Args:
-        device_type (str): device type of the mesh. Currently supports: cpu, cuda.
+        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
         mesh (ndarray): could be a multi-dimension array or an integer tensor that
             describes the layout of devices, the ids are global ids of the
             default process group.
@@ -124,17 +134,18 @@ def _get_or_create_default_group(self):
                 f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
             )
 
+        device_handle = _get_device_handle(self.device_type)
         # TODO: if user want to pass pg_options, offer a way to do it
-        if not default_initialized and self.device_type == "cuda":
-            # automatically set the current cuda device base on num of gpu devices available in each host
+        if not default_initialized and device_handle:
+            # automatically set the current cuda/cuda-like device based on the number of devices available on each host
             # NOTE: This device selection would only work for homogeneous hardware.
-            num_gpus_per_host = torch.cuda.device_count()
-            if world_size % num_gpus_per_host != 0:
+            num_devices_per_host = device_handle.device_count()
+            if world_size % num_devices_per_host != 0:
                 raise RuntimeError(
                     f"DeviceMesh only support homogeneous hardware, but found "
-                    f"{world_size} ranks and {num_gpus_per_host} cuda devices!"
+                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
                 )
-            torch.cuda.set_device(get_rank() % num_gpus_per_host)
+            device_handle.set_device(get_rank() % num_devices_per_host)
         # TODO (xilunwu): to perform DTensor random ops, we need to ensure all ranks in mesh is initialized
         # with the same random seed. The seed to use will be the current seed on rank 0. We store this seed
         # as an attribute of device mesh for future use. However, the detail is still TBD how we gonna use
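For reference, here is a minimal standalone sketch (not part of the diff) of how the new `_get_device_handle` helper resolves a backend module and how the mesh setup then drives device selection through it. The availability guard and the placeholder rank are illustrative additions, not code from the PR:

```python
# Standalone illustration of the helper added in device_mesh.py.
import torch


def _get_device_handle(device_type: str = "cuda"):
    # torch.<device_type> (e.g. torch.cuda) when it exists, None for cpu/unknown backends
    return getattr(torch, device_type, None) if device_type != "cpu" else None


print(_get_device_handle("cpu"))      # None: cpu needs no backend module
print(_get_device_handle("cuda"))     # <module 'torch.cuda' ...>
print(_get_device_handle("no_such"))  # None: torch has no attribute "no_such"

# DeviceMesh then calls device_count()/set_device() on the handle instead of
# torch.cuda directly. Guarded here so the sketch also runs on CPU-only machines:
handle = _get_device_handle("cuda")
if handle is not None and handle.is_available():
    rank = 0  # placeholder for torch.distributed.get_rank()
    num_devices_per_host = handle.device_count()
    handle.set_device(rank % num_devices_per_host)
```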
39 changes: 22 additions & 17 deletions torch/distributed/_tensor/random.py
@@ -7,7 +7,7 @@
 import torch.distributed as dist
 
 from torch import Tensor
-from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.device_mesh import _get_device_handle, DeviceMesh
 from torch.distributed._tensor.placement_types import DTensorSpec, Shard
 

@@ -32,11 +32,12 @@ def set_rng_state(new_state: Tensor, device_mesh: DeviceMesh) -> None:
 
     if device_mesh.get_coordinate() is not None:
         # the current rank is in mesh
-        if device_mesh.device_type == "cuda":
-            torch.cuda.set_rng_state(new_state)
+        device_handle = _get_device_handle(device_mesh.device_type)
+        if device_handle:
+            device_handle.set_rng_state(new_state)
         else:
             raise NotImplementedError(
-                f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
+                f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
             )


@@ -58,12 +59,12 @@ def get_rng_state(device_mesh: DeviceMesh) -> Tensor:
     assert isinstance(
         device_mesh, DeviceMesh
     ), f"expect a DeviceMesh but {type(device_mesh)} was passed in."
-
-    if device_mesh.device_type == "cuda":
-        return torch.cuda.get_rng_state()
+    device_handle = _get_device_handle(device_mesh.device_type)
+    if device_handle:
+        return device_handle.get_rng_state()
     else:
         raise NotImplementedError(
-            f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
+            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
         )


@@ -101,11 +102,12 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
 
     # the current rank is in mesh
     if device_mesh.get_coordinate() is not None:
-        if device_mesh.device_type == "cuda":
-            torch.cuda.manual_seed(seed)
+        device_handle = _get_device_handle(device_mesh.device_type)
+        if device_handle:
+            device_handle.manual_seed(seed)
         else:
             raise NotImplementedError(
-                f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
+                f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
             )
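To make the dispatch pattern in `set_rng_state` / `get_rng_state` / `manual_seed` concrete, here is a hedged sketch of my own (not code from the PR); the extra availability guard exists only so the snippet runs on machines without the backend:

```python
import torch


def _get_device_handle(device_type: str = "cuda"):
    return getattr(torch, device_type, None) if device_type != "cpu" else None


def seed_backend(device_type: str, seed: int) -> None:
    # Same shape as the PR's manual_seed path: resolve the handle, use it if it
    # exists, otherwise raise the NotImplementedError shown in the diff.
    handle = _get_device_handle(device_type)
    if handle is None:
        raise NotImplementedError(
            f"DTensor randomness only supports cuda/cuda-like device type, but got {device_type}"
        )
    if handle.is_available():  # extra guard so this sketch runs on CPU-only machines
        handle.manual_seed(seed)


seed_backend("cuda", 1234)   # seeds torch.cuda when a device is present, otherwise skips
# seed_backend("cpu", 1234)  # would raise NotImplementedError, matching the diff
```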


@@ -232,15 +234,16 @@ def _get_rng_offset(device_mesh: DeviceMesh) -> int:
     If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
     `_get_rng_offset` still returns its GPU device's RNG offset.
     """
-    if device_mesh.device_type == "cuda":
+    device_handle = _get_device_handle(device_mesh.device_type)
+    if device_handle:
         # source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
         # last sizeof(int64_t) bytes are the offset
         state = get_rng_state(device_mesh)
         offset = state[-8:].view(torch.int64)
         return int(offset[0].item())
     else:
         raise NotImplementedError(
-            f"DTensor randomness only supports cuda device type, "
+            f"DTensor randomness only supports cuda/cuda-like device type, "
             f"but got {device_mesh.device_type}"
         )

@@ -264,7 +267,8 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
     """
     if device_mesh.get_coordinate() is not None:
         # the current rank is in mesh
-        if device_mesh.device_type == "cuda":
+        device_handle = _get_device_handle(device_mesh.device_type)
+        if device_handle:
             # source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
             # the RNG state tensor returned from torch.cuda.get_rng_state() is a ByteTensor
             # first 200 * sizeof(4120) bytes in tensor are 0xFF
@@ -276,7 +280,7 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
             set_rng_state(state, device_mesh)
         else:
             raise NotImplementedError(
-                f"DTensor randomness only supports cuda device type, "
+                f"DTensor randomness only supports cuda/cuda-like device type, "
                 f"but got {device_mesh.device_type}"
             )
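The offset handling in `_get_rng_offset` / `_set_rng_offset` relies on the layout of the RNG state tensor: the last eight bytes hold the 64-bit offset. A small self-contained illustration follows; the state tensor below is fabricated so the snippet runs without a GPU (the real one comes from `get_rng_state`), and the byte reinterpretation assumes a little-endian host:

```python
import torch

# Fabricated stand-in for the device RNG state ByteTensor.
dummy_state = torch.zeros(816, dtype=torch.uint8)

# _set_rng_offset: write a new 64-bit offset into the last 8 bytes.
dummy_state[-8:] = torch.tensor([64], dtype=torch.int64).view(torch.uint8)

# _get_rng_offset: read the last 8 bytes back as an int64.
offset = int(dummy_state[-8:].view(torch.int64)[0].item())
print(offset)  # 64
```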

@@ -293,8 +297,9 @@ def _calc_shard_linear_idx(shard_coord: List[int], shard_size: List[int]) -> int
 
 
 def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
-    # currently we only support correct RNG on cuda device
-    if device_mesh.device_type == "cuda":
+    # currently we only support correct RNG on cuda/cuda-like device
+    device_handle = _get_device_handle(device_mesh.device_type)
+    if device_handle and hasattr(device_handle, "set_rng_state"):
         return True
     else:
         warnings.warn(
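Finally, a short sketch of the capability check that `is_rng_supported_mesh` now performs: the backend module must actually expose `set_rng_state` before DTensor treats its RNG as supported. This is my own illustration; "foo" is a made-up device type:

```python
import torch


def _get_device_handle(device_type: str = "cuda"):
    return getattr(torch, device_type, None) if device_type != "cpu" else None


def rng_supported(device_type: str) -> bool:
    # Mirrors the hasattr(..., "set_rng_state") check added in is_rng_supported_mesh.
    handle = _get_device_handle(device_type)
    return handle is not None and hasattr(handle, "set_rng_state")


print(rng_supported("cuda"))  # True: torch.cuda ships set_rng_state
print(rng_supported("cpu"))   # False: cpu is handled outside this path
print(rng_supported("foo"))   # False: no torch.foo module
```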