Skip to content

Commit

Permalink
Add fake process group
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 1f5983b9437c2ab2b0f218ab82c27a81f3b385ae
Pull Request resolved: #102180
  • Loading branch information
ezyang committed May 24, 2023
1 parent fcf812c commit 2bee7ea
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
29 changes: 29 additions & 0 deletions test/distributed/test_fake_pg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Owner(s): ["oncall: distributed"]

import sys
import torch
import torch.distributed as dist
import unittest

# Exit the whole module cleanly (status 0, i.e. "skipped", not "failed")
# when this torch build was compiled without the distributed package —
# every import and test below depends on torch.distributed.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

# Evaluated once at import time; used by skip decorators on CUDA-only tests.
HAS_CUDA = torch.cuda.is_available()

class Test(TestCase):
    """Smoke tests for the "fake" process group backend."""

    @unittest.skipIf(not HAS_CUDA, 'No CUDA')
    def test_construct_fsdp(self):
        """Constructing FSDP under a fake 2-rank group must succeed in a
        single process, since the fake backend never actually communicates.
        """
        # Bug fix: FakeStore, FSDP, and nn were referenced but never
        # imported anywhere in this file.  Import locally — importing
        # fake_pg also registers the "fake" backend, and must happen
        # after the dist.is_available() guard at the top of the module.
        import torch.nn as nn
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        # world_size=2 while only rank 0 exists: the fake group
        # hallucinates the other rank instead of waiting for it.
        dist.init_process_group(
            backend="fake", rank=0, world_size=2, store=store
        )
        sharded_module = FSDP(nn.Linear(2, 3, device='cuda'))

# Standard torch test entry point; run_tests handles CLI args and filtering.
if __name__ == "__main__":
    run_tests()
5 changes: 4 additions & 1 deletion torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,10 @@ def init_process_group(
# these barriers may be unnecessary, as proved by a green CI after
# removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
# added which, when set to 0, will disable these barriers.
if backend == Backend.MPI:
if backend == "fake":
# Fake process group doesn't need barrier
pass
elif backend == Backend.MPI:
# MPI backend doesn't use store.
barrier()
else:
Expand Down
17 changes: 17 additions & 0 deletions torch/testing/_internal/distributed/fake_pg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch.distributed as dist

# A fake process group (not related to FakeTensor) is a process group which
# doesn't actually do any communication, it just hallucinates some
# communication. You can run a single rank with a fake process group
# without needing multiple processes.

class FakeProcessGroup(dist.ProcessGroup):
    """A process group that performs no real communication.

    NOTE(review): no collective methods are overridden here, so all
    behavior comes from the ``dist.ProcessGroup`` base class —
    presumably the "fake" backend is only exercised on code paths that
    never invoke a collective (e.g. constructing wrappers like FSDP).
    """
    pass

class FakeStore(dist.Store):
    """A store to use with ``backend="fake"``; inherits everything from
    ``dist.Store`` unchanged — the fake backend never reads it (see
    ``_create_fake_pg``, which ignores its ``prefix_store`` argument).
    """
    pass

def _create_fake_pg(prefix_store, rank, world_size, timeout):
    """Backend-constructor hook for the "fake" backend.

    ``prefix_store`` and ``timeout`` are accepted only to satisfy the
    ``register_backend`` creation-function signature; a fake group has
    no store to rendezvous through and never blocks on communication.
    """
    group = FakeProcessGroup(rank, world_size)
    return group

# Importing this module has the side effect of making
# dist.init_process_group(backend="fake", ...) available.
dist.Backend.register_backend("fake", _create_fake_pg)

0 comments on commit 2bee7ea

Please sign in to comment.