Skip to content

Commit

Permalink
[Dist] Enable FSDP on CPU (#112145)
Browse files Browse the repository at this point in the history
Differential Revision: [D50688958](https://our.internmc.facebook.com/intern/diff/D50688958/)

Pull Request resolved: #112145
Approved by: https://github.com/fegin
ghstack dependencies: #112144
  • Loading branch information
rohan-varma authored and pytorchmergebot committed Nov 7, 2023
1 parent 5ffa98f commit c608b0e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
22 changes: 22 additions & 0 deletions test/distributed/fsdp/test_fsdp_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,28 @@ def test_fsdp_optimizer_overlap(self):
(n, p.clone()) for n, p in fsdp_overlap.named_parameters()
]

@skip_if_lt_x_gpu(2)
def test_fsdp_cpu_training(self):
    """Smoke-test a forward/backward pass through an FSDP model on CPU.

    A process group is created with the "gloo" backend, and for every
    entry in the strategy list a fresh ``MyModel`` is wrapped with FSDP
    using ``device_id=torch.device("cpu")``. The test passes as long as
    construction, forward and backward complete without raising; no
    numerical result is checked.
    """
    torch.manual_seed(0)
    cpu_group = dist.new_group(backend="gloo")
    strategies = (
        ShardingStrategy.NO_SHARD,
        ShardingStrategy.FULL_SHARD,
        ShardingStrategy.SHARD_GRAD_OP,
        ShardingStrategy.HYBRID_SHARD,
        ShardingStrategy._HYBRID_SHARD_ZERO2,
    )
    # NOTE(review): the loop variable is never forwarded to FSDP (there is
    # no ``sharding_strategy=`` argument below), so every iteration appears
    # to build the wrapper with whatever FSDP's default strategy is.
    # Confirm whether that is intended before relying on this test for
    # per-strategy coverage.
    for _strategy in strategies:
        wrapped = FSDP(
            MyModel(),
            auto_wrap_policy=always_wrap_policy,
            process_group=cpu_group,
            device_id=torch.device("cpu"),
        )
        batch = torch.randn(2, 2)
        # The same tensor is passed for both positional inputs.
        loss = wrapped(batch, batch).sum()
        loss.backward()

@skip_if_lt_x_gpu(2)
def test_fsdp_cpu_init_stays_on_cpu(self):
# Move me to MT test once warning logging and backward collective issue
Expand Down
21 changes: 20 additions & 1 deletion torch/cpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"device_count",
"Stream",
"StreamContext",
"Event",
]

_device_t = Union[_device, str, int, None]
Expand Down Expand Up @@ -56,7 +57,25 @@ class Stream:
N.B. This class only exists to facilitate device-agnostic code
"""

pass
def __init__(self, priority: int = -1):
    """Accept and ignore ``priority``; no state is stored on the instance."""
    return None

def wait_stream(self, stream) -> None:
    """Do nothing; ``stream`` is accepted and ignored."""
    return None


class Event:
    """Placeholder event whose operations are all no-ops.

    ``query`` always reports completion and the remaining methods do
    nothing.

    NOTE(review): presumably this exists, like the neighbouring ``Stream``
    class, only to facilitate device-agnostic code -- confirm.
    """

    def query(self) -> bool:
        """Return ``True`` unconditionally."""
        return True

    def record(self, stream=None):
        """Do nothing; ``stream`` is accepted and ignored."""
        return None

    def synchronize(self):
        """Do nothing and return immediately."""
        return None

    def wait(self, stream=None):
        """Do nothing; ``stream`` is accepted and ignored."""
        return None


_default_cpu_stream = Stream()
Expand Down

0 comments on commit c608b0e

Please sign in to comment.