Skip to content

Commit

Permalink
[DTensor] Allow DTensor support cuda-like device (#102468)
Browse files Browse the repository at this point in the history
 Allow DTensor support cuda-like device, fix #102442

Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example #101914 and #101911. However, this support only extends a portion of third-party devices and is no good support for third-party cuda-like devices. Therefore, we would like to extend DTensor to support cuda-like devices, after all, cuda is so popular!

1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh. So `_default_backend_for_device` is added to `Backend`. It is worth noting that when we register a new backend for a device other than cpu and cuda, we also need to add a new default backend for this device.
2. Adding `_device_handle` to `DeviceMesh` for cuda-like devices, similar to what is set in FSDP. When `_device_handle` is not None, the device has similar behavior to `cuda`. In this way, functions like `torch.cuda.device_count()` need to be modified to `device_mesh._device_handle.device_count()`.
Pull Request resolved: #102468
Approved by: https://github.com/wanchaol
  • Loading branch information
shaoyf42 authored and pytorchmergebot committed Jun 7, 2023
1 parent 790f573 commit 17737f9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 24 deletions.
25 changes: 18 additions & 7 deletions torch/distributed/_tensor/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def get_current_mesh(self) -> "DeviceMesh":
mesh_resources: _MeshEnv = _MeshEnv()


def _get_device_handle(device_type: str = "cuda"):
"""
Get the module corresponding to the device_type which is cuda or cuda-like device.
For example, when the device_type is cuda, the module `torch.cuda` is returned.
Return None when device_type is cpu or there is no corresponding module,
otherwise return the corresponding module.
"""
return getattr(torch, device_type, None) if device_type != "cpu" else None


class DeviceMesh(object):
"""
DeviceMesh represents a mesh of devices, where layout of devices could be
Expand All @@ -65,7 +75,7 @@ class DeviceMesh(object):
DeviceMesh can be used as a context manager.
Args:
device_type (str): device type of the mesh. Currently supports: cpu, cuda.
device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
mesh (ndarray): could be a multi-dimension array or an integer tensor that
describes the layout of devices, the ids are global ids of the
default process group.
Expand Down Expand Up @@ -124,17 +134,18 @@ def _get_or_create_default_group(self):
f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
)

device_handle = _get_device_handle(self.device_type)
# TODO: if user want to pass pg_options, offer a way to do it
if not default_initialized and self.device_type == "cuda":
# automatically set the current cuda device base on num of gpu devices available in each host
if not default_initialized and device_handle:
# automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
# NOTE: This device selection would only work for homogeneous hardware.
num_gpus_per_host = torch.cuda.device_count()
if world_size % num_gpus_per_host != 0:
num_devices_per_host = device_handle.device_count()
if world_size % num_devices_per_host != 0:
raise RuntimeError(
f"DeviceMesh only support homogeneous hardware, but found "
f"{world_size} ranks and {num_gpus_per_host} cuda devices!"
f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
)
torch.cuda.set_device(get_rank() % num_gpus_per_host)
device_handle.set_device(get_rank() % num_devices_per_host)
# TODO (xilunwu): to perform DTensor random ops, we need to ensure all ranks in mesh is initialized
# with the same random seed. The seed to use will be the current seed on rank 0. We store this seed
# as an attribute of device mesh for future use. However, the detail is still TBD how we gonna use
Expand Down
39 changes: 22 additions & 17 deletions torch/distributed/_tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.distributed as dist

from torch import Tensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, Shard


Expand All @@ -32,11 +32,12 @@ def set_rng_state(new_state: Tensor, device_mesh: DeviceMesh) -> None:

if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
torch.cuda.set_rng_state(new_state)
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
device_handle.set_rng_state(new_state)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


Expand All @@ -58,12 +59,12 @@ def get_rng_state(device_mesh: DeviceMesh) -> Tensor:
assert isinstance(
device_mesh, DeviceMesh
), f"expect a DeviceMesh but {type(device_mesh)} was passed in."

if device_mesh.device_type == "cuda":
return torch.cuda.get_rng_state()
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
return device_handle.get_rng_state()
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


Expand Down Expand Up @@ -101,11 +102,12 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:

# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
if device_mesh.device_type == "cuda":
torch.cuda.manual_seed(seed)
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
device_handle.manual_seed(seed)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


Expand Down Expand Up @@ -232,15 +234,16 @@ def _get_rng_offset(device_mesh: DeviceMesh) -> int:
If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
`_get_rng_offset` still returns its GPU device's RNG offset.
"""
if device_mesh.device_type == "cuda":
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# last sizeof(int64_t) bytes are the offset
state = get_rng_state(device_mesh)
offset = state[-8:].view(torch.int64)
return int(offset[0].item())
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)

Expand All @@ -264,7 +267,8 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
"""
if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# the RNG state tensor returned from torch.cuda.get_rng_state() is a ByteTensor
# first 200 * sizeof(4120) bytes in tensor are 0xFF
Expand All @@ -276,7 +280,7 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
set_rng_state(state, device_mesh)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)

Expand All @@ -293,8 +297,9 @@ def _calc_shard_linear_idx(shard_coord: List[int], shard_size: List[int]) -> int


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
# currently we only support correct RNG on cuda device
if device_mesh.device_type == "cuda":
# currently we only support correct RNG on cuda/cuda-like device
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle and hasattr(device_handle, "set_rng_state"):
return True
else:
warnings.warn(
Expand Down

0 comments on commit 17737f9

Please sign in to comment.