Skip to content

Commit

Permalink
fix circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoyf42 committed Jun 7, 2023
1 parent 190b389 commit ef61de3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
10 changes: 0 additions & 10 deletions torch/distributed/_tensor/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,3 @@ def compute_global_tensor_info(
elif not isinstance(placement, (Replicate, _Partial)):
raise RuntimeError(f"placement type {type(placement)} not supported!")
return tensor_shape, tensor_stride


def _get_device_handle(device_type: str = "cuda"):
"""
Get the module corresponding to the device_type which is cuda or cuda-like device.
For example, when the device_type is cuda, the module `torch.cuda` is returned.
Return None when device_type is cpu or there is no corresponding module,
otherwise return the corresponding module.
"""
return getattr(torch, device_type, None) if device_type != "cpu" else None
11 changes: 10 additions & 1 deletion torch/distributed/_tensor/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
import torch.distributed._functional_collectives as funcol

from torch.distributed._tensor._utils import _get_device_handle
from torch.distributed.distributed_c10d import (
_get_default_group,
all_gather,
Expand Down Expand Up @@ -50,6 +49,16 @@ def get_current_mesh(self) -> "DeviceMesh":
mesh_resources: _MeshEnv = _MeshEnv()


def _get_device_handle(device_type: str = "cuda"):
"""
Get the module corresponding to the device_type which is cuda or cuda-like device.
For example, when the device_type is cuda, the module `torch.cuda` is returned.
Return None when device_type is cpu or there is no corresponding module,
otherwise return the corresponding module.
"""
return getattr(torch, device_type, None) if device_type != "cpu" else None


class DeviceMesh(object):
"""
DeviceMesh represents a mesh of devices, where layout of devices could be
Expand Down
3 changes: 1 addition & 2 deletions torch/distributed/_tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import torch.distributed as dist

from torch import Tensor
from torch.distributed._tensor._utils import _get_device_handle
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.device_mesh import DeviceMesh, _get_device_handle
from torch.distributed._tensor.placement_types import DTensorSpec, Shard


Expand Down

0 comments on commit ef61de3

Please sign in to comment.