Make DTensor support cuda-like device.
shaoyf42 committed May 29, 2023
1 parent 53d1d30 commit 0585e16
Showing 4 changed files with 49 additions and 34 deletions.
36 changes: 18 additions & 18 deletions torch/distributed/_tensor/device_mesh.py
@@ -107,6 +107,7 @@ def __init__(
_init_process_groups: bool = True,
) -> None:
self.device_type = device_type
self._device_handle = getattr(torch, device_type, None) if device_type != "cpu" else None
self.mesh = (
mesh.detach()
if isinstance(mesh, torch.Tensor)
@@ -120,7 +121,7 @@ def __init__(
self._dim_groups = self._init_process_groups()

def _get_or_create_default_group(self):
self._backend = Backend.GLOO if self.device_type == "cpu" else Backend.NCCL
self._backend = Backend.get_default_backend_for_device(self.device_type)
default_initialized = is_initialized()
if not default_initialized:
init_process_group(backend=self._backend)
@@ -138,34 +139,33 @@ def _get_or_create_default_group(self):
assert (
world_backend in cpu_backends
), f"Default PG backend: {world_backend} not supporting CPU!"
elif self.device_type == "cuda":
cuda_backends = ["nccl", "gloo", "threaded"]
if world_backend == "gloo":
logger.warning(
"We recommend using nccl backend for cuda device type, gloo backend might only have partial support!"
)
assert (
world_backend in cuda_backends
), f"Default PG backend: {world_backend} not supporting CUDA!"
else:
if self.device_type == "cuda":
cuda_backends = ["nccl", "gloo", "threaded"]
if world_backend == "gloo":
logger.warning(
"We recommend using nccl backend for cuda device type, gloo backend might only have partial support!"
)
assert (
world_backend in cuda_backends
), f"Default PG backend: {world_backend} not supporting CUDA!"
if self._device_handle is None:
raise RuntimeError(f"DeviceMesh don't support {self.device_type}")
if not default_initialized:
# automatically set the current cuda device base on num of gpu devices available in each host
# NOTE: This device selection would only work for homogeneous hardware.
num_gpus_per_host = torch.cuda.device_count()
num_gpus_per_host = self._device_handle.device_count()
if world_size % num_gpus_per_host != 0:
raise RuntimeError(
f"DeviceMesh only support homogeneous hardware, but found "
f"{world_size} ranks and {num_gpus_per_host} cuda devices!"
f"{world_size} ranks and {num_gpus_per_host} devices!"
)
torch.cuda.set_device(get_rank() % num_gpus_per_host)
self._device_handle.set_device(get_rank() % num_gpus_per_host)
# TODO (xilunwu): to perform DTensor random ops, we need to ensure all ranks in mesh is initialized
# with the same random seed. The seed to use will be the current seed on rank 0. We store this seed
# as an attribute of device mesh for future use. However, the detail is still TBD how we gonna use
# this attribute, so we will implement this logic once we figure out the answer.
self._seed = torch.cuda.initial_seed()
else:
raise RuntimeError(
f"DeviceMesh only support cpu or cuda device type for now, but got {self.device_type}"
)
self._seed = self._device_handle.initial_seed()

# calculate the coordinates of the current global rank on the mesh
rank_coords = (self.mesh == get_rank()).nonzero()
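The key change in this hunk is the new `_device_handle` attribute, which resolves the `torch.<device_type>` module (e.g. `torch.cuda`) instead of hard-coding CUDA calls. A minimal sketch of that lookup and of the rank-to-local-device mapping it enables; the "cuda" device type and the homogeneous single-host layout are assumptions for illustration, not part of the commit:

    import torch
    import torch.distributed as dist

    def resolve_device_handle(device_type: str):
        # Same lookup as the _device_handle attribute added above: torch.cuda for "cuda",
        # torch.xpu for "xpu", and so on; None for "cpu" or device types torch does not expose.
        return getattr(torch, device_type, None) if device_type != "cpu" else None

    handle = resolve_device_handle("cuda")  # "cuda" is an assumed example device type
    if handle is not None and dist.is_initialized():
        num_devices_per_host = handle.device_count()
        # Homogeneous hosts assumed, matching the divisibility check in the hunk above.
        if num_devices_per_host > 0 and dist.get_world_size() % num_devices_per_host == 0:
            handle.set_device(dist.get_rank() % num_devices_per_host)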
4 changes: 3 additions & 1 deletion torch/distributed/_tensor/placement_types.py
@@ -312,7 +312,9 @@ def _replicate_tensor(
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
# if rank is not part of mesh, we simply return an empty tensor
return tensor.new_empty(0, requires_grad=tensor.requires_grad)
return tensor.new_empty(
0, device=mesh.device_type, requires_grad=tensor.requires_grad
)

tensor = tensor.contiguous()
mesh.broadcast(tensor, mesh_dim=mesh_dim)
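The change above passes `device=mesh.device_type` so the empty placeholder returned to ranks outside the mesh is allocated on the mesh's device type instead of inheriting the source tensor's device. A small stand-alone illustration of `Tensor.new_empty` with an explicit device, using "cpu" only so it runs anywhere:

    import torch

    src = torch.ones(4, requires_grad=True)
    # Zero-element placeholder on an explicitly chosen device, keeping requires_grad.
    placeholder = src.new_empty(0, device="cpu", requires_grad=src.requires_grad)
    print(placeholder.shape, placeholder.device, placeholder.requires_grad)
    # torch.Size([0]) cpu True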
31 changes: 16 additions & 15 deletions torch/distributed/_tensor/random.py
@@ -32,11 +32,12 @@ def set_rng_state(new_state: Tensor, device_mesh: DeviceMesh) -> None:

if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
torch.cuda.set_rng_state(new_state)

if device_mesh._device_handle:
device_mesh._device_handle.set_rng_state(new_state)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


@@ -59,11 +60,11 @@ def get_rng_state(device_mesh: DeviceMesh) -> Tensor:
device_mesh, DeviceMesh
), f"expect a DeviceMesh but {type(device_mesh)} was passed in."

if device_mesh.device_type == "cuda":
return torch.cuda.get_rng_state()
if device_mesh._device_handle:
return device_mesh._device_handle.get_rng_state()
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


@@ -101,11 +102,11 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:

# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
if device_mesh.device_type == "cuda":
torch.cuda.manual_seed(seed)
if device_mesh._device_handle:
device_mesh._device_handle.manual_seed(seed)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)


@@ -236,15 +237,15 @@ def _get_rng_offset(device_mesh: DeviceMesh) -> int:
If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
`_get_rng_offset` still returns its GPU device's RNG offset.
"""
if device_mesh.device_type == "cuda":
if device_mesh._device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# last sizeof(int64_t) bytes are the offset
state = get_rng_state(device_mesh)
offset = state[-8:].view(torch.int64)
return int(offset[0].item())
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)

@@ -268,7 +269,7 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
"""
if device_mesh.get_coordinate() is not None:
# the current rank is in mesh
if device_mesh.device_type == "cuda":
if device_mesh._device_handle:
# source: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
# the RNG state tensor returned from torch.cuda.get_rng_state() is a ByteTensor
# first 200 * sizeof(4120) bytes in tensor are 0xFF
Expand All @@ -280,7 +281,7 @@ def _set_rng_offset(new_offset: int, device_mesh: DeviceMesh) -> None:
set_rng_state(state, device_mesh)
else:
raise NotImplementedError(
f"DTensor randomness only supports cuda device type, "
f"DTensor randomness only supports cuda/cuda-like device type, "
f"but got {device_mesh.device_type}"
)

@@ -297,8 +298,8 @@ def _calc_shard_linear_idx(shard_coord: List[int], shard_size: List[int]) -> int


def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
# currently we only support correct RNG on cuda device
if device_mesh.device_type == "cuda":
# currently we only support correct RNG on cuda/cuda-like device
if device_mesh._device_handle:
return True
else:
warnings.warn(
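The offset helpers in this file treat the RNG state ByteTensor as raw bytes, with the trailing 8 bytes holding the philox offset as an int64. The round-trip below demonstrates only that byte-view trick on a dummy CPU buffer; the buffer size is arbitrary and not the real CUDA generator layout:

    import torch

    # Dummy stand-in for the RNG state ByteTensor; only the last 8 bytes matter here.
    state = torch.zeros(816, dtype=torch.uint8)

    new_offset = 40
    # Write the offset into the trailing 8 bytes, as _set_rng_offset does above.
    state[-8:] = torch.tensor([new_offset], dtype=torch.int64).view(torch.uint8)

    # Read it back, as _get_rng_offset does above.
    offset = int(state[-8:].view(torch.int64)[0].item())
    assert offset == new_offset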
12 changes: 12 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -181,6 +181,11 @@ class Backend:
MPI : ["cpu"],
}

_default_backend_for_device: Dict[str, List[str]] = {
"cpu": GLOO,
"cuda": NCCL,
}

def __new__(cls, name: str):
if not isinstance(name, str):
raise ValueError(f"Backend name must be a string, but got: {name}")
@@ -190,6 +195,13 @@ def __new__(cls, name: str):
value = name.lower()
return value

@classmethod
def get_default_backend_for_device(cls, device: str):
if device not in Backend._default_backend_for_device:
raise RuntimeError(f"Default backend not set for device type {device}, please set a default using \
set_default_backend_for_device")
return Backend._default_backend_for_device[device]

@classmethod
def register_backend(cls, name, func, extended_api=False, devices: Optional[Union[str, List[str]]] = None):
"""
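With the mapping and classmethod added here, callers can ask for a device type's default backend instead of hard-coding the cpu/gloo and cuda/nccl choice. A hedged usage sketch, assuming a PyTorch build that contains this commit; the "xpu" lookup is included only to show the error path for a device type that has no default registered:

    from torch.distributed.distributed_c10d import Backend

    print(Backend.get_default_backend_for_device("cpu"))   # -> "gloo"
    print(Backend.get_default_backend_for_device("cuda"))  # -> "nccl"

    try:
        Backend.get_default_backend_for_device("xpu")  # not in the default map
    except RuntimeError as err:
        print(err)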
