-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: 1f5983b9437c2ab2b0f218ab82c27a81f3b385ae Pull Request resolved: #102180
- Loading branch information
Showing 3 changed files with 50 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Owner(s): ["oncall: distributed"] | ||
|
||
import sys | ||
import torch | ||
import torch.distributed as dist | ||
import unittest | ||
|
||
if not dist.is_available(): | ||
print("Distributed not available, skipping tests", file=sys.stderr) | ||
sys.exit(0) | ||
|
||
from torch.testing._internal.common_utils import ( | ||
TestCase, | ||
run_tests, | ||
) | ||
|
||
HAS_CUDA = torch.cuda.is_available() | ||
|
||
class Test(TestCase):
    """Smoke tests for the "fake" process group backend."""

    @unittest.skipIf(not HAS_CUDA, 'No CUDA')
    def test_construct_fsdp(self):
        """FSDP wrapping succeeds on a single rank under the fake backend,
        i.e. without any real multi-process communication."""
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=2, store=store
        )
        # Tear the global default group down even if the assertions below
        # fail, so later tests can initialize their own process group.
        self.addCleanup(dist.destroy_process_group)
        sharded_module = FSDP(nn.Linear(2, 3, device='cuda'))
        # The original test constructed the module and asserted nothing;
        # at minimum, pin down that wrapping produced an FSDP instance.
        self.assertIsInstance(sharded_module, FSDP)
||
# Standard test-file entry point: dispatch to the shared test runner.
if __name__ == "__main__":
    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch.distributed as dist | ||
|
||
# A fake process group (not related to FakeTensor) is a process group which | ||
# doesn't actually do any communication, it just hallucinates some | ||
# communication. You can run a single rank with a fake process group | ||
# without needing multiple processes. | ||
|
||
class FakeProcessGroup(dist.ProcessGroup):
    """A process group that performs no real communication.

    Everything is inherited from ``dist.ProcessGroup``; this subclass only
    exists so the fake backend has a distinct type to construct.
    """
|
||
class FakeStore(dist.Store):
    """A no-op store to hand to ``init_process_group`` for the fake backend.

    Inherits all behavior from ``dist.Store``; no overrides are needed since
    the fake process group never actually rendezvouses.
    """
|
||
def _create_fake_pg(prefix_store, rank, world_size, timeout):
    """Backend constructor hook for the "fake" backend.

    ``prefix_store`` and ``timeout`` are part of the required constructor
    signature but are irrelevant here, since the returned group never
    communicates.
    """
    fake_group = FakeProcessGroup(rank, world_size)
    return fake_group


dist.Backend.register_backend("fake", _create_fake_pg)