From 097defb1608827d82b18b27adeec0a98b72a9281 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 11 Oct 2023 17:02:39 -0700 Subject: [PATCH] [device mesh] only check when world size > num_devices per host (#111091) As titled: only run the homogeneous-hardware divisibility check when world_size is greater than the number of devices per host. Previously, a world size smaller than the per-host device count (e.g. 2 ranks on a host with 8 devices) failed the `world_size % num_devices_per_host != 0` check and raised a RuntimeError. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111091 Approved by: https://github.com/awgu, https://github.com/wz337 ghstack dependencies: #110898, #110900 --- torch/distributed/_tensor/device_mesh.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py index f71c4178ede88..75a32b335d1f1 100644 --- a/torch/distributed/_tensor/device_mesh.py +++ b/torch/distributed/_tensor/device_mesh.py @@ -188,7 +188,10 @@ def _get_or_create_default_group(self): # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host # NOTE: This device selection would only work for homogeneous hardware. num_devices_per_host = device_handle.device_count() - if world_size % num_devices_per_host != 0: + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): raise RuntimeError( f"DeviceMesh only support homogeneous hardware, but found " f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"