
Commit

Fix lintrunner check
shaoyf committed May 29, 2023
1 parent 7f98b13 commit 96437a5
Showing 2 changed files with 7 additions and 10 deletions.
11 changes: 4 additions & 7 deletions torch/distributed/_tensor/device_mesh.py
@@ -150,26 +150,23 @@ def _get_or_create_default_group(self):
             ), f"Default PG backend: {world_backend} not supporting CUDA!"
         device_model = getattr(torch, self.device_type, None)
         if device_model is None:
-            raise RuntimeError(
-                f"DeviceMesh don\'t support {self.device_type}"
-            )
+            raise RuntimeError(f"DeviceMesh don't support {self.device_type}")
         if not default_initialized:
             # automatically set the current cuda device base on num of gpu devices available in each host
             # NOTE: This device selection would only work for homogeneous hardware.
-            device_count_func = getattr(device_model, "device_count")
+            device_count_func = device_model.device_count
             num_gpus_per_host = device_count_func()
             if world_size % num_gpus_per_host != 0:
                 raise RuntimeError(
                     f"DeviceMesh only support homogeneous hardware, but found "
                     f"{world_size} ranks and {num_gpus_per_host} devices!"
                 )
-            getattr(device_model, "set_device")(get_rank() % num_gpus_per_host)
+            device_model.set_device(get_rank() % num_gpus_per_host)
         # TODO (xilunwu): to perform DTensor random ops, we need to ensure all ranks in mesh is initialized
         # with the same random seed. The seed to use will be the current seed on rank 0. We store this seed
         # as an attribute of device mesh for future use. However, the detail is still TBD how we gonna use
         # this attribute, so we will implement this logic once we figure out the answer.
-        self._seed = getattr(device_model, "initial_seed")()
-
+        self._seed = device_model.initial_seed()
 
         # calculate the coordinates of the current global rank on the mesh
         rank_coords = (self.mesh == get_rank()).nonzero()
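The hunk above follows one pattern throughout: the device module is still resolved dynamically from the device type string, but its methods are then called as plain attributes rather than through getattr with a literal name (the kind of call a lintrunner flake8 pass typically flags, e.g. flake8-bugbear B009). A minimal sketch of that pattern, not part of the commit; the helper name and signature below are hypothetical:

    import torch

    def _setup_device_for_rank(device_type: str, rank: int, world_size: int) -> int:
        # Hypothetical helper (not from the repo) mirroring the device_mesh.py logic above.
        device_model = getattr(torch, device_type, None)  # e.g. torch.cuda for "cuda"
        if device_model is None:
            raise RuntimeError(f"DeviceMesh don't support {device_type}")
        num_gpus_per_host = device_model.device_count()  # plain attribute call, no getattr(..., "device_count")
        if world_size % num_gpus_per_host != 0:
            raise RuntimeError("DeviceMesh only support homogeneous hardware")
        device_model.set_device(rank % num_gpus_per_host)
        return device_model.initial_seed()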
6 changes: 3 additions & 3 deletions torch/distributed/_tensor/random.py
@@ -34,7 +34,7 @@ def set_rng_state(new_state: Tensor, device_mesh: DeviceMesh) -> None:
         # the current rank is in mesh
         device_model = getattr(torch, device_mesh.device_type, None)
         if device_model:
-            getattr(device_model, "set_rng_state")(new_state)
+            device_model.set_rng_state(new_state)
         else:
             raise NotImplementedError(
                 f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
@@ -62,7 +62,7 @@ def get_rng_state(device_mesh: DeviceMesh) -> Tensor:
 
     device_model = getattr(torch, device_mesh.device_type, None)
     if device_model:
-        return getattr(device_model, "get_rng_state")()
+        return device_model.get_rng_state()
     else:
         raise NotImplementedError(
             f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
@@ -105,7 +105,7 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
     if device_mesh.get_coordinate() is not None:
         device_model = getattr(torch, device_mesh.device_type, None)
         if device_model:
-            getattr(device_model, "manual_seed")(seed)
+            device_model.manual_seed(seed)
         else:
             raise NotImplementedError(
                 f"DTensor randomness only supports cuda device type, but got {device_mesh.device_type}"
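The three random.py hunks apply the same rewrite to the RNG helpers. A hypothetical usage sketch, not taken from the repo, showing the resulting call style against torch.cuda's RNG API (the helper name is assumed):

    import torch

    def _reseed(device_type: str, seed: int) -> torch.Tensor:
        # Hypothetical helper (not from the repo): same lookup-then-direct-call pattern.
        device_model = getattr(torch, device_type, None)  # e.g. torch.cuda
        if device_model is None:
            raise NotImplementedError(
                f"DTensor randomness only supports cuda device type, but got {device_type}"
            )
        device_model.manual_seed(seed)       # instead of getattr(device_model, "manual_seed")(seed)
        return device_model.get_rng_state()  # instead of getattr(device_model, "get_rng_state")()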
