From 726041a4c6a0408752fa6f6aac9c938c555d3123 Mon Sep 17 00:00:00 2001 From: HelioStrike Date: Mon, 5 Oct 2020 15:35:49 +0530 Subject: [PATCH 1/5] warning if current device index is lower than current local rank --- ignite/distributed/comp_models/horovod.py | 3 +++ ignite/distributed/comp_models/native.py | 2 ++ tests/ignite/distributed/comp_models/test_horovod.py | 9 +++++++++ tests/ignite/distributed/comp_models/test_native.py | 9 +++++++++ 4 files changed, 23 insertions(+) diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py index 1bdcb1402ad7..4b8b569c8e9c 100644 --- a/ignite/distributed/comp_models/horovod.py +++ b/ignite/distributed/comp_models/horovod.py @@ -1,4 +1,5 @@ import os +import warnings from typing import Callable, Mapping, Optional, Tuple import torch @@ -97,6 +98,8 @@ def get_node_rank(self) -> int: def device(self) -> torch.device: if torch.cuda.is_available(): index = torch.cuda.current_device() + if index < self.get_local_rank(): + warnings.warn("Current device index is less than current local rank.") return torch.device("cuda:{}".format(index)) return torch.device("cpu") diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index be44b8211b5a..c3b73e6f9a78 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -221,6 +221,8 @@ def get_node_rank(self) -> int: def device(self) -> torch.device: if self.backend() == dist.Backend.NCCL: index = torch.cuda.current_device() + if index < self.get_local_rank(): + warnings.warn("Current device index is less than current local rank.") return torch.device("cuda:{}".format(index)) return torch.device("cpu") diff --git a/tests/ignite/distributed/comp_models/test_horovod.py b/tests/ignite/distributed/comp_models/test_horovod.py index daff622d2d4b..7b091c5bc26b 100644 --- a/tests/ignite/distributed/comp_models/test_horovod.py +++ b/tests/ignite/distributed/comp_models/test_horovod.py @@ -195,3 +195,12 @@ def test__hvd_dist_model_spawn_cuda(): nproc_per_node=num_workers_per_machine, use_gloo=True, ) + + +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") +def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size): + with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): + _test__hvd_dist_model_create_from_context_dist( + local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) + ) diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py index b509df82b182..6afed4b69145 100644 --- a/tests/ignite/distributed/comp_models/test_native.py +++ b/tests/ignite/distributed/comp_models/test_native.py @@ -299,3 +299,12 @@ def test__native_dist_model_spawn_gloo(): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test__native_dist_model_spawn_nccl(): _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda") + + +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") +def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size): + with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): + _test__native_dist_model_create_from_context_dist( + local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) + ) From 
029197d54bdf58cf706233311f0d005b4e8433b9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 7 Oct 2020 21:34:38 +0000 Subject: [PATCH 2/5] Updated code and tests --- ignite/distributed/comp_models/native.py | 8 ++- .../distributed/comp_models/test_native.py | 51 ++++++++++++------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index 83362491df59..150168de6561 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -97,7 +97,13 @@ def _init_from_context(self) -> None: self._setup_attrs() def _compute_nproc_per_node(self) -> int: - tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device()) + print("_compute_nproc_per_node") + local_rank = self.get_local_rank() + device = torch.device("cpu") + if self.backend() == dist.Backend.NCCL: + # we manually set cuda device to local rank in order to avoid a hang on all_reduce + device = torch.device("cuda:{}".format(local_rank)) + tensor = torch.tensor([self.get_local_rank() + 1]).to(device) dist.all_reduce(tensor, op=dist.ReduceOp.MAX) return int(tensor.item()) diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py index 2a18705b0e36..fdfd6f16dbbc 100644 --- a/tests/ignite/distributed/comp_models/test_native.py +++ b/tests/ignite/distributed/comp_models/test_native.py @@ -211,6 +211,7 @@ def _test__native_dist_model_create_from_context_dist(local_rank, rank, world_si dist.init_process_group(true_backend, "tcp://0.0.0.0:2222", world_size=world_size, rank=rank) dist.barrier() + torch.cuda.set_device(local_rank) true_conf = { "device": true_device, @@ -244,22 +245,51 @@ def test__native_dist_model_create_no_dist_nccl(clean_env): @pytest.mark.distributed -def test__native_dist_model_create_dist_gloo(local_rank, world_size): +def test__native_dist_model_create_dist_gloo_1(local_rank, world_size): _test__native_dist_model_create_from_backend_dist(local_rank, local_rank, world_size, "gloo", "cpu") + + +@pytest.mark.distributed +def test__native_dist_model_create_dist_gloo_2(local_rank, world_size): _test__native_dist_model_create_from_context_dist(local_rank, local_rank, world_size, "gloo", "cpu") @pytest.mark.distributed @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test__native_dist_model_create_dist_nccl(local_rank, world_size): +def test__native_dist_model_create_dist_nccl_1(local_rank, world_size): _test__native_dist_model_create_from_backend_dist( local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) ) + +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test__native_dist_model_create_dist_nccl_2(local_rank, world_size): _test__native_dist_model_create_from_context_dist( local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) ) +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") +def test__native_dist_model_warning_index_less_localrank(local_rank, world_size): + + assert _NativeDistModel.create_from_context() is None + + dist.init_process_group("nccl", "tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank) + dist.barrier() + # We deliberately incorrectly set cuda device to 0 + torch.cuda.set_device(0) + + model = _NativeDistModel.create_from_context() + assert isinstance(model, _NativeDistModel), "{} vs _NativeDistModel".format(type(model)) 
+ + if local_rank == 1: + with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): + model.device() + + dist.destroy_process_group() + + def _test_dist_spawn_fn(local_rank, backend, world_size, device): from ignite.distributed.utils import _model @@ -299,20 +329,3 @@ def test__native_dist_model_spawn_gloo(): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test__native_dist_model_spawn_nccl(): _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda") - - -@pytest.mark.distributed -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") -def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size): - - assert _NativeDistModel.create_from_context() is None - - dist.init_process_group("nccl", "tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank) - dist.barrier() - - with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): - model = _NativeDistModel.create_from_context() - - _test__native_dist_model_create_from_context_dist( - local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) - ) From d605e70bbeae41617de5e899b8167b91854dea93 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 7 Oct 2020 21:39:47 +0000 Subject: [PATCH 3/5] Fixed formatting --- tests/ignite/distributed/comp_models/test_native.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py index fdfd6f16dbbc..806efc6ed39b 100644 --- a/tests/ignite/distributed/comp_models/test_native.py +++ b/tests/ignite/distributed/comp_models/test_native.py @@ -261,6 +261,7 @@ def test__native_dist_model_create_dist_nccl_1(local_rank, world_size): local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) ) + @pytest.mark.distributed @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test__native_dist_model_create_dist_nccl_2(local_rank, world_size): From 354dfca58226de2283a7e7f947f232381b795924 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 7 Oct 2020 22:07:06 +0000 Subject: [PATCH 4/5] Updated code and tests for horovod - fixed failing test --- ignite/distributed/comp_models/horovod.py | 7 ++- ignite/distributed/comp_models/native.py | 5 +- .../distributed/comp_models/test_horovod.py | 53 ++++++++++++++----- .../distributed/comp_models/test_native.py | 3 +- 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py index 0c6aafb1e5e6..ef4600bb4d7b 100644 --- a/ignite/distributed/comp_models/horovod.py +++ b/ignite/distributed/comp_models/horovod.py @@ -69,7 +69,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None: self._local_rank = hvd.local_rank() - if torch.cuda.is_available(): + if do_init and torch.cuda.is_available(): torch.cuda.set_device(self._local_rank) self._setup_attrs() @@ -99,7 +99,10 @@ def device(self) -> torch.device: if torch.cuda.is_available(): index = torch.cuda.current_device() if index < self.get_local_rank(): - warnings.warn("Current device index is less than current local rank.") + warnings.warn( + "Current device index is less than current local rank. " + "Please, make sure to call torch.cuda.set_device(local_rank)." 
+ ) return torch.device("cuda:{}".format(index)) return torch.device("cpu") diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py index 150168de6561..c38637860188 100644 --- a/ignite/distributed/comp_models/native.py +++ b/ignite/distributed/comp_models/native.py @@ -227,7 +227,10 @@ def device(self) -> torch.device: if self.backend() == dist.Backend.NCCL: index = torch.cuda.current_device() if index < self.get_local_rank(): - warnings.warn("Current device index is less than current local rank.") + warnings.warn( + "Current device index is less than current local rank. " + "Please, make sure to call torch.cuda.set_device(local_rank)." + ) return torch.device("cuda:{}".format(index)) return torch.device("cpu") diff --git a/tests/ignite/distributed/comp_models/test_horovod.py b/tests/ignite/distributed/comp_models/test_horovod.py index 7b091c5bc26b..937977b7b7e3 100644 --- a/tests/ignite/distributed/comp_models/test_horovod.py +++ b/tests/ignite/distributed/comp_models/test_horovod.py @@ -109,10 +109,13 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device): assert _HorovodDistModel.create_from_context() is None hvd.init() + lrank = hvd.local_rank() + if torch.cuda.is_available(): + torch.cuda.set_device(lrank) true_conf = { "device": true_device, - "local_rank": hvd.local_rank(), + "local_rank": lrank, "rank": hvd.rank(), "world_size": hvd.size(), "node_index": 0, @@ -121,6 +124,7 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device): } model = _HorovodDistModel.create_from_context() + assert model.backend() == true_backend _assert_model(model, true_conf) hvd.shutdown() @@ -142,18 +146,52 @@ def test__hvd_dist_model_create_no_dist_cuda(gloo_hvd_executor): @pytest.mark.distributed @pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU") -def test__hvd_dist_model_create_dist(gloo_hvd_executor): +def test__hvd_dist_model_create_dist_1(gloo_hvd_executor): gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cpu"), np=4) + + +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU") +def test__hvd_dist_model_create_dist_2(gloo_hvd_executor): gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cpu"), np=4) @pytest.mark.distributed @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test__hvd_dist_model_create_dist_cuda(gloo_hvd_executor): +def test__hvd_dist_model_create_dist_cuda_1(gloo_hvd_executor): gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cuda"), np=torch.cuda.device_count()) + + +@pytest.mark.distributed +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test__hvd_dist_model_create_dist_cuda_2(gloo_hvd_executor): gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cuda"), np=torch.cuda.device_count()) +def _test__hvd_dist_model_warning_index_less_localrank(): + + assert _HorovodDistModel.create_from_context() is None + + hvd.init() + # We deliberately incorrectly set cuda device to 0 + torch.cuda.set_device(0) + + model = _HorovodDistModel.create_from_context() + assert isinstance(model, _HorovodDistModel), "{} vs _HorovodDistModel".format(type(model)) + + if hvd.local_rank() == 1: + with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): + model.device() + + hvd.shutdown() + + +@pytest.mark.distributed 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") +def test__hvd_dist_model_warning_index_less_localrank(gloo_hvd_executor): + gloo_hvd_executor(_test__hvd_dist_model_warning_index_less_localrank, (), np=torch.cuda.device_count()) + + def _test_dist_spawn_fn(local_rank, backend, world_size, device): from ignite.distributed.utils import _model @@ -195,12 +233,3 @@ def test__hvd_dist_model_spawn_cuda(): nproc_per_node=num_workers_per_machine, use_gloo=True, ) - - -@pytest.mark.distributed -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs") -def test__warning_if_deviceindex_less_than_localrank(local_rank, world_size): - with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."): - _test__hvd_dist_model_create_from_context_dist( - local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank) - ) diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py index 806efc6ed39b..6bdf6d821d71 100644 --- a/tests/ignite/distributed/comp_models/test_native.py +++ b/tests/ignite/distributed/comp_models/test_native.py @@ -211,7 +211,8 @@ def _test__native_dist_model_create_from_context_dist(local_rank, rank, world_si dist.init_process_group(true_backend, "tcp://0.0.0.0:2222", world_size=world_size, rank=rank) dist.barrier() - torch.cuda.set_device(local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) true_conf = { "device": true_device, From c3a2db5a3eb9ae891474e159c2ef4a4bfb71dea4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 7 Oct 2020 22:59:31 +0000 Subject: [PATCH 5/5] Updated tests --- tests/ignite/distributed/utils/test_horovod.py | 4 ++++ tests/run_cpu_tests.sh | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py index 4f7a06895327..3c8169aaa271 100644 --- a/tests/ignite/distributed/utils/test_horovod.py +++ b/tests/ignite/distributed/utils/test_horovod.py @@ -111,6 +111,10 @@ def _test_idist_methods_in_hvd_context(backend, device): ws = hvd.size() rank = hvd.rank() local_rank = hvd.local_rank() + + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + _test_distrib_config(local_rank, backend=backend, ws=ws, true_device=device, rank=rank) hvd.shutdown() diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh index 4e48e19f9f28..78dd6b16d816 100644 --- a/tests/run_cpu_tests.sh +++ b/tests/run_cpu_tests.sh @@ -2,7 +2,7 @@ set -xeu -CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python$CI_PYTHON_VERSION --cov ignite --cov-report term-missing -vvv tests/ +CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python${CI_PYTHON_VERSION:-3.7} --cov ignite --cov-report term-missing -vvv tests/ # https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 if [ "${SKIP_DISTRIB_TESTS:-0}" -eq "1" ]; then @@ -11,4 +11,4 @@ fi export WORLD_SIZE=2 -CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python$CI_PYTHON_VERSION tests -m distributed -vvv +CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python${CI_PYTHON_VERSION:-3.7} tests -m distributed -vvv
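
For reference, a minimal standalone sketch of the behaviour this series introduces: device() warns when the current CUDA device index is below the process's local rank, and the caller-side fix is to pin the device with torch.cuda.set_device(local_rank) before any collective call (the comment added in _compute_nproc_per_node notes that an unpinned device can hang all_reduce). The device_for_rank helper and the LOCAL_RANK environment variable below are illustrative assumptions for the sketch, not part of the ignite API; the warning text is the one added by these patches.

    import os
    import warnings

    import torch


    def device_for_rank(local_rank: int) -> torch.device:
        """Return this process's device, warning on an index/rank mismatch."""
        if not torch.cuda.is_available():
            return torch.device("cpu")
        index = torch.cuda.current_device()
        if index < local_rank:
            # Same condition as _NativeDistModel.device() / _HorovodDistModel.device():
            # the process has not pinned its GPU, so collectives may clash or hang.
            warnings.warn(
                "Current device index is less than current local rank. "
                "Please, make sure to call torch.cuda.set_device(local_rank)."
            )
        return torch.device("cuda:{}".format(index))


    if __name__ == "__main__":
        # LOCAL_RANK is read from the environment purely for illustration here.
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        if torch.cuda.is_available():
            # Pinning the device up front, as the updated tests now do, avoids the warning.
            torch.cuda.set_device(local_rank)
        print(device_for_rank(local_rank))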