Make DTensor support cuda-like device.
shaoyf42 committed Jun 7, 2023
1 parent 2800a04 commit 190b389
Showing 3 changed files with 41 additions and 23 deletions.
10 changes: 10 additions & 0 deletions torch/distributed/_tensor/_utils.py
@@ -123,3 +123,13 @@ def compute_global_tensor_info(
elif not isinstance(placement, (Replicate, _Partial)):
raise RuntimeError(f"placement type {type(placement)} not supported!")
return tensor_shape, tensor_stride


def _get_device_handle(device_type: str = "cuda"):
"""
Get the module corresponding to the device_type, which is cuda or a cuda-like device.
For example, when the device_type is cuda, the module `torch.cuda` is returned.
Return None when the device_type is cpu or there is no corresponding module;
otherwise return the corresponding module.
"""
return getattr(torch, device_type, None) if device_type != "cpu" else None
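
For context, a minimal usage sketch of the new helper (illustrative only, not part of the commit); it relies on getattr's default to fall back to None for unknown device types:

import torch
from torch.distributed._tensor._utils import _get_device_handle

# "cuda" (or any cuda-like backend registered on the torch namespace) resolves to its module
assert _get_device_handle("cuda") is torch.cuda
# "cpu" is special-cased to None so callers can skip device-specific setup
assert _get_device_handle("cpu") is None
# an unregistered device type also falls back to None via getattr's default
assert _get_device_handle("made_up_device") is None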
16 changes: 9 additions & 7 deletions torch/distributed/_tensor/device_mesh.py
Expand Up @@ -5,6 +5,7 @@
import torch
import torch.distributed._functional_collectives as funcol

from torch.distributed._tensor._utils import _get_device_handle
from torch.distributed.distributed_c10d import (
_get_default_group,
all_gather,
@@ -65,7 +66,7 @@ class DeviceMesh(object):
DeviceMesh can be used as a context manager.
Args:
device_type (str): device type of the mesh. Currently supports: cpu, cuda.
device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
mesh (ndarray): could be a multi-dimension array or an integer tensor that
describes the layout of devices, the ids are global ids of the
default process group.
@@ -124,17 +125,18 @@ def _get_or_create_default_group(self):
f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
)

device_handle = _get_device_handle(self.device_type)
# TODO: if user want to pass pg_options, offer a way to do it
if not default_initialized and self.device_type == "cuda":
# automatically set the current cuda device base on num of gpu devices available in each host
if not default_initialized and device_handle:
# automatically set the current cuda/cuda-like device based on the number of gpu devices available in each host
# NOTE: This device selection would only work for homogeneous hardware.
num_gpus_per_host = torch.cuda.device_count()
if world_size % num_gpus_per_host != 0:
num_devices_per_host = device_handle.device_count()
if world_size % num_devices_per_host != 0:
raise RuntimeError(
f"DeviceMesh only support homogeneous hardware, but found "
f"{world_size} ranks and {num_gpus_per_host} cuda devices!"
f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
)
torch.cuda.set_device(get_rank() % num_gpus_per_host)
device_handle.set_device(get_rank() % num_devices_per_host)
# TODO (xilunwu): to perform DTensor random ops, we need to ensure all ranks in mesh is initialized
# with the same random seed. The seed to use will be the current seed on rank 0. We store this seed
# as an attribute of device mesh for future use. However, the detail is still TBD how we gonna use
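Roughly, the device-selection step above reduces to the following sketch (simplified; process-group creation is omitted, the function name _select_local_device is hypothetical, and world_size/rank stand in for the values the real code obtains from the c10d helpers):

from torch.distributed._tensor._utils import _get_device_handle

def _select_local_device(device_type: str, world_size: int, rank: int) -> None:
    # resolve torch.cuda or a cuda-like module; None means cpu / nothing to set
    device_handle = _get_device_handle(device_type)
    if device_handle is None:
        return
    num_devices_per_host = device_handle.device_count()
    # homogeneous hardware assumption: every host exposes the same device count
    if world_size % num_devices_per_host != 0:
        raise RuntimeError(
            f"DeviceMesh only supports homogeneous hardware, but found "
            f"{world_size} ranks and {num_devices_per_host} {device_type} devices!"
        )
    # bind this rank to one local device, assuming ranks are packed host by host
    device_handle.set_device(rank % num_devices_per_host)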
38 changes: 22 additions & 16 deletions torch/distributed/_tensor/random.py
@@ -7,6 +7,7 @@
import torch.distributed as dist

from torch import Tensor
from torch.distributed._tensor._utils import _get_device_handle
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, Shard

@@ -32,11 +33,12 @@ def set_rng_state(new_state: Tensor, device_mesh: DeviceMesh) -> None:

if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
torch.cuda.set_rng_state(new_state)
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
device_handle.set_rng_state(new_state)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


@@ -58,12 +60,12 @@ def get_rng_state(device_mesh: DeviceMesh) -> Tensor:
assert isinstance(
device_mesh, DeviceMesh
), f"expect a DeviceMesh but {type(device_mesh)} was passed in."

if device_mesh.device_type == "cuda":
return torch.cuda.get_rng_state()
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
return device_handle.get_rng_state()
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


@@ -101,11 +103,12 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:

# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
if device_mesh.device_type == "cuda":
torch.cuda.manual_seed(seed)
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
device_handle.manual_seed(seed)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)
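
set_rng_state, get_rng_state, and manual_seed above all follow the same dispatch pattern after this change: resolve a device handle for the mesh's device type and delegate to it, or raise for cpu/unsupported backends. A condensed sketch of that pattern (the helper name _rng_dispatch is hypothetical, not part of the commit):

from torch.distributed._tensor._utils import _get_device_handle
from torch.distributed._tensor.device_mesh import DeviceMesh

def _rng_dispatch(device_mesh: DeviceMesh, method: str, *args):
    # torch.cuda for "cuda", the matching module for a cuda-like backend, None for cpu
    device_handle = _get_device_handle(device_mesh.device_type)
    if device_handle is None:
        raise NotImplementedError(
            f"DTensor randomness only supports cuda/cuda-like device type, "
            f"but got {device_mesh.device_type}"
        )
    # method is e.g. "get_rng_state", "set_rng_state", or "manual_seed"
    return getattr(device_handle, method)(*args)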


@@ -232,15 +235,16 @@ def _get_rng_offset(device_mesh: DeviceMesh) -> int:
If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
`_get_rng_offset` still returns its GPU device's RNG offset.
"""
if device_mesh.device_type == "cuda":
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# last sizeof(int64_t) bytes are the offset
state = get_rng_state(device_mesh)
offset = state[-8:].view(torch.int64)
return int(offset[0].item())
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)

@@ -264,7 +268,8 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
"""
if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# the RNG state tensor returned from torch.cuda.get_rng_state() is a ByteTensor
# first 200 * sizeof(4120) bytes in tensor are 0xFF
@@ -276,7 +281,7 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
set_rng_state(state, device_mesh)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)
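
Both offset helpers above rely on the layout of the CUDA generator state documented in the linked CUDAGeneratorImpl.cpp: the state is a ByteTensor whose last sizeof(int64_t) bytes hold the philox offset. A standalone sketch of that byte-level manipulation (helper names are hypothetical):

import torch

def _read_philox_offset(state: torch.Tensor) -> int:
    # reinterpret the trailing 8 bytes of the ByteTensor as a single int64
    return int(state[-8:].view(torch.int64)[0].item())

def _write_philox_offset(state: torch.Tensor, new_offset: int) -> torch.Tensor:
    # the view shares storage with `state`, so this writes the offset in place
    state[-8:].view(torch.int64)[0] = new_offset
    return state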

@@ -293,8 +298,9 @@ def _calc_shard_linear_idx(shard_coord: List[int], shard_size: List[int]) -> int


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
# currently we only support correct RNG on cuda device
if device_mesh.device_type == "cuda":
# currently we only support correct RNG on cuda/cuda-like device
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle and hasattr(device_handle, "set_rng_state"):
return True
else:
warnings.warn(
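A quick usage sketch for the capability probe above (illustrative; assumes the default process group is already initialized, e.g. via torchrun):

import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.random import is_rng_supported_mesh

mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
if is_rng_supported_mesh(mesh):
    # the backend exposes set_rng_state, so DTensor random ops can keep ranks in sync
    ...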
