Add fake process group #102180

Closed
wants to merge 4 commits
32 changes: 32 additions & 0 deletions test/distributed/test_fake_pg.py
@@ -0,0 +1,32 @@
# Owner(s): ["oncall: distributed"]

import sys
import torch
import torch.distributed as dist
import unittest
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.nn as nn

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

HAS_CUDA = torch.cuda.is_available()

class Test(TestCase):
    @unittest.skipIf(not HAS_CUDA, 'No CUDA')
    def test_construct_fsdp(self):
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=2, store=store
        )
        sharded_module = FSDP(nn.Linear(2, 3, device='cuda'))

if __name__ == "__main__":
    run_tests()
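
The test above only checks that FSDP can be constructed under the fake backend. As a minimal standalone sketch of the same initialization path (illustrative only, not part of this PR; it assumes rank and world-size queries are answered from the values recorded by init_process_group and never touch real communication), one could run:

import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# A single real process pretends to be rank 0 of a 2-rank job; the "fake"
# backend registered by this PR hands back a FakeProcessGroup instead of
# opening any connections.
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

# These only read bookkeeping recorded at init time, so no communication happens.
assert dist.get_rank() == 0
assert dist.get_world_size() == 2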
5 changes: 4 additions & 1 deletion torch/distributed/distributed_c10d.py
@@ -1124,7 +1124,10 @@ def init_process_group(
     # these barriers may be unnecessary, as proved by a green CI after
     # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
     # added which, when set to 0, will disable these barriers.
-    if backend == Backend.MPI:
+    if backend == "fake":
+        # Fake process group doesn't need barrier
+        pass
+    elif backend == Backend.MPI:
         # MPI backend doesn't use store.
         barrier()
     else:
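
For reference, the escape hatch mentioned in the comment above is an environment variable rather than an API. A hedged example of opting out of the init-time barriers in a test run (assuming, as the comment states, that the variable is consulted when init_process_group executes):

import os

# Per the comment in this hunk, setting TORCH_DIST_INIT_BARRIER to 0 disables
# the init-time barriers; set it before calling torch.distributed.init_process_group.
os.environ["TORCH_DIST_INIT_BARRIER"] = "0"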
17 changes: 17 additions & 0 deletions torch/testing/_internal/distributed/fake_pg.py
@@ -0,0 +1,17 @@
import torch.distributed as dist

# A fake process group (not related to FakeTensor) is a process group which
# doesn't actually do any communication; it just hallucinates some
# communication. You can run a single rank with a fake process group
# without needing multiple processes.

class FakeProcessGroup(dist.ProcessGroup):
    pass

class FakeStore(dist.Store):

wanchaol (Contributor) commented on this line on May 24, 2023:

we probably don't need a FakeStore and instead we can just use HashStore I suppose. But that could be in a separate PR.

Contributor Author replied:

Can't use HashStore. For example, FSDP will attempt to do a barrier. The barrier will block you until enough writes into the store have happened. If we're running a fake PG there will be no other writes, and you'll deadlock. It's best to have the store error if you try to do anything with it and route around it differently.

Contributor replied:

Hmmm, I think the FakePG barrier should be a no-op, and as for the init_process_group store-based barrier, it seems we already skip it, so nothing would be written to the HashStore. So I feel either FakeStore or HashStore could work (I can give it a try and see whether that's feasible).

Contributor Author replied:

FSDP does a barrier, which is why I ended up adding FakeStore. But yeah, try some stuff out; the goal is to be able to run FSDP end-to-end with the fake group on only one node.

[A hedged sketch of the erroring-store idea from this thread appears after this file's diff.]

    pass

def _create_fake_pg(prefix_store, rank, world_size, timeout):
    return FakeProcessGroup(rank, world_size)

dist.Backend.register_backend("fake", _create_fake_pg)
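
The FakeStore above inherits everything from dist.Store unchanged. As a hypothetical illustration of the suggestion in the review thread (not code from this PR), a variant that fails loudly if anything actually touches the store, rather than hanging or silently succeeding, might look like:

import torch.distributed as dist

class LoudFakeStore(dist.Store):
    # Hypothetical variant: under the fake backend nothing should ever use the
    # store, so surface any accidental use as an immediate error instead of
    # blocking the single real process.
    def _refuse(self, op):
        raise RuntimeError(f"fake store is not meant to be used (attempted {op})")

    def set(self, key, value):
        self._refuse(f"set({key!r})")

    def get(self, key):
        self._refuse(f"get({key!r})")

    def add(self, key, amount):
        self._refuse(f"add({key!r}, {amount})")

    def wait(self, keys, timeout=None):
        self._refuse(f"wait({keys!r})")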